-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaoc2216.py
144 lines (127 loc) · 3.98 KB
/
aoc2216.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import re
from functools import lru_cache
# Parse the puzzle input read from stdin.  Per line, the first integer is the
# flow rate and the two-capital-letter tokens are valve names: the first one
# is the valve itself, the rest are its tunnel targets (presumably lines like
# "Valve AA has flow rate=0; tunnels lead to valves DD, II, BB").
with open(0) as fh:  # release the stdin file object once it has been read
    lines = fh.read().splitlines()
flos = {}   # valve -> flow rate
tuns = {}   # valve -> list of tunnel targets
mains = []  # valve names in input order
M = {} # p2: valve -> (flow rate, tunnel targets)
for line in lines:
    if not line.strip():
        # Skip blank lines (e.g. a trailing one) instead of failing with an
        # IndexError on the missing flow-rate number below.
        continue
    flo = int(re.findall(r'\d+', line)[0])
    main, *vvs = re.findall(r'[A-Z]{2}', line)
    if not vvs:
        # Explicit check: an `assert` here would vanish under `python -O`.
        raise ValueError(f'valve {main} has no tunnels: {line!r}')
    print('parsing/',main,flo,'\tremain',vvs)
    flos[main] = flo
    tuns[main] = vvs
    mains.append(main)
    M[main] = (flo, vvs)
for main in mains: print(main,'- flow/',flos[main],'\tβ',tuns[main])
# graphviz: dump the tunnel graph as a DOT digraph to '16.x'; each edge is
# labelled with the flow rate of its source valve.  The `with` block makes
# sure the file is flushed and closed before the (long) search below starts.
with open('16.x', 'w') as f:
    f.write('strict digraph {\n')
    for m in mains:
        for t in tuns[m]:
            f.write(f' {m} -> {t} [label={flos[m]}]\n')
    f.write('}\n')
# part 2, also works for p1
# Re-index the valves so DFSp2 can work on small integers and bitmasks:
#   index 0              -> 'AA' (DFSp2 is always started at index 0)
#   indices 1 .. not0-1  -> valves with a positive flow rate
#   remaining indices    -> the zero-flow valves
# Only valves with a positive rate are ever opened, and they all sit below
# `not0`, so an opened-set bitmask needs just `not0` bits.
# NOTE(review): assumes 'AA' is present in the input; otherwise index 0 is
# some other valve.
not0 = 0
N = len(M)
indexofmain = {}
sortedmains = []
# 'AA' added first
if 'AA' in M:
    indexofmain['AA'] = len(sortedmains)
    sortedmains.append('AA')
    not0 += 1
# then other substantial vs; skip anything already indexed so a positive-flow
# 'AA' cannot be listed twice (which would desync sortedmains from N)
for main, (flo, _) in M.items():
    if flo > 0 and main not in indexofmain:
        indexofmain[main] = len(sortedmains)
        sortedmains.append(main)
        not0 += 1
# 0-flow valve
for main in M:
    if main in indexofmain:  # O(1) dict lookup instead of a list scan
        continue
    indexofmain[main] = len(sortedmains)
    sortedmains.append(main)
