add snake

This commit is contained in:
JohnJim0816
2020-09-08 13:29:47 +08:00
parent 29f2703ae4
commit f0d19ac14f
10 changed files with 549 additions and 0 deletions

107
codes/snake/agent.py Normal file
View File

@@ -0,0 +1,107 @@
import numpy as np
import utils
import random
import math
class Agent:
    """Tabular Q-learning agent for the snake game.

    The agent keeps two tables of identical shape (built by
    ``utils.create_q_table``): ``Q`` holds the action values and ``N`` counts
    how often each (state, action) pair has been taken.
    """

    def __init__(self, actions, Ne, C, gamma):
        """Store the hyper-parameters and allocate empty Q and N tables.

        :param actions: iterable of action indices the agent may choose from
        :param Ne: visit-count threshold used by the exploration function ``f``
        :param C: constant of the learning rate ``C / (C + N[s][a])``
        :param gamma: discount factor applied to the best next-state value
        """
        self.actions = actions
        self.Ne = Ne  # used in exploration function
        self.C = C
        self.gamma = gamma
        # Default to evaluation so act() works even if neither train() nor
        # eval() has been called yet.
        self._train = False
        # Create the Q and N Table to work with
        self.Q = utils.create_q_table()
        self.N = utils.create_q_table()
        self.reset()

    def train(self):
        """Switch to training mode: act() updates the tables and explores."""
        self._train = True

    def eval(self):
        """Switch to evaluation mode: act() is greedy and leaves the tables alone."""
        self._train = False

    def save_model(self, model_path):
        """Save the Q table to ``model_path`` via ``utils.save``."""
        utils.save(model_path, self.Q)

    def load_model(self, model_path):
        """Replace the Q table with the one stored at ``model_path``.

        NOTE: ``utils.load`` returns None when the file cannot be read or has
        the wrong shape; in that case ``self.Q`` becomes None as well.
        """
        self.Q = utils.load(model_path)

    def reset(self):
        """Forget the per-episode memory (last points, last state, last action)."""
        self.points = 0
        self.s = None
        self.a = None

    def f(self, u, n):
        """Exploration function: optimistic value 1 until tried ``Ne`` times, else ``u``."""
        if n < self.Ne:
            return 1
        return u

    def R(self, points, dead):
        """Reward: -1 on death, +1 when the score increased, -0.1 otherwise."""
        if dead:
            return -1
        elif points > self.points:
            return 1
        return -0.1

    def get_state(self, state):
        """Discretize an environment state into a Q-table index tuple.

        :param state: [snake_head_x, snake_head_y, snake_body, food_x, food_y]
        :return: tuple (adjoining_wall_x, adjoining_wall_y, food_dir_x,
                 food_dir_y, adjoining_body_top, adjoining_body_bottom,
                 adjoining_body_left, adjoining_body_right)
        """
        head_x, head_y, body, food_x, food_y = state
        # Coordinate of the last cell still inside the board on the right /
        # bottom side; one grid step further is already inside the wall.
        far_edge = utils.DISPLAY_SIZE - utils.WALL_SIZE - utils.GRID_SIZE
        # 0: no wall next to the head (also when the head is off the board),
        # 1: wall on the left/top, 2: wall on the right/bottom.
        adjoining_wall_x = int(head_x == utils.WALL_SIZE) + 2 * int(head_x == far_edge)
        adjoining_wall_y = int(head_y == utils.WALL_SIZE) + 2 * int(head_y == far_edge)
        # 0: same coordinate, 1: food has the smaller coordinate,
        # 2: food has the larger coordinate.
        food_dir_x = 1 + int(head_x < food_x) - int(head_x == food_x)
        food_dir_y = 1 + int(head_y < food_y) - int(head_y == food_y)
        # Offsets head - segment. They are kept as tuples and compared with
        # tuples: a list never compares equal to a tuple, so testing list
        # literals against these offsets would always be False.
        offsets = {(head_x - seg[0], head_y - seg[1]) for seg in body}
        adjoining_body_top = int((0, utils.GRID_SIZE) in offsets)
        adjoining_body_bottom = int((0, -utils.GRID_SIZE) in offsets)
        adjoining_body_left = int((utils.GRID_SIZE, 0) in offsets)
        adjoining_body_right = int((-utils.GRID_SIZE, 0) in offsets)
        return (adjoining_wall_x, adjoining_wall_y, food_dir_x, food_dir_y,
                adjoining_body_top, adjoining_body_bottom,
                adjoining_body_left, adjoining_body_right)

    def update_tables(self, _state, points, dead):
        """Apply one Q-learning update for the previously taken (state, action).

        Does nothing at the start of an episode, when no previous state exists.

        :param _state: discretized state reached after the previous action
        :param points: current score reported by the environment
        :param dead: whether the previous action killed the snake
        """
        if self.s is not None:
            maxq = max(self.Q[_state])
            reward = self.R(points, dead)
            # Learning rate decays with the number of visits of (s, a).
            alpha = self.C / (self.C + self.N[self.s][self.a])
            self.Q[self.s][self.a] += alpha * (reward + self.gamma * maxq - self.Q[self.s][self.a])
            self.N[self.s][self.a] += 1.0

    def act(self, state, points, dead):
        '''
        Choose the next action and, in training mode, learn from the last one.

        :param state: a list of [snake_head_x, snake_head_y, snake_body, food_x, food_y] from environment.
        :param points: float, the current points from environment
        :param dead: boolean, if the snake is dead
        :return: the index of the chosen action along the last axis of the
                 Q table, or None when ``dead`` is true (the episode memory is
                 reset instead). The Snake environment in this project moves
                 right, left, up, down for the indices 0, 1, 2, 3.
        '''
        _state = self.get_state(state)
        if self._train:
            self.update_tables(_state, points, dead)
        if dead:
            self.reset()
            return None
        # Read the row after the update so a self-transition sees the new value.
        q_values = self.Q[_state]
        if self._train:
            visits = self.N[_state]
            # argmax returns a position in the list, which equals the action
            # index as long as self.actions is 0..n-1 in order.
            action = np.argmax([self.f(q_values[a], visits[a]) for a in self.actions])
            self.s = _state
            self.a = action
        else:
            action = np.argmax(q_values)
        self.points = points
        return action

