forked from NathanKlineInstitute/SMARTAgent
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsimplePong.py
257 lines (237 loc) · 11.7 KB
/
simplePong.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import numpy as np
from conf import dconf
import random
from matplotlib import pyplot as plt
class simplePong:
    """Single-player Pong simulator that renders 210x160 RGB frames.

    The frame ``obs`` is indexed ``obs[y, x]``: y is the row (grows downwards)
    and x is the column (grows rightwards).  The court occupies rows
    ``court_top`` .. ``court_bottom - 1``.  The ball is served from the left
    edge (x = 0) and bounces off the left wall and the top/bottom of the court;
    the only racket sits at a fixed column on the right and is moved by
    ``step``.  Gameplay parameters are read from the project configuration
    ``dconf`` (``dconf['wiggle']`` and ``dconf['simulatedEnvParams'][...]``).
    """

    # RGB colours painted into the frame.
    COURT_COLOR = (144, 72, 17)
    BALL_COLOR = (236, 236, 236)
    RACKET_COLOR = (92, 186, 92)

    def __init__(self, seed=1234):
        # Seeds Python's global ``random`` module, which drives serves and bounces.
        random.seed(seed)
        self.createcourt()
        # Observation frame (rows, columns, RGB).  Allocated as uint8 up front so
        # its dtype matches the frames returned by ``step``.
        self.obs = np.zeros(shape=(210, 160, 3), dtype=np.uint8)
        self.createnewframe()
        self.createball()
        self.createrackets()
        self.reward = 0  # reward of the latest step: 1 hit, -1 miss, 0 otherwise
        self.TotalHits = 0
        self.TotalMissed = 0
        self.MissedTheBall = 0  # 1 from a miss until the re-served ball is back in play
        self.NewServe = 0  # 1 once the ball has been re-served after a miss
        self.scoreRecorded = 0  # 1 once the current miss has been counted
        self.createFigure()

    def createFigure(self):
        """Create the matplotlib figure, image artist and score label used for drawing."""
        self.fig, self.ax = plt.subplots(1, 1)
        self.im = self.ax.imshow(np.zeros(shape=(210, 160, 3)))
        self.scorestr = self.ax.text(1, 20, 'M,H:0,0', style='normal', color='lightgreen', size=28)

    def createcourt(self):
        """Set the court boundaries in frame (pixel) coordinates."""
        self.court_top = 34
        self.court_bottom = 194
        self.court_redge = 159
        self.court_ledge = 0

    def _paint(self, y1, y2, x1, x2, color):
        """Fill the rectangle ``obs[y1:y2, x1:x2]`` with an RGB colour (broadcast over channels)."""
        self.obs[y1:y2, x1:x2] = color

    def createball(self):
        """Place the 4x4 ball at its configured start position, paint it, and load ball settings."""
        params = dconf['simulatedEnvParams']
        # Start position: ypos is a row offset below court_top, xpos a column (0 = left edge).
        self.ypos_ball = params['yball']
        self.xpos_ball = 0
        self.randomizeYpos = params['random']
        self.wiggle = dconf['wiggle']  # tolerance used by the automatic racket tracker in step()
        self.ball_width = 4
        self.ball_height = 4
        # Bounding box: columns ballx1..ballx2-1, rows bally1..bally2-1.
        self.ballx1 = self.xpos_ball
        self.ballx2 = self.xpos_ball + self.ball_width
        self.bally1 = self.court_top + self.ypos_ball
        self.bally2 = self.court_top + self.ypos_ball + self.ball_height
        self._paint(self.bally1, self.bally2, self.ballx1, self.ballx2, self.BALL_COLOR)
        # Per-frame displacement in pixels.
        self.ball_dx = 1
        self.ball_dy = 1
        # Candidate values drawn on bounces and new serves.
        self.possible_ball_ypos = params['possible_ball_ypos']
        self.possible_ball_dy = params['possible_ball_dy']
        self.possible_ball_dx = params['possible_ball_dx']

    def createrackets(self):
        """Place the right racket at its configured start row, paint it, and load its speed."""
        self.racket_width = 4
        self.racket_height = 16
        self.xpos_racket = 140  # fixed column
        self.ypos_racket = dconf['simulatedEnvParams']['yracket']  # row offset below court_top
        self.rightracketx1 = self.xpos_racket
        self.rightracketx2 = self.xpos_racket + self.racket_width
        self.rightrackety1 = self.court_top + self.ypos_racket
        self.rightrackety2 = self.court_top + self.ypos_racket + self.racket_height
        self._paint(self.rightrackety1, self.rightrackety2,
                    self.rightracketx1, self.rightracketx2, self.RACKET_COLOR)
        self.racket_dy = dconf['simulatedEnvParams']['racket_dy']  # pixels per racket move

    def createnewframe(self):
        """Clear the frame to black and paint the court rows."""
        self.obs.fill(0)
        self.obs[self.court_top:self.court_bottom, :] = self.COURT_COLOR

    def moveball(self, xshift_ball, yshift_ball):
        """Shift the ball's bounding box by (xshift_ball, yshift_ball) pixels and paint it."""
        self.ballx1 += xshift_ball
        self.ballx2 += xshift_ball
        self.bally1 += yshift_ball
        self.bally2 += yshift_ball
        self._paint(self.bally1, self.bally2, self.ballx1, self.ballx2, self.BALL_COLOR)

    def moveracket(self, yshift_racket):
        """Move the right racket vertically by ``yshift_racket`` pixels and paint it.

        The racket's top row is clamped to ``[court_top, court_bottom - racket_height]``,
        so the racket stays inside the court and can reach both walls even when
        the step size does not divide the remaining distance.
        """
        lowest_top = self.court_bottom - self.racket_height
        new_top = min(max(self.rightrackety1 + yshift_racket, self.court_top), lowest_top)
        self.rightrackety1 = new_top
        self.rightrackety2 = new_top + self.racket_height
        self._paint(self.rightrackety1, self.rightrackety2,
                    self.rightracketx1, self.rightracketx2, self.RACKET_COLOR)

    def getNextBallShift(self, right_racket_yshift):
        """Work out how far the ball moves this frame, updating velocity, score and reward.

        The ball occupies ``obs[bally1:bally2, ballx1:ballx2]`` and moves by
        (``ball_dx``, ``ball_dy``) per frame; the racket occupies
        ``obs[rightrackety1:rightrackety2, rightracketx1:rightracketx2]``.
        The rules, applied to a tentative move with the current velocity, are:

        * ball moving left past the left wall: reverse ``ball_dx`` and re-draw
          the magnitude of ``ball_dy`` from ``possible_ball_dy`` (sign kept);
        * ball reaching the top or bottom of the court: stop it at the wall and
          reverse ``ball_dy``;
        * ball moving right into the racket column: if it overlaps the racket
          vertically, reflect it and score a hit (``reward = 1``).  With
          ``top_bottom_rule`` enabled, a ball edge within 2 px of the racket's
          top (racket moving up) or bottom (racket moving down) also counts as a
          hit and doubles the re-drawn ``ball_dy``.  Once the ball is behind the
          racket, a single miss is scored (``reward = -1``);
        * a missed ball that leaves the court is re-served: ``NewServe`` is set
          and a new serve row and velocity are drawn for ``step`` to apply.

        Args:
            right_racket_yshift: racket displacement chosen this frame; only the
                optional top/bottom rule reads it.

        Returns:
            (xshift_ball, yshift_ball) to pass to ``moveball``.
        """
        self.reward = 0
        # The miss flags persist across calls so one miss is counted only once
        # while the ball keeps travelling past the racket; clear them once the
        # re-served ball is back in front of the racket.
        if self.ballx2 < self.rightracketx1 and self.NewServe:
            self.MissedTheBall = 0
            self.NewServe = 0
            self.scoreRecorded = 0

        # 1. Tentative move with the current velocity.
        xshift_ball = self.ball_dx
        yshift_ball = self.ball_dy
        tmp_ballx1 = self.ballx1 + xshift_ball
        tmp_ballx2 = self.ballx2 + xshift_ball
        tmp_bally1 = self.bally1 + yshift_ball
        tmp_bally2 = self.bally2 + yshift_ball

        # Left wall: reverse horizontally and re-draw |dy| (keeping its sign).
        # The reversed ball_dx becomes this frame's horizontal shift in step 4;
        # the re-drawn ball_dy takes effect from the next frame.
        if self.ball_dx < 0 and tmp_ballx1 < self.court_ledge:
            self.ball_dy = np.sign(self.ball_dy) * abs(random.choice(self.possible_ball_dy))
            self.ball_dx *= -1

        # 2. Top/bottom walls.  The tentative rows were computed with yshift_ball,
        # so the clipped shift is derived from it rather than from ball_dy, which
        # the left-wall bounce may just have changed (otherwise a corner bounce
        # pushes the ball past the wall).
        if self.ball_dy > 0:
            if tmp_bally2 >= self.court_bottom:
                yshift_ball += self.court_bottom - tmp_bally2  # stop against the bottom wall
                tmp_bally1 = self.bally1 + yshift_ball
                tmp_bally2 = self.bally2 + yshift_ball
                self.ball_dy = -1 * self.ball_dy
        elif self.ball_dy < 0:
            if tmp_bally1 <= self.court_top:
                yshift_ball += self.court_top - tmp_bally1  # stop against the top wall
                tmp_bally1 = self.bally1 + yshift_ball
                tmp_bally2 = self.bally2 + yshift_ball
                self.ball_dy = -1 * self.ball_dy
        else:
            yshift_ball = self.ball_dy

        # 4. Right racket (the one moved by step()).
        approaching_racket = (
            self.ball_dx > 0
            and self.rightracketx1 <= tmp_ballx2 <= self.court_redge
            and self.MissedTheBall == 0
        )
        if approaching_racket:
            top_edge_on_racket = self.rightrackety1 <= tmp_bally1 <= self.rightrackety2
            bottom_edge_on_racket = self.rightrackety1 <= tmp_bally2 <= self.rightrackety2
            if top_edge_on_racket or bottom_edge_on_racket:
                xshift_ball = self._bounceOffRightRacket(tmp_ballx2, 1)
            elif (dconf['simulatedEnvParams']['top_bottom_rule'] and right_racket_yshift < 0
                  and abs(tmp_bally2 - self.rightrackety1) <= 2):
                print('hit top R')
                xshift_ball = self._bounceOffRightRacket(tmp_ballx2, 2)
            elif (dconf['simulatedEnvParams']['top_bottom_rule'] and right_racket_yshift > 0
                  and abs(tmp_bally1 - self.rightrackety2) <= 2):
                print('hit bottom R')
                xshift_ball = self._bounceOffRightRacket(tmp_ballx2, 2)
            elif self.scoreRecorded == 0 and tmp_ballx1 > self.rightracketx2:
                # The ball is behind the racket: count this miss exactly once.
                self.TotalMissed += 1
                self.MissedTheBall = 1
                if not dconf['simulatedEnvParams']['dodraw']:
                    print('Player missed the ball')
                    print('Hits: ', self.TotalHits, 'Missed: ', self.TotalMissed)
                    print('Ball (projected):', tmp_ballx1, tmp_ballx2, tmp_bally1, tmp_bally2)
                    print('Racket:', self.rightracketx1, self.rightracketx2,
                          self.rightrackety1, self.rightrackety2)
                self.reward = -1
                self.scoreRecorded = 1
        else:
            xshift_ball = self.ball_dx

        # 5. A missed ball that has left the court is re-served; step() moves it
        # to (xpos_ball, court_top + ypos_ball) when NewServe is set.
        if self.MissedTheBall and (tmp_ballx1 < self.court_ledge or tmp_ballx2 > self.court_redge):
            self.NewServe = 1
            xshift_ball = 0
            yshift_ball = 0
            self.ball_dy = random.choice(self.possible_ball_dy)
            self.xpos_ball = 0
            self.ypos_ball = random.choice(self.possible_ball_ypos)
            self.ball_dx = random.choice(self.possible_ball_dx)
        return xshift_ball, yshift_ball

    def _bounceOffRightRacket(self, tmp_ballx2, dy_scale):
        """Reflect the ball off the right racket, score a hit, and return this frame's x shift.

        The returned shift puts the ball's right edge on the racket's left column.
        |ball_dy| is re-drawn from ``possible_ball_dy`` (sign kept) and multiplied
        by ``dy_scale``; the new ball_dy takes effect from the next frame.
        """
        xshift_ball = self.ball_dx + self.rightracketx1 - tmp_ballx2
        self.ball_dy = np.sign(self.ball_dy) * abs(random.choice(self.possible_ball_dy)) * dy_scale
        self.ball_dx *= -1
        self.TotalHits += 1
        self.reward = 1
        return xshift_ball

    def step(self, action):
        """Advance the game by one frame.

        Args:
            action: 3 moves the racket down by ``racket_dy`` pixels, 4 moves it
                up, 1 keeps it still.  Any other value makes the racket track the
                ball automatically (no learning model involved).

        Returns:
            (frame, reward): a uint8 copy of the new frame that later steps do
            not overwrite, and this step's reward (1 hit, -1 miss, 0 otherwise).
        """
        stepsize = self.racket_dy
        if action == 3:
            right_racket_yshift = stepsize
        elif action == 4:
            right_racket_yshift = -stepsize
        elif action == 1:
            right_racket_yshift = 0
        else:
            # Automatic tracker: move only when the ball is outside the racket's
            # rows +/- wiggle.  A wiggle of about half the racket height makes the
            # racket oscillate while it tracks the ball.
            ballmidY = self.bally1 + 0.5 * self.ball_height
            if ballmidY > self.rightrackety2 - self.wiggle:  # ball below the racket
                right_racket_yshift = stepsize
            elif self.bally2 < self.rightrackety1 + self.wiggle:  # ball above the racket
                right_racket_yshift = -stepsize
            else:
                right_racket_yshift = 0
        self.createnewframe()
        self.moveracket(right_racket_yshift)
        xshift_ball, yshift_ball = self.getNextBallShift(right_racket_yshift)
        if self.NewServe == 1:
            # Re-serve from the position chosen in getNextBallShift.
            self.ballx1 = self.xpos_ball
            self.ballx2 = self.xpos_ball + self.ball_width
            self.bally1 = self.court_top + self.ypos_ball
            self.bally2 = self.court_top + self.ypos_ball + self.ball_height
        self.moveball(xshift_ball, yshift_ball)
        # self.obs is cleared and redrawn in place on the next step, so hand the
        # caller an independent copy (astype copies by default).
        frame = self.obs.astype(np.uint8)
        if dconf['simulatedEnvParams']['dodraw']:
            self.im.set_data(frame)
            self.drawscore()
            plt.pause(0.0001)
        return frame, self.reward

    def drawscore(self):
        """Update the on-screen label with misses (M) and hits (H)."""
        self.scorestr.set_text('M,H:{},{}'.format(self.TotalMissed, self.TotalHits))

    def reset(self):
        """Placeholder: only prints a warning and leaves the game state unchanged."""
        print('WARNING: empty reset')
def testsim(nstep=10000):
    """Smoke-run the simulated Pong for ``nstep`` frames.

    Each frame is stepped with action -1.  That is not one of the model actions
    (1, 3, 4), so the racket follows the ball on its own.  Observations and
    rewards are discarded.
    """
    game = simplePong()
    for _ in range(nstep):
        game.step(-1)


if __name__ == '__main__':
    testsim()