diff --git a/src/zennit/canonizers.py b/src/zennit/canonizers.py index fc4a4f2..063e6ab 100644 --- a/src/zennit/canonizers.py +++ b/src/zennit/canonizers.py @@ -125,8 +125,8 @@ def merge_batch_norm(modules, batch_norm): for module in modules: if module.bias is None: - object.__setattr__( - module, 'bias', torch.zeros(1, device=module.weight.device, dtype=module.weight.dtype) + setattr( + module, 'bias', torch.nn.Parameter(torch.zeros(1, device=module.weight.device, dtype=module.weight.dtype)) ) index = (slice(None), *((None,) * (module.weight.ndim - 1)))