BIN
codes/snake/checkpoint.npy Normal file

Binary file not shown.

BIN
codes/snake/checkpoint1.npy Normal file

Binary file not shown.

BIN
codes/snake/checkpoint2.npy Normal file

Binary file not shown.

BIN
codes/snake/checkpoint3.npy Normal file

Binary file not shown.

Binary file not shown.

183
codes/snake/main.py Normal file
View File

@@ -0,0 +1,183 @@
import pygame
from pygame.locals import *
import argparse
from agent import Agent
from snake_env import SnakeEnv
import utils
import time
class Application:
    """Run the training, testing and on-screen phases of the snake agent."""

    def __init__(self, args):
        """Build the environment and the agent from parsed command-line ``args``."""
        self.args = args
        self.env = SnakeEnv(args.snake_head_x, args.snake_head_y, args.food_x, args.food_y)
        self.agent = Agent(self.env.get_actions(), args.Ne, args.C, args.gamma)

    def execute(self):
        """Train and test (unless in human mode), then display games."""
        if not self.args.human:
            if self.args.train_eps != 0:
                self.train()
            self.test()
        self.show_games()

    def train(self):
        """Play ``train_eps`` games in training mode and save the model.

        Prints running statistics every ``window`` games and stores the
        per-game scores in ``self.points_results``.
        """
        print("Train Phase:")
        self.agent.train()
        window = self.args.window
        self.points_results = []
        first_eat = True
        start = time.time()

        for game in range(1, self.args.train_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.act(state, 0, dead)
            while not dead:
                state, points, dead = self.env.step(action)

                # For debug convenience, the Q-table is saved the first time
                # the snake eats, so it can be compared with a reference table
                # for a given setting of parameters.
                if first_eat and points == 1:
                    self.agent.save_model(utils.CHECKPOINT)
                    first_eat = False

                action = self.agent.act(state, points, dead)

            points = self.env.get_points()
            self.points_results.append(points)
            # A non-positive window disables the report instead of raising
            # ZeroDivisionError on the modulo below.
            if window > 0 and game % window == 0:
                recent = self.points_results[-window:]
                print(
                    "Games:", len(self.points_results) - window, "-", len(self.points_results),
                    "Points (Average:", sum(recent)/window,
                    "Max:", max(recent),
                    "Min:", min(recent), ")",
                )
            self.env.reset()
        print("Training takes", time.time() - start, "seconds")
        self.agent.save_model(self.args.model_name)

    def test(self):
        """Load the saved model and play ``test_eps`` games in evaluation mode."""
        print("Test Phase:")
        self.agent.eval()
        self.agent.load_model(self.args.model_name)
        points_results = []
        start = time.time()

        for game in range(1, self.args.test_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.act(state, 0, dead)
            while not dead:
                state, points, dead = self.env.step(action)
                action = self.agent.act(state, points, dead)
            points = self.env.get_points()
            points_results.append(points)
            self.env.reset()

        print("Testing takes", time.time() - start, "seconds")
        print("Number of Games:", len(points_results))
        # With zero test games there is nothing to average or to take the
        # max/min of.
        if not points_results:
            return
        print("Average Points:", sum(points_results)/len(points_results))
        print("Max Points:", max(points_results))
        print("Min Points:", min(points_results))

    def show_games(self):
        """Open the game window and play ``show_eps`` games on screen.

        The agent picks the moves unless ``args.human`` is set, in which case
        the arrow keys steer the snake. Escape or closing the window stops
        the display early.
        """
        print("Display Games")
        self.env.display()
        pygame.event.pump()
        self.agent.eval()
        points_results = []
        end = False
        for game in range(1, self.args.show_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.act(state, 0, dead)
            while not dead:
                pygame.event.pump()
                keys = pygame.key.get_pressed()
                if keys[K_ESCAPE] or self.check_quit():
                    end = True
                    break
                state, points, dead = self.env.step(action)
                # Qlearning agent
                if not self.args.human:
                    action = self.agent.act(state, points, dead)
                # for human player
                else:
                    for event in pygame.event.get():
                        # This loop empties the event queue, so a QUIT that
                        # arrived since check_quit() must be honoured here or
                        # it would be lost.
                        if event.type == pygame.QUIT:
                            end = True
                        elif event.type == pygame.KEYDOWN:
                            if event.key == pygame.K_UP:
                                action = 2
                            elif event.key == pygame.K_DOWN:
                                action = 3
                            elif event.key == pygame.K_LEFT:
                                action = 1
                            elif event.key == pygame.K_RIGHT:
                                action = 0
                    if end:
                        break
            if end:
                break
            self.env.reset()
            points_results.append(points)
            print("Game:", str(game)+"/"+str(self.args.show_eps), "Points:", points)
        if len(points_results) == 0:
            return
        print("Average Points:", sum(points_results)/len(points_results))

    def check_quit(self):
        """Return True if a window-close (QUIT) event is waiting in the queue.

        Note that this consumes every event currently in the queue.
        """
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return True
        return False
def main():
    """Parse the command-line options and run the application."""
    parser = argparse.ArgumentParser(description='CS440 MP4 Snake')

    parser.add_argument('--human', default = False, action="store_true",
                        help='making the game human playable - default False')

    # The help text now states the real default instead of the outdated "q_agent".
    parser.add_argument('--model_name', dest="model_name", type=str, default="checkpoint3.npy",
                        help='name of model to save if training or to load if evaluating - default checkpoint3.npy')

    parser.add_argument('--train_episodes', dest="train_eps", type=int, default=10000,
                        help='number of training episodes - default 10000')

    parser.add_argument('--test_episodes', dest="test_eps", type=int, default=1000,
                        help='number of testing episodes - default 1000')

    parser.add_argument('--show_episodes', dest="show_eps", type=int, default=10,
                        help='number of displayed episodes - default 10')

    parser.add_argument('--window', dest="window", type=int, default=100,
                        help='number of episodes to keep running stats for during training - default 100')

    parser.add_argument('--Ne', dest="Ne", type=int, default=40,
                        help='the Ne parameter used in exploration function - default 40')

    parser.add_argument('--C', dest="C", type=int, default=40,
                        help='the C parameter used in learning rate - default 40')

    # gamma multiplies the best next-state value in the agent's Q update; the
    # help text is corrected to say so and to quote the real default.
    parser.add_argument('--gamma', dest="gamma", type=float, default=0.2,
                        help='the gamma parameter (discount factor) used in the Q-value update - default 0.2')

    parser.add_argument('--snake_head_x', dest="snake_head_x", type=int, default=200,
                        help='initialized x position of snake head - default 200')

    parser.add_argument('--snake_head_y', dest="snake_head_y", type=int, default=200,
                        help='initialized y position of snake head - default 200')

    parser.add_argument('--food_x', dest="food_x", type=int, default=80,
                        help='initialized x position of food - default 80')

    parser.add_argument('--food_y', dest="food_y", type=int, default=80,
                        help='initialized y position of food - default 80')

    args = parser.parse_args()
    app = Application(args)
    app.execute()


if __name__ == "__main__":
    main()

BIN
codes/snake/q_agent.npy Normal file

Binary file not shown.

204
codes/snake/snake_env.py Normal file
View File

@@ -0,0 +1,204 @@
import random
import pygame
import utils
class SnakeEnv:
    """Wrapper around a ``Snake`` game that can also render it with pygame."""

    def __init__(self, snake_head_x, snake_head_y, food_x, food_y):
        """Create the game with the given start positions; rendering is off."""
        self.game = Snake(snake_head_x, snake_head_y, food_x, food_y)
        self.render = False

    def get_actions(self):
        """Return the list of action indices accepted by the game."""
        return self.game.get_actions()

    def reset(self):
        """Put the game back into its initial configuration."""
        return self.game.reset()

    def get_points(self):
        """Return the current score."""
        return self.game.get_points()

    def get_state(self):
        """Return [snake_head_x, snake_head_y, snake_body, food_x, food_y]."""
        return self.game.get_state()

    def step(self, action):
        """Advance the game by one action and redraw if rendering is enabled.

        :return: tuple (state, points, dead)
        """
        state, points, dead = self.game.step(action)
        if self.render:
            self.draw(state, points, dead)
        return state, points, dead

    def draw(self, state, points, dead):
        """Paint one frame. Only valid after ``display()`` has been called.

        :param state: [snake_head_x, snake_head_y, snake_body, food_x, food_y]
        :param points: score shown in the caption text
        :param dead: when true the frame is held longer (clock ticks at 1 fps
                     instead of 5)
        """
        snake_head_x, snake_head_y, snake_body, food_x, food_y = state
        cell = utils.GRID_SIZE
        inner = utils.DISPLAY_SIZE - utils.GRID_SIZE * 2

        # Blue background acts as the wall; the black rectangle is the board.
        self.screen.fill(utils.BLUE)
        pygame.draw.rect(self.screen, utils.BLACK, [cell, cell, inner, inner])

        # draw snake head (outline, 3 px wide)
        pygame.draw.rect(self.screen, utils.GREEN, [snake_head_x, snake_head_y, cell, cell], 3)

        # draw snake body (outline, 1 px wide)
        for seg in snake_body:
            pygame.draw.rect(self.screen, utils.GREEN, [seg[0], seg[1], cell, cell], 1)

        # draw food (filled)
        pygame.draw.rect(self.screen, utils.RED, [food_x, food_y, cell, cell])

        text_surface = self.font.render("Points: " + str(points), True, utils.BLACK)
        text_rect = text_surface.get_rect()
        text_rect.center = (280, 25)
        self.screen.blit(text_surface, text_rect)
        pygame.display.flip()

        if dead:
            # slow clock if dead
            self.clock.tick(1)
        else:
            self.clock.tick(5)
        return

    def display(self):
        """Initialise pygame, open the window, draw the first frame and enable rendering.

        The window surface is kept in ``self.screen``. Storing it under the
        name ``display`` would replace this method on the instance and make a
        second call to ``display()`` fail.
        """
        pygame.init()
        pygame.display.set_caption('MP4: Snake')
        self.clock = pygame.time.Clock()
        pygame.font.init()

        self.font = pygame.font.Font(pygame.font.get_default_font(), 15)
        self.screen = pygame.display.set_mode((utils.DISPLAY_SIZE, utils.DISPLAY_SIZE), pygame.HWSURFACE)
        self.draw(self.game.get_state(), self.game.get_points(), False)
        self.render = True
class Snake:
    """Game logic of snake: head, body segments, food and score."""

    def __init__(self, snake_head_x, snake_head_y, food_x, food_y):
        """Remember the initial head and food positions and start a game."""
        self.init_snake_head_x = snake_head_x
        self.init_snake_head_y = snake_head_y
        self.init_food_x = food_x
        self.init_food_y = food_y
        self.reset()

    def reset(self):
        """Restore the initial head/food positions, empty body and zero score."""
        self.points = 0
        self.snake_head_x = self.init_snake_head_x
        self.snake_head_y = self.init_snake_head_y
        self.snake_body = []
        self.food_x = self.init_food_x
        self.food_y = self.init_food_y

    def get_points(self):
        """Return the current score."""
        return self.points

    def get_actions(self):
        """Return the valid action indices (see ``move`` for their meaning)."""
        return [0, 1, 2, 3]

    def get_state(self):
        """Return [snake_head_x, snake_head_y, snake_body, food_x, food_y].

        ``snake_body`` is the live list of (x, y) tuples, not a copy.
        """
        return [
            self.snake_head_x,
            self.snake_head_y,
            self.snake_body,
            self.food_x,
            self.food_y
        ]

    def move(self, action):
        """Move the head one grid cell and report whether the snake died.

        :param action: 0 = +x (right), 1 = -x (left), 2 = -y (up), 3 = +y (down);
                       any other value leaves the head where it is
        :return: True if the move hit the body, reversed onto a one-segment
                 body, or hit the wall; False otherwise
        """
        delta_x = delta_y = 0
        if action == 0:
            delta_x = utils.GRID_SIZE
        elif action == 1:
            delta_x = - utils.GRID_SIZE
        elif action == 2:
            delta_y = - utils.GRID_SIZE
        elif action == 3:
            delta_y = utils.GRID_SIZE

        # Remember the single body segment before it is shifted out, so a
        # reversal onto it can still be detected afterwards.
        old_body_head = None
        if len(self.snake_body) == 1:
            old_body_head = self.snake_body[0]

        self.snake_body.append((self.snake_head_x, self.snake_head_y))
        self.snake_head_x += delta_x
        self.snake_head_y += delta_y

        # The body holds as many segments as points were scored.
        if len(self.snake_body) > self.points:
            del(self.snake_body[0])

        self.handle_eatfood()

        # colliding with the snake body or going backwards while its body length
        # greater than 1
        for seg in self.snake_body:
            if self.snake_head_x == seg[0] and self.snake_head_y == seg[1]:
                return True

        # moving towards body direction, not allowing snake to go backwards while
        # its body length is 1
        if len(self.snake_body) == 1:
            if old_body_head == (self.snake_head_x, self.snake_head_y):
                return True

        # collide with the wall: the board spans WALL_SIZE up to
        # DISPLAY_SIZE - WALL_SIZE, the same bounds random_food() uses.
        low = utils.WALL_SIZE
        high = utils.DISPLAY_SIZE - utils.WALL_SIZE
        if (self.snake_head_x < low or self.snake_head_y < low or
                self.snake_head_x + utils.GRID_SIZE > high or
                self.snake_head_y + utils.GRID_SIZE > high):
            return True

        return False

    def step(self, action):
        """Apply ``action`` and return (state, points, dead)."""
        is_dead = self.move(action)
        return self.get_state(), self.get_points(), is_dead

    def handle_eatfood(self):
        """If the head is on the food, place new food and add one point."""
        if (self.snake_head_x == self.food_x) and (self.snake_head_y == self.food_y):
            self.random_food()
            self.points += 1

    def random_food(self):
        """Place the food on a random board cell not occupied by the snake.

        Cells are drawn with ``randrange`` in steps of GRID_SIZE so that every
        column and row is equally likely. Drawing a pixel with ``randint`` and
        rounding it down to the grid would make the last column/row far less
        likely than the others.
        """
        max_x = (utils.DISPLAY_SIZE - utils.WALL_SIZE - utils.GRID_SIZE)
        max_y = (utils.DISPLAY_SIZE - utils.WALL_SIZE - utils.GRID_SIZE)

        while True:
            self.food_x = random.randrange(utils.WALL_SIZE, max_x + 1, utils.GRID_SIZE)
            self.food_y = random.randrange(utils.WALL_SIZE, max_y + 1, utils.GRID_SIZE)
            if not self.check_food_on_snake():
                break

    def check_food_on_snake(self):
        """Return True if the food lies on the head or on any body segment."""
        if self.food_x == self.snake_head_x and self.food_y == self.snake_head_y:
            return True
        for seg in self.snake_body:
            if self.food_x == seg[0] and self.food_y == seg[1]:
                return True
        return False

55
codes/snake/utils.py Normal file
View File

@@ -0,0 +1,55 @@
import numpy as np
# Window and board geometry (the display is DISPLAY_SIZE x DISPLAY_SIZE).
DISPLAY_SIZE = 560
GRID_SIZE = 40
WALL_SIZE = 40

# RGB colours used when drawing.
WHITE = (255, 255, 255)
RED = (255, 0, 0)
BLUE = (72, 61, 139)
BLACK = (0, 0, 0)
GREEN = (0, 255, 0)

# Number of discrete values along each axis of the Q table.
NUM_ADJOINING_WALL_X_STATES = 3
NUM_ADJOINING_WALL_Y_STATES = 3
NUM_FOOD_DIR_X = 3
NUM_FOOD_DIR_Y = 3
NUM_ADJOINING_BODY_TOP_STATES = 2
NUM_ADJOINING_BODY_BOTTOM_STATES = 2
NUM_ADJOINING_BODY_LEFT_STATES = 2
NUM_ADJOINING_BODY_RIGHT_STATES = 2
NUM_ACTIONS = 4

CHECKPOINT = 'checkpoint.npy'

# Single definition of the table shape, shared by create_q_table() and
# sanity_check() so the two cannot drift apart.
Q_TABLE_SHAPE = (
    NUM_ADJOINING_WALL_X_STATES,
    NUM_ADJOINING_WALL_Y_STATES,
    NUM_FOOD_DIR_X,
    NUM_FOOD_DIR_Y,
    NUM_ADJOINING_BODY_TOP_STATES,
    NUM_ADJOINING_BODY_BOTTOM_STATES,
    NUM_ADJOINING_BODY_LEFT_STATES,
    NUM_ADJOINING_BODY_RIGHT_STATES,
    NUM_ACTIONS,
)


def create_q_table():
    """Return a zero-filled array with one entry per (state, action) pair."""
    return np.zeros(Q_TABLE_SHAPE)


def sanity_check(arr):
    """Return True if ``arr`` is a numpy array with the Q-table shape."""
    return isinstance(arr, np.ndarray) and arr.shape == Q_TABLE_SHAPE
def save(filename, arr):
    """Write ``arr`` to ``filename`` with ``np.save`` if it passes ``sanity_check``.

    :return: True when the array was written, False (after printing a
             message) when it was rejected
    """
    if not sanity_check(arr):
        print("Failed to save model")
        return False
    np.save(filename, arr)
    return True
def load(filename):
    """Read a Q table from ``filename`` and validate it with ``sanity_check``.

    :return: the loaded array, or None (after printing a message) when the
             file is missing, cannot be read, or has the wrong shape
    """
    # Only the read itself is guarded, and only against errors that mean the
    # file is missing or unreadable; a bare except would also swallow
    # KeyboardInterrupt and report every failure as a missing file.
    try:
        arr = np.load(filename)
    except FileNotFoundError:
        print("Filename doesnt exist")
        return None
    except (OSError, ValueError, EOFError) as err:
        print("Failed to read model file:", err)
        return None

    if sanity_check(arr):
        print("Loaded model successfully")
        return arr
    print("Model loaded is not in the required format")
    return None