Skip to content

Commit

Permalink
improve common sum
Browse files Browse the repository at this point in the history
  • Loading branch information
jinluchang committed Jan 7, 2025
1 parent 56f3e0e commit 1234ac6
Showing 1 changed file with 36 additions and 25 deletions.
61 changes: 36 additions & 25 deletions qlat/auto_contractor/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def collect_common_prod_in_factors(variables_factor_intermediate, variables_fact
@q.timer
def find_common_sum_in_factors(variables_factor):
"""
return None or (var_name_1, var_name_2,)
return [ (var_name_1, var_name_2,), ... ]
"""
subexpr_count = {}
for _, ea_coef in variables_factor:
Expand All @@ -281,25 +281,32 @@ def find_common_sum_in_factors(variables_factor):
count = subexpr_count.get(pair, 0)
subexpr_count[pair] = count + 1
max_num_repeat = 0
best_match = None
best_match_list = []
for pair, num_repeat in subexpr_count.items():
if num_repeat > max_num_repeat:
max_num_repeat = num_repeat
best_match = pair
return best_match
best_match_list = [ pair, ]
elif num_repeat == max_num_repeat:
best_match_list.append(pair)
return best_match_list

@q.timer
def collect_common_sum_in_factors(variables_factor, common_pair, var):
def collect_common_sum_in_factors(variables_factor_intermediate, variables_factor, var_nameset, var_counter, common_pair_list, var_dataset):
"""
common_pair = find_common_sum_in_factors(variables_factor)
common_pair_list = find_common_sum_in_factors(variables_factor)
var = var_dataset[pair]
var = ea.Factor(name, variables=[], otype="Var")
"""
common_pair_set = set(common_pair_list)
for _, ea_coef in variables_factor:
assert isinstance(ea_coef, ea.Expr)
x = ea_coef.terms
if len(x) < 1:
if len(x) <= 1:
continue
for i, t in enumerate(x[:-1]):
for i in range(len(x) - 1):
t = x[i]
if t is None:
continue
t1 = x[i+1]
assert t.coef == 1
assert len(t.factors) == 1
Expand All @@ -310,7 +317,23 @@ def collect_common_sum_in_factors(variables_factor, common_pair, var):
assert f.otype == "Var"
assert f1.otype == "Var"
pair = (f.code, f1.code,)
if pair == common_pair:
if pair in common_pair_set:
if pair in var_dataset:
var = var_dataset[pair]
else:
code1, code2 = pair
var1 = ea.Factor(code1, variables=[], otype="Var")
var2 = ea.Factor(code2, variables=[], otype="Var")
pair_expr = ea.mk_expr(var1) + ea.mk_expr(var2)
while True:
name = f"V_factor_sum_{var_counter}"
var_counter += 1
if name not in var_nameset:
break
var_nameset.add(name)
variables_factor_intermediate.append((name, pair_expr,))
var = ea.Factor(name, variables=[], otype="Var")
var_dataset[pair] = var
x[i].factors[0] = var
x[i+1] = None
for _, ea_coef in variables_factor:
Expand All @@ -321,6 +344,7 @@ def collect_common_sum_in_factors(variables_factor, common_pair, var):
if v is not None:
x_new.append(v)
ea_coef.terms = x_new
return var_counter

@q.timer
def collect_factor_in_cexpr(variables_factor, var_nameset, named_exprs, named_terms):
Expand Down Expand Up @@ -473,23 +497,10 @@ def collect_factor_sum_in_cexpr(variables_factor_intermediate, var_nameset, vari
var_counter = 0
var_dataset = {} # var_dataset[(code1, code2,)] = factor_var
while True:
pair = find_common_sum_in_factors(variables_factor)
if pair is None:
pair_list = find_common_sum_in_factors(variables_factor)
if len(pair_list) == 0:
break
code1, code2 = pair
var1 = ea.Factor(code1, variables=[], otype="Var")
var2 = ea.Factor(code2, variables=[], otype="Var")
pair_expr = ea.mk_expr(var1) + ea.mk_expr(var2)
while True:
name = f"V_factor_sum_{var_counter}"
var_counter += 1
if name not in var_nameset:
break
var_nameset.add(name)
variables_factor_intermediate.append((name, pair_expr,))
var = ea.Factor(name, variables=[], otype="Var")
var_dataset[pair] = var
collect_common_sum_in_factors(variables_factor, pair, var)
var_counter = collect_common_sum_in_factors(variables_factor_intermediate, variables_factor, var_nameset, var_counter, pair_list, var_dataset)

@q.timer
def collect_and_optimize_factor_in_cexpr(named_exprs, named_terms):
Expand Down

0 comments on commit 1234ac6

Please sign in to comment.