Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix get gaussian observation statistics #46

Merged
merged 4 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ public override void Show(object value)
{
if (value is Observations.GaussianObservationsStatistics gaussianObservationsStatistics)
{

var statesCount = gaussianObservationsStatistics.Means.GetLength(0);
var observationDimensions = gaussianObservationsStatistics.Means.GetLength(1);

Expand Down Expand Up @@ -227,12 +226,14 @@ public override void Show(object value)

var batchObservationsCount = gaussianObservationsStatistics.BatchObservations.GetLength(0);
var offset = BufferData && batchObservationsCount > BufferCount ? batchObservationsCount - BufferCount : 0;
var predictedStatesCount = gaussianObservationsStatistics.PredictedStates.Length;

for (int i = offset; i < batchObservationsCount; i++)
{
var dim1 = gaussianObservationsStatistics.BatchObservations[i, dimension1SelectedIndex];
var dim2 = gaussianObservationsStatistics.BatchObservations[i, dimension2SelectedIndex];
var state = gaussianObservationsStatistics.InferredMostProbableStates[i];
allScatterSeries[(int)state].Points.Add(new ScatterPoint(dim1, dim2, value: state, tag: state));
var state = gaussianObservationsStatistics.PredictedStates[i];
allScatterSeries[Convert.ToInt32(state)].Points.Add(new ScatterPoint(dim1, dim2, value: state, tag: state));
}

for (int i = 0; i < statesCount; i++)
Expand Down
2 changes: 1 addition & 1 deletion src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
</Expression>
<Expression xsi:type="Combinator">
<Combinator xsi:type="py:Exec">
<py:Script>hmm.most_likely_states([59.7382107943162,3.99285183724331])</py:Script>
<py:Script>hmm.infer_state([59.7382107943162,3.99285183724331])</py:Script>
</Combinator>
</Expression>
<Expression xsi:type="WorkflowOutput" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ public class GaussianObservationsStatistics
public double[,] BatchObservations { get; set; }

/// <summary>
/// The sequence of inferred most probable states.
/// The predicted state for each observation in the batch of observations.
/// </summary>
[Description("The sequence of inferred most probable states.")]
[Description("The predicted state for each observation in the batch of observations.")]
[XmlIgnore]
public int[] InferredMostProbableStates { get; set; }
public long[] PredictedStates { get; set; }

/// <summary>
/// Transforms an observable sequence of <see cref="PyObject"/> into an observable sequence
Expand All @@ -64,15 +64,15 @@ public IObservable<GaussianObservationsStatistics> Process(IObservable<PyObject>
var covarianceMatricesPyObj = (double[,,])observationsPyObj.GetArrayAttr("Sigmas");
var stdDevsPyObj = DiagonalSqrt(covarianceMatricesPyObj);
var batchObservationsPyObj = (double[,])pyObject.GetArrayAttr("batch_observations");
var inferredMostProbableStatesPyObj = (int[])pyObject.GetArrayAttr("inferred_most_probable_states");
var predictedStatesPyObj = (long[])pyObject.GetArrayAttr("predicted_states");

return new GaussianObservationsStatistics
{
Means = meansPyObj,
StdDevs = stdDevsPyObj,
CovarianceMatrices = covarianceMatricesPyObj,
BatchObservations = batchObservationsPyObj,
InferredMostProbableStates = inferredMostProbableStatesPyObj
PredictedStates = predictedStatesPyObj
};
});
}
Expand Down
22 changes: 13 additions & 9 deletions src/Bonsai.ML.HiddenMarkovModels/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,15 @@ def get_nonlinearity_type(func):
self.state_probabilities = None

self.batch = None
self.batch_observations = np.array([[]], dtype=float)
self.batch_observations = np.array([[]], dtype=float).reshape((0, dimensions))
self.is_running = False
self._fit_finished = False
self.loop = None
self.thread = None
self.curr_batch_size = 0
self.flush_data_between_batches = True
self.inferred_most_probable_states = np.array([], dtype=int)
self.predicted_states = np.array([], dtype=int)
self.buffer_count = 250

def update_params(self, initial_state_distribution, transitions_params, observations_params):
hmm_params = self.params
Expand Down Expand Up @@ -124,10 +125,17 @@ def update_params(self, initial_state_distribution, transitions_params, observat

def infer_state(self, observation: list[float]):

self.log_alpha = self.compute_log_alpha(
np.expand_dims(np.array(observation), 0), self.log_alpha)
observation = np.expand_dims(np.array(observation), 0)
self.log_alpha = self.compute_log_alpha(observation, self.log_alpha)
self.state_probabilities = np.exp(self.log_alpha).astype(np.double)
return self.state_probabilities.argmax()
prediction = self.state_probabilities.argmax()
self.predicted_states = np.append(self.predicted_states, prediction)
if self.predicted_states.shape[0] > self.buffer_count:
self.predicted_states = self.predicted_states[1:]
self.batch_observations = np.vstack([self.batch_observations, observation])
if self.batch_observations.shape[0] == self.buffer_count:
self.batch_observations = self.batch_observations[1:]
return prediction

def compute_log_alpha(self, obs, log_alpha=None):

Expand Down Expand Up @@ -171,8 +179,6 @@ def fit_async(self,
self.batch = np.vstack(
[self.batch[1:], np.expand_dims(np.array(observation), 0)])

self.batch_observations = self.batch

if not self.is_running and self.loop is None and self.thread is None:

if self.curr_batch_size >= batch_size:
Expand Down Expand Up @@ -221,8 +227,6 @@ def on_completion(future):
if self.flush_data_between_batches:
self.batch = None

self.inferred_most_probable_states = np.array([self.infer_state(obs) for obs in self.batch_observations]).astype(int)

self.is_running = True

if self.loop is None or self.loop.is_closed():
Expand Down