Debugging tips for neural network training

December 16, 2022

Andrej Karpathy wrote a great post a few years ago about training neural networks. Here are a few additional things that I follow during implementation, with a bias towards debugging large language models.

Log anything and everything

Weights

Set up logging upfront as extensively as possible. I use wandb for experiment reporting. I find it the best option on the market right now for experiment tracking with near unlimited personal use.

Log tensors of the training process, especially the updates in layer weights. Look out for gradients that go to zero and stay there. Sometimes this is just a byproduct of the current loss landscape, but it's often a signal that the network has saturated the learning it can actually do.

Specifically,

  • Log gradient magnitude of each layer
  • Log gradient distribution of each layer
  • Log weight matrix norm
  • Log weight matrix distribution

Both of the distribution logs are handled through wandb’s built-in watch utility.

wandb.watch(model, log="all", log_freq=10)

Matrix norms can be logged by iterating over the current module's parameters. The norms and the distributions give you slightly different views on the same data.
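As a rough sketch of the norm logging, the helper below walks named_parameters() and collects one weight norm and one gradient norm per parameter. The function name collect_norms and the metric key format are placeholders of mine, and the toy model exists only to make the example runnable. In a real run the returned dictionary could be handed to wandb.log.

```python
import torch
from torch import nn

def collect_norms(model):
    """Return the L2 norm of every parameter and of its gradient, keyed by name."""
    metrics = {}
    for name, parameter in model.named_parameters():
        metrics[f"weight_norm/{name}"] = parameter.detach().norm().item()
        if parameter.grad is not None:
            metrics[f"grad_norm/{name}"] = parameter.grad.norm().item()
    return metrics

# Toy model and batch, only here so the sketch runs end to end.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))
inputs, targets = torch.randn(16, 4), torch.randint(0, 2, (16,))

loss = nn.functional.cross_entropy(model(inputs), targets)
loss.backward()

# Call after backward() and before optimizer.zero_grad(), otherwise the gradients are gone.
norms = collect_norms(model)
```

A gradient norm that sits at zero for one layer while its neighbours keep moving is usually the first thing worth investigating.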

Datapoints

Dataset examples are worth their weight in gold over time. Loss metrics can successfully go down if there's a bug in the loss calculation, even though the network is learning nothing of value. Bugs are harder to hide in human-interpretable datapoints.

  • Log training set and inference set examples, along with the prediction for each example, at a fixed interval of batches (for example every 100 or 1000 batches, depending on dataset size). Record multimedia where you can to visualize images, videos, or decoded text.
  • Use a tokenizer.decode() anywhere you have logits that can be converted into token indexes. This is most often at the input (pre-embedding) and output (post-linear projection to vocab space) but it can also be places like additional padding generation or masking.
  • For sequence or token embeddings, log their projection directly.
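To make the decoding step concrete, here is a minimal sketch that turns output logits back into readable text with an argmax. The four-word VOCAB and the decode helper are stand-ins for a real tokenizer, not any particular library's API.

```python
import torch

# Toy vocabulary standing in for a real tokenizer (hypothetical).
VOCAB = ["<pad>", "the", "cat", "sat"]

def decode(token_ids):
    """Map token ids back to text, skipping padding."""
    return " ".join(VOCAB[i] for i in token_ids if VOCAB[i] != "<pad>")

def decode_logits(logits):
    """Decode (batch, sequence, vocab) logits into one string per example."""
    token_ids = logits.argmax(dim=-1)
    return [decode(ids.tolist()) for ids in token_ids]

# One-hot logits so the argmax is unambiguous.
input_ids = torch.tensor([[1, 2, 3, 0]])
logits = torch.nn.functional.one_hot(input_ids, num_classes=len(VOCAB)).float()
decoded = decode_logits(logits)
```

Logging the decoded input next to the decoded prediction makes off-by-one shifts and padding mistakes visible at a glance.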

Final layer logits

In single-label, multi-class problems you're likely using softmax as part of the loss function. Since softmax isn't a hard maximization, you're encouraging your model to concentrate a probability distribution on the correct class. If you plot the softmax of the final layer logits you should notice that the majority of probability mass converges on the right value over time (especially during an overfitting run). The further you go during your overfit, the more dramatic this peak should be. There should also be a noticeable transition of the mass from being scattered equally to sitting just around the correct value.

This provides an additional sanity check on the progress of the overfit. It ensures that the model is statistically evolving in the way that you expect, with a smooth increase in the softly maximized value. Often looking at changes over time is a helpful way to analyze training.
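One way to watch this happen, sketched on a toy problem: overfit a single linear layer on one fixed example and record the softmax probability of the correct class at every step. The model size, learning rate, and step count are arbitrary choices for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(8, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# A single fixed example, so the run can only memorize it.
features = torch.linspace(-1, 1, 8).unsqueeze(0)  # shape (1, 8)
target = torch.tensor([2])

history = []  # probability assigned to the correct class at each step
for step in range(200):
    logits = model(features)
    loss = nn.functional.cross_entropy(logits, target)
    probabilities = logits.softmax(dim=-1).detach()
    history.append(probabilities[0, target.item()].item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Plotting history should give a curve that starts near chance and climbs towards 1. A curve that stalls or oscillates during an overfit run is worth a closer look.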

When in doubt, always log.

Start with a simple architecture substitute

Most of the core ML activities have common abstraction layers. A Transformer can be swapped for an RNN, a ResNet for a plain CNN. These simpler approaches won’t be able to reach the precision you want but might be able to show whether there’s a problem with the gradient flows of the overall harness. They also train much faster, which is useful if you’re trying to run a quick sanity check of a new overfitting pipeline.

I've also noticed that newer laptops have gotten surprisingly fast when it comes to matrix multiplication. Certainly not fast enough to train a whole network on but I frequently find myself doing initial prototyping locally now. This work is focused on overfitting on simple architectures, sanity checking data, and making sure vectorization is logically correct.

Make randomness reproducible

Modern models have a lot of randomness built into them.

  • Data augmentation
  • Masked language modeling
  • Dropout and regularization

As intended, these techniques help to generalize models by making it difficult for networks to overfit. But during overfitting you do want them to overfit, ideally aggressively memorizing the input to test model capacity and the training harness. It’s difficult to determine whether there is an issue during overfitting if there's randomness to the inputs, outputs, or losses.

I used to manually disable all the random elements of the model when overfitting: setting dropout to zero, disabling data augmentation, etc. That approach suffers from having a lot of if/elif statements, and it doesn't necessarily catch randomness that is baked into an imported module's implementation. Instead of this approach I've started seeding the model with a fixed seed during each training and validation step. In PyTorch Lightning this looks something like:

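import pytorch_lightning as pl

CONSTANT_SEED = 60

class MySmartModule(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        if self.trainer.overfit_batches:
            # Re-seed every RNG so dropout, masking, etc. repeat on each step
            print("Will reset seed for reproducible overfitting")
            pl.seed_everything(CONSTANT_SEED)
        ...  # compute and return the loss as usual

    def validation_step(self, batch, batch_idx):
        if self.trainer.overfit_batches:
            print("Will reset seed for reproducible overfitting")
            pl.seed_everything(CONSTANT_SEED)
        ...  # compute and log validation metrics as usual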

This doesn't remove randomness directly but it should make the random values consistent across each training and validation step, which is effectively the same thing. This should allow the model to overfit as well as a zero randomness implementation. Double check via logging to confirm the input values are indeed equal.
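The effect is easy to check in isolation. The sketch below uses plain torch.manual_seed rather than the Lightning helper (an assumption made to keep the example dependency-free) and confirms that re-seeding right before a forward pass replays the same dropout mask.

```python
import torch
from torch import nn

CONSTANT_SEED = 60
dropout = nn.Dropout(p=0.5)  # modules start in training mode, so dropout is active
inputs = torch.ones(4, 16)

def seeded_forward():
    # Re-seeding right before the forward pass replays the same dropout mask.
    torch.manual_seed(CONSTANT_SEED)
    return dropout(inputs)

first = seeded_forward()
second = seeded_forward()
unseeded = dropout(inputs)  # continues the random stream, so the mask will usually differ

assert torch.equal(first, second)
```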

Overfit on 1, then 2, then 5

Any sufficiently sized network should be able to go to 0 loss on a handful of datapoints. I typically start with one example (1 batch, batch size 1). This should be trivially learnable since there's not even a need to create a discriminative output space. If this succeeds, expand to 2 distinct examples and then 5 distinct examples.
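A minimal sketch of that progression, with a toy MLP and random inputs standing in for the real model and dataset. The helper name overfit, the architecture, and the hyperparameters are all illustrative assumptions.

```python
import torch
from torch import nn

def overfit(num_examples, steps=500):
    """Train a fresh toy model on `num_examples` fixed datapoints and return the last loss."""
    torch.manual_seed(0)
    inputs = torch.randn(num_examples, 8)
    targets = torch.arange(num_examples) % 3  # distinct examples, three classes
    model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 3))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    for _ in range(steps):
        loss = nn.functional.cross_entropy(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()

# Grow the dataset one stage at a time; stop at the first stage that fails.
final_losses = {count: overfit(count) for count in (1, 2, 5)}
```

If the single-example stage fails, the problem is almost certainly in the harness rather than in model capacity.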

Write each custom vectorization twice

This sounds a bit like overkill but it's saved many a headache. Any time I'm doing anything other than a plain feedforward pass of values, I restructure tensor transformations into separate functions. I then re-write this logic using standard for-loops and single-element indexing of the tensors. Then I run a few examples through and make sure their values match. It's the easiest way to validate that vector broadcasting and other parallel operations are working as intended.

Specifically I wrap both implementations within a class descriptive of the transformation. Let's say we are going to write a function that masks out colors of particular values. I start with the for-loop implementation, going through index by index. Attempts to call the vectorized pipeline fail loudly. The original class structure looks something like this:

import logging

class ColorMasking:
    def __init__(self, vectorize):
        self.vectorize = vectorize

    def __call__(self, *args, **kwargs):
        if self.vectorize:
            return self.vectorized(*args, **kwargs)
        else:
            logging.warning("Using greedy implementation of ColorMasking")
            return self.greedy(*args, **kwargs)

    def greedy(self, img):
        for y in range(img.shape[0]):
            for x in range(img.shape[1]):
                ...

    def vectorized(self, img):
        raise NotImplementedError()

This class lets you switch easily between the explicit implementation (slow but more likely correct) and the vectorized one (fast but more likely to introduce bugs). It also builds in a unit-testable code block that makes it easier to check for implementation issues over time. Within your neural network module you can then choose whether to switch to vectorization across the board or to sanity check a few epochs with the explicit implementation.

class MySmartModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.vectorize = False

    def forward(self, img):
        mask = ColorMasking(vectorize=self.vectorize)(img)

This also complements the implementation process where you might only have an explicit implementation thus far:

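class MySmartModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.vectorize = False

    def forward(self, img):
        shape_mask = ShapeMask(vectorize=self.vectorize)(img)
        # ColorMasking only has an explicit implementation so far
        color_mask = ColorMasking(vectorize=False)(img)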

Sometimes I'll run the overfitting job directly with this non-vectorized function to sanity check that it's doing what I'd like. Other times, because of speed constraints, I'll jump to writing the vectorized logic.
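Once both versions exist, the comparison itself is short. The sketch below assumes that "masking a color" means producing a boolean mask that is True wherever a pixel equals a target color. That reading, along with both function names, is my own illustration rather than the original ColorMasking logic.

```python
import torch

def mask_color_greedy(img, color):
    """Loop reference: True where the pixel equals `color`. img is (height, width, channels)."""
    height, width, _ = img.shape
    mask = torch.zeros(height, width, dtype=torch.bool)
    for y in range(height):
        for x in range(width):
            mask[y, x] = bool(torch.equal(img[y, x], color))
    return mask

def mask_color_vectorized(img, color):
    """Broadcast `color` over every pixel and require all channels to match."""
    return (img == color).all(dim=-1)

torch.manual_seed(0)
color = torch.tensor([1, 0, 1])
# Several shapes, including degenerate ones, to exercise the broadcasting.
for height, width in [(1, 1), (3, 5), (8, 2)]:
    img = torch.randint(0, 2, (height, width, 3))
    assert torch.equal(mask_color_greedy(img, color), mask_color_vectorized(img, color))
```

Binary pixel values keep matches frequent, so the test exercises both the True and False branches of the mask.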

Unit test helper functions

Add heavy unit test groups for vectorization, data loaders, and the training pipeline.

Vectorization: As a continuation of the above section, validate that vectorized code works correctly. Define the expected transformations via a few hand-written examples. Try them with different tensor sizes and log expected weights or some expected transformations.

Data Loader: This goes for the data loader as well. Guarantee that transformations are as expected by reverse-converting them if possible. Take the argmax of text logits and retrieve the text, convert image pixels to an actual PIL image that can be compared to a static artifact, etc.

Training Pipeline: An additional integration test can validate some behavior of the training pipeline. One of the biggest issues is typically gradient flow - a loss that doesn't correctly propagate to earlier tensors in the network. At best you're missing out on valuable learning - at worst, the earlier layers will remain randomly initialized and the rest of the network will be guessing about random input noise. One route around this is a test that passes in some synthetic data, steps over each parameter's gradient, and asserts that its norm is non-zero. There should be some learning for every layer of the network.
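A sketch of such a gradient flow test follows. The helper name find_dead_parameters and both toy models are hypothetical; the second model contains a deliberate detach() bug to show what the check catches.

```python
import torch
from torch import nn

def find_dead_parameters(model, inputs, targets):
    """Run one backward pass and return the names of parameters that received no gradient."""
    model.zero_grad()
    loss = nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    return [
        name
        for name, parameter in model.named_parameters()
        if parameter.grad is None or parameter.grad.norm().item() == 0.0
    ]

class DetachedModel(nn.Module):
    """Deliberately broken: detach() cuts the graph before the first layer."""
    def __init__(self):
        super().__init__()
        self.first = nn.Linear(4, 8)
        self.second = nn.Linear(8, 2)

    def forward(self, x):
        hidden = torch.tanh(self.first(x)).detach()  # bug: blocks gradient flow
        return self.second(hidden)

torch.manual_seed(0)
inputs, targets = torch.randn(16, 4), torch.randint(0, 2, (16,))  # synthetic batch

healthy = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))
healthy_dead = find_dead_parameters(healthy, inputs, targets)
broken_dead = find_dead_parameters(DetachedModel(), inputs, targets)
```

In a test suite the assertion would simply be that the returned list is empty for the real model.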

My training pipelines tend to kick off via a CLI train executable. To make sure the unit tests pass before every training run, I add a pytest run command to this implementation before the harness is initialized.

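import click
import pytest

@click.command()
def train():
    # pytest.main() returns a non-zero exit code when any test fails
    if pytest.main() != 0:
        raise SystemExit("Unit tests failed, aborting training")

    # Training block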

Avoid global variables

Global variables are generally bad form in regular software engineering and they can be just as bad in ML. Jupyter is great for prototyping but it's easy for bugs to leak in when things are defined as regular cells. Even if you refactor some of the logic into functions, they might still be grabbing variables in global state.

As a general workflow, I prototype entirely in global space. It's easier to pass variables around and make sure tensors are sized correctly. Here I'm usually just working with one batch.

Before kicking off a full training run I'll refactor all the cells into separate functions. This ensures that there's no global variable leakage. It has caught some trivial errors before where the same value is being unintentionally re-used multiple times versus iterating over a larger list.
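A contrived illustration of the kind of leak this refactor catches: a function that was lifted out of a notebook cell but still reads a leftover global. All of the names here are made up for the example.

```python
batches = [[1, 2], [3, 4]]
batch = batches[0]  # left over from prototyping on a single batch

def total_leaky(batches):
    # Bug: the loop variable is `item`, but the body reads the global `batch`,
    # so the first batch is silently counted once per iteration.
    return sum(sum(batch) for item in batches)

def total_fixed(batches):
    return sum(sum(item) for item in batches)

leaky_total = total_leaky(batches)
fixed_total = total_fixed(batches)
```

The leaky version returns 6 instead of 10 without raising anything. Moving the code into a module where batch is not defined turns the same mistake into an immediate NameError.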

Ensure (static) batches are the same over time

When adding custom datasets, custom loaders, and custom collate functions, batches may subtly skew over time. This is particularly apt to happen when manipulating dictionaries in place or aggregating some values. I use the snippet below to check for equality over time.

Needless to say, this won't work if you are introducing random augmentations within the dataloader class. For those cases I temporarily disable the transformation and then run this validation. You can also selectively whitelist keys that should stay the same over time, while allowing variation in the keys that are part of the random transformations.

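import torch

# Assumes train_loader does not shuffle, so each fresh iterator yields the same first batch
first_sample = next(iter(train_loader))
second_sample = next(iter(train_loader))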

print("Will check equality...")
for key in first_sample.keys():
    first_value = first_sample[key]
    second_value = second_sample[key]

    if isinstance(first_value, torch.Tensor):
        if not torch.equal(first_value, second_value):
            print(first_value)
            print(second_value)
            raise ValueError(f"Unequal iterations: {key} (torch tensor)")
    else:
        if first_value != second_value:
            print(first_value)
            print(second_value)
            raise ValueError(f"Unequal iterations: {key}")
print("Success...")
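The key whitelisting mentioned above could look something like the sketch below. The helper find_unequal_keys and the batch keys are hypothetical; here the image key stands in for a field that a random augmentation is allowed to change.

```python
import torch

def find_unequal_keys(first, second, keys=None):
    """Return the keys whose values differ, checking only `keys` when it is given."""
    keys = list(first.keys()) if keys is None else keys
    unequal = []
    for key in keys:
        first_value, second_value = first[key], second[key]
        if isinstance(first_value, torch.Tensor):
            same = torch.equal(first_value, second_value)
        else:
            same = first_value == second_value
        if not same:
            unequal.append(key)
    return unequal

# Two hand-built batches: only the augmented image differs.
first = {"input_ids": torch.tensor([1, 2, 3]), "image": torch.zeros(2, 2), "path": "a.png"}
second = {"input_ids": torch.tensor([1, 2, 3]), "image": torch.ones(2, 2), "path": "a.png"}

all_keys = find_unequal_keys(first, second)
stable_only = find_unequal_keys(first, second, keys=["input_ids", "path"])
```

Returning the offending keys instead of raising on the first mismatch makes it easier to see every field that drifted in a single run.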

Synthetically generate data of different sizes

Within larger networks, especially with back propagation through time, gradients might disappear earlier on in the network. I’ve found a helpful method to debug these issues is to synthetically generate new datapoints of different sizes.

This is a natural fit if your data loader ends up accepting files on disk, which is common in most of the large scale architectures I end up building. Write a function that dumps a new dataset to disk in the correct format. Output values don't matter here since the network should just memorize the raw values during overfitting.

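import json
import os
import random
import tempfile
from contextlib import contextmanager

tokenizer = Tokenizer()
labels = ["A", "B", "C"]

@contextmanager
def create_synthetic_datapoint(text_length):
    # Values are arbitrary: the network only needs to memorize them
    tokens = random.sample(list(tokenizer.vocab), text_length)
    label = random.choice(labels)

    with tempfile.TemporaryDirectory() as directory:
        # Placeholder format and filename; write whatever your dataset class reads
        with open(os.path.join(directory, "datapoint.json"), "w") as file:
            json.dump({"tokens": tokens, "label": label}, file)
        yield directory

with create_synthetic_datapoint(50) as path:
    train_dataset = MyDataset([path])

    trainer.overfit(model, train_dataset)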
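To show why the size sweep is informative, the sketch below measures how much gradient reaches the first timestep of a sequence as the sequence grows. A plain nn.RNN stands in for the real network, and the squared-activation loss is arbitrary; both are assumptions for illustration only.

```python
import torch
from torch import nn

def first_step_gradient_norm(sequence_length, hidden_size=16):
    """Gradient norm of a last-step loss with respect to the first timestep's input."""
    torch.manual_seed(0)
    rnn = nn.RNN(input_size=4, hidden_size=hidden_size, batch_first=True)
    inputs = torch.randn(1, sequence_length, 4, requires_grad=True)

    outputs, _ = rnn(inputs)
    loss = outputs[:, -1].pow(2).sum()  # the loss only sees the final timestep
    loss.backward()
    return inputs.grad[0, 0].norm().item()

# Same weights each time (fixed seed), only the synthetic sequence length changes.
gradient_norms = {length: first_step_gradient_norm(length) for length in (5, 50, 200)}
```

With the default initialization the norm typically shrinks sharply as the sequence gets longer, which is the signature of gradients disappearing through time.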

Use einops when possible

Any time I need a tensor transformation (view, transpose, stacking, etc.) I try to fit it into an einops operation. They make these manipulations more descriptive by referencing axes by string values for what they mean. They can also infer some dimensions where you might otherwise need .shape arithmetic. I try to use full words or variable names within these strings unless something is readily obvious, like b for batch.

x = rearrange(x, "b height width embedding -> b (height width) embedding")

I've found these einops make it far easier to debug once I've been away from a function for a few days.

Conclusions

The road to successful training is getting easier thanks to great open source projects and an increasing trend to publish code alongside publications. But still, when trying something novel (either on a dataset or with a new model architecture) the road to success is winding. A one-character indexing bug can throw a result from SOTA to barely outperforming the baseline.

I had an old colleague who said "everything is possible in software, you just have to spend enough time building it." The challenging thing with ML is that some things are not possible - at least not with the current state of the art in data and architectures. ML research is the process of minimizing the chance of logical errors as much as possible, because a failure might mean that something just isn't possible - or it might mean there's a bug. Being diligent and defensive upfront is the best way to ensure failures are the former and not the latter. Feeling comfortable around legitimate failure is the best way to evolve into an experiment that does actually succeed.

Thanks to Richard Diehl Martinez and Geoffrey Angus for their feedback and suggestions of some additional debugging techniques.