Skip to content

Commit

Permalink
there was a bug with 1d inputs when using torch.searchsorted
Browse files Browse the repository at this point in the history
  • Loading branch information
aliutkus committed Oct 3, 2022
1 parent 9898ae5 commit ac29a3a
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion torchinterp1d/interp1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ def forward(ctx, x, y, xnew, out=None):
if v['xnew'].shape[0] == 1:
v['xnew'] = v['xnew'].expand(v['x'].shape[0], -1)

torch.searchsorted(v['x'].contiguous(),
# the squeeze is because torch.searchsorted does accept either a nd with
# matching shapes for x and xnew or a 1d vector for x. Here we would
# have (1,len) for x sometimes
torch.searchsorted(v['x'].contiguous().squeeze(),
v['xnew'].contiguous(), out=ind)

# the `-1` is because searchsorted looks for the index where the values
Expand Down

0 comments on commit ac29a3a

Please sign in to comment.