-
Notifications
You must be signed in to change notification settings - Fork 0
/
custom_picture_folder_env.py
94 lines (74 loc) · 3.04 KB
/
custom_picture_folder_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import random
import glob
import cv2
class CustomPictureFolderEnv:
    """Toy 1-D navigation environment whose states are images from a folder.

    The agent moves an index through the list of loaded images, wrapping
    around at both ends. Landing on ``terminal_picture_idx`` yields reward 10
    and ``done=True``; every other step yields reward 0.

    Args:
        picture_folder_path: glob pattern matching the image files.
        terminal_picture_idx: index (into the sorted, successfully loaded
            images) of the goal state.
        image_size: ``(width, height)`` passed to ``cv2.resize`` for each state.
        show_preview: if True, briefly display every loaded image on startup.

    Raises:
        ValueError: if no readable image matches ``picture_folder_path``.
    """

    def __init__(self, picture_folder_path='pictures/*', terminal_picture_idx=8,
                 image_size=(80, 60), show_preview=True):
        self.state_picture_index = 0
        self.all_pictures = self.read_all_pictures_states(
            picture_folder_path, image_size=image_size, show_preview=show_preview)
        self.terminal_picture_idx = terminal_picture_idx
        print('Number of image states: {}'.format(len(self.all_pictures)))
        if not self.all_pictures:
            # Fail early: reset()/step() would otherwise crash with an obscure IndexError.
            raise ValueError('No readable images match {!r}'.format(picture_folder_path))
        if not 0 <= terminal_picture_idx < len(self.all_pictures):
            print('Warning: terminal_picture_idx {} is outside [0, {}); '
                  'episodes will never terminate.'.format(
                      terminal_picture_idx, len(self.all_pictures)))

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done, info)``.

        Action 1 decrements the index, action 0 increments it, and any other
        value leaves it unchanged. The index wraps around at both ends.
        ``info`` is always None.
        """
        if action == 1:  # go right
            delta = -1
        elif action == 0:  # go left
            delta = 1
        else:  # do nothing
            delta = 0
        # Modulo wraps one step past either end to the opposite end.
        self.state_picture_index = (self.state_picture_index + delta) % len(self.all_pictures)
        done = self.state_picture_index == self.terminal_picture_idx
        reward = 10 if done else 0
        return self.all_pictures[self.state_picture_index], reward, done, None

    def reset(self, random_start=False):
        """Move to the start state (index 0, or a uniformly random index) and return its image."""
        if random_start:
            self.state_picture_index = random.randint(0, len(self.all_pictures) - 1)
            print(self.state_picture_index, len(self.all_pictures))
        else:
            self.state_picture_index = 0
        # all_picture_paths is kept index-aligned with all_pictures by read_all_pictures_states.
        print('Start image on reset: {}'.format(self.all_picture_paths[self.state_picture_index]))
        return self.all_pictures[self.state_picture_index]

    def render(self):
        """Show the current state image in an OpenCV window."""
        cv2.imshow('state image', self.all_pictures[self.state_picture_index])
        cv2.waitKey(1)  # lets HighGUI process window events so the frame is drawn

    def read_all_pictures_states(self, picture_folder_path, image_size=(80, 60), show_preview=True):
        """Load, resize and return every readable image matching ``picture_folder_path``.

        Side effect: sets ``self.all_picture_paths`` to the paths of the images
        that were actually loaded, so index i of the returned list and of
        ``self.all_picture_paths`` refer to the same file. Unreadable files are
        reported and skipped.
        """
        # glob order is filesystem-dependent; sort so state indices are reproducible.
        candidate_paths = sorted(glob.glob(picture_folder_path))
        loaded_paths = []
        originals = []
        resized = []
        for path in candidate_paths:
            image = cv2.imread(path)
            if image is None:  # imread reports failure by returning None, not by raising
                print('Couldn\'t read image with path: {}'.format(path))
                continue
            try:
                small = cv2.resize(image, image_size)
            except cv2.error as e:
                print('Couldn\'t resize image with path: {}'.format(path))
                print(e)
                continue
            loaded_paths.append(path)
            resized.append(small)
            if show_preview:
                originals.append(image)
        self.all_picture_paths = loaded_paths

        # Optionally flash each successfully loaded image for a quick sanity check.
        if originals:
            for image in originals:
                cv2.imshow('image', image)
                cv2.waitKey(100)
            cv2.destroyAllWindows()
        return resized
if __name__ == '__main__':
    # Human control of the environment: A turns left, D turns right, Q quits.
    env = CustomPictureFolderEnv(picture_folder_path='pictures/*')
    # reset() returns only the start image, not a (state, done) pair.
    state = env.reset()
    while True:
        cv2.imshow('state image', state)
        k = cv2.waitKey(1)
        action = 2  # any action other than 0/1 leaves the state unchanged
        if k == ord('a'):
            print("pressed A, turned left")
            action = 0
        elif k == ord('d'):
            print('pressed D, turned right')
            action = 1
        elif k == ord('q'):
            break
        state, reward, done, _ = env.step(action=action)
        if done:
            print('Reached terminal image, reward: {}'.format(reward))
            state = env.reset()
    cv2.destroyAllWindows()