diff --git a/opensoundscape/ml/cnn.py b/opensoundscape/ml/cnn.py index 51401f1f..1eb2a390 100644 --- a/opensoundscape/ml/cnn.py +++ b/opensoundscape/ml/cnn.py @@ -1565,7 +1565,12 @@ def train( if epoch % validation_interval == 0: if validation_df is not None: self._log("\nValidation.") - val_metrics = self.run_validation(validation_df) + val_metrics = self.run_validation( + validation_df, + batch_size=batch_size, + num_workers=num_workers, + raise_errors=raise_errors, + ) self.valid_metrics[self.current_epoch] = val_metrics score = val_metrics[self.score_metric] # overall score if wandb_session is not None: