From e51d6e0cf9f9b462ff9ca2eaf62f233b0a1ce4be Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 10 Jul 2024 21:57:44 +0100 Subject: [PATCH] Added neural network recurrent transitions --- .../NeuralNetworkRecurrentTransitions.cs | 99 +++++++++++++++++++ .../NeuralNetworkRecurrentTransitionsModel.cs | 80 +++++++++++++++ .../Transitions/TransitionsLookup.cs | 3 +- .../Transitions/TransitionsType.cs | 7 +- 4 files changed, 187 insertions(+), 2 deletions(-) create mode 100644 src/Bonsai.ML.HiddenMarkovModels/Transitions/NeuralNetworkRecurrentTransitions.cs create mode 100644 src/Bonsai.ML.HiddenMarkovModels/Transitions/NeuralNetworkRecurrentTransitionsModel.cs diff --git a/src/Bonsai.ML.HiddenMarkovModels/Transitions/NeuralNetworkRecurrentTransitions.cs b/src/Bonsai.ML.HiddenMarkovModels/Transitions/NeuralNetworkRecurrentTransitions.cs new file mode 100644 index 00000000..3ecfa7b0 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Transitions/NeuralNetworkRecurrentTransitions.cs @@ -0,0 +1,99 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Xml.Serialization; +using Newtonsoft.Json; + +namespace Bonsai.ML.HiddenMarkovModels.Transitions +{ + /// + /// Represents a constrained stationary transitions model. + /// + [JsonObject(MemberSerialization.OptIn)] + public class NeuralNetworkRecurrentTransitions : TransitionsModel + { + private int[] hiddenLayerSizes = [50]; + private readonly string nonlinearity = "relu"; + + /// + /// The sizes of the hidden layers. + /// + [XmlIgnore] + public int[] HiddenLayerSizes { get => hiddenLayerSizes; set {hiddenLayerSizes = value; UpdateString();} } + + /// + /// The sizes of the hidden layers. + /// + public string Nonlinearity => nonlinearity; + + /// + /// The Log Ps of the transitions. + /// + public double[,] LogPs { get; private set; } = null; + + /// + /// The weights of the transitions. + /// + public double[,,] Weights { get; private set; } = null; + + /// + /// The biases of the transitions. + /// + public double[,,] Biases { get; private set; } = null; + + /// + [JsonProperty] + [JsonConverter(typeof(TransitionsTypeJsonConverter))] + public override TransitionsType TransitionsType => TransitionsType.NeuralNetworkRecurrent; + + /// + [JsonProperty] + public override object[] Params + { + get => [LogPs, Weights, Biases]; + set + { + LogPs = (double[,])value[0]; + Weights = (double[,,])value[1]; + Biases = (double[,,])value[2]; + UpdateString(); + } + } + + /// + [JsonProperty] + public override Dictionary Kwargs => new Dictionary + { + ["hidden_layer_sizes"] = HiddenLayerSizes, + ["nonlinearity"] = Nonlinearity, + }; + + /// + /// Initializes a new instance of the class. + /// + public NeuralNetworkRecurrentTransitions(params object[] args) + { + if (args is not null && args.Length > 0) + { + HiddenLayerSizes = args[0] switch + { + int[] layers => layers, + long[] layers => ConvertLongArrayToIntArray(layers), + _ => null + }; + UpdateString(); + } + } + + private static int[] ConvertLongArrayToIntArray(long[] longArray) + { + int count = longArray.Length; + int[] intArray = new int[count]; + + for (int i = 0; i < count; i++) + intArray[i] = Convert.ToInt32(longArray[i]); + + return intArray; + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Transitions/NeuralNetworkRecurrentTransitionsModel.cs b/src/Bonsai.ML.HiddenMarkovModels/Transitions/NeuralNetworkRecurrentTransitionsModel.cs new file mode 100644 index 00000000..02c231e5 --- /dev/null +++ b/src/Bonsai.ML.HiddenMarkovModels/Transitions/NeuralNetworkRecurrentTransitionsModel.cs @@ -0,0 +1,80 @@ +using System; +using Python.Runtime; +using System.Reactive.Linq; +using System.ComponentModel; +using Newtonsoft.Json; +using System.Xml.Serialization; + +namespace Bonsai.ML.HiddenMarkovModels.Transitions +{ + /// + /// Represents an operator that is used to create and transform an observable sequence + /// of objects. + /// + [Combinator] + [Description("Creates an observable sequence of NeuralNetworkRecurrentTransitions objects.")] + [WorkflowElementCategory(ElementCategory.Source)] + [JsonObject(MemberSerialization.OptIn)] + public class NeuralNetworkRecurrentTransitionsModel + { + /// + /// The sizes of the hidden layers. + /// + [XmlIgnore] + [Description("The sizes of the hidden layers.")] + [JsonProperty] + public int[] HiddenLayerSizes { get; set; } = [50]; + + /// + /// The Log Ps of the transitions. + /// + [XmlIgnore] + [Description("The log Ps of the transitions.")] + public double[,] LogPs { get; set; } = null; + + /// + /// The weights. + /// + [XmlIgnore] + [Description("The weights.")] + public double[,,] Weights { get; set; } = null; + + /// + /// The biases. + /// + [XmlIgnore] + [Description("The biases.")] + public double[,,] Biases { get; set; } = null; + + /// + /// Returns an observable sequence of objects. + /// + new public IObservable Process() + { + return Observable.Return( + new NeuralNetworkRecurrentTransitions(HiddenLayerSizes) + { + Params = [LogPs, Weights, Biases] + }); + } + + /// + /// Transforms an observable sequence of into an observable sequence + /// of objects by accessing internal attributes of the . + /// + new public IObservable Process(IObservable source) + { + return Observable.Select(source, pyObject => + { + var logPsPyObj = (double[,])pyObject.GetArrayAttr("log_Ps"); + var weightsPyObj = (double[,])pyObject.GetArrayAttr("weights"); + var biasesPyObj = (double[,])pyObject.GetArrayAttr("biases"); + + return new NeuralNetworkRecurrentTransitions() + { + Params = [logPsPyObj, weightsPyObj, biasesPyObj] + }; + }); + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsLookup.cs b/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsLookup.cs index f909b0da..f6f23b11 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsLookup.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsLookup.cs @@ -14,7 +14,8 @@ public static class TransitionsLookup { { TransitionsType.Stationary, (typeof(StationaryTransitions), "stationary") }, { TransitionsType.ConstrainedStationary, (typeof(ConstrainedStationaryTransitions), "constrained") }, - { TransitionsType.Sticky, (typeof(StickyTransitions), "sticky") } + { TransitionsType.Sticky, (typeof(StickyTransitions), "sticky") }, + { TransitionsType.NeuralNetworkRecurrent, (typeof(NeuralNetworkRecurrentTransitions), "nn_recurrent") } }; /// diff --git a/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsType.cs b/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsType.cs index 649dbd5c..771e9024 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsType.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Transitions/TransitionsType.cs @@ -18,6 +18,11 @@ public enum TransitionsType /// /// Sticky transitions. /// - Sticky + Sticky, + + /// + /// Neural network recurrent transitions. + /// + NeuralNetworkRecurrent } } \ No newline at end of file