-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaidriver.py
118 lines (90 loc) · 3.14 KB
/
aidriver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import neat
import pickle
import pygame as pg
from config import AIConfig, GameConfig
from game import Game
from setuppygame import init_pygame, quit_pygame
class AIDriver:
    """Drive a single car with a NEAT network and accumulate a fitness score.

    The network receives the car's radar readings. Its two outputs are
    shifted and scaled into throttle and steering commands for the car.
    """

    def __init__(self, net):
        # Network object exposing ``activate(inputs)`` that returns two values.
        self.net = net
        # Running score; read back by the trainer after each race.
        self.fitness = 0

    def drive(self, car, player_events):
        """Apply one network-chosen control step to ``car``.

        ``player_events`` is accepted so this driver shares a call signature
        with other drivers; it is not used here.
        """
        raw_gas, raw_turn = self.net.activate(car.radar_lengths)
        # Map each raw output x to (x + 1) / 2. If the outputs lie in [-1, 1]
        # (depends on the activation in the NEAT config) this gives [0, 1].
        gas = (raw_gas + 1) * 0.5
        turn = (raw_turn + 1) * 0.5
        car.drive(gas, turn)

        if car.alive:
            # Reward throttle and penalise distance of steering from 0.5
            # (presumably "straight ahead").
            reward = gas - abs(turn - 0.5)
            self.fitness += reward

        # Stop a car whose accumulated score has dropped below -10.
        if self.fitness < -10:
            car.alive = False
class AIDriverManager:
    """Train a population of AIDriver networks with NEAT.

    Each generation, every genome is converted to a FeedForwardNetwork. All of
    them race together in one Game, and the fitness each AIDriver collects is
    written back to its genome.
    """

    def __init__(self, config_file, max_time=600):
        """Load the NEAT config and create the population.

        Args:
            config_file: path to the NEAT configuration file.
            max_time: value passed as ``max_time`` to ``Game.start`` for each
                generation's race (600 was the previously hard-coded value).
        """
        self.is_graphics = GameConfig.is_graphics
        self.max_time = max_time
        self.config = neat.config.Config(
            neat.DefaultGenome,
            neat.DefaultReproduction,
            neat.DefaultSpeciesSet,
            neat.DefaultStagnation,
            config_file,
        )
        self.p = neat.Population(self.config)
        # stdout reporter
        self.p.add_reporter(neat.StdOutReporter(True))
        self.stats = neat.StatisticsReporter()
        self.p.add_reporter(self.stats)
        # Filled in by study()/eval_genome(). Initialised here so status()
        # does not raise AttributeError before the first generation's Game exists.
        self.window = None
        self.game = None
        self.winner = None

    def status(self):
        """Return a multi-line text summary of training progress.

        Passed to Game as a callback (presumably rendered on screen). Reports
        0 cars alive before any Game exists and 0 fitness before NEAT has a
        best genome.
        """
        alive = self.game.cars_alive if self.game is not None else 0
        best = self.p.best_genome
        fitness = best.fitness if best is not None else 0
        return (
            f'Population size: {self.p.config.pop_size}\n'
            f'Generation: {self.p.generation}\n'
            f'Alive: {alive}\n'
            f'Best fitness: {fitness:.2f}'
        )

    def eval_genome(self, genomes_numerated, config):
        """Race one generation and assign each genome its driver's fitness.

        Args:
            genomes_numerated: iterable of ``(genome_id, genome)`` pairs, as
                passed to the fitness function by ``neat.Population.run``.
            config: NEAT config used to build the networks.
        """
        # Materialise once: the genomes are needed again after the race, and a
        # one-shot iterator would be exhausted on the second pass.
        genomes = [genome for _genome_id, genome in genomes_numerated]
        drivers = [
            AIDriver(neat.nn.FeedForwardNetwork.create(genome, config))
            for genome in genomes
        ]
        self.game = Game(self.window, drivers, self.status)
        self.game.start(max_time=self.max_time)
        for driver, genome in zip(drivers, genomes):
            genome.fitness = driver.fitness

    def study(self, generation_number):
        """Run NEAT for ``generation_number`` generations; return the winner.

        When graphics are enabled, pygame is initialised first. quit_pygame()
        is always called afterwards, even if training raises.
        """
        self.window = init_pygame() if self.is_graphics else None
        try:
            self.winner = self.p.run(self.eval_genome, generation_number)
        finally:
            if self.is_graphics:
                quit_pygame()
        return self.winner
def train_ai(config_file='neat-config.txt', generations=10):
    """Train a NEAT population and optionally pickle the best network.

    Args:
        config_file: path to the NEAT configuration file.
        generations: number of generations to run.

    Returns:
        The winning genome returned by ``AIDriverManager.study``.
    """
    ai_driver_manager = AIDriverManager(config_file)
    winner = ai_driver_manager.study(generations)
    if AIConfig.save_best_net:
        print('Saving best net')
        best_net = neat.nn.FeedForwardNetwork.create(
            winner,
            ai_driver_manager.config
        )
        with open(AIConfig.best_net_path, 'wb') as f:
            pickle.dump(best_net, f)
    return winner
def load_net(filename):
    """Unpickle and return the network stored in ``filename``.

    Raises OSError (e.g. FileNotFoundError) if the file cannot be opened.

    Security: ``pickle.load`` can execute arbitrary code. Only load files you
    produced yourself (e.g. via train_ai), never untrusted input.
    """
    with open(filename, 'rb') as f:
        return pickle.load(f)
def run_ai(filenames):
    """Race one AIDriver per network file in a pygame window.

    Files for which load_net returns None are skipped. Networks are loaded
    before pygame is initialised. quit_pygame() is always called, even if
    the game raises.
    """
    drivers = [
        AIDriver(net)
        for net in (load_net(filename) for filename in filenames)
        if net is not None
    ]
    window = init_pygame()
    try:
        game = Game(window, drivers)
        game.start()
    finally:
        quit_pygame()
def run_best_ai():
    """Replay the single network saved at ``AIConfig.best_net_path``."""
    best_paths = [AIConfig.best_net_path]
    run_ai(best_paths)
if __name__ == "__main__":
    # Running this module directly replays the saved best network.
    run_best_ai()