From 1234ac6d37cfcdcdee7bc445f5b27c3c7fee8044 Mon Sep 17 00:00:00 2001 From: Luchang Jin Date: Tue, 7 Jan 2025 02:12:13 -0500 Subject: [PATCH] improve common sum --- qlat/auto_contractor/compile.py | 61 +++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/qlat/auto_contractor/compile.py b/qlat/auto_contractor/compile.py index 30620577..63c872f1 100644 --- a/qlat/auto_contractor/compile.py +++ b/qlat/auto_contractor/compile.py @@ -259,7 +259,7 @@ def collect_common_prod_in_factors(variables_factor_intermediate, variables_fact @q.timer def find_common_sum_in_factors(variables_factor): """ - return None or (var_name_1, var_name_2,) + return [ (var_name_1, var_name_2,), ... ] """ subexpr_count = {} for _, ea_coef in variables_factor: @@ -281,25 +281,32 @@ def find_common_sum_in_factors(variables_factor): count = subexpr_count.get(pair, 0) subexpr_count[pair] = count + 1 max_num_repeat = 0 - best_match = None + best_match_list = [] for pair, num_repeat in subexpr_count.items(): if num_repeat > max_num_repeat: max_num_repeat = num_repeat - best_match = pair - return best_match + best_match_list = [ pair, ] + elif num_repeat == max_num_repeat: + best_match_list.append(pair) + return best_match_list @q.timer -def collect_common_sum_in_factors(variables_factor, common_pair, var): +def collect_common_sum_in_factors(variables_factor_intermediate, variables_factor, var_nameset, var_counter, common_pair_list, var_dataset): """ - common_pair = find_common_sum_in_factors(variables_factor) + common_pair_list = find_common_sum_in_factors(variables_factor) + var = var_dataset[pair] var = ea.Factor(name, variables=[], otype="Var") """ + common_pair_set = set(common_pair_list) for _, ea_coef in variables_factor: assert isinstance(ea_coef, ea.Expr) x = ea_coef.terms - if len(x) < 1: + if len(x) <= 1: continue - for i, t in enumerate(x[:-1]): + for i in range(len(x) - 1): + t = x[i] + if t is None: + continue t1 = x[i+1] assert t.coef == 1 assert len(t.factors) == 1 @@ -310,7 +317,23 @@ def collect_common_sum_in_factors(variables_factor, common_pair, var): assert f.otype == "Var" assert f1.otype == "Var" pair = (f.code, f1.code,) - if pair == common_pair: + if pair in common_pair_set: + if pair in var_dataset: + var = var_dataset[pair] + else: + code1, code2 = pair + var1 = ea.Factor(code1, variables=[], otype="Var") + var2 = ea.Factor(code2, variables=[], otype="Var") + pair_expr = ea.mk_expr(var1) + ea.mk_expr(var2) + while True: + name = f"V_factor_sum_{var_counter}" + var_counter += 1 + if name not in var_nameset: + break + var_nameset.add(name) + variables_factor_intermediate.append((name, pair_expr,)) + var = ea.Factor(name, variables=[], otype="Var") + var_dataset[pair] = var x[i].factors[0] = var x[i+1] = None for _, ea_coef in variables_factor: @@ -321,6 +344,7 @@ def collect_common_sum_in_factors(variables_factor, common_pair, var): if v is not None: x_new.append(v) ea_coef.terms = x_new + return var_counter @q.timer def collect_factor_in_cexpr(variables_factor, var_nameset, named_exprs, named_terms): @@ -473,23 +497,10 @@ def collect_factor_sum_in_cexpr(variables_factor_intermediate, var_nameset, vari var_counter = 0 var_dataset = {} # var_dataset[(code1, code2,)] = factor_var while True: - pair = find_common_sum_in_factors(variables_factor) - if pair is None: + pair_list = find_common_sum_in_factors(variables_factor) + if len(pair_list) == 0: break - code1, code2 = pair - var1 = ea.Factor(code1, variables=[], otype="Var") - var2 = ea.Factor(code2, variables=[], otype="Var") - pair_expr = ea.mk_expr(var1) + ea.mk_expr(var2) - while True: - name = f"V_factor_sum_{var_counter}" - var_counter += 1 - if name not in var_nameset: - break - var_nameset.add(name) - variables_factor_intermediate.append((name, pair_expr,)) - var = ea.Factor(name, variables=[], otype="Var") - var_dataset[pair] = var - collect_common_sum_in_factors(variables_factor, pair, var) + var_counter = collect_common_sum_in_factors(variables_factor_intermediate, variables_factor, var_nameset, var_counter, pair_list, var_dataset) @q.timer def collect_and_optimize_factor_in_cexpr(named_exprs, named_terms):