-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test.py
70 lines (54 loc) · 2.29 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Evaluation script: loads a pre-trained convolutional network and lets it
# drive the ``Car`` environment from ``game_cart``.
# NOTE(review): gym, random, randint, cv2, deque, Dropout, Activation,
# BatchNormalization and the optimizers (Adam, SGD, Adadelta, RMSprop) are not
# referenced anywhere in this file.
# NOTE(review): ``keras.layers.convolutional`` / ``keras.layers.normalization``
# are older standalone-Keras module paths; newer Keras releases expose these
# layers directly from ``keras.layers`` — confirm the pinned Keras version.
from game_cart import Car
import gym
import numpy as np
import random
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import Adam, SGD, Adadelta, RMSprop
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers import Activation, Flatten
import cv2
from keras.layers.normalization import BatchNormalization
import pygame
from random import randint
# Module-level pygame clock used by the main loop to cap the frame rate.
clock = pygame.time.Clock()
from collections import deque
import os
# Captured once, at import time: the process's current working directory
# (not necessarily this script's folder). The trained weights file is looked
# up relative to this directory.
path = os.getcwd()
class DQNAgent:
    """Inference-only agent that picks actions with a pre-trained network.

    The network is rebuilt by ``build_model`` and its weights are restored
    from disk during construction; this class performs no training.
    """

    def __init__(self, state_size, action_size, weights_path=None):
        """Build the network and load its trained weights.

        Args:
            state_size: input shape given to the first ``Conv2D`` layer
                (presumably ``(height, width, channels)`` — TODO confirm
                against ``Car.state_size``).
            action_size: number of discrete actions, i.e. network outputs.
            weights_path: optional weights file. When omitted,
                ``success.model`` inside the module-level ``path`` directory
                is used.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.model = self.load_trained_model(weights_path)

    def build_model(self):
        """Return the (uncompiled) convolutional network.

        The layer stack must match the architecture the weights were saved
        from, because ``load_weights`` restores them layer by layer.
        """
        model = Sequential()
        # Only the first layer of a Sequential model needs an input shape;
        # Keras ignores shape kwargs on later layers, so none are passed there.
        model.add(Conv2D(32, kernel_size=(3, 3), activation="relu",
                         input_shape=self.state_size))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Conv2D(32, kernel_size=(3, 3), activation="relu"))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Flatten())
        model.add(Dense(32, activation="relu"))
        model.add(Dense(32, activation="relu"))
        # Linear head: one unbounded score per action (presumably Q-values,
        # given the DQN naming).
        model.add(Dense(self.action_size, activation="linear"))
        return model

    def act(self, state):
        """Return the index of the highest-scoring action for ``state``.

        ``state`` is passed to ``model.predict`` unchanged and only the first
        row of the prediction is used, so it is presumably already batched
        (batch size 1) — TODO confirm against ``Car.run``. Ties resolve to
        the lowest index. The result is a plain Python ``int``.
        """
        scores = self.model.predict(state)[0]
        return int(np.argmax(scores))

    def load_trained_model(self, weights_path=None):
        """Build the network and restore saved weights into it.

        Args:
            weights_path: weights file to load. ``None`` falls back to
                ``success.model`` inside the module-level ``path`` directory.

        Returns:
            The Keras model with its weights loaded.
        """
        if weights_path is None:
            weights_path = os.path.join(path, "success.model")
        model = self.build_model()
        model.load_weights(weights_path)
        return model
if __name__ == "__main__":
    # Greedy evaluation loop: let the pre-trained agent drive the car for a
    # fixed number of episodes, capped at 60 frames per second.
    env = Car()
    agent = DQNAgent(env.state_size, env.act_size)
    episodes = 200
    max_steps = 512
    quit_requested = False
    for episode in range(episodes):
        # Calling env.run() without an action is assumed to start a new
        # episode and return its first observation — TODO confirm in game_cart.
        state, _, _ = env.run()
        for step in range(max_steps):
            action = agent.act(state)
            # NOTE(review): `done` is not acted on, so every episode runs the
            # full max_steps even after the env reports termination — confirm
            # whether that is intended.
            next_state, reward, done = env.run(action)
            state = next_state
            pressed = pygame.key.get_pressed()
            if pressed[pygame.K_q]:
                # Stop both loops before shutting pygame down; calling pygame
                # (or env.run) after pygame.quit() raises pygame.error.
                quit_requested = True
                break
            clock.tick(60)
        if quit_requested:
            break
    pygame.quit()