Feature extraction code puts models into eval() and inference_mode() behind the scenes without printing an explicit message #892

Closed
GeorgeBatch opened this issue Dec 3, 2024 · 1 comment

Comments

@GeorgeBatch
Contributor

  • TIA Toolbox version: develop branch
  • Python version: 3.11
  • Operating System: Linux

Description

I realised that in my code, as in the foundation models example notebook, we never explicitly put the model into eval mode with .eval(), nor do we wrap the forward pass in a torch.no_grad() or torch.inference_mode() context manager.

The same is true for the older example notebooks on which I based my code:

I was worried that I had been running models in train mode and would need to recompute all the features I had extracted over the last couple of weeks. However, after checking the source code, I realised that feature extraction relies on the _infer_batch() function, which calls model.eval() for every model that uses it (line 173):

```python
import numpy as np
import torch
from torch import nn


def _infer_batch(
    model: nn.Module,
    batch_data: torch.Tensor,
    device: str,
) -> dict[str, np.ndarray]:
    """Run inference on an input batch.

    Contains logic for forward operation as well as i/o aggregation.

    Args:
        model (nn.Module):
            PyTorch defined model.
        batch_data (torch.Tensor):
            A batch of data generated by
            `torch.utils.data.DataLoader`.
        device (str):
            Transfers model to the specified device. Default is "cpu".

    """
    img_patches_device = batch_data.to(device=device).type(
        torch.float32,
    )  # to NCHW
    img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()
    # Inference mode
    model.eval()
    # Do not compute the gradient (not training)
    with torch.inference_mode():
        output = model(img_patches_device)
    # Output should be a single tensor or scalar
    return output.cpu().numpy()
```
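To double-check this without touching TIAToolbox itself, here is a minimal sketch that mirrors the quoted logic in a hypothetical `infer_batch_sketch` function and feeds it a toy model containing Dropout that is deliberately left in train mode. The function name, the toy model and the random NHWC batch are all assumptions for illustration only. It also shows a side effect visible in the quoted code: the function calls `model.eval()` but never switches back, so the model is still in eval mode after the call.

```python
import numpy as np
import torch
from torch import nn


def infer_batch_sketch(model: nn.Module, batch_data: torch.Tensor, device: str) -> np.ndarray:
    """Hypothetical re-implementation of the quoted _infer_batch logic."""
    # Cast NHWC patches to float32 on the target device, then reorder to NCHW.
    patches = batch_data.to(device=device).type(torch.float32)
    patches = patches.permute(0, 3, 1, 2).contiguous()
    model.eval()  # same call as in the quoted source; the old mode is not restored
    with torch.inference_mode():  # no autograd graph is recorded
        output = model(patches)
    return output.cpu().numpy()


# Toy "feature extractor" with Dropout, left in train mode on purpose.
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),
    nn.Dropout(p=0.5),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
assert model.training  # freshly built modules start in training mode

batch = torch.randint(0, 256, (2, 16, 16, 3), dtype=torch.uint8)  # NHWC patches
feats_a = infer_batch_sketch(model, batch, device="cpu")
feats_b = infer_batch_sketch(model, batch, device="cpu")

print(model.training)                    # False: the model stays in eval mode afterwards
print(feats_a.shape)                     # (2, 4)
print(np.allclose(feats_a, feats_b))     # True: dropout is a no-op in eval mode
```

If the same model object is later reused for fine-tuning, it would need an explicit `model.train()` call first.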

Suggestion

I think this behaviour is correct, but it might be helpful to print or log a message stating that the model is being put into eval mode.
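One way the suggested message could look is sketched below. `prepare_for_inference` and the logger name are hypothetical and not part of the TIAToolbox API; the sketch only uses the standard `logging` module and `nn.Module.eval()`, which returns the module itself. It logs only when some submodule is actually in training mode, so the message is not repeated on every batch.

```python
import logging

from torch import nn

logger = logging.getLogger("feature_extraction_sketch")  # hypothetical logger name


def prepare_for_inference(model: nn.Module) -> nn.Module:
    """Hypothetical helper: put model into eval mode, logging only if that changes anything."""
    # model.modules() includes the model itself, so nested submodules are checked too.
    if any(module.training for module in model.modules()):
        logger.info(
            "%s was in training mode; switching to eval mode. The forward pass "
            "will run under torch.inference_mode().",
            type(model).__name__,
        )
    # nn.Module.eval() sets training=False recursively and returns the module itself.
    return model.eval()


logging.basicConfig(level=logging.INFO)
model = nn.Linear(4, 2)       # freshly built modules start in training mode
prepare_for_inference(model)  # logs one INFO message
prepare_for_inference(model)  # silent: the model is already in eval mode
```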

@shaneahmed
Member

Thanks @GeorgeBatch for the suggestion. However, all the models in TIAToolbox are currently used for inference only. Printing such a message would be redundant unless we add training code for these models to the toolbox.
