diff --git a/.github/workflows/develop-test.yml b/.github/workflows/develop-test.yml index 7ffd6172..7f5b759e 100644 --- a/.github/workflows/develop-test.yml +++ b/.github/workflows/develop-test.yml @@ -13,7 +13,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.9", "3.10"] test-group: [short, optimize, multivariate, optimize-experimental] steps: - uses: actions/checkout@v2 @@ -60,7 +60,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.9", "3.10"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -80,7 +80,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10"] + python-version: ["3.10"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -102,7 +102,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.9", "3.10"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -125,7 +125,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8"] + python-version: ["3.10"] steps: - uses: actions/checkout@v2 diff --git a/pyproject.toml b/pyproject.toml index 2c1e83a9..931250c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ tracker = "https://github.com/LLNL/MuyGPyS/issues" tests = [ "absl-py>=0.13.0", "matplotlib>=3.2.1", - "pandas==1.5.2", + "pandas>=2.2.2", ] dev = [ "black>=21.1.0", diff --git a/src/MuyGPyS/_src/gp/muygps/numpy.py b/src/MuyGPyS/_src/gp/muygps/numpy.py index cf3e6b16..93901d3d 100644 --- a/src/MuyGPyS/_src/gp/muygps/numpy.py +++ b/src/MuyGPyS/_src/gp/muygps/numpy.py @@ -90,4 +90,6 @@ def _muygps_fast_posterior_mean_precompute( train_nn_targets_fast: np.ndarray, **kwargs, ) -> np.ndarray: - return np.linalg.solve(Kin, train_nn_targets_fast) + if train_nn_targets_fast.ndim == 2: + train_nn_targets_fast = train_nn_targets_fast[:, :, None] + return np.squeeze(np.linalg.solve(Kin, train_nn_targets_fast))