Skip to content

Commit

Permalink
consolidated the two versions of yo matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
rjgriffin42 committed Oct 20, 2024
1 parent 0e82749 commit e034cc4
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 18 deletions.
64 changes: 50 additions & 14 deletions src/main/java/us/ihmc/yoVariables/math/YoMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
*/
public class YoMatrix implements DMatrix, ReshapeMatrix
{
private static final long serialVersionUID = 2156411740647948028L;

private final int maxNumberOfRows, maxNumberOfColumns;

private final YoInteger numberOfRows, numberOfColumns;
Expand Down Expand Up @@ -144,14 +146,43 @@ public YoMatrix(String name, String description, int maxNumberOfRows, int maxNum
{
switch (checkNames(rowNames, columnNames))
{
case NONE -> variables[row][column] = new YoDouble(name + "_" + row + "_" + column, description, registry);
case ROWS -> variables[row][column] = new YoDouble(name + rowNames[row], description, registry);
case ROWS_AND_COLUMNS -> variables[row][column] = new YoDouble(name + rowNames[row] + columnNames[column], description, registry);
case NONE:
{
variables[row][column] = new YoDouble(getFieldName(name, row, column), description, registry);
variables[row][column].setToNaN();
break;
}
case ROWS:
{
if (maxNumberOfColumns > 1)
throw new IllegalArgumentException(
"The YoMatrix must be a column vector if only row names are provided, else unique names cannot be generated.");

variables[row][column] = new YoDouble(getFieldName(name, rowNames[row], ""), description, registry);
variables[row][column].setToNaN();
break;
}
case ROWS_AND_COLUMNS:
{
variables[row][column] = new YoDouble(getFieldName(name, rowNames[row], columnNames[column]), description, registry);
variables[row][column].setToNaN();
break;
}
}
}
}
}

public static String getFieldName(String prefix, int row, int column)
{
return getFieldName(prefix, "_" + row, "_" + column);
}

public static String getFieldName(String prefix, String rowName, String columName)
{
return prefix + rowName + columName;
}

/**
* Enum used to determine what names have been provided to the YoMatrix.
*/
Expand Down Expand Up @@ -212,7 +243,7 @@ public void scale(double scale)
{
for (int col = 0; col < getNumCols(); col++)
{
unsafe_set(row, col, unsafe_get(row, col) * scale);
unsafe_set(row, col, unsafe_get(row, col) * scale, false);
}
}
}
Expand All @@ -233,7 +264,7 @@ public void scale(double scale, DMatrix matrix)
{
for (int col = 0; col < getNumCols(); col++)
{
unsafe_set(row, col, matrix.unsafe_get(row, col) * scale);
unsafe_set(row, col, matrix.unsafe_get(row, col) * scale, false);
}
}
}
Expand Down Expand Up @@ -283,7 +314,7 @@ public void add(double alpha, DMatrix a, double beta, DMatrix b)
{
for (int col = 0; col < getNumCols(); col++)
{
unsafe_set(row, col, alpha * a.unsafe_get(row, col) + beta * b.unsafe_get(row, col));
unsafe_set(row, col, alpha * a.unsafe_get(row, col) + beta * b.unsafe_get(row, col), false);
}
}
}
Expand Down Expand Up @@ -314,7 +345,7 @@ public void addEquals(double alpha, DMatrix a)
{
for (int col = 0; col < getNumCols(); col++)
{
unsafe_set(row, col, unsafe_get(row, col) + alpha * a.unsafe_get(row, col));
unsafe_set(row, col, unsafe_get(row, col) + alpha * a.unsafe_get(row, col), false);
}
}
}
Expand Down Expand Up @@ -398,15 +429,15 @@ else if (numRows < 0 || numCols < 0)
{
for (int col = numCols; col < maxNumberOfColumns; col++)
{
unsafe_set(row, col, Double.NaN);
unsafe_set(row, col, Double.NaN, false);
}
}

