Skip to content

Monkey patch

Monkey patch a method

To patch the `forward` method of an `nn.Module` instance, define a factory function that captures the values it needs in a closure and returns your new `forward`:

import torch.nn as nn

class A(nn.Module):
    """Toy module whose ``forward`` reports which instance it belongs to."""

    def __init__(self, name):
        super().__init__()
        # Label echoed by ``forward`` so we can tell instances apart.
        self.name = name

    def forward(self):
        """Print a message identifying this instance's original forward."""
        message = f'original forward of {self.name}'
        print(message)

a = A('a')
b = A('b')


def make_forward(cur_name):
    """Return a replacement ``forward`` that remembers *cur_name*.

    The name is passed as an argument, so its value is captured when the
    factory is called. Each patched forward therefore keeps its own name even
    though the loop variable below keeps changing. This avoids the classic
    late-binding closure pitfall.
    """
    def _forward():
        print(f'patched forward of {cur_name}')
    return _forward


for name, m in zip(['a', 'b'], [a, b]):
    # Calling an nn.Module dispatches to ``self.forward``. An instance
    # attribute shadows the class's method, so this per-instance patch
    # takes effect.
    m.forward = make_forward(name)

a()
b()

Output:

patched forward of a
patched forward of b

However, you cannot patch magic (dunder) methods such as `__call__` this way:

class A:
    """Callable object that announces its original ``__call__``."""

    def __init__(self, name):
        super().__init__()
        self.name = name  # identifier included in the printed message

    def __call__(self):
        """Print which instance's original ``__call__`` ran."""
        text = f'original forward of {self.name}'
        print(text)

a = A('a')
b = A('b')


def make_call(cur_name):
    """Return a zero-argument callable that remembers *cur_name*.

    Passing the name explicitly captures its current value, instead of relying
    on the loop variable at call time (late binding).
    """
    def _call():
        print(f'patched forward of {cur_name}')
    return _call


for name, m in zip(['a', 'b'], [a, b]):
    # This only creates an instance attribute named ``__call__``. Implicit
    # calls such as ``a()`` look the special method up on the type, not the
    # instance, so this attribute is ignored and ``A.__call__`` still runs.
    m.__call__ = make_call(name)

a()
b()

Output:

original forward of a
original forward of b

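This is because implicit invocations of special methods such as `__call__` look the method up on the object's type, not on the instance. As a result, `a()` still runs the original `A.__call__`, and the instance attribute we set is never consulted.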

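To make this work, we must patch at the class level. Define a derived class whose `__call__` delegates to a per-instance attribute, then reassign each instance's `__class__` to that derived class: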

class A:
    """Callable whose original ``__call__`` prints the instance name."""

    def __init__(self, name):
        super().__init__()
        # Stored so ``__call__`` can report which instance was invoked.
        self.name = name

    def __call__(self):
        """Announce that this instance's original ``__call__`` ran."""
        label = self.name
        print(f'original forward of {label}')

# Two independent instances that will be re-classed and patched below.
a, b = A('a'), A('b')

# A derived class whose ``__call__`` redirects to a per-instance
# ``patched_call`` attribute. Special methods are looked up on the type, so the
# redirection has to live on the class. The per-instance behaviour lives in
# the instance attribute it delegates to.
class B(A):
    def __call__(self, *args, **kwargs):
        """Forward the call, its arguments, and its result to ``patched_call``."""
        return self.patched_call(*args, **kwargs)

def make_call(cur_name):
    """Return a zero-argument callable that remembers *cur_name*.

    Passing the name explicitly captures its current value, instead of relying
    on the loop variable at call time (late binding).
    """
    def _call():
        print(f'patched forward of {cur_name}')
    return _call


for name, m in zip(['a', 'b'], [a, b]):
    # "Magic cast": swap the instance's type to B, so ``m()`` resolves to
    # ``B.__call__``, which delegates to the per-instance ``patched_call``.
    m.__class__ = B
    m.patched_call = make_call(name)

a()
b()

Output:

patched forward of a
patched forward of b