186 lines
7.2 KiB
Python
186 lines
7.2 KiB
Python
import pygame
|
|
from pygame.locals import *
|
|
import argparse
|
|
|
|
from agent import Agent
|
|
from snake_env import SnakeEnv
|
|
import utils
|
|
import time
|
|
|
|
def _build_parser():
    """Build the command-line parser for the snake application.

    Every help string uses argparse's ``%(default)s`` placeholder so the
    advertised default is always the real one (the old hand-written defaults
    for --gamma and --model_name had drifted out of sync).
    """
    parser = argparse.ArgumentParser(description='CS440 MP4 Snake')

    parser.add_argument('--human', default=False, action="store_true",
                        help='making the game human playable - default %(default)s')

    parser.add_argument('--model_name', dest="model_name", type=str, default="checkpoint3.npy",
                        help='name of model to save if training or to load if evaluating'
                             ' - default %(default)s')

    parser.add_argument('--train_episodes', dest="train_eps", type=int, default=10000,
                        help='number of training episodes - default %(default)s')

    parser.add_argument('--test_episodes', dest="test_eps", type=int, default=1000,
                        help='number of testing episodes - default %(default)s')

    parser.add_argument('--show_episodes', dest="show_eps", type=int, default=10,
                        help='number of displayed episodes - default %(default)s')

    parser.add_argument('--window', dest="window", type=int, default=100,
                        help='number of episodes to keep running stats for during training'
                             ' - default %(default)s')

    parser.add_argument('--Ne', dest="Ne", type=int, default=40,
                        help='the Ne parameter used in exploration function - default %(default)s')

    parser.add_argument('--C', dest="C", type=int, default=40,
                        help='the C parameter used in learning rate - default %(default)s')

    parser.add_argument('--gamma', dest="gamma", type=float, default=0.2,
                        help='the gamma parameter used in the Q-learning update'
                             ' - default %(default)s')

    parser.add_argument('--snake_head_x', dest="snake_head_x", type=int, default=200,
                        help='initialized x position of snake head - default %(default)s')

    parser.add_argument('--snake_head_y', dest="snake_head_y", type=int, default=200,
                        help='initialized y position of snake head - default %(default)s')

    parser.add_argument('--food_x', dest="food_x", type=int, default=80,
                        help='initialized x position of food - default %(default)s')

    parser.add_argument('--food_y', dest="food_y", type=int, default=80,
                        help='initialized y position of food - default %(default)s')

    return parser


def get_args(argv=None):
    """Parse the command-line options for the snake application.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    return _build_parser().parse_args(argv)
|
|
|
|
class Application:
    """Drive a SnakeEnv with an Agent: training, evaluation and on-screen games.

    All behaviour is controlled by the parsed command-line namespace ``args``
    (as produced by ``get_args``).
    """

    def __init__(self, args):
        """Create the environment and the agent from the parsed arguments."""
        self.args = args
        self.env = SnakeEnv(args.snake_head_x, args.snake_head_y, args.food_x, args.food_y)
        self.agent = Agent(self.env.get_actions(), args.Ne, args.C, args.gamma)

    def execute(self):
        """Run the phases selected by the arguments.

        Outside human mode: train (skipped when ``train_eps == 0``), then
        evaluate. The display phase always runs last.
        """
        if not self.args.human:
            if self.args.train_eps != 0:
                self.train()
            self.eval()
        self.show_games()

    def train(self):
        """Train the agent for ``args.train_eps`` games and save the model.

        Every ``args.window`` games the average/max/min points of the last
        ``window`` games are printed (no stats are printed when ``window`` is
        not positive). The model is saved to ``utils.CHECKPOINT`` the first
        time a step reports exactly 1 point, and to ``args.model_name`` when
        training finishes. Per-game points are stored in
        ``self.points_results``.
        """
        print("Train Phase:")
        self.agent.train()
        window = self.args.window
        self.points_results = []
        first_eat = True
        start = time.time()

        for game in range(1, self.args.train_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            while not dead:
                state, points, dead = self.env.step(action)

                # For debug convenience, you can check if your Q-table matches ours for a given
                # setting of parameters (see Debug Convenience part on homework 4 web page).
                if first_eat and points == 1:
                    self.agent.save_model(utils.CHECKPOINT)
                    first_eat = False

                action = self.agent.choose_action(state, points, dead)

            self.points_results.append(self.env.get_points())
            # A non-positive window would make the modulo divide by zero (or
            # slice nonsensically), so running stats are simply skipped then.
            if window > 0 and game % window == 0:
                recent = self.points_results[-window:]
                print(
                    "Games:", len(self.points_results) - window, "-", len(self.points_results),
                    "Points (Average:", sum(recent) / window,
                    "Max:", max(recent),
                    "Min:", min(recent), ")",
                )
            self.env.reset()

        print("Training takes", time.time() - start, "seconds")
        self.agent.save_model(self.args.model_name)

    def eval(self):
        """Load ``args.model_name`` and play ``args.test_eps`` games in eval mode.

        Prints elapsed time, the number of games and, when at least one game
        was played, the average/max/min points.
        """
        print("Evaling Phase:")
        self.agent.eval()
        self.agent.load_model(self.args.model_name)
        points_results = []
        start = time.time()

        for _ in range(self.args.test_eps):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            while not dead:
                state, points, dead = self.env.step(action)
                action = self.agent.choose_action(state, points, dead)
            points_results.append(self.env.get_points())
            self.env.reset()

        print("Testing takes", time.time() - start, "seconds")
        print("Number of Games:", len(points_results))
        if not points_results:
            # e.g. --test_episodes 0: average/max/min of no games would raise.
            return
        print("Average Points:", sum(points_results) / len(points_results))
        print("Max Points:", max(points_results))
        print("Min Points:", min(points_results))

    def show_games(self):
        """Play and display ``args.show_eps`` games, printing each game's points.

        Escape or closing the window stops early. In human mode the arrow keys
        choose the action; otherwise the agent does.
        """
        print("Display Games")
        self.env.display()
        pygame.event.pump()
        self.agent.eval()
        # Arrow key -> action index used for the human player.
        key_to_action = {
            pygame.K_RIGHT: 0,
            pygame.K_LEFT: 1,
            pygame.K_UP: 2,
            pygame.K_DOWN: 3,
        }
        points_results = []
        end = False
        for game in range(1, self.args.show_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            while not dead:
                pygame.event.pump()
                keys = pygame.key.get_pressed()
                if end or keys[pygame.K_ESCAPE] or self.check_quit():
                    end = True
                    break
                state, points, dead = self.env.step(action)
                if not self.args.human:
                    # Q-learning agent
                    action = self.agent.choose_action(state, points, dead)
                else:
                    # Human player. This drains the event queue, so a QUIT
                    # event consumed here would never reach check_quit();
                    # handle it directly.
                    for event in pygame.event.get():
                        if event.type == pygame.QUIT:
                            end = True
                        elif event.type == pygame.KEYDOWN:
                            action = key_to_action.get(event.key, action)
            if end:
                break
            self.env.reset()
            points_results.append(points)
            print("Game:", str(game) + "/" + str(self.args.show_eps), "Points:", points)
        if not points_results:
            return
        print("Average Points:", sum(points_results) / len(points_results))

    def check_quit(self):
        """Drain pending pygame events; return True if any was a QUIT event."""
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return True
        return False
|
|
|
|
|
|
def main():
    """Entry point: build the application from command-line options and run it."""
    application = Application(get_args())
    application.execute()


if __name__ == "__main__":
    main()
|