Skip to content

Commit

Permalink
consolidated the two versions of yo matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
rjgriffin42 committed Oct 21, 2024
1 parent 0e82749 commit d5108a8
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package us.ihmc.yoVariables.filters;

import org.ejml.data.DMatrix;
import org.ejml.data.DMatrixRMaj;
import org.ejml.data.Matrix;
import org.ejml.dense.row.CommonOps_DDRM;
import us.ihmc.yoVariables.math.YoMatrix;
import us.ihmc.yoVariables.registry.YoRegistry;
import us.ihmc.yoVariables.variable.YoDouble;

Expand Down Expand Up @@ -44,7 +45,7 @@ public void setAlpha(double alpha)
* @param current the current value of the matrix to be filtered. Not modified.
*/
@Override
public void set(DMatrix current)
public void set(Matrix current)
{
super.set(current);
this.current.set(current);
Expand All @@ -53,7 +54,7 @@ public void set(DMatrix current)
/**
* Assuming that the current value has been set, this method solves for the filtered value.
* <p>
* See {@link #set(DMatrix)} for how to set the matrix's current value.
* See {@link #set(Matrix)} for how to set the matrix's current value.
* </p>
*/
public void solve()
Expand All @@ -73,7 +74,7 @@ public void solve()
*
* @param current the current value of the matrix to be filtered. Not modified.
*/
public void setAndSolve(DMatrix current)
public void setAndSolve(Matrix current)
{
CommonOps_DDRM.scale(alpha.getDoubleValue(), previous, filtered);

Expand Down
64 changes: 50 additions & 14 deletions src/main/java/us/ihmc/yoVariables/math/YoMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
*/
public class YoMatrix implements DMatrix, ReshapeMatrix
{
private static final long serialVersionUID = 2156411740647948028L;

private final int maxNumberOfRows, maxNumberOfColumns;

private final YoInteger numberOfRows, numberOfColumns;
Expand Down Expand Up @@ -144,14 +146,43 @@ public YoMatrix(String name, String description, int maxNumberOfRows, int maxNum
{
switch (checkNames(rowNames, columnNames))
{
case NONE -> variables[row][column] = new YoDouble(name + "_" + row + "_" + column, description, registry);
case ROWS -> variables[row][column] = new YoDouble(name + rowNames[row], description, registry);
case ROWS_AND_COLUMNS -> variables[row][column] = new YoDouble(name + rowNames[row] + columnNames[column], description, registry);
case NONE:
{
variables[row][column] = new YoDouble(getFieldName(name, row, column), description, registry);
variables[row][column].setToNaN();
break;
}
case ROWS:
{
if (maxNumberOfColumns > 1)
throw new IllegalArgumentException(
"The YoMatrix must be a column vector if only row names are provided, else unique names cannot be generated.");

variables[row][column] = new YoDouble(getFieldName(name, rowNames[row], ""), description, registry);
variables[row][column].setToNaN();
break;
}
case ROWS_AND_COLUMNS:
{
variables[row][column] = new YoDouble(getFieldName(name, rowNames[row], columnNames[column]), description, registry);
variables[row][column].setToNaN();
break;
}
}
}
}
}

public static String getFieldName(String prefix, int row, int column)
{
return getFieldName(prefix, "_" + row, "_" + column);
}

public static String getFieldName(String prefix, String rowName, String columName)
{
return prefix + rowName + columName;
}

/**
* Enum used to determine what names have been provided to the YoMatrix.
*/
Expand Down Expand Up @@ -212,7 +243,7 @@ public void scale(double scale)
{
for (int col = 0; col < getNumCols(); col++)
{
unsafe_set(row, col, unsafe_get(row, col) * scale);
unsafe_set(row, col, unsafe_get(row, col) * scale, false);
}
}
}
Expand All @@ -233,7 +264,7 @@ public void scale(double scale, DMatrix matrix)
{
for (int col = 0; col < getNumCols(); col++)
{
unsafe_set(row, col, matrix.unsafe_get(row, col) * scale);
unsafe_set(row, col, matrix.unsafe_get(row, col) * scale, false);
}
}
}
Expand Down Expand Up @@ -283,7 +314,7 @@ public void add(double alpha, DMatrix a, double beta, DMatrix b)
{
for (int col = 0; col < getNumCols(); col++)
{
unsafe_set(row, col, alpha * a.unsafe_get(row, col) + beta * b.unsafe_get(row, col));
unsafe_set(row, col, alpha * a.unsafe_get(row, col) + beta * b.unsafe_get(row, col), false);
}
}
}
Expand Down Expand Up @@ -314,7 +345,7 @@ public void addEquals(double alpha, DMatrix a)
{
for (int col = 0; col < getNumCols(); col++)
{
unsafe_set(row, col, unsafe_get(row, col) + alpha * a.unsafe_get(row, col));
unsafe_set(row, col, unsafe_get(row, col) + alpha * a.unsafe_get(row, col), false);
}
}
}
Expand Down Expand Up @@ -398,15 +429,15 @@ else if (numRows < 0 || numCols < 0)
{
for (int col = numCols; col < maxNumberOfColumns; col++)
{
unsafe_set(row, col, Double.NaN);
unsafe_set(row, col, Double.NaN, false);
}
}

for (int row = numRows; row < maxNumberOfRows; row++)
{
for (int col = 0; col < maxNumberOfColumns; col++)
{
unsafe_set(row, col, Double.NaN);
unsafe_set(row, col, Double.NaN, false);
}
}
}
Expand All @@ -423,7 +454,7 @@ public void set(int row, int col, double val)
{
if (col < 0 || col >= getNumCols() || row < 0 || row >= getNumRows())
throw new IllegalArgumentException("Specified element is out of bounds: (" + row + " , " + col + ")");
unsafe_set(row, col, val);
unsafe_set(row, col, val, false);
}

@Override
Expand All @@ -432,6 +463,11 @@ public void unsafe_set(int row, int col, double val)
variables[row][col].set(val);
}

private void unsafe_set(int row, int col, double val, boolean notifyListeners)
{
variables[row][col].set(val, notifyListeners);
}

/**
* Set {@code this} to the matrix {@code original}.
* <p>
Expand All @@ -450,7 +486,7 @@ public void set(Matrix original)
{
for (int col = 0; col < getNumCols(); col++)
{
unsafe_set(row, col, otherMatrix.unsafe_get(row, col));
unsafe_set(row, col, otherMatrix.unsafe_get(row, col), false);
}
}
}
Expand All @@ -473,7 +509,7 @@ public void setToNaN(int numRows, int numCols)
{
for (int col = 0; col < numCols; col++)
{
unsafe_set(row, col, Double.NaN);
unsafe_set(row, col, Double.NaN, false);
}
}
}
Expand Down Expand Up @@ -507,9 +543,9 @@ public void zero()
for (int col = 0; col < maxNumberOfColumns; col++)
{
if (row < getNumRows() && col < getNumCols())
unsafe_set(row, col, 0.0);
unsafe_set(row, col, 0.0, false);
else
unsafe_set(row, col, Double.NaN);
unsafe_set(row, col, Double.NaN, false);
}
}
}
Expand Down
Loading

0 comments on commit d5108a8

Please sign in to comment.