Victor Poughon

How to get all intermediate layer outputs in PyTorch

In PyTorch, have you ever wanted to get the inputs and outputs of every layer, not just the final output of the model? I recently needed that for my torchlensmaker project, so here's a little snippet!

import torch.nn as nn

from typing import Any, Iterator
from dataclasses import dataclass


@dataclass
class ModuleEvalContext:
    module: nn.Module
    inputs: Any
    outputs: Any

    def __iter__(self) -> Iterator[Any]:
        return iter((self.module, self.inputs, self.outputs))


def full_forward(
    module: nn.Module, inputs: Any
) -> tuple[list[ModuleEvalContext], Any]:
    """
    Forward evaluate a model, but return all intermediate inputs and outputs.

    This is kind of like normal forward evaluation of a model, as in:

        > outputs = model(inputs)

    except that the inputs and outputs of every intermediate layer are also
    returned, as a list of ModuleEvalContext objects that unpack like
    three-element (module, inputs, outputs) tuples:

        > execute_list, output = full_forward(model, inputs)
        > for module, inputs, outputs in execute_list:
        >     print(module, inputs, outputs)

    Args:
        module: PyTorch nn.Module to evaluate
        inputs: input data to the module

    Returns:
        execute_list: list of (module, inputs, outputs), one entry per module call
        outputs: output of the top level module execution
    """

    execute_list = []

    # Define the forward hook
    def hook(mod: nn.Module, inp: Any, out: Any) -> None:
        # inp[0] here restricts us to forward() first argument
        # so this only works with single argument forward() functions
        execute_list.append(ModuleEvalContext(mod, inp[0], out))

    # Register forward hooks to every module recursively
    hooks = []
    for mod in module.modules():
        hooks.append(mod.register_forward_hook(hook))

    # Evaluate the full model, then remove all hooks
    try:
        outputs = module(inputs)
    finally:
        for h in hooks:
            h.remove()

    return execute_list, outputs
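One detail worth knowing about the order of execute_list: a forward hook fires when its module's forward() returns, so entries come out in post-order. Children appear before their parent container, and the top-level module is the last entry. The standalone sketch below records the hook firing order on a small nested nn.Sequential (the model is just an illustration, not from torchlensmaker) to show this.

```python
import torch
import torch.nn as nn

# A small nested model: the outer Sequential contains an inner Sequential.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.Sequential(nn.ReLU(), nn.Linear(8, 2)),
)

order: list[str] = []


def make_hook(name: str):
    # Forward hooks run after the module's forward() returns,
    # so a container's hook fires only after all of its children's hooks.
    def hook(mod: nn.Module, inp, out) -> None:
        order.append(name or "<root>")  # the root module's name is ""

    return hook


handles = [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]
try:
    model(torch.randn(3, 4))
finally:
    for h in handles:
        h.remove()

print(order)  # → ['0', '1.0', '1.1', '1', '<root>']
```

So with full_forward, execute_list[-1] is always the call of the model itself.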

And this is how you use it:

execute_list, output = full_forward(model, X)

for module, inputs, outputs in execute_list:
    # Do something here:
    ...
    print(module, inputs, outputs)
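Each entry holds the module object but not its name. If you'd rather look layers up by name, a small variant could key the results by the qualified names that named_modules() yields. The sketch below is a hypothetical helper, named_outputs, not part of torchlensmaker or PyTorch; the three-layer model is only for illustration.

```python
import torch
import torch.nn as nn
from typing import Any


def named_outputs(model: nn.Module, inputs: Any) -> dict[str, Any]:
    """Hypothetical variant of full_forward: map each submodule's qualified
    name (as given by named_modules()) to the output it produced."""
    results: dict[str, Any] = {}
    hooks = []
    # named_modules() lists each module once, so a module that runs
    # several times keeps only its last output here.
    for name, mod in model.named_modules():
        # The default argument binds the current loop name to this hook.
        def hook(m: nn.Module, inp: Any, out: Any, name: str = name) -> None:
            results[name or "<root>"] = out

        hooks.append(mod.register_forward_hook(hook))
    try:
        model(inputs)
    finally:
        for h in hooks:
            h.remove()
    return results


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
outs = named_outputs(model, torch.randn(3, 4))
for name, out in outs.items():
    print(name, tuple(out.shape))
```

The dict keeps the hook firing order, so iterating over it walks the layers in the same post-order as execute_list.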

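This is based on PyTorch's register_forward_hook, so it captures every submodule call made during the forward pass, however deeply nested, without touching the model's code (as long as submodules are invoked as module(x) rather than module.forward(x)). Enjoy :)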
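The hook in full_forward keeps only inp[0], as its comment warns, so modules whose forward() takes several arguments lose information. Since PyTorch 2.0, register_forward_hook also accepts with_kwargs=True, in which case the hook is called as hook(module, args, kwargs, output). Here is a sketch of a variant built on that, assuming PyTorch 2.0 or newer; full_forward_all_args, CallRecord, and the toy Scale/Net modules are hypothetical names for illustration only.

```python
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Any


@dataclass
class CallRecord:
    module: nn.Module
    args: tuple  # all positional arguments passed to forward()
    kwargs: dict  # all keyword arguments passed to forward()
    output: Any


def full_forward_all_args(model: nn.Module, *args: Any, **kwargs: Any):
    """Hypothetical variant of full_forward that records every positional
    and keyword argument instead of only inp[0] (needs PyTorch >= 2.0)."""
    records: list[CallRecord] = []

    def hook(mod: nn.Module, a: tuple, kw: dict, out: Any) -> None:
        records.append(CallRecord(mod, a, kw, out))

    handles = [
        m.register_forward_hook(hook, with_kwargs=True) for m in model.modules()
    ]
    try:
        output = model(*args, **kwargs)
    finally:
        for h in handles:
            h.remove()
    return records, output


# Toy modules with multi-argument forward() functions.
class Scale(nn.Module):
    def forward(self, x: torch.Tensor, factor: float = 1.0) -> torch.Tensor:
        return x * factor


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.scale = Scale()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return self.scale(x + y, factor=2.0)


records, out = full_forward_all_args(Net(), torch.ones(2), torch.ones(2))
for r in records:
    print(type(r.module).__name__, len(r.args), r.kwargs)
```

Keeping args and kwargs whole costs a little convenience when unpacking, but nothing passed to a module is silently dropped.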