diff --git a/tests/python/linalg/double_factorized_decomposition_test.py b/tests/python/linalg/double_factorized_decomposition_test.py index 06f1c2007..37828a548 100644 --- a/tests/python/linalg/double_factorized_decomposition_test.py +++ b/tests/python/linalg/double_factorized_decomposition_test.py @@ -396,3 +396,52 @@ def test_double_factorized_t2_alpha_beta_random(): orbital_rotations[:, 0, 0], orbital_rotations[:, 3, 0].conj(), atol=1e-8 ) # TODO add the rest of the relations + + +def test_double_factorized_t2_alpha_beta_tol_max_vecs(): + """Test double-factorized decomposition alpha-beta error threshold and max vecs.""" + mol = gto.Mole() + mol.build( + atom=[["H", (0, 0, 0)], ["O", (0, 0, 1.1)]], + basis="6-31g", + spin=1, + symmetry="Coov", + ) + hartree_fock = scf.ROHF(mol).run() + + ccsd = cc.CCSD(hartree_fock).run() + _, t2ab, _ = ccsd.t2 + nocc_a, nocc_b, nvrt_a, _ = t2ab.shape + norb = nocc_a + nvrt_a + + # test max_vecs + max_vecs = 25 + diag_coulomb_mats, orbital_rotations = ffsim.linalg.double_factorized_t2_alpha_beta( + t2ab, max_vecs=max_vecs + ) + reconstructed = reconstruct_t2_alpha_beta( + diag_coulomb_mats, orbital_rotations, norb=norb, nocc_a=nocc_a, nocc_b=nocc_b + ) + assert len(diag_coulomb_mats) == max_vecs + np.testing.assert_allclose(reconstructed, t2ab, atol=1e-4) + + # test error threshold + tol = 1e-3 + diag_coulomb_mats, orbital_rotations = ffsim.linalg.double_factorized_t2_alpha_beta( + t2ab, tol=tol + ) + reconstructed = reconstruct_t2_alpha_beta( + diag_coulomb_mats, orbital_rotations, norb=norb, nocc_a=nocc_a, nocc_b=nocc_b + ) + assert len(diag_coulomb_mats) <= 23 + np.testing.assert_allclose(reconstructed, t2ab, atol=tol) + + # test error threshold and max vecs + diag_coulomb_mats, orbital_rotations = ffsim.linalg.double_factorized_t2_alpha_beta( + t2ab, tol=tol, max_vecs=max_vecs + ) + reconstructed = reconstruct_t2_alpha_beta( + diag_coulomb_mats, orbital_rotations, norb=norb, nocc_a=nocc_a, nocc_b=nocc_b + ) + assert len(orbital_rotations) <= 23 + np.testing.assert_allclose(reconstructed, t2ab, atol=tol)