diff --git a/codes/snake/agent.py b/codes/snake/agent.py
new file mode 100644
index 0000000..1c05b64
--- /dev/null
+++ b/codes/snake/agent.py
@@ -0,0 +1,107 @@
+import numpy as np
+import utils
+import random
+import math
+
+
+class Agent:
+
+    def __init__(self, actions, Ne, C, gamma):
+        self.actions = actions
+        self.Ne = Ne # used in exploration function
+        self.C = C
+        self.gamma = gamma
+
+        # Create the Q and N Table to work with
+        self.Q = utils.create_q_table()
+        self.N = utils.create_q_table()
+        self.reset()
+
+    def train(self):
+        self._train = True
+
+    def eval(self):
+        self._train = False
+
+    # At the end of training save the trained model
+    def save_model(self, model_path):
+        utils.save(model_path, self.Q)
+
+    # Load the trained model for evaluation
+    def load_model(self, model_path):
+        self.Q = utils.load(model_path)
+
+    def reset(self):
+        self.points = 0
+        self.s = None
+        self.a = None
+
+    def f(self,u,n):
+        if n < self.Ne:  # optimistic exploration: try an action until visited Ne times
+            return 1
+        return u
+
+    def R(self,points,dead):
+        if dead:
+            return -1
+        elif points > self.points:  # ate food since last act() call
+            return 1
+        return -0.1
+
+    def get_state(self, state):
+        # [adjoining_wall_x, adjoining_wall_y]
+        adjoining_wall_x = int(state[0] == utils.WALL_SIZE) + 2 * int(state[0] == utils.DISPLAY_SIZE - utils.WALL_SIZE - utils.GRID_SIZE)  # BUGFIX: far wall adjacency is at x == 480; x == 520 is unreachable alive (see Snake.move)
+        adjoining_wall_y = int(state[1] == utils.WALL_SIZE) + 2 * int(state[1] == utils.DISPLAY_SIZE - utils.WALL_SIZE - utils.GRID_SIZE)  # BUGFIX: same boundary fix for the bottom wall
+        # [food_dir_x, food_dir_y]
+        food_dir_x = 1 + int(state[0] < state[3]) - int(state[0] == state[3])
+        food_dir_y = 1 + int(state[1] < state[4]) - int(state[1] == state[4])
+        # [adjoining_body_top, adjoining_body_bottom, adjoining_body_left, adjoining_body_right]
+        adjoining_body = [(state[0] - body_state[0], state[1] - body_state[1]) for body_state in state[2]]
+        adjoining_body_top = int((0, utils.GRID_SIZE) in adjoining_body)  # BUGFIX: adjoining_body holds tuples; a list literal never compares equal, so these flags were always 0
+        adjoining_body_bottom = int((0, -utils.GRID_SIZE) in adjoining_body)
+        adjoining_body_left = int((utils.GRID_SIZE, 0) in adjoining_body)
+        adjoining_body_right = int((-utils.GRID_SIZE, 0) in adjoining_body)
+        return adjoining_wall_x, adjoining_wall_y, food_dir_x, food_dir_y, adjoining_body_top, adjoining_body_bottom, adjoining_body_left, adjoining_body_right
+
+
+    def update_tables(self, _state, points, dead):
+        if self.s is not None:  # BUGFIX: explicit sentinel test; skip the TD update before the first action of an episode
+            maxq = max(self.Q[_state])
+            reward = self.R(points,dead)
+            alpha = self.C / (self.C + self.N[self.s][self.a])
+            self.Q[self.s][self.a] += alpha * (reward + self.gamma * maxq - self.Q[self.s][self.a])
+            self.N[self.s][self.a] += 1.0
+
+    def act(self, state, points, dead):
+        '''
+        :param state: a list of [snake_head_x, snake_head_y, snake_body, food_x, food_y] from environment.
+        :param points: float, the current points from environment
+        :param dead: boolean, if the snake is dead
+        :return: the index of action. 0,1,2,3 indicates up,down,left,right separately
+        TODO: write your function here.
+        Return the index of action the snake needs to take, according to the state and points known from environment.
+        Tips: you need to discretize the state to the state space defined on the webpage first.
+        (Note that [adjoining_wall_x=0, adjoining_wall_y=0] is also the case when snake runs out of the 480x480 board)
+        '''
+
+        _state = self.get_state(state)
+        Qs = self.Q[_state][:]  # NOTE(review): ndarray slicing returns a view, not a copy — confirm this is intended
+
+        if self._train:
+            self.update_tables(_state, points, dead)
+            if dead:
+                self.reset()
+                return
+            Ns = self.N[_state]
+            Fs = [self.f(Qs[a], Ns[a]) for a in self.actions]
+            action = np.argmax(Fs)
+            self.s = _state
+            self.a = action
+        else:
+            if dead:
+                self.reset()
+                return
+            action = np.argmax(Qs)
+
+        self.points = points
+        return action
diff --git a/codes/snake/checkpoint.npy b/codes/snake/checkpoint.npy
new file mode 100644
index 0000000..591d49e
Binary files /dev/null and b/codes/snake/checkpoint.npy differ
diff --git a/codes/snake/checkpoint1.npy b/codes/snake/checkpoint1.npy
new file mode 100644
index 0000000..84b54ca
Binary files /dev/null and b/codes/snake/checkpoint1.npy differ
diff --git a/codes/snake/checkpoint2.npy b/codes/snake/checkpoint2.npy
new file mode 100644
index 0000000..4614eb7
Binary files /dev/null and b/codes/snake/checkpoint2.npy differ
diff --git a/codes/snake/checkpoint3.npy b/codes/snake/checkpoint3.npy
new file mode 100644
index 0000000..20085c9
Binary files /dev/null and b/codes/snake/checkpoint3.npy differ
diff --git a/codes/snake/example_assignment_and_report2.pdf b/codes/snake/example_assignment_and_report2.pdf
new file mode 100644
index 0000000..84008c0
Binary files /dev/null and b/codes/snake/example_assignment_and_report2.pdf differ
diff --git a/codes/snake/main.py b/codes/snake/main.py
new file mode 100644
index 0000000..c407491
--- /dev/null
+++ b/codes/snake/main.py
@@ -0,0 +1,183 @@
+import pygame
+from pygame.locals import *
+import argparse
+
+from agent import Agent
+from snake_env import SnakeEnv
+import utils
+import time
+
+class Application:
+    def __init__(self, args):
+        self.args = args
+        self.env = SnakeEnv(args.snake_head_x, args.snake_head_y, args.food_x, args.food_y)
+        self.agent = Agent(self.env.get_actions(), args.Ne, args.C, args.gamma)
+
+    def execute(self):
+        if not self.args.human:
+            if self.args.train_eps != 0:
+                self.train()
+            self.test()
+        self.show_games()
+
+    def train(self):
+        print("Train Phase:")
+        self.agent.train()
+        window = self.args.window
+        self.points_results = []
+        first_eat = True
+        start = time.time()
+
+        for game in range(1, self.args.train_eps + 1):
+            state = self.env.get_state()
+            dead = False
+            action = self.agent.act(state, 0, dead)
+            while not dead:
+                state, points, dead = self.env.step(action)
+
+                # For debug convenience, you can check if your Q-table mathches ours for given setting of parameters
+                # (see Debug Convenience part on homework 4 web page)
+                if first_eat and points == 1:
+                    self.agent.save_model(utils.CHECKPOINT)
+                    first_eat = False
+
+                action = self.agent.act(state, points, dead)
+
+
+            points = self.env.get_points()
+            self.points_results.append(points)
+            if game % self.args.window == 0:
+                print(
+                    "Games:", len(self.points_results) - window, "-", len(self.points_results),
+                    "Points (Average:", sum(self.points_results[-window:])/window,
+                    "Max:", max(self.points_results[-window:]),
+                    "Min:", min(self.points_results[-window:]),")",
+                )
+            self.env.reset()
+        print("Training takes", time.time() - start, "seconds")
+        self.agent.save_model(self.args.model_name)
+
+    def test(self):
+        print("Test Phase:")
+        self.agent.eval()
+        self.agent.load_model(self.args.model_name)
+        points_results = []
+        start = time.time()
+
+        for game in range(1, self.args.test_eps + 1):
+            state = self.env.get_state()
+            dead = False
+            action = self.agent.act(state, 0, dead)
+            while not dead:
+                state, points, dead = self.env.step(action)
+                action = self.agent.act(state, points, dead)
+            points = self.env.get_points()
+            points_results.append(points)
+            self.env.reset()
+
+        print("Testing takes", time.time() - start, "seconds")
+        print("Number of Games:", len(points_results))
+        print("Average Points:", sum(points_results)/len(points_results))
+        print("Max Points:", max(points_results))
+        print("Min Points:", min(points_results))
+
+    def show_games(self):
+        print("Display Games")
+        self.env.display()
+        pygame.event.pump()
+        self.agent.eval()
+        points_results = []
+        end = False
+        for game in range(1, self.args.show_eps + 1):
+            state = self.env.get_state()
+            dead = False
+            action = self.agent.act(state, 0, dead)
+            count = 0
+            while not dead:
+                count +=1
+                pygame.event.pump()
+                keys = pygame.key.get_pressed()
+                if keys[K_ESCAPE] or self.check_quit():
+                    end = True
+                    break
+                state, points, dead = self.env.step(action)
+                # Qlearning agent
+                if not self.args.human:
+                    action = self.agent.act(state, points, dead)
+                # for human player
+                else:
+                    for event in pygame.event.get():
+                        if event.type == pygame.KEYDOWN:
+                            if event.key == pygame.K_UP:
+                                action = 2
+                            elif event.key == pygame.K_DOWN:
+                                action = 3
+                            elif event.key == pygame.K_LEFT:
+                                action = 1
+                            elif event.key == pygame.K_RIGHT:
+                                action = 0
+            if end:
+                break
+            self.env.reset()
+            points_results.append(points)
+            print("Game:", str(game)+"/"+str(self.args.show_eps), "Points:", points)
+        if len(points_results) == 0:
+            return
+        print("Average Points:", sum(points_results)/len(points_results))
+
+    def check_quit(self):
+        for event in pygame.event.get():
+            if event.type == pygame.QUIT:
+                return True
+        return False
+
+
+def main():
+    parser = argparse.ArgumentParser(description='CS440 MP4 Snake')
+
+    parser.add_argument('--human', default = False, action="store_true",
+                        help='making the game human playable - default False')
+
+    parser.add_argument('--model_name', dest="model_name", type=str, default="checkpoint3.npy",
+                        help='name of model to save if training or to load if evaluating - default checkpoint3.npy')
+
+    parser.add_argument('--train_episodes', dest="train_eps", type=int, default=10000,
+                        help='number of training episodes - default 10000')
+
+    parser.add_argument('--test_episodes', dest="test_eps", type=int, default=1000,
+                        help='number of testing episodes - default 1000')
+
+    parser.add_argument('--show_episodes', dest="show_eps", type=int, default=10,
+                        help='number of displayed episodes - default 10')
+
+    parser.add_argument('--window', dest="window", type=int, default=100,
+                        help='number of episodes to keep running stats for during training - default 100')
+
+    parser.add_argument('--Ne', dest="Ne", type=int, default=40,
+                        help='the Ne parameter used in exploration function - default 40')
+
+    parser.add_argument('--C', dest="C", type=int, default=40,
+                        help='the C parameter used in learning rate - default 40')
+
+    parser.add_argument('--gamma', dest="gamma", type=float, default=0.2,
+                        help='the gamma parameter (discount factor) - default 0.2')
+
+    parser.add_argument('--snake_head_x', dest="snake_head_x", type=int, default=200,
+                        help='initialized x position of snake head - default 200')
+
+    parser.add_argument('--snake_head_y', dest="snake_head_y", type=int, default=200,
+                        help='initialized y position of snake head - default 200')
+
+    parser.add_argument('--food_x', dest="food_x", type=int, default=80,
+                        help='initialized x position of food - default 80')
+
+    parser.add_argument('--food_y', dest="food_y", type=int, default=80,
+                        help='initialized y position of food - default 80')
+
+
+    args = parser.parse_args()
+    app = Application(args)
+    app.execute()
+
+if __name__ == "__main__":
+    main()
diff --git a/codes/snake/q_agent.npy b/codes/snake/q_agent.npy
new file mode 100644
index 0000000..75ef415
Binary files /dev/null and b/codes/snake/q_agent.npy differ
diff --git a/codes/snake/snake_env.py b/codes/snake/snake_env.py
new file mode 100644
index 0000000..aa22809
--- /dev/null
+++ b/codes/snake/snake_env.py
@@ -0,0 +1,204 @@
+import random
+import pygame
+import utils
+
+class SnakeEnv:
+    def __init__(self, snake_head_x, snake_head_y, food_x, food_y):
+        self.game = Snake(snake_head_x, snake_head_y, food_x, food_y)
+        self.render = False
+
+    def get_actions(self):
+        return self.game.get_actions()
+
+    def reset(self):
+        return self.game.reset()
+
+    def get_points(self):
+        return self.game.get_points()
+
+    def get_state(self):
+        return self.game.get_state()
+
+    def step(self, action):
+        state, points, dead = self.game.step(action)
+        if self.render:
+            self.draw(state, points, dead)
+        return state, points, dead
+
+    def draw(self, state, points, dead):
+        snake_head_x, snake_head_y, snake_body, food_x, food_y = state
+        self.display.fill(utils.BLUE)
+        pygame.draw.rect( self.display, utils.BLACK,
+            [
+                utils.GRID_SIZE,
+                utils.GRID_SIZE,
+                utils.DISPLAY_SIZE - utils.GRID_SIZE * 2,
+                utils.DISPLAY_SIZE - utils.GRID_SIZE * 2
+            ])
+
+        # draw snake head
+        pygame.draw.rect(
+            self.display,
+            utils.GREEN,
+            [
+                snake_head_x,
+                snake_head_y,
+                utils.GRID_SIZE,
+                utils.GRID_SIZE
+            ],
+            3
+        )
+        # draw snake body
+        for seg in snake_body:
+            pygame.draw.rect(
+                self.display,
+                utils.GREEN,
+                [
+                    seg[0],
+                    seg[1],
+                    utils.GRID_SIZE,
+                    utils.GRID_SIZE,
+                ],
+                1
+            )
+        # draw food
+        pygame.draw.rect(
+            self.display,
+            utils.RED,
+            [
+                food_x,
+                food_y,
+                utils.GRID_SIZE,
+                utils.GRID_SIZE
+            ]
+        )
+
+        text_surface = self.font.render("Points: " + str(points), True, utils.BLACK)
+        text_rect = text_surface.get_rect()
+        text_rect.center = ((280),(25))
+        self.display.blit(text_surface, text_rect)
+        pygame.display.flip()
+        if dead:
+            # slow clock if dead
+            self.clock.tick(1)
+        else:
+            self.clock.tick(5)
+
+        return
+
+
+    def display(self):
+        pygame.init()
+        pygame.display.set_caption('MP4: Snake')
+        self.clock = pygame.time.Clock()
+        pygame.font.init()
+
+        self.font = pygame.font.Font(pygame.font.get_default_font(), 15)
+        self.display = pygame.display.set_mode((utils.DISPLAY_SIZE, utils.DISPLAY_SIZE), pygame.HWSURFACE)
+        self.draw(self.game.get_state(), self.game.get_points(), False)
+        self.render = True
+
+class Snake:
+    def __init__(self, snake_head_x, snake_head_y, food_x, food_y):
+        self.init_snake_head_x = snake_head_x
+        self.init_snake_head_y = snake_head_y
+        self.init_food_x = food_x
+        self.init_food_y = food_y
+        self.reset()
+
+    def reset(self):
+        self.points = 0
+        self.snake_head_x = self.init_snake_head_x
+        self.snake_head_y = self.init_snake_head_y
+        self.snake_body = []
+        self.food_x = self.init_food_x
+        self.food_y = self.init_food_y
+
+    def get_points(self):
+        return self.points
+
+    def get_actions(self):
+        return [0, 1, 2, 3]
+
+    def get_state(self):
+        return [
+            self.snake_head_x,
+            self.snake_head_y,
+            self.snake_body,
+            self.food_x,
+            self.food_y
+        ]
+
+    def move(self, action):
+        delta_x = delta_y = 0
+        if action == 0:
+            delta_x = utils.GRID_SIZE
+        elif action == 1:
+            delta_x = - utils.GRID_SIZE
+        elif action == 2:
+            delta_y = - utils.GRID_SIZE
+        elif action == 3:
+            delta_y = utils.GRID_SIZE
+
+        old_body_head = None
+        if len(self.snake_body) == 1:
+            old_body_head = self.snake_body[0]
+        self.snake_body.append((self.snake_head_x, self.snake_head_y))
+        self.snake_head_x += delta_x
+        self.snake_head_y += delta_y
+
+        if len(self.snake_body) > self.points:
+            del(self.snake_body[0])
+
+        self.handle_eatfood()
+
+        # colliding with the snake body or going backwards while its body length
+        # greater than 1
+        if len(self.snake_body) >= 1:
+            for seg in self.snake_body:
+                if self.snake_head_x == seg[0] and self.snake_head_y == seg[1]:
+                    return True
+
+        # moving towards body direction, not allowing snake to go backwards while
+        # its body length is 1
+        if len(self.snake_body) == 1:
+            if old_body_head == (self.snake_head_x, self.snake_head_y):
+                return True
+
+        # collide with the wall
+        if (self.snake_head_x < utils.GRID_SIZE or self.snake_head_y < utils.GRID_SIZE or
+            self.snake_head_x + utils.GRID_SIZE > utils.DISPLAY_SIZE-utils.GRID_SIZE or self.snake_head_y + utils.GRID_SIZE > utils.DISPLAY_SIZE-utils.GRID_SIZE):
+            return True
+
+        return False
+
+    def step(self, action):
+        is_dead = self.move(action)
+        return self.get_state(), self.get_points(), is_dead
+
+    def handle_eatfood(self):
+        if (self.snake_head_x == self.food_x) and (self.snake_head_y == self.food_y):
+            self.random_food()
+            self.points += 1
+
+
+    def random_food(self):
+        max_x = (utils.DISPLAY_SIZE - utils.WALL_SIZE - utils.GRID_SIZE)
+        max_y = (utils.DISPLAY_SIZE - utils.WALL_SIZE - utils.GRID_SIZE)
+
+        self.food_x = random.randint(utils.WALL_SIZE, max_x)//utils.GRID_SIZE * utils.GRID_SIZE
+        self.food_y = random.randint(utils.WALL_SIZE, max_y)//utils.GRID_SIZE * utils.GRID_SIZE
+
+        while self.check_food_on_snake():
+            self.food_x = random.randint(utils.WALL_SIZE, max_x)//utils.GRID_SIZE * utils.GRID_SIZE
+            self.food_y = random.randint(utils.WALL_SIZE, max_y)//utils.GRID_SIZE * utils.GRID_SIZE
+
+    def check_food_on_snake(self):
+        if self.food_x == self.snake_head_x and self.food_y == self.snake_head_y:
+            return True
+        for seg in self.snake_body:
+            if self.food_x == seg[0] and self.food_y == seg[1]:
+                return True
+        return False
+
+
diff --git a/codes/snake/utils.py b/codes/snake/utils.py
new file mode 100644
index 0000000..01c9b00
--- /dev/null
+++ b/codes/snake/utils.py
@@ -0,0 +1,55 @@
+import numpy as np
+DISPLAY_SIZE = 560
+GRID_SIZE = 40
+WALL_SIZE = 40
+WHITE = (255, 255, 255)
+RED = (255, 0, 0)
+BLUE = (72, 61, 139)
+BLACK = (0, 0, 0)
+GREEN = (0, 255, 0)
+
+NUM_ADJOINING_WALL_X_STATES=3
+NUM_ADJOINING_WALL_Y_STATES=3
+NUM_FOOD_DIR_X=3
+NUM_FOOD_DIR_Y=3
+NUM_ADJOINING_BODY_TOP_STATES=2
+NUM_ADJOINING_BODY_BOTTOM_STATES=2
+NUM_ADJOINING_BODY_LEFT_STATES=2
+NUM_ADJOINING_BODY_RIGHT_STATES=2
+NUM_ACTIONS = 4
+
+CHECKPOINT = 'checkpoint.npy'
+
+def create_q_table():
+    return np.zeros((NUM_ADJOINING_WALL_X_STATES, NUM_ADJOINING_WALL_Y_STATES, NUM_FOOD_DIR_X, NUM_FOOD_DIR_Y,
+                     NUM_ADJOINING_BODY_TOP_STATES, NUM_ADJOINING_BODY_BOTTOM_STATES, NUM_ADJOINING_BODY_LEFT_STATES,
+                     NUM_ADJOINING_BODY_RIGHT_STATES, NUM_ACTIONS))
+
+def sanity_check(arr):
+    if (type(arr) is np.ndarray and
+        arr.shape==(NUM_ADJOINING_WALL_X_STATES, NUM_ADJOINING_WALL_Y_STATES, NUM_FOOD_DIR_X, NUM_FOOD_DIR_Y,
+                    NUM_ADJOINING_BODY_TOP_STATES, NUM_ADJOINING_BODY_BOTTOM_STATES, NUM_ADJOINING_BODY_LEFT_STATES,
+                    NUM_ADJOINING_BODY_RIGHT_STATES,NUM_ACTIONS)):
+        return True
+    else:
+        return False
+
+def save(filename, arr):
+    if sanity_check(arr):
+        np.save(filename,arr)
+        return True
+    else:
+        print("Failed to save model")
+        return False
+
+def load(filename):
+    try:
+        arr = np.load(filename)
+        if sanity_check(arr):
+            print("Loaded model successfully")
+            return arr
+        print("Model loaded is not in the required format")
+        return None
+    except Exception:  # BUGFIX: bare except also swallowed KeyboardInterrupt/SystemExit
+        print("Filename doesnt exist")
+        return None
\ No newline at end of file