# tunnels represented as sorted indices mapped to sorted valves indices
sortedtunns = [[indexofmain[t] for t in M[name][1]] for name in sortedmains]
# rates represented as sorted indices array
sortedrates = [ M[main][0] for main in sortedmains ]
INF = -10**20
# Flat memo table for DFSp2, one slot per (openedSet, valve, remain 0..30,
# player 0..1); entries start as (INF, None) meaning "not computed yet".
# NOTE(review): size is 2**not0 * N * 62 entries, which can be large.
DP = [(INF,None)] * ( (1 << not0) * N * 31 * 2 )
print(len(DP),'dp/len')
def DFSp2(remain, currentvv, openedSet, player):
    """Best pressure for the index/bitmask formulation (used for both parts).

    remain    -- minutes left for the actor currently moving.
    currentvv -- index into ``sortedmains`` of the valve the actor stands on.
    openedSet -- bitmask of opened valves (bit i <-> ``sortedmains[i]``).
    player    -- how many further actors follow this one.  When the current
                 actor runs out of time and ``player > 0``, the search
                 restarts at index 0 with 26 minutes for ``player - 1``,
                 keeping ``openedSet`` so no valve is credited twice.

    Opening a valve costs one minute and is credited up front as
    ``rate * (remain - 1)``.  Returns ``(pressure, opened bitmask of the best
    plan)``; results are memoised in the flat ``DP`` list.
    """
    if not remain:
        # Out of time: either finished, or hand over to the next actor.
        if player:
            return DFSp2(26, 0, openedSet, player - 1)
        return 0, openedSet

    # Mixed-radix flattening of (openedSet, currentvv, remain, player).
    slot = ((openedSet * N + currentvv) * 31 + remain) * 2 + player
    cached = DP[slot]
    if cached[0] > -1:
        return cached

    best, best_mask = 0, openedSet
    rate = sortedrates[currentvv]
    bit = 1 << currentvv

    # Spend this minute opening the current valve (if shut and useful).
    if rate and not openedSet & bit:
        sub_total, sub_mask = DFSp2(remain - 1, currentvv, openedSet | bit, player)
        gain = (remain - 1) * rate + sub_total
        if gain > best:
            best, best_mask = gain, sub_mask

    # Or walk one tunnel without opening anything.
    for neighbour in sortedtunns[currentvv]:
        gain, sub_mask = DFSp2(remain - 1, neighbour, openedSet, player)
        if gain > best:
            best, best_mask = gain, sub_mask

    DP[slot] = (best, best_mask)
    return best, best_mask
# part 1
# Memo for DFS: (remain, valve, frozenset(opened)) -> (pressure, opened set).
states = {}
def DFS(remain, currentvv, openedSet):
    """Name-based search for a single actor (currently not called; the
    script uses DFSp2 for both parts).

    Every minute, the valves already in *openedSet* release their combined
    flow (from ``flos``); the actor then either opens *currentvv* (if it is
    shut and has positive flow) or walks one tunnel listed in ``tuns``.

    Returns ``(pressure released over the remaining minutes, set of opened
    valves on the best path)``.  Results are memoised in ``states``.
    """
    if remain == 0:
        return 0, set(openedSet)
    key = (remain, currentvv, frozenset(openedSet))
    hit = states.get(key)
    if hit is not None:
        return hit
    if remain < 0:
        print('state/',key,'remaining time/',remain)
        assert False
    tick = sum(flos[v] for v in openedSet)

    def moves():
        # Option 1: spend this minute opening the valve we stand on.
        if currentvv not in openedSet and flos[currentvv] > 0:
            got, trail = DFS(remain - 1, currentvv, openedSet | {currentvv})
            yield tick + got, trail | {currentvv}
        # Option 2: walk to a neighbour without opening anything here.
        for nxt in tuns[currentvv]:
            got, trail = DFS(remain - 1, nxt, openedSet)
            yield tick + got, trail

    best = (0, set(openedSet))
    for total, trail in moves():
        if total > best[0]:  # strict: the first best option wins ties
            best = (total, trail)
    states[key] = best
    return best
# Part 1: a single actor (player 0) with 30 minutes, starting at index 0.
# Part 2: player 1 gets 26 minutes, then player 0 gets another 26 from
# index 0 with whatever player 1 already opened.
# (DFS(30, 'AA', set()) is the older name-based way to compute part 1.)
p1, pm1 = DFSp2(30, 0, 0, 0)
p2, pm2 = DFSp2(26, 0, 0, 1)
def getpath(pathmask, names=None):
    """Return the valve names whose bit is set in *pathmask*, in index order.

    Bit i of *pathmask* stands for ``names[i]``.  *names* defaults to the
    module-level ``sortedmains`` (the ordering DFSp2's bitmasks use); pass a
    list explicitly to decode a mask against another ordering.
    """
    if names is None:
        names = sortedmains
    return [name for i, name in enumerate(names) if pathmask & (1 << i)]
# Print each answer with the valves opened on its best plan, then check it
# against the recorded values (presumably the sample answer and the author's
# own puzzle answer).
for label, total, mask, expected in (
    ('part 1:', p1, pm1, [1651,2183]),
    ('part 2:', p2, pm2, [1707,2911]),
):
    path = getpath(mask)
    print(label, total, '\tβ', ','.join(sorted(path)))
    assert total in expected
# Earlier part-2 results, presumably rejected as too low / too high:
# p2/lo - 1636
# p2/hi - 4642