diff --git a/src/Bonsai.ML.HiddenMarkovModels/LoadHMMFromPkl.bonsai b/src/Bonsai.ML.HiddenMarkovModels/LoadHMMFromPkl.bonsai new file mode 100644 index 0000000..a9eecb5 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/LoadHMMFromPkl.bonsai @@ -0,0 +1,120 @@ + + + + + + + + + + hmm + + + + hmm + + + Name + + + Source1 + + + + + + + model.pkl + + + + + + + + + + {0}=HiddenMarkovModel.load_model("{1}") + Item1,Item2 + + + + + + + + HMMModule + + + + + + + + + + + + + + + + + + + + + HMMModule + + + + + + + + + hmm + + + + + 2 + 2 + Gaussian + Stationary + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/SaveHMMToPkl.bonsai b/src/Bonsai.ML.HiddenMarkovModels/SaveHMMToPkl.bonsai new file mode 100644 index 0000000..8bad3a4 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/SaveHMMToPkl.bonsai @@ -0,0 +1,81 @@ + + + + + + + + + + hmm + + + + hmm + + + Name + + + Source1 + + + + + + + model.pkl + + + + + + + + + + {0}.save_model("{1}") + Item1,Item2 + + + + + + + + HMMModule + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/main.py b/src/Bonsai.ML.HiddenMarkovModels/main.py index acafea6..9e4a76c 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/main.py +++ b/src/Bonsai.ML.HiddenMarkovModels/main.py @@ -11,6 +11,11 @@ npr.seed(0) +class CustomUnpickler(pickle.Unpickler): + def find_class(self, module, name): + if module == 'main' and name == 'HiddenMarkovModel': + return HiddenMarkovModel + return super().find_class(module, name) class HiddenMarkovModel(HMM): @@ -151,11 +156,14 @@ def compute_log_alpha(self, obs, log_alpha=None): return log_alpha - logsumexp(log_alpha) - def save_model_to_pickle(self, path: str): - pickle.save(self, path) - - def load_model_from_pickle(self, path: str): - self = pickle.load(path) + def save_model(self, path: str): + with open(path, "wb") as f: + pickle.dump(self, f) + + @classmethod + def load_model(cls, path: str): + with open(path, 'rb') as f: + return CustomUnpickler(f).load() def fit_async(self, observation: list[float],