Skip to content

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
KindXiaoming authored Sep 15, 2024
1 parent c162d38 commit 91a2f63
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions kan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def augment_input(orig_vars, aux_vars, x):
return x


def batch_jacobian(func, x, create_graph=False):
def batch_jacobian(func, x, create_graph=False, mode='scalar'):
'''
jacobian
Expand All @@ -408,7 +408,10 @@ def batch_jacobian(func, x, create_graph=False):
# x in shape (Batch, Length)
def _func_sum(x):
return func(x).sum(dim=0)
return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
if mode == 'scalar':
return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
elif mode == 'vector':
return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)

def batch_hessian(model, x, create_graph=False):
'''
Expand Down Expand Up @@ -588,4 +591,4 @@ def model2param(model):
p = torch.tensor([]).to(model.device)
for params in model.parameters():
p = torch.cat([p, params.reshape(-1,)], dim=0)
return p
return p

0 comments on commit 91a2f63

Please sign in to comment.