Having recently switched to PyTorch for modeling, after primarily building neural networks in TensorFlow/Keras, I have been enjoying how easy it is to write new (automatically differentiable) functions and layers.
I did exactly that recently: I wanted a layer that would reverse (negate) gradients during the backward pass. There are many reasons you might want to do that. One is to make the parts of your network that come before the layer learn the opposite of what they would learn from minimising the loss, i.e. to push the loss up rather than down.
The result is here, but most of the logic is embedded below.
The main object to use is a “layer”, known in PyTorch as a module. A module can be a single layer (e.g. `torch.nn.Linear`) or a whole network, and because the two are fundamentally the same kind of object, it is easy to build complex networks out of existing, complex building blocks.
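To illustrate the point, here is a small sketch (the layer sizes and variable names are arbitrary) in which a single `nn.Linear` layer and a whole `nn.Sequential` network are both `nn.Module`s, so the smaller network slots into a bigger one like any other layer:

```python
import torch
from torch import nn

# A single layer is an nn.Module...
encoder = nn.Linear(4, 8)

# ...and so is a whole network built out of layers.
feature_extractor = nn.Sequential(encoder, nn.ReLU(), nn.Linear(8, 8))

# Because both are Modules, the network can itself be reused
# as a building block inside a larger model.
classifier = nn.Sequential(feature_extractor, nn.ReLU(), nn.Linear(8, 2))

out = classifier(torch.randn(5, 4))
print(out.shape)  # torch.Size([5, 2])
```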
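However, subclassing `torch.nn.Module` does not let you define your own auto-differentiation functions: a module's forward method only composes existing operations, and autograd works out their gradients for you. Since I wanted to manipulate the gradients during the backward pass, I needed a `torch.autograd.Function` instead.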
`torch.autograd.Function` objects implement both a `forward` and a `backward` method, and they behave as you would expect: `forward` receives the input and returns the output, while `backward` receives the gradient with respect to the output and returns the gradient with respect to each input. (In current PyTorch both are static methods that take a context object, `ctx`, as their first argument.)
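Putting those pieces together, a gradient reversal layer can be sketched as follows. This is a minimal sketch, not the exact code from the post: the names `GradReverse` and `GradientReversal` are hypothetical, and the reversal is a plain sign flip with no scaling factor. The `nn.Module` wrapper lets the function be used like any other layer, for example inside `nn.Sequential`.

```python
import torch
from torch import nn


class GradReverse(torch.autograd.Function):
    """Identity on the forward pass, negated gradient on the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # Values pass through unchanged; view_as(x) is a common idiom
        # in gradient reversal implementations.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient before passing it upstream.
        return grad_output.neg()


class GradientReversal(nn.Module):
    """Module wrapper so the function can be dropped into a network."""

    def forward(self, x):
        return GradReverse.apply(x)


x = torch.ones(3, requires_grad=True)
y = GradientReversal()(x)
y.sum().backward()
print(y.detach(), x.grad)  # forward is the identity; x.grad is -1 everywhere
```

Layers placed before `GradientReversal` in a network therefore receive the negative of the gradient they would otherwise get.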
When it came to testing this code, I wrote just one test, for the layer. Since the layer does nothing except call the function, I figured all my code would be covered.
But figuring that code is covered by tests is good; knowing is better. So I checked with coverage.py, a tool that reports which lines of code ran during a particular session and which did not.
And while the test passed, the coverage report said that the backward method never ran!
So the function seemed to do what it was supposed to (reverse the gradients) without ever being called. After some puzzling, and some help on the PyTorch forums, I realised that the way coverage.py detects whether a line ran (using a trace function, which the Python interpreter calls as each line of Python code executes) does not in fact work for the backward method.
It’s called by PyTorch’s C++ autograd engine, not by Python code, so coverage never records it as having run.
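To make the tracing mechanism concrete, here is a minimal line tracer built on Python's `sys.settrace`. coverage.py's own tracer is considerably more elaborate, but the principle is the same: only lines the tracer is notified about get recorded, so code it never sees is reported as not run. The function names `tracer`, `add` and `never_called` are purely illustrative.

```python
import sys

executed = []  # (function name, line number) pairs reported to the tracer


def tracer(frame, event, arg):
    # The interpreter calls this with a 'call' event for each new frame;
    # returning tracer installs it as the local tracer for 'line' events.
    if event == "line":
        executed.append((frame.f_code.co_name, frame.f_lineno))
    return tracer


def add(a, b):
    total = a + b
    return total


def never_called():
    return "unreachable"


sys.settrace(tracer)
result = add(2, 3)
sys.settrace(None)

traced_functions = {name for name, _ in executed}
print(result, sorted(traced_functions))  # prints: 5 ['add']
```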
So: when measuring test coverage of custom PyTorch functions, exclude their backward methods.
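coverage.py skips any line carrying a `# pragma: no cover` comment, and when that line opens a block such as a `def`, the whole block is excluded. A minimal sketch, assuming a pytest-style test file; `GradReverse` and `test_gradients_are_reversed` are illustrative names, not the original code:

```python
import torch


class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        # Invoked by the autograd engine, where coverage.py does not see it,
        # so the whole method is excluded from the coverage report.
        return grad_output.neg()


def test_gradients_are_reversed():
    x = torch.randn(4, requires_grad=True)
    GradReverse.apply(x).sum().backward()
    # The gradient of sum() is 1 everywhere, so after reversal it is -1.
    assert torch.equal(x.grad, -torch.ones_like(x))
```

For a project-wide rule, the same exclusion can be written as a regular expression in the `[report]` section of the coverage configuration (`exclude_also` in recent coverage.py versions). Note that setting `exclude_lines` instead replaces the default patterns, including `pragma: no cover`, rather than adding to them.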