forked from NathanKlineInstitute/SMARTAgent
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtestRacketPredictions2.py
131 lines (122 loc) · 3.81 KB
/
testRacketPredictions2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import random
from matplotlib import pyplot as plt
import numpy as np
import gym
#from pylab import *
# Number of environment steps the control loop runs.
nbsteps = 200000

# Breakout-specific screen windows, each a (start, stop) pixel slice as used
# by findobj. Alternative (wider) ranges kept for reference:
#   courtXRng = (9, 159), courtYRng = (32, 189)
courtXRng = (9, 149)     # columns searched for both the ball and the racket
courtYRng = (93, 188)    # rows searched for the ball
racketYRng = (189, 192)  # rows searched for the racket

# Ball positions tracked across frames; -1 means "not detected".
xpos_Ball, ypos_Ball = -1, -1    # previous location
xpos_Ball2, ypos_Ball2 = -1, -1  # current location
def findobj(img, xrng, yrng):
    """Locate the brightest object inside a rectangular window of a frame.

    The window is ``img[yrng[0]:yrng[1], xrng[0]:xrng[1], :]``. Channels are
    summed per pixel; the object is every pixel attaining the maximum summed
    value, and its position is the median row/column of those pixels.

    Args:
        img: image array indexed as [row, column, channel].
        xrng: (start, stop) column range of the window.
        yrng: (start, stop) row range of the window.

    Returns:
        (xpos, ypos) relative to (xrng[0], yrng[0]), or (-1, -1) when the
        window is empty, has no positive pixel, or is uniformly bright (every
        pixel at the maximum, so no distinct object stands out).
    """
    subimg = img[yrng[0]:yrng[1], xrng[0]:xrng[1], :]
    if subimg.size == 0:
        # np.amax raises on a zero-size array; treat as "nothing found".
        return -1, -1
    intensity = np.sum(subimg, 2)
    peak = np.amax(intensity)
    if peak <= 0:
        # No positive pixel at all: nothing to track.
        return -1, -1
    objInds = np.argwhere(intensity == peak)  # (row, col) of brightest pixels
    if len(objInds) == intensity.size:
        # Every pixel is equally bright, so no significant object is present.
        return -1, -1
    ypos, xpos = np.median(objInds, 0)
    return xpos, ypos
# Alternative game: Pong (the horizontal game handled by
# predictBallRacketYIntercept instead of predictBallRacketXIntercept).
# env = gym.make('Pong-v0',frameskip=3)
# Breakout environment driven by the control loop. NOTE(review): frameskip=3
# presumably repeats each chosen action for 3 emulator frames — confirm
# against the installed gym/ALE version.
env = gym.make('Breakout-v0', frameskip=3)
# Start the first episode so the environment can be stepped.
env.reset()
def predictBallRacketYIntercept(xpos_Ball, ypos_Ball, xpos_Ball2, ypos_Ball2, *,
                                racketX=120.0, courtHeight=160):
    """Predict the y coordinate where the ball reaches the racket column (Pong).

    The ball is extrapolated linearly from its previous position
    (xpos_Ball, ypos_Ball) through its current position
    (xpos_Ball2, ypos_Ball2) until x reaches racketX. The extrapolated y is
    then mirrored back into [0, courtHeight] once per bounce off the
    top/bottom walls.

    Args:
        xpos_Ball, ypos_Ball: ball position in the previous frame (-1 if not found).
        xpos_Ball2, ypos_Ball2: ball position in the current frame (-1 if not found).
        racketX: x coordinate of the racket column (default keeps the
            original hard-coded 120.0).
        courtHeight: vertical extent used for wall reflections (default keeps
            the original hard-coded 160).

    Returns:
        The predicted y intercept in [0, courtHeight], or -1 when the ball was
        not found in either frame, x is not increasing (ball not moving toward
        the racket), or the previous y is negative.
    """
    if xpos_Ball == -1 or xpos_Ball2 == -1:
        return -1
    deltax = xpos_Ball2 - xpos_Ball
    if deltax <= 0 or ypos_Ball < 0:
        return -1
    nbInterceptSteps = np.ceil((racketX - xpos_Ball2) / deltax)
    deltay = ypos_Ball2 - ypos_Ball
    predY_nodeflection = ypos_Ball2 + nbInterceptSteps * deltay
    # Reflecting off the wall at 0 maps y -> -y; off the wall at courtHeight
    # it maps y -> 2*courtHeight - y. Folding modulo one round trip
    # (2*courtHeight) handles any number of bounces.
    period = 2 * courtHeight
    predY = predY_nodeflection % period
    if predY > courtHeight:
        predY = period - predY
    return predY
