Skip to content

Commit

Permalink
Updated classes to inherit from python string builder
Browse files Browse the repository at this point in the history
  • Loading branch information
ncguilbeault committed Jul 10, 2024
1 parent 5ee0ec0 commit b12756b
Show file tree
Hide file tree
Showing 14 changed files with 146 additions and 138 deletions.
52 changes: 35 additions & 17 deletions src/Bonsai.ML.HiddenMarkovModels/ModelParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
using System.ComponentModel;
using System.Linq;
using System.Reactive.Linq;
using System.Text;
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
using System.Xml.Serialization;
using Python.Runtime;
using Bonsai.ML.HiddenMarkovModels.Observations;
Expand All @@ -15,7 +15,7 @@ namespace Bonsai.ML.HiddenMarkovModels
[Combinator]
[Description("")]
[WorkflowElementCategory(ElementCategory.Source)]
public class ModelParameters
public class ModelParameters : PythonStringBuilder
{

private int numStates;
Expand All @@ -29,15 +29,15 @@ public class ModelParameters
[JsonProperty("num_states")]
[Description("The number of discrete latent states of the HMM model")]
[Category("InitialParameters")]
public int NumStates { get => numStates; set => numStates = value; }
public int NumStates { get => numStates; set { numStates = value; UpdateString(); } }

/// <summary>
/// The dimensionality of the observations into the HMM model.
/// </summary>
[JsonProperty("dimensions")]
[Description("The dimensionality of the observations into the HMM model")]
[Category("InitialParameters")]
public int Dimensions { get => dimensions; set => dimensions = value; }
public int Dimensions { get => dimensions; set { dimensions = value; UpdateString(); } }

/// <summary>
/// The type of distribution that the HMM will use to model the emission of data observations.
Expand All @@ -46,15 +46,24 @@ public class ModelParameters
[JsonConverter(typeof(ObservationsTypeJsonConverter))]
[Description("The type of distribution that the HMM will use to model the emission of data observations.")]
[Category("InitialParameters")]
public ObservationsType ObservationsType { get => observationsType; set => observationsType = value; }
public ObservationsType ObservationsType { get => observationsType; set { observationsType = value; UpdateString(); } }

/// <summary>
/// The state parameters of the HMM model.
/// </summary>
[XmlIgnore]
[Description("The state parameters of the HMM model.")]
[Category("ModelState")]
public StateParameters StateParameters { get => stateParameters; set => stateParameters = value; }
public StateParameters StateParameters
{
get => stateParameters;
set
{
stateParameters = value;
ObservationsType = stateParameters.Observations.ObservationsType;
UpdateString();
}
}

/// <summary>
/// Initializes a new instance of the <see cref="ModelParameters"/> class.
Expand Down Expand Up @@ -98,17 +107,17 @@ public IObservable<ModelParameters> Process(IObservable<PyObject> source)
var stateParametersObservable = new StateParameters().Process(sharedSource);
return sharedSource.Select(pyObject =>
{
var numStatesPyObj = pyObject.GetAttr<int>("num_states");
var dimensionsPyObj = pyObject.GetAttr<int>("dimensions");
NumStates = pyObject.GetAttr<int>("num_states");
Dimensions = pyObject.GetAttr<int>("dimensions");
var observationsTypeStrPyObj = pyObject.GetAttr<string>("observation_type");

var observationsTypePyObj = GetFromString(observationsTypeStrPyObj);
ObservationsType = GetFromString(observationsTypeStrPyObj);

return new ModelParameters()
{
NumStates = numStatesPyObj,
Dimensions = dimensionsPyObj,
ObservationsType = observationsTypePyObj,
NumStates = NumStates,
Dimensions = Dimensions,
ObservationsType = ObservationsType,
};
}).Zip(stateParametersObservable, (modelParameters, stateParameters) =>
{
Expand All @@ -117,12 +126,21 @@ public IObservable<ModelParameters> Process(IObservable<PyObject> source)
});
}

