diff --git a/qlat/auto_contractor/compile.py b/qlat/auto_contractor/compile.py index 004a370a..472f0a36 100644 --- a/qlat/auto_contractor/compile.py +++ b/qlat/auto_contractor/compile.py @@ -1401,6 +1401,7 @@ class CExprCodeGenPy: self.indent self.total_sloppy_flops # + flops per complex addition: 2 flops per complex multiplication: 6 flops per matrix multiplication: 6 M N L + 2 M L (N-1) ==> 13536 (sc * sc), 4320 (sc * s), 480 (s * s), 3168 (sc * c), 198 (c * c) flops per trace 2 (M-1) ==> 22 (sc) @@ -1886,16 +1887,29 @@ def cexpr_function_bs_eval(self): append_cy(f"cdef cc.PyComplexD {name}_v = 0") append_py(f"{name} = 0") ch1, ch2, ch3, = chain_list + def get_ss_tag(ss): + return "_".join([ str(s) for s in ss ]) + ss_dict = dict() + for bs in bs_list: + elem_list = bs.get_spin_spin_tensor_elem_list_code() + for ss, coef, in elem_list: + if ss in ss_dict: + continue + ss_name = f"ss_res_{get_ss_tag(ss)}" + ss_dict[ss] = ss_name + v_s1, b_s1, v_s2, b_s2, v_s3, b_s3, = ss + append_cy(f"cdef cc.PyComplexD {ss_name} = cc.pycc_d(cc.epsilon_contraction({v_s1}, {b_s1}, {v_s2}, {b_s2}, {v_s3}, {b_s3}, {ch1}[0], {ch2}[0], {ch3}[0]))") + append_py(f"{ss_name} = mat_epsilon_contraction_wm_wm_wm({v_s1}, {b_s1}, {v_s2}, {b_s2}, {v_s3}, {b_s3}, {ch1}, {ch2}, {ch3})") + self.total_sloppy_flops += 504 # 6 * 6 * (2 * 6 + 2) for name, bs in zip(name_list, bs_list): elem_list = bs.get_spin_spin_tensor_elem_list_code() for ss, coef, in elem_list: + ss_name = f"ss_res_{get_ss_tag(ss)}" c, t, = self.gen_expr(coef) assert t == "V_a" - v_s1, b_s1, v_s2, b_s2, v_s3, b_s3, = ss - append_cy(f"v = cc.pycc_d(cc.epsilon_contraction({v_s1}, {b_s1}, {v_s2}, {b_s2}, {v_s3}, {b_s3}, {ch1}[0], {ch2}[0], {ch3}[0]))") - append_py(f"v = mat_epsilon_contraction_wm_wm_wm({v_s1}, {b_s1}, {v_s2}, {b_s2}, {v_s3}, {b_s3}, {ch1}, {ch2}, {ch3})") - append_cy(f"{name}_v += ({c}) * v") - append_py(f"{name} += ({c}) * v") + append(f"v = {c}") + append_cy(f"{name}_v += v * {ss_name}") + append_py(f"{name} += v * {ss_name}") for name in name_list: append_cy(f"{name}[0] = {name}_v") name_list_str = ", ".join(name_list)