def predictBallRacketXIntercept(xpos1, ypos1, xpos2, ypos2, *,
                                courtWidth=None, courtHeight=None):
    """Predict the x coordinate where the ball reaches the racket row (Breakout).

    The ball is extrapolated linearly from its previous position
    (xpos1, ypos1) through its current position (xpos2, ypos2) until y reaches
    courtHeight. The extrapolated x is then mirrored back into
    [0, courtWidth] once per bounce off the side walls. Inputs and output
    share one frame of reference; the script feeds findobj results, which are
    relative to the court window.

    Args:
        xpos1, ypos1: ball position in the previous frame (-1 if not found).
        xpos2, ypos2: ball position in the current frame (-1 if not found).
        courtWidth: width used for wall reflections; defaults to the width of
            the module-level courtXRng.
        courtHeight: y coordinate of the intercept line; defaults to the
            height of the module-level courtYRng.

    Returns:
        The predicted x intercept in [0, courtWidth], or -1 when the ball was
        not found in either frame, y is not increasing (ball not moving toward
        the racket), or the previous x is negative.
    """
    if courtWidth is None:
        courtWidth = courtXRng[1] - courtXRng[0]
    if courtHeight is None:
        courtHeight = courtYRng[1] - courtYRng[0]
    if ypos1 == -1 or ypos2 == -1:
        return -1
    deltay = ypos2 - ypos1
    if deltay <= 0 or xpos1 < 0:
        return -1
    nbInterceptSteps = np.ceil((courtHeight - ypos2) / deltay)
    deltax = xpos2 - xpos1
    predX_nodeflection = xpos2 + nbInterceptSteps * deltax
    # Reflecting off the wall at 0 maps x -> -x; off the wall at courtWidth
    # it maps x -> 2*courtWidth - x. Folding modulo one round trip
    # (2*courtWidth) handles any number of bounces.
    period = 2 * courtWidth
    predX = predX_nodeflection % period
    if predX > courtWidth:
        predX = period - predX
    return predX
# Breakout action ids (ALE Breakout action set: 0 NOOP, 1 FIRE, 2 RIGHT, 3 LEFT).
FIRE, RIGHT, LEFT = 1, 2, 3
# Pixels; within this distance of the predicted intercept the racket holds still.
racketDeadZone = 8


def _chooseAction(predX, racketX, ballX):
    """Pick the next Breakout action from the predicted ball/racket intercept.

    Args:
        predX: predicted x intercept, or -1 if there is no usable prediction.
        racketX: current racket x position.
        ballX: current ball x position, or -1 if the ball was not detected.

    Returns:
        FIRE, RIGHT or LEFT.
    """
    if ballX == -1:
        # No ball in the tracked court region (e.g. before a serve or after a
        # lost life). Breakout only serves a new ball on FIRE, so random
        # left/right moves alone would stall the game; FIRE is otherwise
        # harmless while a ball is in play.
        return FIRE
    if predX == -1:
        # Ball visible but no prediction (e.g. moving away from the racket):
        # random RIGHT or LEFT (randint's upper bound is exclusive).
        return np.random.randint(RIGHT, LEFT + 1)
    offset = racketX - predX
    if offset > racketDeadZone:
        return LEFT   # racket is right of the intercept
    if offset < -racketDeadZone:
        return RIGHT  # racket is left of the intercept
    return FIRE  # close enough: FIRE leaves the racket in place (the original "stay")


observation, reward, done, info = env.step(FIRE)  # serve the first ball
xpos_Ball2, ypos_Ball2 = findobj(observation, courtXRng, courtYRng)
xpos_Racket2, ypos_Racket2 = findobj(observation, courtXRng, racketYRng)
predX = predictBallRacketXIntercept(xpos_Ball, ypos_Ball, xpos_Ball2, ypos_Ball2)

for _ in range(nbsteps):
    caction = _chooseAction(predX, xpos_Racket2, xpos_Ball2)
    observation, reward, done, info = env.step(caction)
    env.render()
    if done:
        # New episode: track from the reset frame and forget the previous
        # episode's ball position so no velocity is estimated across episodes.
        observation = env.reset()
        xpos_Ball2 = ypos_Ball2 = -1
    xpos_Ball, ypos_Ball = xpos_Ball2, ypos_Ball2
    xpos_Ball2, ypos_Ball2 = findobj(observation, courtXRng, courtYRng)
    xpos_Racket2, ypos_Racket2 = findobj(observation, courtXRng, racketYRng)
    predX = predictBallRacketXIntercept(xpos_Ball, ypos_Ball, xpos_Ball2, ypos_Ball2)