diff --git a/kan/utils.py b/kan/utils.py index 273b574b..abb4d558 100644 --- a/kan/utils.py +++ b/kan/utils.py @@ -384,7 +384,7 @@ def augment_input(orig_vars, aux_vars, x): return x -def batch_jacobian(func, x, create_graph=False): +def batch_jacobian(func, x, create_graph=False, mode='scalar'): ''' jacobian @@ -408,7 +408,10 @@ def batch_jacobian(func, x, create_graph=False): # x in shape (Batch, Length) def _func_sum(x): return func(x).sum(dim=0) - return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0] + if mode == 'scalar': + return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0] + elif mode == 'vector': + return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2) def batch_hessian(model, x, create_graph=False): ''' @@ -588,4 +591,4 @@ def model2param(model): p = torch.tensor([]).to(model.device) for params in model.parameters(): p = torch.cat([p, params.reshape(-1,)], dim=0) - return p \ No newline at end of file + return p