-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added neural network recurrent transitions
- Loading branch information
1 parent
f6a554d
commit e51d6e0
Showing
4 changed files
with
187 additions
and
2 deletions.
There are no files selected for viewing
99 changes: 99 additions & 0 deletions
99
src/Bonsai.ML.HiddenMarkovModels/Transitions/NeuralNetworkRecurrentTransitions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.ComponentModel; | ||
using System.Xml.Serialization; | ||
using Newtonsoft.Json; | ||
|
||
namespace Bonsai.ML.HiddenMarkovModels.Transitions | ||
{ | ||
/// <summary> | ||
/// Represents a constrained stationary transitions model. | ||
/// </summary> | ||
[JsonObject(MemberSerialization.OptIn)] | ||
public class NeuralNetworkRecurrentTransitions : TransitionsModel | ||
{ | ||
private int[] hiddenLayerSizes = [50]; | ||
private readonly string nonlinearity = "relu"; | ||
|
||
/// <summary> | ||
/// The sizes of the hidden layers. | ||
/// </summary> | ||
[XmlIgnore] | ||
public int[] HiddenLayerSizes { get => hiddenLayerSizes; set {hiddenLayerSizes = value; UpdateString();} } | ||
|
||
/// <summary> | ||
/// The sizes of the hidden layers. | ||
/// </summary> | ||
public string Nonlinearity => nonlinearity; | ||
|
||
/// <summary> | ||
/// The Log Ps of the transitions. | ||
/// </summary> | ||
public double[,] LogPs { get; private set; } = null; | ||
|
||
/// <summary> | ||
/// The weights of the transitions. | ||
/// </summary> | ||
public double[,,] Weights { get; private set; } = null; | ||
|
||
/// <summary> | ||
/// The biases of the transitions. | ||
/// </summary> | ||
public double[,,] Biases { get; private set; } = null; | ||
|
||
/// <inheritdoc/> | ||
[JsonProperty] | ||
[JsonConverter(typeof(TransitionsTypeJsonConverter))] | ||
public override TransitionsType TransitionsType => TransitionsType.NeuralNetworkRecurrent; | ||
|
||
/// <inheritdoc/> | ||
[JsonProperty] | ||
public override object[] Params | ||
{ | ||
get => [LogPs, Weights, Biases]; | ||
set | ||
{ | ||
LogPs = (double[,])value[0]; | ||
Weights = (double[,,])value[1]; | ||
Biases = (double[,,])value[2]; | ||
UpdateString(); | ||
} | ||
} | ||
|
||
/// <inheritdoc/> | ||
[JsonProperty] | ||
public override Dictionary<string, object> Kwargs => new Dictionary<string, object> | ||
{ | ||
["hidden_layer_sizes"] = HiddenLayerSizes, | ||
["nonlinearity"] = Nonlinearity, | ||
}; | ||
|
||
/// <summary> | ||
/// Initializes a new instance of the <see cref="NeuralNetworkRecurrentTransitions"/> class. | ||
/// </summary> | ||
public NeuralNetworkRecurrentTransitions(params object[] args) | ||
{ | ||
if (args is not null && args.Length > 0) | ||
{ | ||
HiddenLayerSizes = args[0] switch | ||
{ | ||
int[] layers => layers, | ||
long[] layers => ConvertLongArrayToIntArray(layers), | ||
_ => null | ||
}; | ||
UpdateString(); | ||
} | ||
} | ||
|
||
private static int[] ConvertLongArrayToIntArray(long[] longArray) | ||
{ | ||
int count = longArray.Length; | ||
int[] intArray = new int[count]; | ||
|
||
for (int i = 0; i < count; i++) | ||
intArray[i] = Convert.ToInt32(longArray[i]); | ||
|
||
return intArray; | ||
} | ||
} | ||
} |
80 changes: 80 additions & 0 deletions
80
src/Bonsai.ML.HiddenMarkovModels/Transitions/NeuralNetworkRecurrentTransitionsModel.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
using System; | ||
using Python.Runtime; | ||
using System.Reactive.Linq; | ||
using System.ComponentModel; | ||
using Newtonsoft.Json; | ||
using System.Xml.Serialization; | ||
|
||
namespace Bonsai.ML.HiddenMarkovModels.Transitions | ||
{ | ||
/// <summary> | ||
/// Represents an operator that is used to create and transform an observable sequence | ||
/// of <see cref="NeuralNetworkRecurrentTransitions"/> objects. | ||
/// </summary> | ||
[Combinator] | ||
[Description("Creates an observable sequence of NeuralNetworkRecurrentTransitions objects.")] | ||
[WorkflowElementCategory(ElementCategory.Source)] | ||
[JsonObject(MemberSerialization.OptIn)] | ||
public class NeuralNetworkRecurrentTransitionsModel | ||
{ | ||
/// <summary> | ||
/// The sizes of the hidden layers. | ||
/// </summary> | ||
[XmlIgnore] | ||
[Description("The sizes of the hidden layers.")] | ||
[JsonProperty] | ||
public int[] HiddenLayerSizes { get; set; } = [50]; | ||
|
||
/// <summary> | ||
/// The Log Ps of the transitions. | ||
/// </summary> | ||
[XmlIgnore] | ||
[Description("The log Ps of the transitions.")] | ||
public double[,] LogPs { get; set; } = null; | ||
|
||
/// <summary> | ||
/// The weights. | ||
/// </summary> | ||
[XmlIgnore] | ||
[Description("The weights.")] | ||
public double[,,] Weights { get; set; } = null; | ||
|
||
/// <summary> | ||
/// The biases. | ||
/// </summary> | ||
[XmlIgnore] | ||
[Description("The biases.")] | ||
public double[,,] Biases { get; set; } = null; | ||
|
||
/// <summary> | ||
/// Returns an observable sequence of <see cref="NeuralNetworkRecurrentTransitions"/> objects. | ||
/// </summary> | ||
new public IObservable<NeuralNetworkRecurrentTransitions> Process() | ||
{ | ||
return Observable.Return( | ||
new NeuralNetworkRecurrentTransitions(HiddenLayerSizes) | ||
{ | ||
Params = [LogPs, Weights, Biases] | ||
}); | ||
} | ||
|
||
/// <summary> | ||
/// Transforms an observable sequence of <see cref="PyObject"/> into an observable sequence | ||
/// of <see cref="NeuralNetworkRecurrentTransitions"/> objects by accessing internal attributes of the <see cref="PyObject"/>. | ||
/// </summary> | ||
new public IObservable<NeuralNetworkRecurrentTransitions> Process(IObservable<PyObject> source) | ||
{ | ||
return Observable.Select(source, pyObject => | ||
{ | ||
var logPsPyObj = (double[,])pyObject.GetArrayAttr("log_Ps"); | ||
var weightsPyObj = (double[,])pyObject.GetArrayAttr("weights"); | ||
var biasesPyObj = (double[,])pyObject.GetArrayAttr("biases"); | ||
|
||
return new NeuralNetworkRecurrentTransitions() | ||
{ | ||
Params = [logPsPyObj, weightsPyObj, biasesPyObj] | ||
}; | ||
}); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters