-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_model.py
123 lines (96 loc) · 3.25 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import cv2
import os
import numpy as np
import time
import pydirectinput
# from alexnet import alexnet
import tensorflow as tf
from tqdm import tqdm
from tensorflow.keras.applications import InceptionResNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from grabscreen import grab_screen
from getkeys import key_check
# Keras mixed precision: layers compute in float16 while variables stay float32.
# This must run before the model below is built so its layers adopt the policy.
tf.keras.mixed_precision.set_global_policy(policy='mixed_float16')
# Frame geometry fed to the network: resize target (pixels) and colour channels.
WIDTH, HEIGHT, CHANNELS = 240, 180, 3
# Training hyper-parameters; both are baked into the saved-weights filename.
LR = 1e-3
EPOCHS = 5
# Filename of the weights produced by training, e.g.
# 'nfsmwai-0.001-InceptionResNetV2-5-epochs.h5'.
MODEL_NAME = f'nfsmwai-{LR}-InceptionResNetV2-{EPOCHS}-epochs.h5'
def straight():
    """Hold the throttle ('w') and let go of both steering keys."""
    pydirectinput.keyDown('w')
    for steering_key in ('a', 'd'):
        pydirectinput.keyUp(steering_key)
def left():
    """Tap left steering ('a') while holding the throttle ('w').

    Right steering ('d') is released first; 'a' is let go again after a
    10 ms hold, while 'w' stays pressed.
    """
    for held_key in ('a', 'w'):
        pydirectinput.keyDown(held_key)
    pydirectinput.keyUp('d')
    # Short hold so the steering input is a brief tap rather than a full turn.
    time.sleep(0.01)
    pydirectinput.keyUp('a')
def right():
    """Tap right steering ('d') while holding the throttle ('w').

    Left steering ('a') is released first; 'd' is let go again after a
    10 ms hold, while 'w' stays pressed.
    """
    for held_key in ('d', 'w'):
        pydirectinput.keyDown(held_key)
    pydirectinput.keyUp('a')
    # Short hold so the steering input is a brief tap rather than a full turn.
    time.sleep(0.01)
    pydirectinput.keyUp('d')
def reset():
    """Release every driving key: throttle 'w', then steering 'a' and 'd'."""
    for driving_key in ('w', 'a', 'd'):
        pydirectinput.keyUp(driving_key)
# Build the classifier: an InceptionResNetV2 backbone (ImageNet weights, no
# top classifier) plus a small head with three softmax outputs, then load the
# trained weights from disk.
# NOTE(review): Keras reads input_shape as (rows, cols, channels), so this
# declares WIDTH (240) rows by HEIGHT (180) columns. main() reshapes frames to
# the same (WIDTH, HEIGHT, CHANNELS) layout; presumably the training script did
# too -- confirm before changing either side, or the loaded weights will see a
# differently arranged input.
# NOTE(review): frames are fed as raw pixel values; no
# inception_resnet_v2.preprocess_input is applied anywhere here -- confirm that
# matches how the model was trained.
base_model = InceptionResNetV2(weights="imagenet",
                               include_top=False,
                               input_shape=(WIDTH, HEIGHT, CHANNELS))
# Add custom layers on top of the base model
x = base_model.output
x = GlobalAveragePooling2D()(x)  # collapse the spatial dims to one feature vector
x = Dense(1024, activation='relu')(x)
# dtype='float32' keeps the softmax output in float32 under the global
# mixed_float16 policy. main() interprets the three outputs as
# [left, straight, right].
predictions = Dense(3, activation='softmax', dtype='float32')(x)
# Define the full model
model = Model(inputs=base_model.input, outputs=predictions)
# Compile the model with the same optimizer/loss/metric settings as training.
# predict() does not need this; presumably kept to mirror the training script.
model.compile(optimizer=Adam(learning_rate=LR),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# The path is resolved against the current working directory, not this file's
# directory. The layer structure above must match the saved weights exactly.
model.load_weights('../models/' + MODEL_NAME)
def main():
    """Run the self-driving loop (screen -> model -> key presses) forever.

    Counts down from 7 to 1, one second per step, so the game window can be
    focused. Each iteration then:
    - grabs the screen region (0, 30, 800, 630);
    - resizes it to WIDTH x HEIGHT;
    - feeds it to ``model``;
    - presses the keys for the predicted move.

    Pressing 'T' toggles pause; entering pause releases the driving keys via
    ``reset()``. Never returns.
    """
    # Countdown 7, 6, ..., 1 before taking control of the keyboard.
    for remaining in range(7, 0, -1):
        print(remaining)
        time.sleep(1)

    paused = False
    while True:
        if not paused:
            screen = grab_screen(region=(0, 30, 800, 630))
            # cv2.resize takes (width, height); the result is HEIGHT x WIDTH.
            screen = cv2.resize(screen, (WIDTH, HEIGHT))
            # NOTE(review): this reshape reinterprets the HEIGHT x WIDTH buffer
            # as WIDTH x HEIGHT (it does not transpose). It is kept because the
            # model's input_shape is (WIDTH, HEIGHT, CHANNELS) and training
            # presumably used the same reshape -- confirm before "fixing".
            batch = screen.reshape(-1, WIDTH, HEIGHT, CHANNELS)
            # verbose=0 stops Keras from printing a progress bar on every frame.
            preds = model.predict(batch, verbose=0)[0]
            # Rounding the softmax output gives a one-hot move only when one
            # class exceeds 0.5; otherwise none of the branches below match.
            moves = list(np.around(preds))
            if moves == [1, 0, 0]:
                left()
                print(f'Turning LEFT.\nPredictions: {preds}')
            elif moves == [0, 1, 0]:
                straight()
                print(f'Going STRAIGHT.\nPredictions: {preds}')
            elif moves == [0, 0, 1]:
                right()
                print(f'Turning RIGHT.\nPredictions: {preds}')
            else:
                # NOTE(review): keys held by the previous action (e.g. 'w' from
                # straight()) remain pressed here; call reset() if the car
                # should coast instead.
                print('Doing nothing')

        keys = key_check()
        if 'T' in keys:
            paused = not paused
            if paused:
                # Release every driving key so nothing stays held while paused.
                reset()
            # Debounce: give the user time to let go of 'T'.
            time.sleep(1)
# Only start the keyboard-driving loop when run as a script, not on import.
if __name__ == "__main__":
    main()