add snake
This commit is contained in:
107
codes/snake/agent.py
Normal file
107
codes/snake/agent.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import numpy as np
|
||||
import utils
|
||||
import random
|
||||
import math
|
||||
|
||||
|
||||
class Agent:
|
||||
|
||||
def __init__(self, actions, Ne, C, gamma):
|
||||
self.actions = actions
|
||||
self.Ne = Ne # used in exploration function
|
||||
self.C = C
|
||||
self.gamma = gamma
|
||||
|
||||
# Create the Q and N Table to work with
|
||||
self.Q = utils.create_q_table()
|
||||
self.N = utils.create_q_table()
|
||||
self.reset()
|
||||
|
||||
def train(self):
|
||||
self._train = True
|
||||
|
||||
def eval(self):
|
||||
self._train = False
|
||||
|
||||
# At the end of training save the trained model
|
||||
def save_model(self, model_path):
|
||||
utils.save(model_path, self.Q)
|
||||
|
||||
# Load the trained model for evaluation
|
||||
def load_model(self, model_path):
|
||||
self.Q = utils.load(model_path)
|
||||
|
||||
def reset(self):
|
||||
self.points = 0
|
||||
self.s = None
|
||||
self.a = None
|
||||
|
||||
def f(self,u,n):
|
||||
if n < self.Ne:
|
||||
return 1
|
||||
return u
|
||||
|
||||
def R(self,points,dead):
|
||||
if dead:
|
||||
return -1
|
||||
elif points > self.points:
|
||||
return 1
|
||||
return -0.1
|
||||
|
||||
def get_state(self, state):
|
||||
# [adjoining_wall_x, adjoining_wall_y]
|
||||
adjoining_wall_x = int(state[0] == utils.WALL_SIZE) + 2 * int(state[0] == utils.DISPLAY_SIZE - utils.WALL_SIZE)
|
||||
adjoining_wall_y = int(state[1] == utils.WALL_SIZE) + 2 * int(state[1] == utils.DISPLAY_SIZE - utils.WALL_SIZE)
|
||||
# [food_dir_x, food_dir_y]
|
||||
food_dir_x = 1 + int(state[0] < state[3]) - int(state[0] == state[3])
|
||||
food_dir_y = 1 + int(state[1] < state[4]) - int(state[1] == state[4])
|
||||
# [adjoining_body_top, adjoining_body_bottom, adjoining_body_left, adjoining_body_right]
|
||||
adjoining_body = [(state[0] - body_state[0], state[1] - body_state[1]) for body_state in state[2]]
|
||||
adjoining_body_top = int([0, utils.GRID_SIZE] in adjoining_body)
|
||||
adjoining_body_bottom = int([0, -utils.GRID_SIZE] in adjoining_body)
|
||||
adjoining_body_left = int([utils.GRID_SIZE, 0] in adjoining_body)
|
||||
adjoining_body_right = int([-utils.GRID_SIZE, 0] in adjoining_body)
|
||||
return adjoining_wall_x, adjoining_wall_y, food_dir_x, food_dir_y, adjoining_body_top, adjoining_body_bottom, adjoining_body_left, adjoining_body_right
|
||||
|
||||
|
||||
def update_tables(self, _state, points, dead):
|
||||
if self.s:
|
||||
maxq = max(self.Q[_state])
|
||||
reward = self.R(points,dead)
|
||||
alpha = self.C / (self.C + self.N[self.s][self.a])
|
||||
self.Q[self.s][self.a] += alpha * (reward + self.gamma * maxq - self.Q[self.s][self.a])
|
||||
self.N[self.s][self.a] += 1.0
|
||||
|
||||
def act(self, state, points, dead):
|
||||
'''
|
||||
:param state: a list of [snake_head_x, snake_head_y, snake_body, food_x, food_y] from environment.
|
||||
:param points: float, the current points from environment
|
||||
:param dead: boolean, if the snake is dead
|
||||
:return: the index of action. 0,1,2,3 indicates up,down,left,right separately
|
||||
TODO: write your function here.
|
||||
Return the index of action the snake needs to take, according to the state and points known from environment.
|
||||
Tips: you need to discretize the state to the state space defined on the webpage first.
|
||||
(Note that [adjoining_wall_x=0, adjoining_wall_y=0] is also the case when snake runs out of the 480x480 board)
|
||||
'''
|
||||
|
||||
_state = self.get_state(state)
|
||||
Qs = self.Q[_state][:]
|
||||
|
||||
if self._train:
|
||||
self.update_tables(_state, points, dead)
|
||||
if dead:
|
||||
self.reset()
|
||||
return
|
||||
Ns = self.N[_state]
|
||||
Fs = [self.f(Qs[a], Ns[a]) for a in self.actions]
|
||||
action = np.argmax(Fs)
|
||||
self.s = _state
|
||||
self.a = action
|
||||
else:
|
||||
if dead:
|
||||
self.reset()
|
||||
return
|
||||
action = np.argmax(Qs)
|
||||
|
||||
self.points = points
|
||||
return action
|
||||
BIN
codes/snake/checkpoint.npy
Normal file
BIN
codes/snake/checkpoint.npy
Normal file
Binary file not shown.
BIN
codes/snake/checkpoint1.npy
Normal file
BIN
codes/snake/checkpoint1.npy
Normal file
Binary file not shown.
BIN
codes/snake/checkpoint2.npy
Normal file
BIN
codes/snake/checkpoint2.npy
Normal file
Binary file not shown.
BIN
codes/snake/checkpoint3.npy
Normal file
BIN
codes/snake/checkpoint3.npy
Normal file
Binary file not shown.
BIN
codes/snake/example_assignment_and_report2.pdf
Normal file
BIN
codes/snake/example_assignment_and_report2.pdf
Normal file
Binary file not shown.
183
codes/snake/main.py
Normal file
183
codes/snake/main.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import pygame
|
||||
from pygame.locals import *
|
||||
import argparse
|
||||
|
||||
from agent import Agent
|
||||
from snake_env import SnakeEnv
|
||||
import utils
|
||||
import time
|
||||
|
||||
class Application:
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.env = SnakeEnv(args.snake_head_x, args.snake_head_y, args.food_x, args.food_y)
|
||||
self.agent = Agent(self.env.get_actions(), args.Ne, args.C, args.gamma)
|
||||
|
||||
def execute(self):
|
||||
if not self.args.human:
|
||||
if self.args.train_eps != 0:
|
||||
self.train()
|
||||
self.test()
|
||||
self.show_games()
|
||||
|
||||
def train(self):
|
||||
print("Train Phase:")
|
||||
self.agent.train()
|
||||
window = self.args.window
|
||||
self.points_results = []
|
||||
first_eat = True
|
||||
start = time.time()
|
||||
|
||||
for game in range(1, self.args.train_eps + 1):
|
||||
state = self.env.get_state()
|
||||
dead = False
|
||||
action = self.agent.act(state, 0, dead)
|
||||
while not dead:
|
||||
state, points, dead = self.env.step(action)
|
||||
|
||||
# For debug convenience, you can check if your Q-table mathches ours for given setting of parameters
|
||||
# (see Debug Convenience part on homework 4 web page)
|
||||
if first_eat and points == 1:
|
||||
self.agent.save_model(utils.CHECKPOINT)
|
||||
first_eat = False
|
||||
|
||||
action = self.agent.act(state, points, dead)
|
||||
|
||||
|
||||
points = self.env.get_points()
|
||||
self.points_results.append(points)
|
||||
if game % self.args.window == 0:
|
||||
print(
|
||||
"Games:", len(self.points_results) - window, "-", len(self.points_results),
|
||||
"Points (Average:", sum(self.points_results[-window:])/window,
|
||||
"Max:", max(self.points_results[-window:]),
|
||||
"Min:", min(self.points_results[-window:]),")",
|
||||
)
|
||||
self.env.reset()
|
||||
print("Training takes", time.time() - start, "seconds")
|
||||
self.agent.save_model(self.args.model_name)
|
||||
|
||||
def test(self):
|
||||
print("Test Phase:")
|
||||
self.agent.eval()
|
||||
self.agent.load_model(self.args.model_name)
|
||||
points_results = []
|
||||
start = time.time()
|
||||
|
||||
for game in range(1, self.args.test_eps + 1):
|
||||
state = self.env.get_state()
|
||||
dead = False
|
||||
action = self.agent.act(state, 0, dead)
|
||||
while not dead:
|
||||
state, points, dead = self.env.step(action)
|
||||
action = self.agent.act(state, points, dead)
|
||||
points = self.env.get_points()
|
||||
points_results.append(points)
|
||||
self.env.reset()
|
||||
|
||||
print("Testing takes", time.time() - start, "seconds")
|
||||
print("Number of Games:", len(points_results))
|
||||
print("Average Points:", sum(points_results)/len(points_results))
|
||||
print("Max Points:", max(points_results))
|
||||
print("Min Points:", min(points_results))
|
||||
|
||||
def show_games(self):
|
||||
print("Display Games")
|
||||
self.env.display()
|
||||
pygame.event.pump()
|
||||
self.agent.eval()
|
||||
points_results = []
|
||||
end = False
|
||||
for game in range(1, self.args.show_eps + 1):
|
||||
state = self.env.get_state()
|
||||
dead = False
|
||||
action = self.agent.act(state, 0, dead)
|
||||
count = 0
|
||||
while not dead:
|
||||
count +=1
|
||||
pygame.event.pump()
|
||||
keys = pygame.key.get_pressed()
|
||||
if keys[K_ESCAPE] or self.check_quit():
|
||||
end = True
|
||||
break
|
||||
state, points, dead = self.env.step(action)
|
||||
# Qlearning agent
|
||||
if not self.args.human:
|
||||
action = self.agent.act(state, points, dead)
|
||||
# for human player
|
||||
else:
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_UP:
|
||||
action = 2
|
||||
elif event.key == pygame.K_DOWN:
|
||||
action = 3
|
||||
elif event.key == pygame.K_LEFT:
|
||||
action = 1
|
||||
elif event.key == pygame.K_RIGHT:
|
||||
action = 0
|
||||
if end:
|
||||
break
|
||||
self.env.reset()
|
||||
points_results.append(points)
|
||||
print("Game:", str(game)+"/"+str(self.args.show_eps), "Points:", points)
|
||||
if len(points_results) == 0:
|
||||
return
|
||||
print("Average Points:", sum(points_results)/len(points_results))
|
||||
|
||||
def check_quit(self):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='CS440 MP4 Snake')
|
||||
|
||||
parser.add_argument('--human', default = False, action="store_true",
|
||||
help='making the game human playable - default False')
|
||||
|
||||
parser.add_argument('--model_name', dest="model_name", type=str, default="checkpoint3.npy",
|
||||
help='name of model to save if training or to load if evaluating - default q_agent')
|
||||
|
||||
parser.add_argument('--train_episodes', dest="train_eps", type=int, default=10000,
|
||||
help='number of training episodes - default 10000')
|
||||
|
||||
parser.add_argument('--test_episodes', dest="test_eps", type=int, default=1000,
|
||||
help='number of testing episodes - default 1000')
|
||||
|
||||
parser.add_argument('--show_episodes', dest="show_eps", type=int, default=10,
|
||||
help='number of displayed episodes - default 10')
|
||||
|
||||
parser.add_argument('--window', dest="window", type=int, default=100,
|
||||
help='number of episodes to keep running stats for during training - default 100')
|
||||
|
||||
parser.add_argument('--Ne', dest="Ne", type=int, default=40,
|
||||
help='the Ne parameter used in exploration function - default 40')
|
||||
|
||||
parser.add_argument('--C', dest="C", type=int, default=40,
|
||||
help='the C parameter used in learning rate - default 40')
|
||||
|
||||
parser.add_argument('--gamma', dest="gamma", type=float, default=0.2,
|
||||
help='the gamma paramter used in learning rate - default 0.7')
|
||||
|
||||
parser.add_argument('--snake_head_x', dest="snake_head_x", type=int, default=200,
|
||||
help='initialized x position of snake head - default 200')
|
||||
|
||||
parser.add_argument('--snake_head_y', dest="snake_head_y", type=int, default=200,
|
||||
help='initialized y position of snake head - default 200')
|
||||
|
||||
parser.add_argument('--food_x', dest="food_x", type=int, default=80,
|
||||
help='initialized x position of food - default 80')
|
||||
|
||||
parser.add_argument('--food_y', dest="food_y", type=int, default=80,
|
||||
help='initialized y position of food - default 80')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
app = Application(args)
|
||||
app.execute()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
BIN
codes/snake/q_agent.npy
Normal file
BIN
codes/snake/q_agent.npy
Normal file
Binary file not shown.
204
codes/snake/snake_env.py
Normal file
204
codes/snake/snake_env.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import random
|
||||
import pygame
|
||||
import utils
|
||||
|
||||
class SnakeEnv:
|
||||
def __init__(self, snake_head_x, snake_head_y, food_x, food_y):
|
||||
self.game = Snake(snake_head_x, snake_head_y, food_x, food_y)
|
||||
self.render = False
|
||||
|
||||
def get_actions(self):
|
||||
return self.game.get_actions()
|
||||
|
||||
def reset(self):
|
||||
return self.game.reset()
|
||||
|
||||
def get_points(self):
|
||||
return self.game.get_points()
|
||||
|
||||
def get_state(self):
|
||||
return self.game.get_state()
|
||||
|
||||
def step(self, action):
|
||||
state, points, dead = self.game.step(action)
|
||||
if self.render:
|
||||
self.draw(state, points, dead)
|
||||
return state, points, dead
|
||||
|
||||
def draw(self, state, points, dead):
|
||||
snake_head_x, snake_head_y, snake_body, food_x, food_y = state
|
||||
self.display.fill(utils.BLUE)
|
||||
pygame.draw.rect( self.display, utils.BLACK,
|
||||
[
|
||||
utils.GRID_SIZE,
|
||||
utils.GRID_SIZE,
|
||||
utils.DISPLAY_SIZE - utils.GRID_SIZE * 2,
|
||||
utils.DISPLAY_SIZE - utils.GRID_SIZE * 2
|
||||
])
|
||||
|
||||
# draw snake head
|
||||
pygame.draw.rect(
|
||||
self.display,
|
||||
utils.GREEN,
|
||||
[
|
||||
snake_head_x,
|
||||
snake_head_y,
|
||||
utils.GRID_SIZE,
|
||||
utils.GRID_SIZE
|
||||
],
|
||||
3
|
||||
)
|
||||
# draw snake body
|
||||
for seg in snake_body:
|
||||
pygame.draw.rect(
|
||||
self.display,
|
||||
utils.GREEN,
|
||||
[
|
||||
seg[0],
|
||||
seg[1],
|
||||
utils.GRID_SIZE,
|
||||
utils.GRID_SIZE,
|
||||
],
|
||||
1
|
||||
)
|
||||
# draw food
|
||||
pygame.draw.rect(
|
||||
self.display,
|
||||
utils.RED,
|
||||
[
|
||||
food_x,
|
||||
food_y,
|
||||
utils.GRID_SIZE,
|
||||
utils.GRID_SIZE
|
||||
]
|
||||
)
|
||||
|
||||
text_surface = self.font.render("Points: " + str(points), True, utils.BLACK)
|
||||
text_rect = text_surface.get_rect()
|
||||
text_rect.center = ((280),(25))
|
||||
self.display.blit(text_surface, text_rect)
|
||||
pygame.display.flip()
|
||||
if dead:
|
||||
# slow clock if dead
|
||||
self.clock.tick(1)
|
||||
else:
|
||||
self.clock.tick(5)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def display(self):
|
||||
pygame.init()
|
||||
pygame.display.set_caption('MP4: Snake')
|
||||
self.clock = pygame.time.Clock()
|
||||
pygame.font.init()
|
||||
|
||||
self.font = pygame.font.Font(pygame.font.get_default_font(), 15)
|
||||
self.display = pygame.display.set_mode((utils.DISPLAY_SIZE, utils.DISPLAY_SIZE), pygame.HWSURFACE)
|
||||
self.draw(self.game.get_state(), self.game.get_points(), False)
|
||||
self.render = True
|
||||
|
||||
class Snake:
|
||||
def __init__(self, snake_head_x, snake_head_y, food_x, food_y):
|
||||
self.init_snake_head_x = snake_head_x
|
||||
self.init_snake_head_y = snake_head_y
|
||||
self.init_food_x = food_x
|
||||
self.init_food_y = food_y
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.points = 0
|
||||
self.snake_head_x = self.init_snake_head_x
|
||||
self.snake_head_y = self.init_snake_head_y
|
||||
self.snake_body = []
|
||||
self.food_x = self.init_food_x
|
||||
self.food_y = self.init_food_y
|
||||
|
||||
def get_points(self):
|
||||
return self.points
|
||||
|
||||
def get_actions(self):
|
||||
return [0, 1, 2, 3]
|
||||
|
||||
def get_state(self):
|
||||
return [
|
||||
self.snake_head_x,
|
||||
self.snake_head_y,
|
||||
self.snake_body,
|
||||
self.food_x,
|
||||
self.food_y
|
||||
]
|
||||
|
||||
def move(self, action):
|
||||
delta_x = delta_y = 0
|
||||
if action == 0:
|
||||
delta_x = utils.GRID_SIZE
|
||||
elif action == 1:
|
||||
delta_x = - utils.GRID_SIZE
|
||||
elif action == 2:
|
||||
delta_y = - utils.GRID_SIZE
|
||||
elif action == 3:
|
||||
delta_y = utils.GRID_SIZE
|
||||
|
||||
old_body_head = None
|
||||
if len(self.snake_body) == 1:
|
||||
old_body_head = self.snake_body[0]
|
||||
self.snake_body.append((self.snake_head_x, self.snake_head_y))
|
||||
self.snake_head_x += delta_x
|
||||
self.snake_head_y += delta_y
|
||||
|
||||
if len(self.snake_body) > self.points:
|
||||
del(self.snake_body[0])
|
||||
|
||||
self.handle_eatfood()
|
||||
|
||||
# colliding with the snake body or going backwards while its body length
|
||||
# greater than 1
|
||||
if len(self.snake_body) >= 1:
|
||||
for seg in self.snake_body:
|
||||
if self.snake_head_x == seg[0] and self.snake_head_y == seg[1]:
|
||||
return True
|
||||
|
||||
# moving towards body direction, not allowing snake to go backwards while
|
||||
# its body length is 1
|
||||
if len(self.snake_body) == 1:
|
||||
if old_body_head == (self.snake_head_x, self.snake_head_y):
|
||||
return True
|
||||
|
||||
# collide with the wall
|
||||
if (self.snake_head_x < utils.GRID_SIZE or self.snake_head_y < utils.GRID_SIZE or
|
||||
self.snake_head_x + utils.GRID_SIZE > utils.DISPLAY_SIZE-utils.GRID_SIZE or self.snake_head_y + utils.GRID_SIZE > utils.DISPLAY_SIZE-utils.GRID_SIZE):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def step(self, action):
|
||||
is_dead = self.move(action)
|
||||
return self.get_state(), self.get_points(), is_dead
|
||||
|
||||
def handle_eatfood(self):
|
||||
if (self.snake_head_x == self.food_x) and (self.snake_head_y == self.food_y):
|
||||
self.random_food()
|
||||
self.points += 1
|
||||
|
||||
|
||||
def random_food(self):
|
||||
max_x = (utils.DISPLAY_SIZE - utils.WALL_SIZE - utils.GRID_SIZE)
|
||||
max_y = (utils.DISPLAY_SIZE - utils.WALL_SIZE - utils.GRID_SIZE)
|
||||
|
||||
self.food_x = random.randint(utils.WALL_SIZE, max_x)//utils.GRID_SIZE * utils.GRID_SIZE
|
||||
self.food_y = random.randint(utils.WALL_SIZE, max_y)//utils.GRID_SIZE * utils.GRID_SIZE
|
||||
|
||||
while self.check_food_on_snake():
|
||||
self.food_x = random.randint(utils.WALL_SIZE, max_x)//utils.GRID_SIZE * utils.GRID_SIZE
|
||||
self.food_y = random.randint(utils.WALL_SIZE, max_y)//utils.GRID_SIZE * utils.GRID_SIZE
|
||||
|
||||
def check_food_on_snake(self):
|
||||
if self.food_x == self.snake_head_x and self.food_y == self.snake_head_y:
|
||||
return True
|
||||
for seg in self.snake_body:
|
||||
if self.food_x == seg[0] and self.food_y == seg[1]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
55
codes/snake/utils.py
Normal file
55
codes/snake/utils.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import numpy as np
|
||||
DISPLAY_SIZE = 560
|
||||
GRID_SIZE = 40
|
||||
WALL_SIZE = 40
|
||||
WHITE = (255, 255, 255)
|
||||
RED = (255, 0, 0)
|
||||
BLUE = (72, 61, 139)
|
||||
BLACK = (0, 0, 0)
|
||||
GREEN = (0, 255, 0)
|
||||
|
||||
NUM_ADJOINING_WALL_X_STATES=3
|
||||
NUM_ADJOINING_WALL_Y_STATES=3
|
||||
NUM_FOOD_DIR_X=3
|
||||
NUM_FOOD_DIR_Y=3
|
||||
NUM_ADJOINING_BODY_TOP_STATES=2
|
||||
NUM_ADJOINING_BODY_BOTTOM_STATES=2
|
||||
NUM_ADJOINING_BODY_LEFT_STATES=2
|
||||
NUM_ADJOINING_BODY_RIGHT_STATES=2
|
||||
NUM_ACTIONS = 4
|
||||
|
||||
CHECKPOINT = 'checkpoint.npy'
|
||||
|
||||
def create_q_table():
|
||||
return np.zeros((NUM_ADJOINING_WALL_X_STATES, NUM_ADJOINING_WALL_Y_STATES, NUM_FOOD_DIR_X, NUM_FOOD_DIR_Y,
|
||||
NUM_ADJOINING_BODY_TOP_STATES, NUM_ADJOINING_BODY_BOTTOM_STATES, NUM_ADJOINING_BODY_LEFT_STATES,
|
||||
NUM_ADJOINING_BODY_RIGHT_STATES, NUM_ACTIONS))
|
||||
|
||||
def sanity_check(arr):
|
||||
if (type(arr) is np.ndarray and
|
||||
arr.shape==(NUM_ADJOINING_WALL_X_STATES, NUM_ADJOINING_WALL_Y_STATES, NUM_FOOD_DIR_X, NUM_FOOD_DIR_Y,
|
||||
NUM_ADJOINING_BODY_TOP_STATES, NUM_ADJOINING_BODY_BOTTOM_STATES, NUM_ADJOINING_BODY_LEFT_STATES,
|
||||
NUM_ADJOINING_BODY_RIGHT_STATES,NUM_ACTIONS)):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def save(filename, arr):
|
||||
if sanity_check(arr):
|
||||
np.save(filename,arr)
|
||||
return True
|
||||
else:
|
||||
print("Failed to save model")
|
||||
return False
|
||||
|
||||
def load(filename):
|
||||
try:
|
||||
arr = np.load(filename)
|
||||
if sanity_check(arr):
|
||||
print("Loaded model successfully")
|
||||
return arr
|
||||
print("Model loaded is not in the required format")
|
||||
return None
|
||||
except:
|
||||
print("Filename doesnt exist")
|
||||
return None
|
||||
Reference in New Issue
Block a user