diff --git a/torchinterp1d/interp1d.py b/torchinterp1d/interp1d.py index d30a939..059338d 100644 --- a/torchinterp1d/interp1d.py +++ b/torchinterp1d/interp1d.py @@ -97,7 +97,10 @@ def forward(ctx, x, y, xnew, out=None): if v['xnew'].shape[0] == 1: v['xnew'] = v['xnew'].expand(v['x'].shape[0], -1) - torch.searchsorted(v['x'].contiguous(), + # the squeeze is because torch.searchsorted does accept either a nd with + # matching shapes for x and xnew or a 1d vector for x. Here we would + # have (1,len) for x sometimes + torch.searchsorted(v['x'].contiguous().squeeze(), v['xnew'].contiguous(), out=ind) # the `-1` is because searchsorted looks for the index where the values