public override string ToString()
protected override string BuildString()
{
return $"num_states={numStates}," +
$"dimensions={dimensions}," +
$"observation_type=\"{GetString(observationsType)}\"," +
$"{(stateParameters == null ? "" : stateParameters)}";
StringBuilder.Clear();
StringBuilder.Append($"num_states={numStates},")
.Append($"dimensions={dimensions},");

if (stateParameters == null)
{
StringBuilder.Append($"observation_type=\"{GetString(observationsType)}\"");
}
else
{
StringBuilder.Append($"{stateParameters}");
}
return StringBuilder.ToString();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
using System;
using Python.Runtime;
using System.Reactive.Linq;
using System.ComponentModel;
using System.Xml.Serialization;
using System.Collections.Generic;
using Newtonsoft.Json;

namespace Bonsai.ML.HiddenMarkovModels.Observations
Expand All @@ -11,12 +8,14 @@ namespace Bonsai.ML.HiddenMarkovModels.Observations
public class AutoRegressiveObservations : ObservationsModel
{

private int lags = 1;

/// <summary>
/// The lags of the observations for each state.
/// </summary>
[JsonProperty]
[Description("The lags of the observations for each state.")]
public int Lags { get; set; } = 1;
public int Lags { get => lags; set {lags = value; UpdateString(); } }

/// <summary>
/// The As of the observations for each state.
Expand Down Expand Up @@ -50,32 +49,28 @@ public class AutoRegressiveObservations : ObservationsModel
[JsonProperty]
public override object[] Params
{
get { return [ As, Bs, Vs, SqrtSigmas ]; }
get =>[ As, Bs, Vs, SqrtSigmas ];
set
{
As = (double[,,])value[0];
Bs = (double[,])value[1];
Vs = (double[,,])value[2];
SqrtSigmas = (double[,,])value[3];
UpdateString();
}
}

public AutoRegressiveObservations (params object[] args)
{
Lags = (int)args[0];
}

/// <inheritdoc/>
public override string ToString()
[JsonProperty]
public override Dictionary<string, object> Kwargs => new Dictionary<string, object>
{
if (As is null || Bs is null || Vs is null || SqrtSigmas is null)
return $"observation_params=None,observation_kwargs={{'lags':{Lags}}}";
["lags"] = Lags,
};

return $"observation_params=({NumpyHelper.NumpyParser.ParseArray(As)}," +
$"{NumpyHelper.NumpyParser.ParseArray(Bs)}," +
$"{NumpyHelper.NumpyParser.ParseArray(Vs)}," +
$"{NumpyHelper.NumpyParser.ParseArray(SqrtSigmas)})," +
$"observation_kwargs={{'lags':{Lags}}}";
public AutoRegressiveObservations (params object[] args)
{
Lags = (int)args[0];
UpdateString();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Bonsai.ML.HiddenMarkovModels.Observations
[Description("")]
[WorkflowElementCategory(ElementCategory.Source)]
[JsonObject(MemberSerialization.OptIn)]
public class AutoRegressiveObservationsModel : ObservationsModelBuilder<AutoRegressiveObservations>
public class AutoRegressiveObservationsModel
{
/// <summary>
/// The lags of the observations for each state.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
using System;
using Python.Runtime;
using System.Reactive.Linq;
using System.ComponentModel;
using System.Xml.Serialization;
using Newtonsoft.Json;
using static Bonsai.ML.HiddenMarkovModels.Observations.ObservationsLookup;

namespace Bonsai.ML.HiddenMarkovModels.Observations
{
Expand All @@ -25,16 +22,11 @@ public class BernoulliObservations : ObservationsModel
public override object[] Params
{
get { return [ LogitPs ]; }
set { LogitPs = (double[,])value[0]; }
}

/// <inheritdoc/>
public override string ToString()
{
if (LogitPs is null)
return $"observation_params=None";

return $"observation_params=({NumpyHelper.NumpyParser.ParseArray(LogitPs)},)";
set
{
LogitPs = (double[,])value[0];
UpdateString();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Bonsai.ML.HiddenMarkovModels.Observations
[Description("")]
[WorkflowElementCategory(ElementCategory.Source)]
[JsonObject(MemberSerialization.OptIn)]
public class BernoulliObservationsModel : ObservationsModelBuilder<BernoulliObservations>
public class BernoulliObservationsModel
{
/// <summary>
/// The logit P of the observations for each state.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
using System;
using Python.Runtime;
using System.Reactive.Linq;
using System.ComponentModel;
using System.Xml.Serialization;
using System.Text;
using Newtonsoft.Json;
using static Bonsai.ML.HiddenMarkovModels.Observations.ObservationsLookup;

namespace Bonsai.ML.HiddenMarkovModels.Observations
{
Expand All @@ -29,16 +27,8 @@ public override object[] Params
set
{
LogLambdas = (double[,])value[0];
UpdateString();
}
}

/// <inheritdoc/>
public override string ToString()
{
if (LogLambdas is null)
return $"observation_params=None";

return $"observation_params=({NumpyHelper.NumpyParser.ParseArray(LogLambdas)},)";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ namespace Bonsai.ML.HiddenMarkovModels.Observations
{
[Combinator]
[Description("")]
[WorkflowElementCategory(ElementCategory.Transform)]
[WorkflowElementCategory(ElementCategory.Source)]
[JsonObject(MemberSerialization.OptIn)]
public class ExponentialObservationsModel : ObservationsModelBuilder<ExponentialObservations>
public class ExponentialObservationsModel
{

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,13 @@ public class GaussianObservations : ObservationsModel
[JsonProperty]
public override object[] Params
{
get { return [ Mus, SqrtSigmas ]; }
get => [ Mus, SqrtSigmas ];
set
{
Mus = (double[,])value[0];
SqrtSigmas = (double[,,])value[1];
UpdateString();
}
}

/// <inheritdoc/>
public override string ToString()
{
if (Mus is null || SqrtSigmas is null)
return $"observation_params=None";

return $"observation_params=({NumpyHelper.NumpyParser.ParseArray(Mus)}," +
$"{NumpyHelper.NumpyParser.ParseArray(SqrtSigmas)})";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ namespace Bonsai.ML.HiddenMarkovModels.Observations
{
[Combinator]
[Description("")]
[WorkflowElementCategory(ElementCategory.Transform)]
[WorkflowElementCategory(ElementCategory.Source)]
[JsonObject(MemberSerialization.OptIn)]
public class GaussianObservationsModel : ObservationsModelBuilder<GaussianObservations>
public class GaussianObservationsModel
{

/// <summary>
Expand Down
48 changes: 45 additions & 3 deletions src/Bonsai.ML.HiddenMarkovModels/Observations/ObservationsModel.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;
using static Bonsai.ML.HiddenMarkovModels.Observations.ObservationsLookup;

namespace Bonsai.ML.HiddenMarkovModels.Observations
{
/// <summary>
/// An abstract class for creating an observations model.
/// </summary>
public abstract class ObservationsModel
public abstract class ObservationsModel : PythonStringBuilder
{
/// <summary>
/// The type of observations model.
Expand All @@ -21,6 +23,46 @@ public abstract class ObservationsModel
/// <summary>
/// The keyword arguments that are used to construct the observations models.
/// </summary>
public virtual object[] Kwargs => Array.Empty<object>();
public virtual Dictionary<string, object> Kwargs => new();

protected override string BuildString()
{
StringBuilder.Clear();
StringBuilder.Append($"observation_type=\"{GetString(ObservationsType)}\"");

if (Params != null && Params.Length > 0)
{
var paramsStringBuilder = new StringBuilder();
paramsStringBuilder.Append(",observation_params=(");

foreach (var param in Params) {
if (param is null) {
paramsStringBuilder.Clear();
break;
}
var arrString = param is Array array ? NumpyHelper.NumpyParser.ParseArray(array) : param.ToString();
paramsStringBuilder.Append($"{arrString},");
}

if (paramsStringBuilder.Length > 0) {
paramsStringBuilder.Remove(paramsStringBuilder.Length - 1, 1);
paramsStringBuilder.Append(")");
StringBuilder.Append(paramsStringBuilder);
}

}

if (Kwargs is not null && Kwargs.Count > 0)
{
StringBuilder.Append(",observation_kwargs={");
foreach (var kp in Kwargs) {
StringBuilder.Append($"\"{kp.Key}\":{(kp.Value is Array ? NumpyHelper.NumpyParser.ParseArray((Array)kp.Value) : kp.Value)},");
}
StringBuilder.Remove(StringBuilder.Length - 1, 1);
StringBuilder.Append("}");
}

return StringBuilder.ToString();
}
}
}

This file was deleted.

Loading

0 comments on commit b12756b

Please sign in to comment.