# main.py
# Entry-point script: runs the solvers imported below on a FrozenLake
# environment and renders the policy/value each one produces.
import time
from frozenlake import FrozenLake
from linearQlearning import linear_q_learning
from linearSarsa import linear_sarsa
from policyIteration import policy_iteration
from qLearning import q_learning
from sarsa import sarsa
from valueIteration import value_iteration
from LinearWrapper import LinearWrapper
def main():
    """Run every imported solver on the small lake and render its result.

    The solvers are executed in a fixed order: policy iteration, value
    iteration, Sarsa, Q-learning, linear Sarsa and linear Q-learning.
    Each one's policy and value are passed to the environment's
    ``render`` method. The same ``seed`` is handed to the environment
    and to every learning routine that accepts one.
    """
    seed = 0

    # Small 4x4 layout; the cell symbols are interpreted by FrozenLake.
    small_lake = [
        ['&', '.', '.', '.'],
        ['.', '#', '.', '#'],
        ['.', '.', '.', '#'],
        ['#', '.', '.', '$'],
    ]

    # Alternative 8x8 layout, kept for reference. To use it, uncomment it
    # and pass it to FrozenLake in place of the small lake.
    # big_lake = [['&', '.', '.', '.', '.', '.', '.', '.'],
    #             ['.', '.', '.', '.', '.', '.', '.', '.'],
    #             ['.', '.', '.', '#', '.', '.', '.', '.'],
    #             ['.', '.', '.', '.', '.', '#', '.', '.'],
    #             ['.', '.', '.', '#', '.', '.', '.', '.'],
    #             ['.', '#', '#', '.', '.', '.', '#', '.'],
    #             ['.', '#', '.', '.', '#', '.', '#', '.'],
    #             ['.', '.', '.', '#', '.', '.', '.', '$']]

    env = FrozenLake(small_lake, slip=0.1, max_steps=16, seed=seed)

    print('# Model-based algorithms')
    gamma, theta, max_iterations = 0.9, 0.001, 100

    # NOTE: to time a planner, take time.time() readings immediately
    # before and after the planner(...) call below.
    planners = (
        ('## Policy iteration', policy_iteration),
        ('## Value iteration', value_iteration),
    )
    for heading, planner in planners:
        print('')
        print(heading)
        policy, value = planner(env, gamma, theta, max_iterations)
        env.render(policy, value)

    print('')
    print('# Model-free algorithms')
    max_episodes, eta, epsilon = 2000, 0.5, 0.5

    learners = (
        ('## Sarsa', sarsa),
        ('## Q-learning', q_learning),
    )
    for heading, learner in learners:
        print('')
        print(heading)
        policy, value = learner(env, max_episodes, eta, gamma, epsilon,
                                seed=seed)
        env.render(policy, value)

    print('')
    # The linear variants operate on a wrapped environment and return a
    # parameter object that the wrapper decodes into (policy, value).
    linear_env = LinearWrapper(env)

    print('## Linear Sarsa')
    weights = linear_sarsa(linear_env, max_episodes, eta, gamma, epsilon,
                           seed=seed)
    policy, value = linear_env.decode_policy(weights)
    linear_env.render(policy, value)

    print('')
    print('## Linear Q-learning')
    weights = linear_q_learning(linear_env, max_episodes, eta, gamma, epsilon,
                                seed=seed)
    policy, value = linear_env.decode_policy(weights)
    linear_env.render(policy, value)
# Run the demo only when this file is executed as a script, so that
# importing the module does not start the solvers.
if __name__ == '__main__':
    main()