Skip to content

Commit

Permalink
optimize bs contraction
Browse files Browse the repository at this point in the history
  • Loading branch information
jinluchang committed Jan 1, 2025
1 parent 3c1a5b1 commit ff053e8
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions qlat/auto_contractor/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,6 +1401,7 @@ class CExprCodeGenPy:
self.indent
self.total_sloppy_flops
#
flops per complex addition: 2
flops per complex multiplication: 6
flops per matrix multiplication: 6 M N L + 2 M L (N-1) ==> 13536 (sc * sc), 4320 (sc * s), 480 (s * s), 3168 (sc * c), 198 (c * c)
flops per trace 2 (M-1) ==> 22 (sc)
Expand Down Expand Up @@ -1886,16 +1887,29 @@ def cexpr_function_bs_eval(self):
append_cy(f"cdef cc.PyComplexD {name}_v = 0")
append_py(f"{name} = 0")
ch1, ch2, ch3, = chain_list
def get_ss_tag(ss):
return "_".join([ str(s) for s in ss ])
ss_dict = dict()
for bs in bs_list:
elem_list = bs.get_spin_spin_tensor_elem_list_code()
for ss, coef, in elem_list:
if ss in ss_dict:
continue
ss_name = f"ss_res_{get_ss_tag(ss)}"
ss_dict[ss] = ss_name
v_s1, b_s1, v_s2, b_s2, v_s3, b_s3, = ss
append_cy(f"cdef cc.PyComplexD {ss_name} = cc.pycc_d(cc.epsilon_contraction({v_s1}, {b_s1}, {v_s2}, {b_s2}, {v_s3}, {b_s3}, {ch1}[0], {ch2}[0], {ch3}[0]))")
append_py(f"{ss_name} = mat_epsilon_contraction_wm_wm_wm({v_s1}, {b_s1}, {v_s2}, {b_s2}, {v_s3}, {b_s3}, {ch1}, {ch2}, {ch3})")
self.total_sloppy_flops += 504 # 6 * 6 * (2 * 6 + 2)
for name, bs in zip(name_list, bs_list):
elem_list = bs.get_spin_spin_tensor_elem_list_code()
for ss, coef, in elem_list:
ss_name = f"ss_res_{get_ss_tag(ss)}"
c, t, = self.gen_expr(coef)
assert t == "V_a"
v_s1, b_s1, v_s2, b_s2, v_s3, b_s3, = ss
append_cy(f"v = cc.pycc_d(cc.epsilon_contraction({v_s1}, {b_s1}, {v_s2}, {b_s2}, {v_s3}, {b_s3}, {ch1}[0], {ch2}[0], {ch3}[0]))")
append_py(f"v = mat_epsilon_contraction_wm_wm_wm({v_s1}, {b_s1}, {v_s2}, {b_s2}, {v_s3}, {b_s3}, {ch1}, {ch2}, {ch3})")
append_cy(f"{name}_v += ({c}) * v")
append_py(f"{name} += ({c}) * v")
append(f"v = {c}")
append_cy(f"{name}_v += v * {ss_name}")
append_py(f"{name} += v * {ss_name}")
for name in name_list:
append_cy(f"{name}[0] = {name}_v")
name_list_str = ", ".join(name_list)
Expand Down

0 comments on commit ff053e8

Please sign in to comment.