diff --git a/qlat/auto_contractor/eval.py b/qlat/auto_contractor/eval.py index 1db08c3c..0521ab98 100644 --- a/qlat/auto_contractor/eval.py +++ b/qlat/auto_contractor/eval.py @@ -298,8 +298,8 @@ def benchmark_eval_cexpr_run_with_ama(): res1 = eval_cexpr(cexpr, positions_dict=positions_dict_list[k], get_prop=get_prop_ama) res2 = eval_cexpr(cexpr, positions_dict=positions_dict_list[k], get_prop=get_prop) res_ama, res_sloppy = eval_cexpr(cexpr, positions_dict=positions_dict_list[k], get_prop=get_prop_ama, is_ama_and_sloppy=True) - assert q.qnorm(res1 - res_ama) == 0 - assert q.qnorm(res2 - res_sloppy) == 0 + assert np.all(res1 == res_ama) + assert np.all(res2 == res_sloppy) res_list.append(res_ama) res = np.array(res_list) assert res.shape == (benchmark_size, n_expr,) @@ -312,6 +312,12 @@ def mk_check_vector(k): return res check_vector_list = [ mk_check_vector(k) for k in range(3) ] def check_res(res): + if res.dtype != np.complex128: + rs_real = benchmark_rng_state.split(f"get_data_sig-real") + rs_imag = benchmark_rng_state.split(f"get_data_sig-imag") + resc = np.zeros_like(res, dtype=np.complex128) + resc.ravel()[:] = [ q.get_data_sig(v, rs_real) + 1j * q.get_data_sig(v, rs_imag) for v in res.ravel() ] + res = resc return [ np.tensordot(res, cv).item() for cv in check_vector_list ] q.displayln_info(f"benchmark_eval_cexpr: benchmark_size={benchmark_size}") q.timer_fork(0)