Skip to content

Commit

Permalink
Added neural network recurrent transitions
Browse files Browse the repository at this point in the history
  • Loading branch information
ncguilbeault committed Jul 10, 2024
1 parent f6a554d commit e51d6e0
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Xml.Serialization;
using Newtonsoft.Json;

namespace Bonsai.ML.HiddenMarkovModels.Transitions
{
/// <summary>
/// Represents a constrained stationary transitions model.
/// </summary>
[JsonObject(MemberSerialization.OptIn)]
public class NeuralNetworkRecurrentTransitions : TransitionsModel
{
private int[] hiddenLayerSizes = [50];
private readonly string nonlinearity = "relu";

/// <summary>
/// The sizes of the hidden layers.
/// </summary>
[XmlIgnore]
public int[] HiddenLayerSizes { get => hiddenLayerSizes; set {hiddenLayerSizes = value; UpdateString();} }

/// <summary>
/// The sizes of the hidden layers.
/// </summary>
public string Nonlinearity => nonlinearity;

/// <summary>
/// The Log Ps of the transitions.
/// </summary>
public double[,] LogPs { get; private set; } = null;

/// <summary>
/// The weights of the transitions.
/// </summary>
public double[,,] Weights { get; private set; } = null;

/// <summary>
/// The biases of the transitions.
/// </summary>
public double[,,] Biases { get; private set; } = null;

/// <inheritdoc/>
[JsonProperty]
[JsonConverter(typeof(TransitionsTypeJsonConverter))]
public override TransitionsType TransitionsType => TransitionsType.NeuralNetworkRecurrent;

/// <inheritdoc/>
[JsonProperty]
public override object[] Params
{
get => [LogPs, Weights, Biases];
set
{
LogPs = (double[,])value[0];
Weights = (double[,,])value[1];
Biases = (double[,,])value[2];
UpdateString();
}
}

/// <inheritdoc/>
[JsonProperty]
public override Dictionary<string, object> Kwargs => new Dictionary<string, object>
{
["hidden_layer_sizes"] = HiddenLayerSizes,
["nonlinearity"] = Nonlinearity,
};

/// <summary>
/// Initializes a new instance of the <see cref="NeuralNetworkRecurrentTransitions"/> class.
/// </summary>
public NeuralNetworkRecurrentTransitions(params object[] args)
{
if (args is not null && args.Length > 0)
{
HiddenLayerSizes = args[0] switch
{
int[] layers => layers,
long[] layers => ConvertLongArrayToIntArray(layers),
_ => null
};
UpdateString();
}
}

private static int[] ConvertLongArrayToIntArray(long[] longArray)
{
int count = longArray.Length;
int[] intArray = new int[count];

for (int i = 0; i < count; i++)
intArray[i] = Convert.ToInt32(longArray[i]);

return intArray;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using System;
using Python.Runtime;
using System.Reactive.Linq;
using System.ComponentModel;
using Newtonsoft.Json;
using System.Xml.Serialization;

namespace Bonsai.ML.HiddenMarkovModels.Transitions
{
/// <summary>
/// Represents an operator that is used to create and transform an observable sequence
/// of <see cref="NeuralNetworkRecurrentTransitions"/> objects.
/// </summary>
[Combinator]
[Description("Creates an observable sequence of NeuralNetworkRecurrentTransitions objects.")]
[WorkflowElementCategory(ElementCategory.Source)]
[JsonObject(MemberSerialization.OptIn)]
public class NeuralNetworkRecurrentTransitionsModel
{
/// <summary>
/// The sizes of the hidden layers.
/// </summary>
[XmlIgnore]
[Description("The sizes of the hidden layers.")]
[JsonProperty]
public int[] HiddenLayerSizes { get; set; } = [50];

/// <summary>
/// The Log Ps of the transitions.
/// </summary>
[XmlIgnore]
[Description("The log Ps of the transitions.")]
public double[,] LogPs { get; set; } = null;

/// <summary>
/// The weights.
/// </summary>
[XmlIgnore]
[Description("The weights.")]
public double[,,] Weights { get; set; } = null;

/// <summary>
/// The biases.
/// </summary>
[XmlIgnore]
[Description("The biases.")]
public double[,,] Biases { get; set; } = null;

/// <summary>
/// Returns an observable sequence of <see cref="NeuralNetworkRecurrentTransitions"/> objects.
/// </summary>
new public IObservable<NeuralNetworkRecurrentTransitions> Process()
{
return Observable.Return(
new NeuralNetworkRecurrentTransitions(HiddenLayerSizes)
{
Params = [LogPs, Weights, Biases]
});
}

/// <summary>
/// Transforms an observable sequence of <see cref="PyObject"/> into an observable sequence
/// of <see cref="NeuralNetworkRecurrentTransitions"/> objects by accessing internal attributes of the <see cref="PyObject"/>.
/// </summary>
new public IObservable<NeuralNetworkRecurrentTransitions> Process(IObservable<PyObject> source)
{
return Observable.Select(source, pyObject =>
{
var logPsPyObj = (double[,])pyObject.GetArrayAttr("log_Ps");
var weightsPyObj = (double[,])pyObject.GetArrayAttr("weights");
var biasesPyObj = (double[,])pyObject.GetArrayAttr("biases");

return new NeuralNetworkRecurrentTransitions()
{
Params = [logPsPyObj, weightsPyObj, biasesPyObj]
};
});
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ public static class TransitionsLookup
{
{ TransitionsType.Stationary, (typeof(StationaryTransitions), "stationary") },
{ TransitionsType.ConstrainedStationary, (typeof(ConstrainedStationaryTransitions), "constrained") },
{ TransitionsType.Sticky, (typeof(StickyTransitions), "sticky") }
{ TransitionsType.Sticky, (typeof(StickyTransitions), "sticky") },
{ TransitionsType.NeuralNetworkRecurrent, (typeof(NeuralNetworkRecurrentTransitions), "nn_recurrent") }
};

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ public enum TransitionsType
/// <summary>
/// Sticky transitions.
/// </summary>
Sticky
Sticky,

/// <summary>
/// Neural network recurrent transitions.
/// </summary>
NeuralNetworkRecurrent
}
}

0 comments on commit e51d6e0

Please sign in to comment.