Skip to content

Commit

Permalink
pass dataloader kwargs to run_validation
Browse files Browse the repository at this point in the history
during refactor, did not pass these (incl batch_size and num_workers) to the method run_validation, resulting in use of batch_size=1 and num_workers=0 during validations in SpectrogramClassifier.train() loop
  • Loading branch information
sammlapp committed Sep 13, 2024
1 parent 3118648 commit 7a2ea1b
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion opensoundscape/ml/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,7 +1565,12 @@ def train(
if epoch % validation_interval == 0:
if validation_df is not None:
self._log("\nValidation.")
val_metrics = self.run_validation(validation_df)
val_metrics = self.run_validation(
validation_df,
batch_size=batch_size,
num_workers=num_workers,
raise_errors=raise_errors,
)
self.valid_metrics[self.current_epoch] = val_metrics
score = val_metrics[self.score_metric] # overall score
if wandb_session is not None:
Expand Down

0 comments on commit 7a2ea1b

Please sign in to comment.