easy-rl/codes/envs/snake/main.py
import pygame
from pygame.locals import *
import argparse
from agent import Agent
from snake_env import SnakeEnv
import utils
import time

def get_args():
    parser = argparse.ArgumentParser(description='CS440 MP4 Snake')
    parser.add_argument('--human', default=False, action="store_true",
                        help='making the game human playable - default False')
    parser.add_argument('--model_name', dest="model_name", type=str, default="checkpoint3.npy",
                        help='name of model to save if training or to load if evaluating - default checkpoint3.npy')
    parser.add_argument('--train_episodes', dest="train_eps", type=int, default=10000,
                        help='number of training episodes - default 10000')
    parser.add_argument('--test_episodes', dest="test_eps", type=int, default=1000,
                        help='number of testing episodes - default 1000')
    parser.add_argument('--show_episodes', dest="show_eps", type=int, default=10,
                        help='number of displayed episodes - default 10')
    parser.add_argument('--window', dest="window", type=int, default=100,
                        help='number of episodes to keep running stats for during training - default 100')
    parser.add_argument('--Ne', dest="Ne", type=int, default=40,
                        help='the Ne parameter used in the exploration function - default 40')
    parser.add_argument('--C', dest="C", type=int, default=40,
                        help='the C parameter used in the learning rate - default 40')
    parser.add_argument('--gamma', dest="gamma", type=float, default=0.2,
                        help='the gamma parameter (discount factor) - default 0.2')
    parser.add_argument('--snake_head_x', dest="snake_head_x", type=int, default=200,
                        help='initial x position of the snake head - default 200')
    parser.add_argument('--snake_head_y', dest="snake_head_y", type=int, default=200,
                        help='initial y position of the snake head - default 200')
    parser.add_argument('--food_x', dest="food_x", type=int, default=80,
                        help='initial x position of the food - default 80')
    parser.add_argument('--food_y', dest="food_y", type=int, default=80,
                        help='initial y position of the food - default 80')
    cfg = parser.parse_args()
    return cfg
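
# Illustrative invocations (flag names as defined above):
#   python main.py                       # train, evaluate, then replay games
#   python main.py --train_episodes 0    # skip training and evaluate a saved model
#   python main.py --human               # play the game yourself with the arrow keys
#
# Ne, C, and gamma are consumed by agent.py (not shown here); presumably, as in
# the usual CS440 setup, the exploration function stays optimistic while a
# state-action pair has been tried fewer than Ne times, the learning rate
# decays roughly as C / (C + N(s, a)), and gamma is the discount factor.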

class Application:
    def __init__(self, args):
        self.args = args
        self.env = SnakeEnv(args.snake_head_x, args.snake_head_y, args.food_x, args.food_y)
        self.agent = Agent(self.env.get_actions(), args.Ne, args.C, args.gamma)

    def execute(self):
        # For the agent: train (unless --train_episodes is 0), then evaluate,
        # then replay games on screen. With --human, skip straight to play.
        if not self.args.human:
            if self.args.train_eps != 0:
                self.train()
            self.eval()
        self.show_games()

    def train(self):
        print("Train Phase:")
        self.agent.train()
        window = self.args.window
        self.points_results = []
        first_eat = True
        start = time.time()

        for game in range(1, self.args.train_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            while not dead:
                state, points, dead = self.env.step(action)
                # For debugging convenience, you can check whether your Q-table matches ours
                # for a given setting of parameters (see the Debug Convenience part on the
                # homework 4 web page).
                if first_eat and points == 1:
                    self.agent.save_model(utils.CHECKPOINT)
                    first_eat = False
                action = self.agent.choose_action(state, points, dead)

            points = self.env.get_points()
            self.points_results.append(points)
            # Print running stats over the last `window` episodes.
            if game % self.args.window == 0:
                print(
                    "Games:", len(self.points_results) - window, "-", len(self.points_results),
                    "Points (Average:", sum(self.points_results[-window:]) / window,
                    "Max:", max(self.points_results[-window:]),
                    "Min:", min(self.points_results[-window:]), ")",
                )
            self.env.reset()

        print("Training takes", time.time() - start, "seconds")
        self.agent.save_model(self.args.model_name)

    def eval(self):
        print("Evaluation Phase:")
        self.agent.eval()
        self.agent.load_model(self.args.model_name)
        points_results = []
        start = time.time()

        for game in range(1, self.args.test_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            while not dead:
                state, points, dead = self.env.step(action)
                action = self.agent.choose_action(state, points, dead)
            points = self.env.get_points()
            points_results.append(points)
            self.env.reset()

        print("Testing takes", time.time() - start, "seconds")
        print("Number of Games:", len(points_results))
        print("Average Points:", sum(points_results) / len(points_results))
        print("Max Points:", max(points_results))
        print("Min Points:", min(points_results))

    def show_games(self):
        print("Display Games")
        self.env.display()
        pygame.event.pump()
        self.agent.eval()
        points_results = []
        end = False

        for game in range(1, self.args.show_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            count = 0
            while not dead:
                count += 1
                pygame.event.pump()
                keys = pygame.key.get_pressed()
                if keys[K_ESCAPE] or self.check_quit():
                    end = True
                    break
                state, points, dead = self.env.step(action)
                # Q-learning agent picks the next action
                if not self.args.human:
                    action = self.agent.choose_action(state, points, dead)
                # human player: arrow keys map to actions 0=right, 1=left, 2=up, 3=down
                else:
                    for event in pygame.event.get():
                        if event.type == pygame.KEYDOWN:
                            if event.key == pygame.K_UP:
                                action = 2
                            elif event.key == pygame.K_DOWN:
                                action = 3
                            elif event.key == pygame.K_LEFT:
                                action = 1
                            elif event.key == pygame.K_RIGHT:
                                action = 0
            if end:
                break
            self.env.reset()
            points_results.append(points)
            print("Game:", str(game) + "/" + str(self.args.show_eps), "Points:", points)

        if len(points_results) == 0:
            return
        print("Average Points:", sum(points_results) / len(points_results))

    def check_quit(self):
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return True
        return False

def main():
    cfg = get_args()
    app = Application(cfg)
    app.execute()

if __name__ == "__main__":
    main()