diff --git a/src/Bonsai.ML.HiddenMarkovModels/LoadHMMFromPkl.bonsai b/src/Bonsai.ML.HiddenMarkovModels/LoadHMMFromPkl.bonsai
new file mode 100644
index 0000000..a9eecb5
--- /dev/null
+++ b/src/Bonsai.ML.HiddenMarkovModels/LoadHMMFromPkl.bonsai
@@ -0,0 +1,120 @@
+
+
+
+
+
+
+
+
+
+ hmm
+
+
+
+ hmm
+
+
+ Name
+
+
+ Source1
+
+
+
+
+
+
+ model.pkl
+
+
+
+
+
+
+
+
+
+ {0}=HiddenMarkovModel.load_model("{1}")
+ Item1,Item2
+
+
+
+
+
+
+
+ HMMModule
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ HMMModule
+
+
+
+
+
+
+
+
+ hmm
+
+
+
+
+ 2
+ 2
+ Gaussian
+ Stationary
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/Bonsai.ML.HiddenMarkovModels/SaveHMMToPkl.bonsai b/src/Bonsai.ML.HiddenMarkovModels/SaveHMMToPkl.bonsai
new file mode 100644
index 0000000..8bad3a4
--- /dev/null
+++ b/src/Bonsai.ML.HiddenMarkovModels/SaveHMMToPkl.bonsai
@@ -0,0 +1,81 @@
+
+
+
+
+
+
+
+
+
+ hmm
+
+
+
+ hmm
+
+
+ Name
+
+
+ Source1
+
+
+
+
+
+
+ model.pkl
+
+
+
+
+
+
+
+
+
+ {0}.save_model("{1}")
+ Item1,Item2
+
+
+
+
+
+
+
+ HMMModule
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/Bonsai.ML.HiddenMarkovModels/main.py b/src/Bonsai.ML.HiddenMarkovModels/main.py
index acafea6..9e4a76c 100644
--- a/src/Bonsai.ML.HiddenMarkovModels/main.py
+++ b/src/Bonsai.ML.HiddenMarkovModels/main.py
@@ -11,6 +11,11 @@
npr.seed(0)
+class CustomUnpickler(pickle.Unpickler):
+ def find_class(self, module, name):
+ if module == 'main' and name == 'HiddenMarkovModel':
+ return HiddenMarkovModel
+ return super().find_class(module, name)
class HiddenMarkovModel(HMM):
@@ -151,11 +156,14 @@ def compute_log_alpha(self, obs, log_alpha=None):
return log_alpha - logsumexp(log_alpha)
- def save_model_to_pickle(self, path: str):
- pickle.save(self, path)
-
- def load_model_from_pickle(self, path: str):
- self = pickle.load(path)
+ def save_model(self, path: str):
+ with open(path, "wb") as f:
+ pickle.dump(self, f)
+
+ @classmethod
+ def load_model(cls, path: str):
+ with open(path, 'rb') as f:
+ return CustomUnpickler(f).load()
def fit_async(self,
observation: list[float],