for (int row = numRows; row < maxNumberOfRows; row++)
{
for (int col = 0; col < maxNumberOfColumns; col++)
{
unsafe_set(row, col, Double.NaN);
unsafe_set(row, col, Double.NaN, false);
}
}
}
Expand All @@ -423,7 +454,7 @@ public void set(int row, int col, double val)
{
if (col < 0 || col >= getNumCols() || row < 0 || row >= getNumRows())
throw new IllegalArgumentException("Specified element is out of bounds: (" + row + " , " + col + ")");
unsafe_set(row, col, val);
unsafe_set(row, col, val, false);
}

@Override
Expand All @@ -432,6 +463,11 @@ public void unsafe_set(int row, int col, double val)
variables[row][col].set(val);
}

private void unsafe_set(int row, int col, double val, boolean notifyListeners)
{
variables[row][col].set(val, notifyListeners);
}

/**
* Set {@code this} to the matrix {@code original}.
* <p>
Expand All @@ -450,7 +486,7 @@ public void set(Matrix original)
{
for (int col = 0; col < getNumCols(); col++)
{
unsafe_set(row, col, otherMatrix.unsafe_get(row, col));
unsafe_set(row, col, otherMatrix.unsafe_get(row, col), false);
}
}
}
Expand All @@ -473,7 +509,7 @@ public void setToNaN(int numRows, int numCols)
{
for (int col = 0; col < numCols; col++)
{
unsafe_set(row, col, Double.NaN);
unsafe_set(row, col, Double.NaN, false);
}
}
}
Expand Down Expand Up @@ -507,9 +543,9 @@ public void zero()
for (int col = 0; col < maxNumberOfColumns; col++)
{
if (row < getNumRows() && col < getNumCols())
unsafe_set(row, col, 0.0);
unsafe_set(row, col, 0.0, false);
else
unsafe_set(row, col, Double.NaN);
unsafe_set(row, col, Double.NaN, false);
}
}
}
Expand Down
186 changes: 182 additions & 4 deletions src/test/java/us/ihmc/yoVariables/math/YoMatrixTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package us.ihmc.yoVariables.math;

import org.ejml.EjmlUnitTests;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.RandomMatrices_DDRM;
Expand All @@ -16,6 +17,39 @@ public class YoMatrixTest
private static final double EPSILON = 1.0e-10;
private static final int ITERATIONS = 1000;

@Test
public void testSimpleYoMatrixExample()
{
int maxNumberOfRows = 4;
int maxNumberOfColumns = 8;
YoRegistry registry = new YoRegistry("testRegistry");
us.ihmc.yoVariables.filters.YoMatrix yoMatrix = new us.ihmc.yoVariables.filters.YoMatrix("testMatrix", maxNumberOfRows, maxNumberOfColumns, registry);
assertEquals(maxNumberOfRows, yoMatrix.getNumRows());
assertEquals(maxNumberOfColumns, yoMatrix.getNumCols());

DMatrixRMaj denseMatrix = new DMatrixRMaj(maxNumberOfRows, maxNumberOfColumns);
yoMatrix.reshape(maxNumberOfRows, maxNumberOfColumns);
yoMatrix.zero();
yoMatrix.get(denseMatrix);

DMatrixRMaj zeroMatrix = new DMatrixRMaj(maxNumberOfRows, maxNumberOfColumns);
EjmlUnitTests.assertEquals(zeroMatrix, denseMatrix, 1e-10);

Random random = new Random(1984L);

DMatrixRMaj randomMatrix = RandomMatrices_DDRM.rectangle(maxNumberOfRows, maxNumberOfColumns, random);
yoMatrix.set(randomMatrix);

DMatrixRMaj checkMatrix = new DMatrixRMaj(maxNumberOfRows, maxNumberOfColumns);
yoMatrix.get(checkMatrix);

EjmlUnitTests.assertEquals(randomMatrix, checkMatrix, 1e-10);

assertEquals(registry.findVariable(us.ihmc.yoVariables.filters.YoMatrix.getFieldName("testMatrix", 0, 0)).getValueAsDouble(),
checkMatrix.get(0, 0),
1e-10);
}

@Test
public void testSimpleYoMatrixRefactorExample()
{
Expand Down Expand Up @@ -187,28 +221,105 @@ public void testSetDimensioning()
}
}

@Test
public void testYoMatrixDimensioning()
{
int maxNumberOfRows = 4;
int maxNumberOfColumns = 8;
String name = "testMatrix";

YoRegistry registry = new YoRegistry("testRegistry");
us.ihmc.yoVariables.filters.YoMatrix yoMatrix = new us.ihmc.yoVariables.filters.YoMatrix(name, maxNumberOfRows, maxNumberOfColumns, registry);

int smallerRows = maxNumberOfRows - 2;
int smallerColumns = maxNumberOfColumns - 3;
DMatrixRMaj denseMatrix = new DMatrixRMaj(smallerRows, smallerColumns);

try
{
yoMatrix.get(denseMatrix);
fail("Should throw an exception if the size isn't right!");
}
catch (Exception e)
{
}

yoMatrix.reshape(maxNumberOfRows, maxNumberOfColumns);
yoMatrix.zero();
yoMatrix.getAndReshape(denseMatrix);
DMatrixRMaj zeroMatrix = new DMatrixRMaj(maxNumberOfRows, maxNumberOfColumns);
EjmlUnitTests.assertEquals(zeroMatrix, denseMatrix, 1e-10);
assertEquals(maxNumberOfRows, denseMatrix.getNumRows());
assertEquals(maxNumberOfColumns, denseMatrix.getNumCols());

assertEquals(maxNumberOfRows, yoMatrix.getNumRows());
assertEquals(maxNumberOfColumns, yoMatrix.getNumCols());

Random random = new Random(1984L);

DMatrixRMaj randomMatrix = RandomMatrices_DDRM.rectangle(maxNumberOfRows, maxNumberOfColumns, random);
yoMatrix.set(randomMatrix);

DMatrixRMaj checkMatrix = new DMatrixRMaj(maxNumberOfRows, maxNumberOfColumns);
yoMatrix.get(checkMatrix);

EjmlUnitTests.assertEquals(randomMatrix, checkMatrix, 1e-10);

DMatrixRMaj smallerMatrix = RandomMatrices_DDRM.rectangle(smallerRows, smallerColumns, random);
yoMatrix.set(smallerMatrix);

assertEquals(smallerRows, smallerMatrix.getNumRows());
assertEquals(smallerColumns, smallerMatrix.getNumCols());

assertEquals(smallerRows, yoMatrix.getNumRows());
assertEquals(smallerColumns, yoMatrix.getNumCols());

DMatrixRMaj checkMatrix2 = new DMatrixRMaj(1, 1);
yoMatrix.getAndReshape(checkMatrix2);

EjmlUnitTests.assertEquals(smallerMatrix, checkMatrix2, 1e-10);

checkMatrixYoVariablesEqualsCheckMatrixAndOutsideValuesAreNaN(name, maxNumberOfRows, maxNumberOfColumns, checkMatrix2, registry);
}

@Test
public void testZero()
{
Random random = new Random(1984L);

int maxNumberOfRows = 4;
int maxNumberOfColumns = 8;
String name = "testMatrixForZero";
YoRegistry registry = new YoRegistry("testRegistry");
YoMatrix yoMatrix = new YoMatrix(name, maxNumberOfRows, maxNumberOfColumns, registry);

DMatrixRMaj randomMatrix = RandomMatrices_DDRM.rectangle(maxNumberOfRows, maxNumberOfColumns, random);
yoMatrix.set(randomMatrix);

int numberOfRows = 2;
int numberOfColumns = 6;
yoMatrix.reshape(numberOfRows, numberOfColumns);
yoMatrix.zero();

DMatrixRMaj zeroMatrix = new DMatrixRMaj(numberOfRows, numberOfColumns);
checkMatrixYoVariablesEqualsCheckMatrixAndOutsideValuesAreNaN(name, maxNumberOfRows, maxNumberOfColumns, zeroMatrix, registry);

for (int i = 0; i < ITERATIONS; ++i)
{
int rowSize = random.nextInt(5, 10);
int columnSize = random.nextInt(5, 10);

String name = "testMatrixForZero";
YoRegistry registry = new YoRegistry("testRegistry");
registry = new YoRegistry("testRegistry");
YoMatrix matrix = new YoMatrix(name, rowSize, columnSize, registry);

DMatrixRMaj randomMatrix = RandomMatrices_DDRM.rectangle(rowSize, columnSize, random);
randomMatrix = RandomMatrices_DDRM.rectangle(rowSize, columnSize, random);
matrix.set(randomMatrix);

int smallerRowSize = random.nextInt(1, rowSize);
int smallerColumnSize = random.nextInt(1, columnSize);
matrix.zero(smallerRowSize, smallerColumnSize);

DMatrixRMaj zeroMatrix = new DMatrixRMaj(smallerRowSize, smallerColumnSize);
zeroMatrix = new DMatrixRMaj(smallerRowSize, smallerColumnSize);
checkMatrixYoVariablesEqualsCheckMatrixAndOutsideValuesAreNaN(name, rowSize, columnSize, zeroMatrix, registry);
}
}
Expand Down Expand Up @@ -537,6 +648,73 @@ public void testSafeSetAndGetWithIndices()
}
}

@Test
public void testYoMatrixSetTooBig()
{
int maxNumberOfRows = 4;
int maxNumberOfColumns = 8;
String name = "testMatrix";
YoRegistry registry = new YoRegistry("testRegistry");
us.ihmc.yoVariables.filters.YoMatrix yoMatrix = new us.ihmc.yoVariables.filters.YoMatrix(name, maxNumberOfRows, maxNumberOfColumns, registry);

DMatrixRMaj tooBigMatrix = new DMatrixRMaj(maxNumberOfRows + 1, maxNumberOfColumns);

try
{
yoMatrix.set(tooBigMatrix);
fail("Too Big");
}
catch (RuntimeException e)
{
}

tooBigMatrix = new DMatrixRMaj(maxNumberOfRows, maxNumberOfColumns + 1);

try
{
yoMatrix.set(tooBigMatrix);
fail("Too Big");
}
catch (RuntimeException e)
{
}

// Test a 0 X Big Matrix
DMatrixRMaj okMatrix = new DMatrixRMaj(0, maxNumberOfColumns + 10);
yoMatrix.set(okMatrix);
assertMatrixYoVariablesAreNaN(name, maxNumberOfRows, maxNumberOfColumns, registry);

DMatrixRMaj checkMatrix = new DMatrixRMaj(1, 1);
yoMatrix.getAndReshape(checkMatrix);

assertEquals(0, checkMatrix.getNumRows());
assertEquals(maxNumberOfColumns + 10, checkMatrix.getNumCols());

// Test a Big X 0 Matrix

okMatrix = new DMatrixRMaj(maxNumberOfRows + 10, 0);
yoMatrix.set(okMatrix);
assertMatrixYoVariablesAreNaN(name, maxNumberOfRows, maxNumberOfColumns, registry);

checkMatrix = new DMatrixRMaj(1, 1);
yoMatrix.getAndReshape(checkMatrix);

assertEquals(maxNumberOfRows + 10, checkMatrix.getNumRows());
assertEquals(0, checkMatrix.getNumCols());
}

private void assertMatrixYoVariablesAreNaN(String name, int maxNumberOfRows, int maxNumberOfColumns, YoRegistry registry)
{
for (int row = 0; row < maxNumberOfRows; row++)
{
for (int column = 0; column < maxNumberOfColumns; column++)
{
YoDouble variable = (YoDouble) registry.findVariable(us.ihmc.yoVariables.filters.YoMatrix.getFieldName(name, row, column));
assertTrue(Double.isNaN(variable.getDoubleValue()));
}
}
}

private void checkMatrixYoVariablesEqualsCheckMatrixAndOutsideValuesAreNaN(String name,
int maxNumberOfRows,
int maxNumberOfColumns,
Expand Down

0 comments on commit e034cc4

Please sign in to comment.