I realised that some code like this

def train():
    for epoch in epochs:
        for batch in batches:
            ...

can be modified to exit after a specified maximum number of training steps as follows:

def train(max_steps: int = 20):
    step = 0
    for epoch in epochs:
        for batch in batches:
            ...  # training work for this batch
            step += 1
            if step >= max_steps:
                return  # a single return exits both nested loops
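
A minimal runnable sketch of the single-function version, so the step cap can actually be checked. The loop sizes (`num_epochs`, `batches_per_epoch`) and the `processed` list standing in for the real training work are assumptions made for illustration, not part of the original code. The check is placed after the work for the batch, so exactly `max_steps` batches are processed.

```python
from typing import List, Tuple


def train(
    max_steps: int = 20,
    num_epochs: int = 5,          # hypothetical loop size
    batches_per_epoch: int = 8,   # hypothetical loop size
) -> List[Tuple[int, int]]:
    """Run nested epoch/batch loops and stop after max_steps batches."""
    processed: List[Tuple[int, int]] = []  # stand-in for real training work
    step = 0
    for epoch in range(num_epochs):
        for batch in range(batches_per_epoch):
            processed.append((epoch, batch))  # "train" on this batch
            step += 1
            if step >= max_steps:
                return processed  # one return leaves both loops at once
    return processed  # ran out of data before reaching max_steps
```

Calling `train()` with these defaults stops partway through the third epoch, because 20 is not a multiple of 8.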

But crucially, this doesn’t work when the two loops do not live in the same function, for example if we use a separate train_one_epoch function, as is popular (e.g. in the videowalk codebase):

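MAX_STEPS: int = 20

def train_one_epoch(step: int) -> int:
    for batch in batches:
        ...  # training work for this batch
        step += 1
        if step >= MAX_STEPS:
            return step  # only ends this epoch, not the whole training run
    return step

def train() -> None:
    step = 0
    for epoch in epochs:
        step = train_one_epoch(step)  # pass the running step count in
        if step >= MAX_STEPS:
            return  # second check needed to stop the epoch loop as well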

We have to return explicitly from both functions, and we also need some way of making the max_steps setting available to both of them, here for example as the global variable MAX_STEPS.
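
As an alternative to the global, the limit can be threaded through as an ordinary argument. The following is only a sketch under assumptions: the `batches` argument, the loop sizes, and the `log` list standing in for real training work are hypothetical names introduced here for illustration.

```python
from typing import Iterable, List


def train_one_epoch(
    batches: Iterable[int], step: int, max_steps: int, log: List[int]
) -> int:
    """Process one epoch and return the updated global step count."""
    for batch in batches:
        log.append(batch)  # stand-in for real training work
        step += 1
        if step >= max_steps:
            break  # stops this epoch only; the caller must check as well
    return step


def train(
    max_steps: int = 20,
    num_epochs: int = 5,          # hypothetical loop size
    batches_per_epoch: int = 8,   # hypothetical loop size
) -> List[int]:
    """Run up to num_epochs epochs, stopping once max_steps is reached."""
    log: List[int] = []
    step = 0
    for _ in range(num_epochs):
        step = train_one_epoch(range(batches_per_epoch), step, max_steps, log)
        if step >= max_steps:
            break  # second check: stop the epoch loop too
    return log
```

The two checks are still needed, but the limit now lives in one place (the `train` signature) and both functions can be exercised with different limits without touching module state.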


Disclaimer: The code above was written on the fly as a Markdown note to myself and has not been tested.