PyTorch Lightning v2.0 has made two major changes to pl.LightningModule methods.
First, training_epoch_end(outputs) methods have been removed in favor of on_training_epoch_end().
Notice that the `on_*` hook is not quite the same: the `outputs` argument has been removed, so users must now explicitly store step outputs in an attribute on `self`. The release notes have a good example:
Before:

```python
class LitModel(L.LightningModule):

    def training_step(self, batch, batch_idx):
        ...
        return {"loss": loss, "banana": banana}

    # `outputs` is a list of all bananas returned in the epoch
    def training_epoch_end(self, outputs):
        avg_banana = torch.cat(out["banana"] for out in outputs).mean()
```

Now:
```python
class LitModel(L.LightningModule):

    def __init__(self):
        super().__init__()
        # 1. Create a list to hold the outputs of `*_step`
        self.bananas = []

    def training_step(self, batch, batch_idx):
        ...
        # 2. Add the outputs to the list
        # You should be aware of the implications on memory usage
        self.bananas.append(banana)
        return loss

    # 3. Rename the hook to `on_*_epoch_end`
    def on_training_epoch_end(self):
        # 4. Do something with all outputs
        avg_banana = torch.cat(self.bananas).mean()
        # Don't forget to clear the memory for the next epoch!
        self.bananas.clear()
```

The idea is to make it obvious that using epoch_end actions requires all the batch-level outputs to be stored until the end of the epoch rolls around.
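The memory-usage warning in that comment is worth taking seriously: if the stored outputs are tensors produced during training, each one can keep its whole autograd graph alive for the rest of the epoch. A minimal sketch outside Lightning of the same accumulate/aggregate/clear pattern — the loop here is a hypothetical stand-in for training steps, not Lightning API — shows the usual mitigation of detaching before storing:

```python
import torch

# Stand-in for `self.bananas` in the Lightning example above
bananas = []

for batch_idx in range(3):
    # Pretend step output: a tensor attached to an autograd graph
    banana = torch.randn(4, requires_grad=True) * 2.0
    # Detach before storing, so the list does not retain the
    # computation graph of every batch until the epoch ends
    bananas.append(banana.detach())

# Epoch end: aggregate, then clear for the next epoch
avg_banana = torch.cat(bananas).mean()
bananas.clear()
```

Moving the detached tensors to CPU (`.detach().cpu()`) is a further option if GPU memory is the constraint.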
If you were using training_epoch_end(outputs) to update your TorchMetrics metrics, as the TorchMetrics docs suggest (my issue), this will now raise an error.
You might think you could move the TorchMetrics updates into the similar training_step_end(outputs) method. Unfortunately, and much less obviously from the release notes, this hook has also been removed - overriding it will now silently do nothing.
Instead, you can use on_train_batch_end(outputs). For example, in Zoobot:
```python
def on_train_batch_end(self, outputs, *args):
    self.train_loss_metric(outputs['loss'])
    self.log(
        "finetuning/train_loss",
        self.train_loss_metric,
        prog_bar=self.prog_bar,
        on_step=False,
        on_epoch=True
    )
```

This appears to work well for multi-GPU training. It also avoids storing batch outputs in memory. Overall I’m grateful to the Lightning team for the new release - I hadn’t noticed the previous memory-consuming behaviour before.
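The reason a metric object avoids the memory cost is that it keeps only running state, not the batch outputs themselves. As a toy illustration of what a MeanMetric-style object does internally - `RunningMean` is a hypothetical class for this sketch, not part of TorchMetrics - updating per batch needs only constant memory:

```python
import torch

class RunningMean:
    """Toy running-mean metric: constant state, no stored batch outputs."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value: torch.Tensor):
        # Detach so metric state never holds an autograd graph
        self.total += value.detach().sum().item()
        self.count += value.numel()

    def compute(self) -> float:
        return self.total / self.count

    def reset(self):
        self.total, self.count = 0.0, 0

metric = RunningMean()
for loss in [torch.tensor(1.0), torch.tensor(3.0)]:
    metric.update(loss)  # called once per batch, e.g. from on_train_batch_end

epoch_mean = metric.compute()  # (1.0 + 3.0) / 2 = 2.0
metric.reset()  # ready for the next epoch
```

Real TorchMetrics objects additionally handle device placement and cross-GPU synchronisation, which is what makes the `self.log(..., on_epoch=True)` pattern above safe for multi-GPU training.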
The Zoobot changes are now on dev and will be live on main and PyPI in the next few days.