Commit bf0f2990cf by JohnJim0816, 2021-03-23 16:10:11 +08:00 (parent d4690c2058)
198 changed files with 1668 additions and 1545 deletions


@@ -1,8 +1,8 @@
# Snake with Q-learning
# Snake
Snake is an arcade game that dates back to the 1976 game Blockade. The player steers the snake up, down, left, or right to eat food, which makes its body grow; the snake speeds up after each piece of food, and the game ends when it runs into a wall or its own body.
![image-20200901202636603](assets/image-20200901202636603.png)
![image-20200901202636603](img/image-20200901202636603.png)
As shown in the figure, the game board for this task is 560x560 pixels. The green segments are our agent, the snake; the red square is the food; and walls line all four edges. Once a piece of food is eaten, a new one spawns at a random location. Each snake segment and the food are 40x40, so excluding the walls (also 40 thick) the snake can move within a 480x480 area, i.e., a 12x12 grid. The environment's state and related information are as follows:
@@ -34,8 +34,5 @@
* reward: +1 for eating food, -1 when the snake dies, and -0.1 in every other case (see the sketch below).
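The agent's `R(points, dead)` helper, which the `update` method below calls, can implement this scheme directly. A minimal sketch, assuming the agent caches the previous step's score in `self.points` (that attribute is an assumption, not shown in the diff):

```python
def R(self, points, dead):
    """Reward scheme from the task: +1 for food, -1 for dying, -0.1 otherwise."""
    if dead:
        return -1.0
    if points > self.points:  # the score rose, so food was just eaten
        return 1.0
    return -0.1
```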
## Task Requirements
Design a Q-learning agent that learns to play the snake game, plot the reward and its moving average as curves over episodes, record the hyperparameters, and write everything up in a report.
[Reference code](https://github.com/datawhalechina/leedeeprl-notes/tree/master/codes/snake)
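For the required curves, a fixed-window moving average is enough. A minimal matplotlib sketch, where `rewards` stands for the per-episode rewards collected during training (the function and argument names are illustrative):

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_rewards(rewards, window=50):
    """Plot per-episode reward and its moving average."""
    rewards = np.asarray(rewards, dtype=float)
    # moving average: mean over the trailing `window` episodes
    kernel = np.ones(window) / window
    smoothed = np.convolve(rewards, kernel, mode="valid")
    plt.plot(rewards, alpha=0.4, label="reward")
    plt.plot(np.arange(window - 1, len(rewards)), smoothed,
             label=f"moving average ({window})")
    plt.xlabel("episode")
    plt.ylabel("reward")
    plt.legend()
    plt.show()
```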


@@ -64,7 +64,7 @@ class Agent:
return adjoining_wall_x, adjoining_wall_y, food_dir_x, food_dir_y, adjoining_body_top, adjoining_body_bottom, adjoining_body_left, adjoining_body_right
def update_tables(self, _state, points, dead):
def update(self, _state, points, dead):
if self.s:
maxq = max(self.Q[_state])
reward = self.R(points,dead)
@@ -72,7 +72,7 @@ class Agent:
self.Q[self.s][self.a] += alpha * (reward + self.gamma * maxq - self.Q[self.s][self.a])
self.N[self.s][self.a] += 1.0
def act(self, state, points, dead):
def choose_action(self, state, points, dead):
'''
:param state: a list of [snake_head_x, snake_head_y, snake_body, food_x, food_y] from environment.
:param points: float, the current points from environment
@@ -88,7 +88,7 @@ class Agent:
Qs = self.Q[_state][:]
if self._train:
self.update_tables(_state, points, dead)
self.update(_state, points, dead)
if dead:
self.reset()
return
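The increment on `self.Q` above is the standard tabular Q-learning update, Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)). The diff does not show how `alpha` or the `Ne` and `C` constructor arguments are used, but in agents of this shape they typically drive a count-based learning rate and an optimistic exploration bonus through the `N` visit-count table. A minimal sketch, with both formulas assumed rather than taken from the repository:

```python
# Sketch only: assumed usage of the Ne / C hyperparameters and the N table.
def learning_rate(self, state, action):
    # alpha decays as the (state, action) pair is visited more often
    return self.C / (self.C + self.N[state][action])

def exploration_value(self, state, action, rplus=1.0):
    # force each action to be tried at least Ne times by reporting an
    # optimistic value rplus until the visit count reaches Ne
    if self.N[state][action] < self.Ne:
        return rplus
    return self.Q[state][action]
```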



@@ -7,135 +7,10 @@ from snake_env import SnakeEnv
import utils
import time
class Application:
def __init__(self, args):
self.args = args
self.env = SnakeEnv(args.snake_head_x, args.snake_head_y, args.food_x, args.food_y)
self.agent = Agent(self.env.get_actions(), args.Ne, args.C, args.gamma)
def execute(self):
if not self.args.human:
if self.args.train_eps != 0:
self.train()
self.test()
self.show_games()
def train(self):
print("Train Phase:")
self.agent.train()
window = self.args.window
self.points_results = []
first_eat = True
start = time.time()
for game in range(1, self.args.train_eps + 1):
state = self.env.get_state()
dead = False
action = self.agent.act(state, 0, dead)
while not dead:
state, points, dead = self.env.step(action)
# For debugging convenience, you can check whether your Q-table matches ours for a given setting of parameters
# (see Debug Convenience part on homework 4 web page)
if first_eat and points == 1:
self.agent.save_model(utils.CHECKPOINT)
first_eat = False
action = self.agent.act(state, points, dead)
points = self.env.get_points()
self.points_results.append(points)
if game % self.args.window == 0:
print(
"Games:", len(self.points_results) - window, "-", len(self.points_results),
"Points (Average:", sum(self.points_results[-window:])/window,
"Max:", max(self.points_results[-window:]),
"Min:", min(self.points_results[-window:]),")",
)
self.env.reset()
print("Training takes", time.time() - start, "seconds")
self.agent.save_model(self.args.model_name)
def test(self):
print("Test Phase:")
self.agent.eval()
self.agent.load_model(self.args.model_name)
points_results = []
start = time.time()
for game in range(1, self.args.test_eps + 1):
state = self.env.get_state()
dead = False
action = self.agent.act(state, 0, dead)
while not dead:
state, points, dead = self.env.step(action)
action = self.agent.act(state, points, dead)
points = self.env.get_points()
points_results.append(points)
self.env.reset()
print("Testing takes", time.time() - start, "seconds")
print("Number of Games:", len(points_results))
print("Average Points:", sum(points_results)/len(points_results))
print("Max Points:", max(points_results))
print("Min Points:", min(points_results))
def show_games(self):
print("Display Games")
self.env.display()
pygame.event.pump()
self.agent.eval()
points_results = []
end = False
for game in range(1, self.args.show_eps + 1):
state = self.env.get_state()
dead = False
action = self.agent.act(state, 0, dead)
count = 0
while not dead:
count +=1
pygame.event.pump()
keys = pygame.key.get_pressed()
if keys[K_ESCAPE] or self.check_quit():
end = True
break
state, points, dead = self.env.step(action)
# Qlearning agent
if not self.args.human:
action = self.agent.act(state, points, dead)
# for human player
else:
for event in pygame.event.get():
if event.type == pygame.KEYDOWN:
if event.key == pygame.K_UP:
action = 2
elif event.key == pygame.K_DOWN:
action = 3
elif event.key == pygame.K_LEFT:
action = 1
elif event.key == pygame.K_RIGHT:
action = 0
if end:
break
self.env.reset()
points_results.append(points)
print("Game:", str(game)+"/"+str(self.args.show_eps), "Points:", points)
if len(points_results) == 0:
return
print("Average Points:", sum(points_results)/len(points_results))
def check_quit(self):
for event in pygame.event.get():
if event.type == pygame.QUIT:
return True
return False
def main():
def get_args():
parser = argparse.ArgumentParser(description='CS440 MP4 Snake')
parser.add_argument('--human', default = False, action="store_true",
parser.add_argument('--human', default = True, action="store_true",
help='making the game human playable - default False')
parser.add_argument('--model_name', dest="model_name", type=str, default="checkpoint3.npy",
@@ -173,10 +48,137 @@ def main():
parser.add_argument('--food_y', dest="food_y", type=int, default=80,
help='initialized y position of food - default 80')
cfg = parser.parse_args()
return cfg
class Application:
def __init__(self, args):
self.args = args
self.env = SnakeEnv(args.snake_head_x, args.snake_head_y, args.food_x, args.food_y)
self.agent = Agent(self.env.get_actions(), args.Ne, args.C, args.gamma)
def execute(self):
if not self.args.human:
if self.args.train_eps != 0:
self.train()
self.eval()
self.show_games()
args = parser.parse_args()
app = Application(args)
def train(self):
print("Train Phase:")
self.agent.train()
window = self.args.window
self.points_results = []
first_eat = True
start = time.time()
for game in range(1, self.args.train_eps + 1):
state = self.env.get_state()
dead = False
action = self.agent.choose_action(state, 0, dead)
while not dead:
state, points, dead = self.env.step(action)
# For debugging convenience, you can check whether your Q-table matches ours for a given setting of parameters
# (see Debug Convenience part on homework 4 web page)
if first_eat and points == 1:
self.agent.save_model(utils.CHECKPOINT)
first_eat = False
action = self.agent.choose_action(state, points, dead)
points = self.env.get_points()
self.points_results.append(points)
if game % self.args.window == 0:
print(
"Games:", len(self.points_results) - window, "-", len(self.points_results),
"Points (Average:", sum(self.points_results[-window:])/window,
"Max:", max(self.points_results[-window:]),
"Min:", min(self.points_results[-window:]),")",
)
self.env.reset()
print("Training takes", time.time() - start, "seconds")
self.agent.save_model(self.args.model_name)
def eval(self):
print("Evaling Phase:")
self.agent.eval()
self.agent.load_model(self.args.model_name)
points_results = []
start = time.time()
for game in range(1, self.args.test_eps + 1):
state = self.env.get_state()
dead = False
action = self.agent.choose_action(state, 0, dead)
while not dead:
state, points, dead = self.env.step(action)
action = self.agent.choose_action(state, points, dead)
points = self.env.get_points()
points_results.append(points)
self.env.reset()
print("Testing takes", time.time() - start, "seconds")
print("Number of Games:", len(points_results))
print("Average Points:", sum(points_results)/len(points_results))
print("Max Points:", max(points_results))
print("Min Points:", min(points_results))
def show_games(self):
print("Display Games")
self.env.display()
pygame.event.pump()
self.agent.eval()
points_results = []
end = False
for game in range(1, self.args.show_eps + 1):
state = self.env.get_state()
dead = False
action = self.agent.choose_action(state, 0, dead)
count = 0
while not dead:
count +=1
pygame.event.pump()
keys = pygame.key.get_pressed()
if keys[K_ESCAPE] or self.check_quit():
end = True
break
state, points, dead = self.env.step(action)
# Qlearning agent
if not self.args.human:
action = self.agent.choose_action(state, points, dead)
# for human player
else:
for event in pygame.event.get():
if event.type == pygame.KEYDOWN:
if event.key == pygame.K_UP:
action = 2
elif event.key == pygame.K_DOWN:
action = 3
elif event.key == pygame.K_LEFT:
action = 1
elif event.key == pygame.K_RIGHT:
action = 0
if end:
break
self.env.reset()
points_results.append(points)
print("Game:", str(game)+"/"+str(self.args.show_eps), "Points:", points)
if len(points_results) == 0:
return
print("Average Points:", sum(points_results)/len(points_results))
def check_quit(self):
for event in pygame.event.get():
if event.type == pygame.QUIT:
return True
return False
def main():
cfg = get_args()
app = Application(cfg)
app.execute()
if __name__ == "__main__":
    main()


@@ -23,6 +23,7 @@ class SnakeEnv:
state, points, dead = self.game.step(action)
if self.render:
self.draw(state, points, dead)
# return state, reward, done
return state, points, dead
def draw(self, state, points, dead):
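As the `# return state, reward, done` comment notes, `step` follows the usual gym-style triple except that it returns the raw score instead of a shaped reward. A minimal wrapper sketch that derives the reward from the score change using the +1/-1/-0.1 scheme from the task description (the wrapper class and its attribute names are illustrative):

```python
class RewardWrapper:
    """Sketch: adapt SnakeEnv.step to the (state, reward, done) convention."""
    def __init__(self, env):
        self.env = env
        self.last_points = 0

    def reset(self):
        self.last_points = 0
        return self.env.reset()

    def step(self, action):
        state, points, dead = self.env.step(action)
        if dead:
            reward = -1.0
        elif points > self.last_points:  # the score rose: food was eaten
            reward = 1.0
        else:
            reward = -0.1
        self.last_points = points
        return state, reward, dead
```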
@@ -99,22 +100,16 @@ class SnakeEnv:
self.render = True
class Snake:
''' Class defining the snake.
'''
def __init__(self, snake_head_x, snake_head_y, food_x, food_y):
# initialize the snake head position
self.init_snake_head_x, self.init_snake_head_y = snake_head_x, snake_head_y
# initialize the food position
self.init_food_x, self.init_food_y = food_x, food_y
self.init_snake_head_x, self.init_snake_head_y = snake_head_x, snake_head_y # initial snake head position
self.init_food_x, self.init_food_y = food_x, food_y # initial food position
self.reset()
def reset(self):
self.points = 0
self.snake_head_x = self.init_snake_head_x
self.snake_head_y = self.init_snake_head_y
self.snake_body = []
self.food_x = self.init_food_x
self.food_y = self.init_food_y
self.snake_head_x, self.snake_head_y = self.init_snake_head_x, self.init_snake_head_y
self.food_x, self.food_y = self.init_food_x, self.init_food_y
self.snake_body = [] # positions of the snake body segments
def get_points(self):
return self.points
@@ -132,8 +127,10 @@ class Snake:
]
def move(self, action):
'''Move the snake head according to the action and return whether the snake has died.
'''
delta_x = delta_y = 0
if action == 0:
if action == 0: # right (K_RIGHT; x increases)
delta_x = utils.GRID_SIZE
elif action == 1: # left
delta_x = - utils.GRID_SIZE
@@ -141,33 +138,31 @@ class Snake:
delta_y = - utils.GRID_SIZE # action 2: up (y decreases)
elif action == 3: # down
delta_y = utils.GRID_SIZE
old_body_head = None
if len(self.snake_body) == 1:
old_body_head = self.snake_body[0]
self.snake_body.append((self.snake_head_x, self.snake_head_y))
self.snake_head_x += delta_x
self.snake_head_y += delta_y
if len(self.snake_body) > self.points:
if len(self.snake_body) > self.points: # no food was eaten this step, so drop the tail
del(self.snake_body[0])
self.handle_eatfood()
# colliding with the snake body or going backwards while its body length
# greater than 1
# if the snake's length is at least 1 and the head overlaps any body segment, the snake has collided with itself
if len(self.snake_body) >= 1:
for seg in self.snake_body:
if self.snake_head_x == seg[0] and self.snake_head_y == seg[1]:
return True
# moving towards body direction, not allowing snake to go backwards while
# its body length is 1
# if the snake's length is 1, moving back onto the previous head position also counts as a self-collision (no reversing)
if len(self.snake_body) == 1:
if old_body_head == (self.snake_head_x, self.snake_head_y):
return True
# collide with the wall
# check whether the head has hit a wall
if (self.snake_head_x < utils.GRID_SIZE or self.snake_head_y < utils.GRID_SIZE or
self.snake_head_x + utils.GRID_SIZE > utils.DISPLAY_SIZE-utils.GRID_SIZE or self.snake_head_y + utils.GRID_SIZE > utils.DISPLAY_SIZE-utils.GRID_SIZE):
return True
@@ -183,15 +178,16 @@ class Snake:
self.random_food()
self.points += 1
def random_food(self):
'''Spawn food at a random position.
'''
max_x = (utils.DISPLAY_SIZE - utils.WALL_SIZE - utils.GRID_SIZE)
max_y = (utils.DISPLAY_SIZE - utils.WALL_SIZE - utils.GRID_SIZE)
self.food_x = random.randint(utils.WALL_SIZE, max_x)//utils.GRID_SIZE * utils.GRID_SIZE # snap to the 40-pixel grid
self.food_y = random.randint(utils.WALL_SIZE, max_y)//utils.GRID_SIZE * utils.GRID_SIZE
while self.check_food_on_snake():
while self.check_food_on_snake(): # food must not spawn on the snake's body
self.food_x = random.randint(utils.WALL_SIZE, max_x)//utils.GRID_SIZE * utils.GRID_SIZE
self.food_y = random.randint(utils.WALL_SIZE, max_y)//utils.GRID_SIZE * utils.GRID_SIZE