I realised that some code like this
def train():
    for epoch in epochs:
        for batch in batches:
            ...
can be modified to exit after a specified maximum number of training steps as follows:
def train(max_steps: int = 20):
    step = 0
    for epoch in epochs:
        for batch in batches:
            ...  # training step on this batch
            step += 1
            if step >= max_steps:
                return  # early return here from nested loop
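To check the stopping behaviour, here is a minimal runnable sketch of the same pattern. The toy `range` loops and the `num_epochs` and `batches_per_epoch` parameters are assumptions standing in for the real data, and the function returns the step count only so the result can be inspected.

```python
def train(max_steps: int = 20, num_epochs: int = 5, batches_per_epoch: int = 8) -> int:
    """Run the nested loop and return how many steps were executed."""
    step = 0
    for epoch in range(num_epochs):
        for batch in range(batches_per_epoch):
            # a real training step on `batch` would go here (placeholder)
            step += 1
            if step >= max_steps:
                return step  # one return leaves both loops at once
    return step  # ran out of data before reaching max_steps
```

With these toy sizes there are 40 batches in total, so any `max_steps` above 40 simply runs through all of them.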
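But crucially, this doesn’t work if the two loops are not in the same function, because a return in the inner function only ends the current epoch and the outer loop carries on. This is the case if we use a train_one_epoch function, as is popular e.g. in the videowalk codebase: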
MAX_STEPS: int = 20

def train_one_epoch(step: int) -> int:
    for batch in batches:
        step += 1
        if step >= MAX_STEPS:
            return step  # only exits this function, not the epoch loop
    return step

def train() -> None:
    step = 0
    for epoch in epochs:
        step = train_one_epoch(step)  # pass the running step count in
        if step >= MAX_STEPS:
            return  # second explicit return to stop the epoch loop
We have to explicitly return from both functions, and we also have to find some way of making the max_steps configuration parameter available in both of them, here for example as a global variable.
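One alternative to the global is to thread the limit through as an explicit argument. A rough sketch of that variant follows; the toy `range` loops and the `num_epochs` and `batches_per_epoch` parameters are hypothetical stand-ins for the real data, and `train` returns the step count only to make the behaviour easy to check.

```python
def train_one_epoch(step: int, max_steps: int, batches_per_epoch: int = 8) -> int:
    """Run one epoch and return the updated global step count."""
    for batch in range(batches_per_epoch):
        # a real training step on `batch` would go here (placeholder)
        step += 1
        if step >= max_steps:
            break  # stop this epoch; the caller decides whether to stop training
    return step


def train(max_steps: int = 20, num_epochs: int = 5) -> int:
    """Run up to num_epochs epochs, stopping once max_steps is reached."""
    step = 0
    for epoch in range(num_epochs):
        step = train_one_epoch(step, max_steps)
        if step >= max_steps:
            break  # the same check has to be repeated in the outer function
    return step
```

The check still appears twice, once per function, which is the cost of splitting the loops; passing `max_steps` as an argument just avoids the module-level state.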
Disclaimer: Code written on the fly for myself to remember this in the Markdown - no testing here.