fix single-target eval()
torchmetrics multiclass (single-target) metrics expect the labels as a vector of integer class indices (one per sample) rather than one-hot labels

adds and reformats documentation of attributes

adds a test that catches the old, incorrect behavior (previously the test set self.single_target=True after init, which did not update the metrics to single-target metrics)
sammlapp committed Sep 17, 2024
1 parent 7a2ea1b commit b86b91b
Showing 2 changed files with 65 additions and 28 deletions.
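For illustration of the label-shape convention described in the commit message (not part of the commit, and assuming torchmetrics >= 0.11 class names), a minimal sketch:

import torch
from torchmetrics.classification import MulticlassAccuracy, MultilabelAccuracy

preds = torch.tensor([[0.1, 0.8, 0.1],
                      [0.7, 0.2, 0.1],
                      [0.2, 0.2, 0.6]])  # (n_samples, n_classes) scores

# single-target (multiclass) metrics expect integer class indices, shape (n_samples,)
print(MulticlassAccuracy(num_classes=3)(preds, torch.tensor([1, 0, 2])))  # 1.0, all correct

# multi-target (multilabel) metrics expect one-/multi-hot labels, shape (n_samples, n_classes)
print(MultilabelAccuracy(num_labels=3)(preds, torch.tensor([[0, 1, 0],
                                                            [1, 0, 0],
                                                            [0, 0, 1]])))  # 1.0, all correct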
90 changes: 64 additions & 26 deletions opensoundscape/ml/cnn.py
@@ -280,10 +280,13 @@ def training_step(self, samples, batch_idx):
         # compute and log any metrics in self.torch_metrics
         # TODO: consider using validation set names rather than integer index
         # (would have to store a set of names for the validation set)

+        # single-target torchmetrics expect labels as integer class indices rather than one-hot
+        y = batch_labels.argmax(dim=1) if self.single_target else batch_labels
+        # TODO: does not allow soft labels, but some torchmetrics expect long type?
         batch_metrics = {
             f"train_{name}": metric.to(self.device)(
-                output.detach(), batch_labels.detach().long()
+                output.detach(), y.detach().long()
             ).cpu()
             for name, metric in self.torch_metrics.items()
         }
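A standalone sketch, not part of the diff, of what the added argmax line does to a one-hot label batch:

import torch

batch_labels = torch.tensor([[0, 1, 0],
                             [1, 0, 0],
                             [0, 0, 1]]).float()  # one-hot labels, shape (3 samples, 3 classes)

y = batch_labels.argmax(dim=1)  # integer class indices expected by Multiclass* torchmetrics
print(y)  # tensor([1, 0, 2])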
@@ -840,50 +843,82 @@ class SpectrogramClassifier(SpectrogramModule, torch.nn.Module):
    def __init__(self, *args, **kwargs):
        """defines pure pytorch train, predict, and eval methods for a spectrogram classifier

        subclasses SpectrogramModule, defines methods that are used for a pure PyTorch workflow.
        To use lightning, see ml.lightning.LightningSpectrogramModule.

        Args:
            see SpectrogramModule for arguments

        Methods:
            predict: generate predictions across a set of audio files or a dataframe defining
                audio files and start/end clip times
            train: fit the machine learning model using training data and evaluate with
                validation data
            save: save the model to a file
            load: load the model from a file
            embed: generate embeddings for a set of audio files
            generate_embeddings: generate embeddings for a set of audio files
            generate_samples: creates preprocessed sample tensors, same arguments as predict()
            generate_cams: generate gradient activation maps for a set of audio files
            eval: evaluate performance by applying self.torch_metrics to predictions and labels
            run_validation: test accuracy by running inference on a validation set and
                computing metrics
            change_classes: change the classes that the model predicts
            freeze_feature_extractor: freeze all layers except the classifier
            freeze_layers_except: freeze all parameters of a model, optionally excluding some layers
            train_dataloader: create dataloader for training
            predict_dataloader: create dataloader for inference (predict/validate/test)
            save_weights: save just the self.network state dict to a file
            load_weights: load just the self.network state dict from a file

        Editable Attributes & Properties:
            single_target: (bool) if True, predict only the class with the max score
            device: (torch.device or str) device to use for training and inference
            preprocessor: object defining preprocessing and augmentation operations,
                e.g. SpectrogramPreprocessor
            network: pytorch model object, e.g. Resnet18
            loss_fn: callable object to use for calculating loss during training,
                e.g. BCEWithLogitsLoss_hot()
            optimizer_params: (dict) with "class" and "kwargs" keys for class.__init__(**kwargs)
            lr_scheduler_params: (dict) with "class" and "kwargs" for class.__init__(**kwargs)
            use_amp: (bool) if True, uses automatic mixed precision for training
            wandb_logging: (dict) settings for logging to Weights and Biases
            score_metric: (str) name of the metric for overall evaluation - one of the keys
                in self.torch_metrics
            log_file: (str) path to save output to a text file
            logging_level: (int) amount of logging to the log file. 0 for nothing, 1,2,3 for
                increasing logged info
            verbose: (int) amount of logging to stdout. 0 for nothing, 1,2,3 for increasing
                printed output

        Other attributes:
            torch_metrics: dictionary of torchmetrics name:object pairs to use for
                calculating metrics
                - override the _init_torch_metrics() method in a subclass rather than
                  modifying directly
                - in general, if self.single_target is True, metrics will be called with
                  metric(predictions, labels) where predictions has shape (n_samples, n_classes)
                  and labels has integer class-index labels with shape (n_samples,). If
                  single_target is False, labels are instead multi-hot encoded with shape
                  (n_samples, n_classes)
            classes: list of class names
                - set the class list with __init__() or change_classes(), rather than modifying
                  directly. This ensures that other parameters like self.torch_metrics are
                  updated accordingly
        """
        super().__init__(*args, **kwargs)
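The optimizer_params and lr_scheduler_params format described above can be illustrated roughly as follows; the diff does not show how the model instantiates them, so the last two lines are an assumption about usage, not the library's internal code:

import torch

net = torch.nn.Linear(10, 2)  # stand-in for self.network

# each dict names the class and the kwargs that will be passed to class.__init__(**kwargs)
optimizer_params = {"class": torch.optim.Adam, "kwargs": {"lr": 0.001}}
lr_scheduler_params = {"class": torch.optim.lr_scheduler.StepLR, "kwargs": {"step_size": 10, "gamma": 0.5}}

optimizer = optimizer_params["class"](net.parameters(), **optimizer_params["kwargs"])
scheduler = lr_scheduler_params["class"](optimizer, **lr_scheduler_params["kwargs"])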

@@ -1210,12 +1245,15 @@ def eval(self, targets=None, scores=None, reset_metrics=True):

         # map is failing with memory limit on MPS, use CPU instead
         # TODO: reconsider casting labels to int (support soft labels)
+        # if self.single_target, convert one-hot labels to integer class indices with argmax
+        # because torchmetrics Multiclass metrics expect class-index labels
+        y = targets.argmax(dim=1) if self.single_target else targets
         for name, metric in self.torch_metrics.items():
             device = (
                 self.device if self.device.type != "mps" else torch.device("cpu")
             )
             metrics[name] = metric.to(device)(
-                scores.detach().to(device), targets.detach().int().to(device)
+                scores.detach().to(device), y.detach().long().to(device)
             ).cpu()

         if reset_metrics:
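A standalone sketch of the CPU fallback used above for metric computation when the model runs on Apple's MPS backend; the metric choice and torchmetrics >= 0.11 class name are assumptions, not from the repo:

import torch
from torchmetrics.classification import MulticlassAUROC

model_device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
# some metrics hit MPS memory limits, so compute them on CPU whenever the model is on MPS
metric_device = model_device if model_device.type != "mps" else torch.device("cpu")

metric = MulticlassAUROC(num_classes=3).to(metric_device)
scores = torch.randn(8, 3)          # model outputs, shape (n_samples, n_classes)
labels = torch.randint(0, 3, (8,))  # integer class indices, shape (n_samples,)
value = metric(scores.to(metric_device), labels.to(metric_device)).cpu()
print(value)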
@@ -1252,12 +1290,12 @@ def run_validation(self, validation_df, progress_bar=True, **kwargs):
         # run inference
         validation_scores = self.predict(
             validation_df,
-            activation_layer=("softmax_and_logit" if self.single_target else None),
+            activation_layer=("softmax" if self.single_target else "sigmoid"),
             progress_bar=progress_bar,
             **kwargs,
         )

-        # if validation_df is a list of file paths, we need to generate clip-df with labels
+        # if validation_df index is file paths, we need to generate clip-df with labels
         # to evaluate the scores. Easiest to do this with self.predict_dataloader()
         dl = self.predict_dataloader(validation_df, **kwargs)
         val_labels = dl.dataset.dataset.label_df.values
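For context, not part of the commit: softmax treats the classes of a single-target model as mutually exclusive, while sigmoid scores each class of a multi-target model independently. A minimal sketch:

import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])  # raw scores for one sample, 3 classes

probs_single = torch.softmax(logits, dim=1)  # single-target: each row sums to 1
probs_multi = torch.sigmoid(logits)          # multi-target: each class scored independently

print(probs_single.sum(dim=1))  # approximately 1.0
print(probs_multi)              # each value in (0, 1); rows need not sum to 1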
3 changes: 1 addition & 2 deletions tests/test_cnn.py
@@ -130,8 +130,7 @@ def test_save_load_pickel(train_df):


 def test_train_single_target(train_df):
-    model = cnn.CNN(architecture="resnet18", classes=[0, 1], sample_duration=5.0)
-    model.single_target = True
+    model = cnn.CNN(architecture="resnet18", classes=[0, 1], sample_duration=5.0, single_target=True)
     model.train(
         train_df,
         train_df,
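A hedged usage sketch of the pattern the updated test exercises; the import path is inferred from the repository layout shown above. Passing single_target=True to the constructor configures single-target (Multiclass) metrics, whereas assigning model.single_target = True after init did not rebuild self.torch_metrics:

from opensoundscape.ml import cnn  # assumed import path, per opensoundscape/ml/cnn.py above

model = cnn.CNN(
    architecture="resnet18",
    classes=[0, 1],
    sample_duration=5.0,
    single_target=True,  # configures Multiclass (single-target) metrics at construction
)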
