Skip to content

Commit

Permalink
fix benchmark auto contract
Browse files Browse the repository at this point in the history
  • Loading branch information
jinluchang committed Dec 25, 2024
1 parent 8eabbee commit 18a197b
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions qlat/auto_contractor/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ def benchmark_eval_cexpr_run_with_ama():
res1 = eval_cexpr(cexpr, positions_dict=positions_dict_list[k], get_prop=get_prop_ama)
res2 = eval_cexpr(cexpr, positions_dict=positions_dict_list[k], get_prop=get_prop)
res_ama, res_sloppy = eval_cexpr(cexpr, positions_dict=positions_dict_list[k], get_prop=get_prop_ama, is_ama_and_sloppy=True)
assert q.qnorm(res1 - res_ama) == 0
assert q.qnorm(res2 - res_sloppy) == 0
assert np.all(res1 == res_ama)
assert np.all(res2 == res_sloppy)
res_list.append(res_ama)
res = np.array(res_list)
assert res.shape == (benchmark_size, n_expr,)
Expand All @@ -312,6 +312,12 @@ def mk_check_vector(k):
return res
check_vector_list = [ mk_check_vector(k) for k in range(3) ]
def check_res(res):
if res.dtype != np.complex128:
rs_real = benchmark_rng_state.split(f"get_data_sig-real")
rs_imag = benchmark_rng_state.split(f"get_data_sig-imag")
resc = np.zeros_like(res, dtype=np.complex128)
resc.ravel()[:] = [ q.get_data_sig(v, rs_real) + 1j * q.get_data_sig(v, rs_imag) for v in res.ravel() ]
res = resc
return [ np.tensordot(res, cv).item() for cv in check_vector_list ]
q.displayln_info(f"benchmark_eval_cexpr: benchmark_size={benchmark_size}")
q.timer_fork(0)
Expand Down

0 comments on commit 18a197b

Please sign in to comment.