Skip to content

Commit

Permalink
Merge branch 'release/2.1'
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin committed Sep 15, 2021
2 parents b7e1183 + cb32dfe commit 51b3b8a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
5 changes: 5 additions & 0 deletions doc/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
Changelog
#########

Version 2.1 (2021-09-15)
~~~~~~~~~~~~~~~~~~~~~~~~

- feat: add Pipeline.training attribute

Version 2.0 (2020-11-25)
~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
14 changes: 13 additions & 1 deletion pyannote/pipeline/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# The MIT License (MIT)

# Copyright (c) 2018-2020 CNRS
# Copyright (c) 2018-2021 CNRS

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -261,6 +261,9 @@ def tune(
['params'] nested dictionary of optimal parameters
"""

# pipeline is currently being optimized
self.pipeline.training = True

objective = self.get_objective(inputs, show_progress=show_progress)

if warm_start:
Expand All @@ -272,6 +275,9 @@ def tune(

self.study_.optimize(objective, n_trials=n_iterations, timeout=None, n_jobs=1)

# pipeline is no longer being optimized
self.pipeline.training = False

return {"loss": self.best_loss, "params": self.best_params}

def tune_iter(
Expand Down Expand Up @@ -311,6 +317,9 @@ def tune_iter(

while True:

# pipeline is currently being optimized
self.pipeline.training = True

# one trial at a time
self.study_.optimize(objective, n_trials=1, timeout=None, n_jobs=1)

Expand All @@ -320,4 +329,7 @@ def tune_iter(
except ValueError as e:
continue

# pipeline is no longer being optimized
self.pipeline.training = False

yield {"loss": best_loss, "params": best_params}
5 changes: 4 additions & 1 deletion pyannote/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# The MIT License (MIT)

# Copyright (c) 2018-2020 CNRS
# Copyright (c) 2018-2021 CNRS

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -57,6 +57,9 @@ def __init__(self):
# sub-pipelines
self._pipelines = OrderedDict()

# whether pipeline is currently being optimized
self.training = False

def __hash__(self):
# FIXME -- also keep track of (sub)pipeline attribtes
frozen = self.parameters(frozen=True)
Expand Down

0 comments on commit 51b3b8a

Please sign in to comment.