From b86b91b05765d634511a3c4492b8f8b64457db84 Mon Sep 17 00:00:00 2001 From: sammlapp Date: Tue, 17 Sep 2024 16:55:02 -0400 Subject: [PATCH] fix single-target eval() torchmetrics multiclass (single-target) metrics expect the labels as a vector of integer class indices (one per sample), rather than one-hot labels. adds and reformats documentation of attributes. adds test that catches the old, wrong behavior (previously, the test set self.single_target=True after init, which did not update metrics to single-target metrics) --- opensoundscape/ml/cnn.py | 90 ++++++++++++++++++++++++++++------------ tests/test_cnn.py | 3 +- 2 files changed, 65 insertions(+), 28 deletions(-) diff --git a/opensoundscape/ml/cnn.py b/opensoundscape/ml/cnn.py index 1eb2a390..ce8e6b8e 100644 --- a/opensoundscape/ml/cnn.py +++ b/opensoundscape/ml/cnn.py @@ -280,10 +280,13 @@ def training_step(self, samples, batch_idx): # compute and log any metrics in self.torch_metrics # TODO: consider using validation set names rather than integer index # (would have to store a set of names for the validation set) + + # single-target torchmetrics expect labels as integer class indices rather than one-hot + y = batch_labels.argmax(dim=1) if self.single_target else batch_labels # TODO: does not allow soft labels, but some torchmetrics expect long type? batch_metrics = { f"train_{name}": metric.to(self.device)( - output.detach(), batch_labels.detach().long() + output.detach(), y.detach().long() ).cpu() for name, metric in self.torch_metrics.items() } @@ -840,50 +843,82 @@ class SpectrogramClassifier(SpectrogramModule, torch.nn.Module): def __init__(self, *args, **kwargs): """defines pure pytorch train, predict, and eval methods for a spectrogram classifier - subclasses SpectrogramModule, defines methods that are used for pure PyTorch workflow. - To use lightning, see ml.lightning.LightningSpectrogramModule. + subclasses SpectrogramModule, defines methods that are used for pure PyTorch workflow. 
To + use lightning, see ml.lightning.LightningSpectrogramModule. Args: see SpectrogramModule for arguments Methods: - predict: generate predictions across a set of audio files or a dataframe defining audio files and start/end clip times - train: fit the machine learning model using training data and evaluate with validation data - save: save the model to a file - load: load the model from a file + predict: generate predictions across a set of audio files or a dataframe defining audio + files and start/end clip times + + train: fit the machine learning model using training data and evaluate with validation + data + + save: save the model to a file load: load the model from a file + embed: generate embeddings for a set of audio files - generate_embeddings: generate embeddings for a set of audio files + + generate_samples: creates preprocessed sample tensors, same arguments as predict() + generate_cams: generate gradient activation maps for a set of audio files + eval: evaluate performance by applying self.torch_metrics to predictions and labels - run_validation: test accuracy by running inference on a validation set and computing metrics - change_classes: change the classes that the model predicts + + run_validation: test accuracy by running inference on a validation set and computing + + metrics change_classes: change the classes that the model predicts + freeze_feature_extractor: freeze all layers except the classifier + freeze_layers_except: freeze all parameters of a model, optionally exluding some layers - train_dataloader: create dataloader for training - predict_dataloader: create dataloader for inference (predict/validate/test) + + train_dataloader: create dataloader for training predict_dataloader: create dataloader + for inference (predict/validate/test) + save_weights: save just the self.network state dict to a file + load_weights: load just the self.network state dict from a file Editable Attributes & Properties: single_target: (bool) if True, predict 
only class with max score - device: (torch.device or str) device to use for training and inference - preprocessor: object defining preprocessing and augmentation operations, e.g. SpectrogramPreprocessor + + device: (torch.device or str) device to use for training and inference preprocessor: + object defining preprocessing and augmentation operations, e.g. SpectrogramPreprocessor + network: pytorch model object, e.g. Resnet18 - loss_fn: callable object to use for calculating loss during training, e.g. BCEWithLogitsLoss_hot() + + loss_fn: callable object to use for calculating loss during training, e.g. + BCEWithLogitsLoss_hot() + optimizer_params: (dict) with "class" and "kwargs" keys for class.__init__(**kwargs) + lr_scheduler_params: (dict) with "class" and "kwargs" for class.__init__(**kwargs) - use_amp: (bool) if True, uses automatic mixed precision for training - wandb_logging: (dict) settings for logging to Weights and Biases - score_metric: (str) name of the metric for overall evaluation - one of the keys in self.torch_metrics - log_file: (str) path to save output to a text file - logging_level: (int) amt of logging to log file. 0 for nothing, 1,2,3 for increasing logged info - verbose: (int) amt of logging to stdout. 0 for nothing, 1,2,3 for increasing printed output + use_amp: (bool) if True, uses automatic mixed precision for training wandb_logging: + (dict) settings for logging to Weights and Biases + + score_metric: (str) name of the metric for overall evaluation - one of the keys in + self.torch_metrics + + log_file: (str) path to save output to a text file logging_level: (int) amt of logging + to log file. 0 for nothing, 1,2,3 for increasing logged info + + verbose: (int) amt of logging to stdout. 
0 for nothing, 1,2,3 for increasing printed + output Other attributes: - torch_metrics: dictionary of torchmetrics name:object pairs to use for calculating metrics + torch_metrics: dictionary of torchmetrics name:object pairs to use for calculating + metrics - override _init_torch_metrics() method in a subclass rather than modifying directly + - in general, if self.single_target is True, metrics will be called with + metric(predictions, labels) where predictions is shape (n_samples,n_classes) and + labels has integer labels with shape (n_samples,). If single_target is False, + instead labels are multi-hot encoded and have shape (n_samples,n_classes) classes: list of class names - - set via __init__() or change_classes(), rather than modifying directly + - set the class list with __init__() or change_classes(), rather than modifying + directly + This ensures that other parameters like self.torch_metrics are updated accordingly """ super().__init__(*args, **kwargs) @@ -1210,12 +1245,15 @@ def eval(self, targets=None, scores=None, reset_metrics=True): # map is failing with memory limit on MPS, use CPU instead # TODO: reconsider casting labels to int (support soft labels) + # if self.single_target, use argmax to convert one-hot labels to class indices + # because torchmetrics Multiclass metrics expect class indices + y = targets.argmax(dim=1) if self.single_target else targets for name, metric in self.torch_metrics.items(): device = ( self.device if self.device.type != "mps" else torch.device("cpu") ) metrics[name] = metric.to(device)( - scores.detach().to(device), targets.detach().int().to(device) + scores.detach().to(device), y.detach().long().to(device) ).cpu() if reset_metrics: @@ -1252,12 +1290,12 @@ def run_validation(self, validation_df, progress_bar=True, **kwargs): # run inference validation_scores = self.predict( validation_df, - activation_layer=("softmax_and_logit" if self.single_target else None), + activation_layer=("softmax" if self.single_target else "sigmoid"), 
progress_bar=progress_bar, **kwargs, ) - # if validation_df is a list of file paths, we need to generate clip-df with labels + # if validation_df index is file paths, we need to generate clip-df with labels # to evaluate the scores. Easiest to do this with self.predict_dataloader() dl = self.predict_dataloader(validation_df, **kwargs) val_labels = dl.dataset.dataset.label_df.values diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 8c5cb857..5f1b2650 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -130,8 +130,7 @@ def test_save_load_pickel(train_df): def test_train_single_target(train_df): - model = cnn.CNN(architecture="resnet18", classes=[0, 1], sample_duration=5.0) - model.single_target = True + model = cnn.CNN(architecture="resnet18", classes=[0, 1], sample_duration=5.0, single_target=True) model.train( train_df, train_df,