From b12756b80f2f1aa5e3a86981d0642ae25c41bd05 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 10 Jul 2024 01:21:12 +0100 Subject: [PATCH] Updated classes to inherit from python string builder --- .../ModelParameters.cs | 52 +++++++++++++------ .../AutoRegressiveObservations.cs | 33 +++++------- .../AutoRegressiveObservationsModel.cs | 2 +- .../Observations/BernoulliObservations.cs | 20 +++---- .../BernoulliObservationsModel.cs | 2 +- .../Observations/ExponentialObservations.cs | 16 ++---- .../ExponentialObservationsModel.cs | 4 +- .../Observations/GaussianObservations.cs | 13 +---- .../Observations/GaussianObservationsModel.cs | 4 +- .../Observations/ObservationsModel.cs | 48 +++++++++++++++-- .../Observations/ObservationsModelBuilder.cs | 6 --- .../Observations/PoissonObservations.cs | 31 +---------- .../Observations/PoissonObservationsModel.cs | 14 ++--- .../StateParameters.cs | 39 +++++++++++--- 14 files changed, 146 insertions(+), 138 deletions(-) delete mode 100644 src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModelBuilder.cs diff --git a/src/Bonsai.ML.HiddenMarkovModels/ModelParameters.cs b/src/Bonsai.ML.HiddenMarkovModels/ModelParameters.cs index 4ccfe7e2..a94bae28 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/ModelParameters.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/ModelParameters.cs @@ -3,8 +3,8 @@ using System.ComponentModel; using System.Linq; using System.Reactive.Linq; +using System.Text; using Newtonsoft.Json; -using Newtonsoft.Json.Converters; using System.Xml.Serialization; using Python.Runtime; using Bonsai.ML.HiddenMarkovModels.Observations; @@ -15,7 +15,7 @@ namespace Bonsai.ML.HiddenMarkovModels [Combinator] [Description("")] [WorkflowElementCategory(ElementCategory.Source)] - public class ModelParameters + public class ModelParameters : PythonStringBuilder { private int numStates; @@ -29,7 +29,7 @@ public class ModelParameters [JsonProperty("num_states")] [Description("The number of discrete latent states of the HMM model")] [Category("InitialParameters")] - public int NumStates { get => numStates; set => numStates = value; } + public int NumStates { get => numStates; set { numStates = value; UpdateString(); } } /// /// The dimensionality of the observations into the HMM model. @@ -37,7 +37,7 @@ public class ModelParameters [JsonProperty("dimensions")] [Description("The dimensionality of the observations into the HMM model")] [Category("InitialParameters")] - public int Dimensions { get => dimensions; set => dimensions = value; } + public int Dimensions { get => dimensions; set { dimensions = value; UpdateString(); } } /// /// The type of distribution that the HMM will use to model the emission of data observations. @@ -46,7 +46,7 @@ public class ModelParameters [JsonConverter(typeof(ObservationsTypeJsonConverter))] [Description("The type of distribution that the HMM will use to model the emission of data observations.")] [Category("InitialParameters")] - public ObservationsType ObservationsType { get => observationsType; set => observationsType = value; } + public ObservationsType ObservationsType { get => observationsType; set { observationsType = value; UpdateString(); } } /// /// The state parameters of the HMM model. @@ -54,7 +54,16 @@ public class ModelParameters [XmlIgnore] [Description("The state parameters of the HMM model.")] [Category("ModelState")] - public StateParameters StateParameters { get => stateParameters; set => stateParameters = value; } + public StateParameters StateParameters + { + get => stateParameters; + set + { + stateParameters = value; + ObservationsType = stateParameters.Observations.ObservationsType; + UpdateString(); + } + } /// /// Initializes a new instance of the class. @@ -98,17 +107,17 @@ public IObservable Process(IObservable source) var stateParametersObservable = new StateParameters().Process(sharedSource); return sharedSource.Select(pyObject => { - var numStatesPyObj = pyObject.GetAttr("num_states"); - var dimensionsPyObj = pyObject.GetAttr("dimensions"); + NumStates = pyObject.GetAttr("num_states"); + Dimensions = pyObject.GetAttr("dimensions"); var observationsTypeStrPyObj = pyObject.GetAttr("observation_type"); - var observationsTypePyObj = GetFromString(observationsTypeStrPyObj); + ObservationsType = GetFromString(observationsTypeStrPyObj); return new ModelParameters() { - NumStates = numStatesPyObj, - Dimensions = dimensionsPyObj, - ObservationsType = observationsTypePyObj, + NumStates = NumStates, + Dimensions = Dimensions, + ObservationsType = ObservationsType, }; }).Zip(stateParametersObservable, (modelParameters, stateParameters) => { @@ -117,12 +126,21 @@ public IObservable Process(IObservable source) }); } - public override string ToString() + protected override string BuildString() { - return $"num_states={numStates}," + - $"dimensions={dimensions}," + - $"observation_type=\"{GetString(observationsType)}\"," + - $"{(stateParameters == null ? "" : stateParameters)}"; + StringBuilder.Clear(); + StringBuilder.Append($"num_states={numStates},") + .Append($"dimensions={dimensions},"); + + if (stateParameters == null) + { + StringBuilder.Append($"observation_type=\"{GetString(observationsType)}\""); + } + else + { + StringBuilder.Append($"{stateParameters}"); + } + return StringBuilder.ToString(); } } } diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/AutoRegressiveObservations.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/AutoRegressiveObservations.cs index 2701cf13..cbc39344 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/AutoRegressiveObservations.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/AutoRegressiveObservations.cs @@ -1,8 +1,5 @@ -using System; -using Python.Runtime; -using System.Reactive.Linq; using System.ComponentModel; -using System.Xml.Serialization; +using System.Collections.Generic; using Newtonsoft.Json; namespace Bonsai.ML.HiddenMarkovModels.Observations @@ -11,12 +8,14 @@ namespace Bonsai.ML.HiddenMarkovModels.Observations public class AutoRegressiveObservations : ObservationsModel { + private int lags = 1; + /// /// The lags of the observations for each state. /// [JsonProperty] [Description("The lags of the observations for each state.")] - public int Lags { get; set; } = 1; + public int Lags { get => lags; set {lags = value; UpdateString(); } } /// /// The As of the observations for each state. @@ -50,32 +49,28 @@ public class AutoRegressiveObservations : ObservationsModel [JsonProperty] public override object[] Params { - get { return [ As, Bs, Vs, SqrtSigmas ]; } + get =>[ As, Bs, Vs, SqrtSigmas ]; set { As = (double[,,])value[0]; Bs = (double[,])value[1]; Vs = (double[,,])value[2]; SqrtSigmas = (double[,,])value[3]; + UpdateString(); } } - public AutoRegressiveObservations (params object[] args) - { - Lags = (int)args[0]; - } - /// - public override string ToString() + [JsonProperty] + public override Dictionary Kwargs => new Dictionary { - if (As is null || Bs is null || Vs is null || SqrtSigmas is null) - return $"observation_params=None,observation_kwargs={{'lags':{Lags}}}"; + ["lags"] = Lags, + }; - return $"observation_params=({NumpyHelper.NumpyParser.ParseArray(As)}," + - $"{NumpyHelper.NumpyParser.ParseArray(Bs)}," + - $"{NumpyHelper.NumpyParser.ParseArray(Vs)}," + - $"{NumpyHelper.NumpyParser.ParseArray(SqrtSigmas)})," + - $"observation_kwargs={{'lags':{Lags}}}"; + public AutoRegressiveObservations (params object[] args) + { + Lags = (int)args[0]; + UpdateString(); } } } \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/AutoRegressiveObservationsModel.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/AutoRegressiveObservationsModel.cs index 15f0779e..4ea95582 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/AutoRegressiveObservationsModel.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/AutoRegressiveObservationsModel.cs @@ -11,7 +11,7 @@ namespace Bonsai.ML.HiddenMarkovModels.Observations [Description("")] [WorkflowElementCategory(ElementCategory.Source)] [JsonObject(MemberSerialization.OptIn)] - public class AutoRegressiveObservationsModel : ObservationsModelBuilder + public class AutoRegressiveObservationsModel { /// /// The lags of the observations for each state. diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/BernoulliObservations.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/BernoulliObservations.cs index 2fa48384..e2b33035 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/BernoulliObservations.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/BernoulliObservations.cs @@ -1,9 +1,6 @@ -using System; -using Python.Runtime; -using System.Reactive.Linq; using System.ComponentModel; -using System.Xml.Serialization; using Newtonsoft.Json; +using static Bonsai.ML.HiddenMarkovModels.Observations.ObservationsLookup; namespace Bonsai.ML.HiddenMarkovModels.Observations { @@ -25,16 +22,11 @@ public class BernoulliObservations : ObservationsModel public override object[] Params { get { return [ LogitPs ]; } - set { LogitPs = (double[,])value[0]; } - } - - /// - public override string ToString() - { - if (LogitPs is null) - return $"observation_params=None"; - - return $"observation_params=({NumpyHelper.NumpyParser.ParseArray(LogitPs)},)"; + set + { + LogitPs = (double[,])value[0]; + UpdateString(); + } } } } \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/BernoulliObservationsModel.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/BernoulliObservationsModel.cs index 450da47a..529cce89 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/BernoulliObservationsModel.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/BernoulliObservationsModel.cs @@ -11,7 +11,7 @@ namespace Bonsai.ML.HiddenMarkovModels.Observations [Description("")] [WorkflowElementCategory(ElementCategory.Source)] [JsonObject(MemberSerialization.OptIn)] - public class BernoulliObservationsModel : ObservationsModelBuilder + public class BernoulliObservationsModel { /// /// The logit P of the observations for each state. diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/ExponentialObservations.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/ExponentialObservations.cs index 38e33ac5..536575a2 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/ExponentialObservations.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/ExponentialObservations.cs @@ -1,9 +1,7 @@ -using System; -using Python.Runtime; -using System.Reactive.Linq; using System.ComponentModel; -using System.Xml.Serialization; +using System.Text; using Newtonsoft.Json; +using static Bonsai.ML.HiddenMarkovModels.Observations.ObservationsLookup; namespace Bonsai.ML.HiddenMarkovModels.Observations { @@ -29,16 +27,8 @@ public override object[] Params set { LogLambdas = (double[,])value[0]; + UpdateString(); } } - - /// - public override string ToString() - { - if (LogLambdas is null) - return $"observation_params=None"; - - return $"observation_params=({NumpyHelper.NumpyParser.ParseArray(LogLambdas)},)"; - } } } \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/ExponentialObservationsModel.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/ExponentialObservationsModel.cs index decd52f7..7cbece9d 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/ExponentialObservationsModel.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/ExponentialObservationsModel.cs @@ -9,9 +9,9 @@ namespace Bonsai.ML.HiddenMarkovModels.Observations { [Combinator] [Description("")] - [WorkflowElementCategory(ElementCategory.Transform)] + [WorkflowElementCategory(ElementCategory.Source)] [JsonObject(MemberSerialization.OptIn)] - public class ExponentialObservationsModel : ObservationsModelBuilder + public class ExponentialObservationsModel { /// diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservations.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservations.cs index 3478c4bc..2bd81948 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservations.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservations.cs @@ -32,22 +32,13 @@ public class GaussianObservations : ObservationsModel [JsonProperty] public override object[] Params { - get { return [ Mus, SqrtSigmas ]; } + get => [ Mus, SqrtSigmas ]; set { Mus = (double[,])value[0]; SqrtSigmas = (double[,,])value[1]; + UpdateString(); } } - - /// - public override string ToString() - { - if (Mus is null || SqrtSigmas is null) - return $"observation_params=None"; - - return $"observation_params=({NumpyHelper.NumpyParser.ParseArray(Mus)}," + - $"{NumpyHelper.NumpyParser.ParseArray(SqrtSigmas)})"; - } } } \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsModel.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsModel.cs index 804b68f8..3aba98af 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsModel.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsModel.cs @@ -10,9 +10,9 @@ namespace Bonsai.ML.HiddenMarkovModels.Observations { [Combinator] [Description("")] - [WorkflowElementCategory(ElementCategory.Transform)] + [WorkflowElementCategory(ElementCategory.Source)] [JsonObject(MemberSerialization.OptIn)] - public class GaussianObservationsModel : ObservationsModelBuilder + public class GaussianObservationsModel { /// diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModel.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModel.cs index e5e058fb..f145a26c 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModel.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModel.cs @@ -1,12 +1,14 @@ -using Newtonsoft.Json; using System; +using System.Collections.Generic; +using System.Text; +using static Bonsai.ML.HiddenMarkovModels.Observations.ObservationsLookup; namespace Bonsai.ML.HiddenMarkovModels.Observations { /// /// An abstract class for creating an observations model. /// - public abstract class ObservationsModel + public abstract class ObservationsModel : PythonStringBuilder { /// /// The type of observations model. @@ -21,6 +23,46 @@ public abstract class ObservationsModel /// /// The keyword arguments that are used to construct the observations models. /// - public virtual object[] Kwargs => Array.Empty(); + public virtual Dictionary Kwargs => new(); + + protected override string BuildString() + { + StringBuilder.Clear(); + StringBuilder.Append($"observation_type=\"{GetString(ObservationsType)}\""); + + if (Params != null && Params.Length > 0) + { + var paramsStringBuilder = new StringBuilder(); + paramsStringBuilder.Append(",observation_params=("); + + foreach (var param in Params) { + if (param is null) { + paramsStringBuilder.Clear(); + break; + } + var arrString = param is Array array ? NumpyHelper.NumpyParser.ParseArray(array) : param.ToString(); + paramsStringBuilder.Append($"{arrString},"); + } + + if (paramsStringBuilder.Length > 0) { + paramsStringBuilder.Remove(paramsStringBuilder.Length - 1, 1); + paramsStringBuilder.Append(")"); + StringBuilder.Append(paramsStringBuilder); + } + + } + + if (Kwargs is not null && Kwargs.Count > 0) + { + StringBuilder.Append(",observation_kwargs={"); + foreach (var kp in Kwargs) { + StringBuilder.Append($"\"{kp.Key}\":{(kp.Value is Array ? NumpyHelper.NumpyParser.ParseArray((Array)kp.Value) : kp.Value)},"); + } + StringBuilder.Remove(StringBuilder.Length - 1, 1); + StringBuilder.Append("}"); + } + + return StringBuilder.ToString(); + } } } diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModelBuilder.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModelBuilder.cs deleted file mode 100644 index 0b239960..00000000 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModelBuilder.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Bonsai.ML.HiddenMarkovModels.Observations -{ - public abstract class ObservationsModelBuilder where T : ObservationsModel - { - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/PoissonObservations.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/PoissonObservations.cs index 02a79e4b..e0597ea7 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/PoissonObservations.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/PoissonObservations.cs @@ -29,37 +29,8 @@ public override object[] Params set { LogLambdas = (double[,])value[0]; + UpdateString(); } } - - public IObservable Process() - { - return Observable.Return( - new PoissonObservations { - LogLambdas = LogLambdas - }); - } - - public IObservable Process(IObservable source) - { - return Observable.Select(source, pyObject => - { - var logLambdasPyObj = (double[,])pyObject.GetArrayAttr("log_lambdas"); - - return new PoissonObservations - { - LogLambdas = logLambdasPyObj - }; - }); - } - - /// - public override string ToString() - { - if (LogLambdas is null) - return $"observation_params=None"; - - return $"observation_params=({NumpyHelper.NumpyParser.ParseArray(LogLambdas)},)"; - } } } \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/PoissonObservationsModel.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/PoissonObservationsModel.cs index 1af8f77a..0d7ebf6f 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/PoissonObservationsModel.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/PoissonObservationsModel.cs @@ -9,8 +9,9 @@ namespace Bonsai.ML.HiddenMarkovModels.Observations { [Combinator] [Description("")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class PoissonObservationsModel : ObservationsModelBuilder + [WorkflowElementCategory(ElementCategory.Source)] + [JsonObject(MemberSerialization.OptIn)] + public class PoissonObservationsModel { /// @@ -39,14 +40,5 @@ public IObservable Process(IObservable source) }; }); } - - /// - public override string ToString() - { - if (LogLambdas is null) - return $"observation_params=None"; - - return $"observation_params=({NumpyHelper.NumpyParser.ParseArray(LogLambdas)},)"; - } } } \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/StateParameters.cs b/src/Bonsai.ML.HiddenMarkovModels/StateParameters.cs index 940ea355..190c4136 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/StateParameters.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/StateParameters.cs @@ -19,7 +19,7 @@ namespace Bonsai.ML.HiddenMarkovModels [JsonConverter(typeof(StateParametersJsonConverter))] [Description("StateParameters of a Hidden Markov Model (HMM).")] [WorkflowElementCategory(ElementCategory.Source)] - public class StateParameters + public class StateParameters : PythonStringBuilder { private double[] initialStateDistribution = null; @@ -36,7 +36,10 @@ public class StateParameters public double[] InitialStateDistribution { get => initialStateDistribution; - set => initialStateDistribution = value; + set { + initialStateDistribution = value; + UpdateString(); + } } /// @@ -49,7 +52,10 @@ public double[] InitialStateDistribution public double[,] LogTransitionProbabilities { get => logTransitionProbabilities; - set => logTransitionProbabilities = value; + set { + logTransitionProbabilities = value; + UpdateString(); + } } /// @@ -62,7 +68,10 @@ public double[] InitialStateDistribution public ObservationsModel Observations { get => observations; - set => observations = value; + set { + observations = value; + UpdateString(); + } } public IObservable Process() @@ -120,11 +129,25 @@ public IObservable Process(IObservable source) }); } - public override string ToString() + protected override string BuildString() { - return $"initial_state_distribution={(InitialStateDistribution == null ? "None" : NumpyHelper.NumpyParser.ParseArray(InitialStateDistribution))}," + - $"log_transition_probabilities={(LogTransitionProbabilities == null ? "None" : NumpyHelper.NumpyParser.ParseArray(LogTransitionProbabilities))}," + - $"{(Observations == null ? "" : Observations)}"; + StringBuilder.Clear(); + + if (InitialStateDistribution != null) { + StringBuilder.Append($"initial_state_distribution={NumpyHelper.NumpyParser.ParseArray(InitialStateDistribution)},"); + } + + if (LogTransitionProbabilities != null) { + StringBuilder.Append($"log_transition_probabilities={NumpyHelper.NumpyParser.ParseArray(LogTransitionProbabilities)},"); + } + + if (Observations != null) { + StringBuilder.Append($"{Observations},"); + } + + StringBuilder.Remove(StringBuilder.Length - 1, 1); + + return StringBuilder.ToString(); } } }