PyTorch Lightning v2.0 has made two major changes to `pl.LightningModule` methods.

First, `training_epoch_end(outputs)` methods have been removed in favor of `on_train_epoch_end()`. Notice that the new hook is not quite the same: the `outputs` argument has been removed, requiring users to explicitly store the step outputs as an attribute on `self`. The release notes have a good example:
Before:

```python
class LitModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        ...
        return {"loss": loss, "banana": banana}

    # `outputs` is a list of all bananas returned in the epoch
    def training_epoch_end(self, outputs):
        avg_banana = torch.cat([out["banana"] for out in outputs]).mean()
```
Now:

```python
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # 1. Create a list to hold the outputs of `*_step`
        self.bananas = []

    def training_step(self, batch, batch_idx):
        ...
        # 2. Add the outputs to the list
        # You should be aware of the implications on memory usage
        self.bananas.append(banana)
        return loss

    # 3. Rename the hook to `on_*_epoch_end`
    def on_train_epoch_end(self):
        # 4. Do something with all outputs
        avg_banana = torch.cat(self.bananas).mean()
        # Don't forget to clear the memory for the next epoch!
        self.bananas.clear()
```
The idea is to make it obvious that using `epoch_end` actions requires all the batch-level outputs to be stored until the end of the epoch rolls around.
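If you do store per-batch outputs like this, it's worth detaching them first: a tensor still attached to the autograd graph keeps that whole graph alive in memory. A minimal sketch (my own addition, not from the release notes), reusing the `bananas` example above:

```python
    def training_step(self, batch, batch_idx):
        ...
        # Store a detached copy (or a reduced scalar) rather than the
        # raw tensor, so each batch's autograd graph can be freed
        self.bananas.append(banana.detach())
        return loss
```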
If you were using `training_epoch_end(outputs)` to update your TorchMetrics metrics, as the TorchMetrics docs suggest (my issue), this will now raise an error.
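To make that concrete, here is a sketch of the kind of pattern I mean (the metric choice and the `preds`/`targets` names are hypothetical):

```python
import torchmetrics

class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.train_acc = torchmetrics.Accuracy(task="binary")

    def training_step(self, batch, batch_idx):
        ...
        return {"loss": loss, "preds": preds, "targets": targets}

    # Lightning 2.0 raises an error at fit() time if this hook is
    # defined, because it was removed
    def training_epoch_end(self, outputs):
        for out in outputs:
            self.train_acc(out["preds"], out["targets"])
        self.log("train_acc", self.train_acc)
```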
You might think you could move the TorchMetrics updates into the similar `training_step_end(outputs)` methods. Unfortunately, and much less obviously from the release notes, these have also been removed - and will now silently do nothing.
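In other words, code like this (a hypothetical sketch) runs without any error or warning, but the hook body never executes:

```python
class LitModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        ...
        return {"loss": loss, "preds": preds, "targets": targets}

    # Never called in Lightning 2.0: no error, no warning,
    # so the metric is simply never updated
    def training_step_end(self, step_output):
        self.train_acc(step_output["preds"], step_output["targets"])
        return step_output
```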
Instead, you can use `on_train_batch_end(outputs)`. For example, in Zoobot:
```python
def on_train_batch_end(self, outputs, *args):
    # `outputs` is whatever training_step returned for this batch
    self.train_loss_metric(outputs['loss'])
    self.log(
        "finetuning/train_loss",
        self.train_loss_metric,
        prog_bar=self.prog_bar,
        on_step=False,
        on_epoch=True
    )
```
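For this to work, `train_loss_metric` needs to be a TorchMetrics object created in `__init__`. A sketch of the setup this assumes (I'm guessing at the exact metric; any scalar aggregator such as `torchmetrics.MeanMetric` fits, and the class name here is hypothetical):

```python
import torchmetrics

class FinetunedModel(L.LightningModule):
    def __init__(self, prog_bar=True):
        super().__init__()
        # Assumed: a running mean over the per-batch losses.
        # Passing the metric object to self.log lets Lightning handle
        # cross-device synchronisation, which is why this pattern
        # behaves well under multi-GPU training.
        self.train_loss_metric = torchmetrics.MeanMetric()
        self.prog_bar = prog_bar
```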
This appears to work well for multi-GPU training. It also avoids storing batch outputs in memory. Overall I'm grateful to the Lightning team for the new release - I hadn't noticed the memory-consuming behaviour of the previous approach.
The Zoobot changes are now on `dev` and will be live on `main` and PyPI in the next few days.