update
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
# 使用 Q-learning 实现贪吃蛇
|
||||
# 贪吃蛇
|
||||
|
||||
贪吃蛇是一个起源于1976年的街机游戏 Blockade,玩家控制蛇上下左右吃到食物并将身体增长,吃到食物后移动速度逐渐加快,直到碰到墙体或者蛇的身体算游戏结束。
|
||||
|
||||

|
||||

|
||||
|
||||
如图,本次任务整个游戏版面大小为560X560,绿色部分就是我们的智能体贪吃蛇,红色方块就是食物,墙位于四周,一旦食物被吃掉,会在下一个随机位置刷出新的食物。蛇的每一节以及食物的大小为40X40,除开墙体(厚度也为40),蛇可以活动的范围为480X480,也就是12X12的栅格。环境的状态等信息如下:
|
||||
|
||||
@@ -34,8 +34,5 @@
|
||||
|
||||
* reward:如果吃到食物给一个+1的reward,如果蛇没了就-1,其他情况给-0.1的reward
|
||||
|
||||
## 任务要求
|
||||
|
||||
设计一个Q-learning agent用于学习snake游戏,并绘制reward以及滑动平均后的reward随episode的变化曲线图并记录超参数写成报告。
|
||||
|
||||
[参考代码](https://github.com/datawhalechina/leedeeprl-notes/tree/master/codes/snake)
|
||||
@@ -64,7 +64,7 @@ class Agent:
|
||||
return adjoining_wall_x, adjoining_wall_y, food_dir_x, food_dir_y, adjoining_body_top, adjoining_body_bottom, adjoining_body_left, adjoining_body_right
|
||||
|
||||
|
||||
def update_tables(self, _state, points, dead):
|
||||
def update(self, _state, points, dead):
|
||||
if self.s:
|
||||
maxq = max(self.Q[_state])
|
||||
reward = self.R(points,dead)
|
||||
@@ -72,7 +72,7 @@ class Agent:
|
||||
self.Q[self.s][self.a] += alpha * (reward + self.gamma * maxq - self.Q[self.s][self.a])
|
||||
self.N[self.s][self.a] += 1.0
|
||||
|
||||
def act(self, state, points, dead):
|
||||
def choose_action(self, state, points, dead):
|
||||
'''
|
||||
:param state: a list of [snake_head_x, snake_head_y, snake_body, food_x, food_y] from environment.
|
||||
:param points: float, the current points from environment
|
||||
@@ -88,7 +88,7 @@ class Agent:
|
||||
Qs = self.Q[_state][:]
|
||||
|
||||
if self._train:
|
||||
self.update_tables(_state, points, dead)
|
||||
self.update(_state, points, dead)
|
||||
if dead:
|
||||
self.reset()
|
||||
return
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 32 KiB |
Binary file not shown.
BIN
codes/snake/example_assignment_and_report2.pdf
Normal file
BIN
codes/snake/example_assignment_and_report2.pdf
Normal file
Binary file not shown.
@@ -7,135 +7,10 @@ from snake_env import SnakeEnv
|
||||
import utils
|
||||
import time
|
||||
|
||||
class Application:
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.env = SnakeEnv(args.snake_head_x, args.snake_head_y, args.food_x, args.food_y)
|
||||
self.agent = Agent(self.env.get_actions(), args.Ne, args.C, args.gamma)
|
||||
|
||||
def execute(self):
|
||||
if not self.args.human:
|
||||
if self.args.train_eps != 0:
|
||||
self.train()
|
||||
self.test()
|
||||
self.show_games()
|
||||
|
||||
def train(self):
|
||||
print("Train Phase:")
|
||||
self.agent.train()
|
||||
window = self.args.window
|
||||
self.points_results = []
|
||||
first_eat = True
|
||||
start = time.time()
|
||||
|
||||
for game in range(1, self.args.train_eps + 1):
|
||||
state = self.env.get_state()
|
||||
dead = False
|
||||
action = self.agent.act(state, 0, dead)
|
||||
while not dead:
|
||||
state, points, dead = self.env.step(action)
|
||||
|
||||
# For debug convenience, you can check if your Q-table mathches ours for given setting of parameters
|
||||
# (see Debug Convenience part on homework 4 web page)
|
||||
if first_eat and points == 1:
|
||||
self.agent.save_model(utils.CHECKPOINT)
|
||||
first_eat = False
|
||||
|
||||
action = self.agent.act(state, points, dead)
|
||||
|
||||
|
||||
points = self.env.get_points()
|
||||
self.points_results.append(points)
|
||||
if game % self.args.window == 0:
|
||||
print(
|
||||
"Games:", len(self.points_results) - window, "-", len(self.points_results),
|
||||
"Points (Average:", sum(self.points_results[-window:])/window,
|
||||
"Max:", max(self.points_results[-window:]),
|
||||
"Min:", min(self.points_results[-window:]),")",
|
||||
)
|
||||
self.env.reset()
|
||||
print("Training takes", time.time() - start, "seconds")
|
||||
self.agent.save_model(self.args.model_name)
|
||||
|
||||
def test(self):
|
||||
print("Test Phase:")
|
||||
self.agent.eval()
|
||||
self.agent.load_model(self.args.model_name)
|
||||
points_results = []
|
||||
start = time.time()
|
||||
|
||||
for game in range(1, self.args.test_eps + 1):
|
||||
state = self.env.get_state()
|
||||
dead = False
|
||||
action = self.agent.act(state, 0, dead)
|
||||
while not dead:
|
||||
state, points, dead = self.env.step(action)
|
||||
action = self.agent.act(state, points, dead)
|
||||
points = self.env.get_points()
|
||||
points_results.append(points)
|
||||
self.env.reset()
|
||||
|
||||
print("Testing takes", time.time() - start, "seconds")
|
||||
print("Number of Games:", len(points_results))
|
||||
print("Average Points:", sum(points_results)/len(points_results))
|
||||
print("Max Points:", max(points_results))
|
||||
print("Min Points:", min(points_results))
|
||||
|
||||
def show_games(self):
|
||||
print("Display Games")
|
||||
self.env.display()
|
||||
pygame.event.pump()
|
||||
self.agent.eval()
|
||||
points_results = []
|
||||
end = False
|
||||
for game in range(1, self.args.show_eps + 1):
|
||||
state = self.env.get_state()
|
||||
dead = False
|
||||
action = self.agent.act(state, 0, dead)
|
||||
count = 0
|
||||
while not dead:
|
||||
count +=1
|
||||
pygame.event.pump()
|
||||
keys = pygame.key.get_pressed()
|
||||
if keys[K_ESCAPE] or self.check_quit():
|
||||
end = True
|
||||
break
|
||||
state, points, dead = self.env.step(action)
|
||||
# Qlearning agent
|
||||
if not self.args.human:
|
||||
action = self.agent.act(state, points, dead)
|
||||
# for human player
|
||||
else:
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_UP:
|
||||
action = 2
|
||||
elif event.key == pygame.K_DOWN:
|
||||
action = 3
|
||||
elif event.key == pygame.K_LEFT:
|
||||
action = 1
|
||||
elif event.key == pygame.K_RIGHT:
|
||||
action = 0
|
||||
if end:
|
||||
break
|
||||
self.env.reset()
|
||||
points_results.append(points)
|
||||
print("Game:", str(game)+"/"+str(self.args.show_eps), "Points:", points)
|
||||
if len(points_results) == 0:
|
||||
return
|
||||
print("Average Points:", sum(points_results)/len(points_results))
|
||||
|
||||
def check_quit(self):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='CS440 MP4 Snake')
|
||||
|
||||
parser.add_argument('--human', default = False, action="store_true",
|
||||
parser.add_argument('--human', default = True, action="store_true",
|
||||
help='making the game human playable - default False')
|
||||
|
||||
parser.add_argument('--model_name', dest="model_name", type=str, default="checkpoint3.npy",
|
||||
@@ -173,10 +48,137 @@ def main():
|
||||
|
||||
parser.add_argument('--food_y', dest="food_y", type=int, default=80,
|
||||
help='initialized y position of food - default 80')
|
||||
cfg = parser.parse_args()
|
||||
return cfg
|
||||
|
||||
class Application:
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.env = SnakeEnv(args.snake_head_x, args.snake_head_y, args.food_x, args.food_y)
|
||||
self.agent = Agent(self.env.get_actions(), args.Ne, args.C, args.gamma)
|
||||
|
||||
def execute(self):
|
||||
if not self.args.human:
|
||||
if self.args.train_eps != 0:
|
||||
self.train()
|
||||
self.eval()
|
||||
self.show_games()
|
||||
|
||||
args = parser.parse_args()
|
||||
app = Application(args)
|
||||
def train(self):
|
||||
print("Train Phase:")
|
||||
self.agent.train()
|
||||
window = self.args.window
|
||||
self.points_results = []
|
||||
first_eat = True
|
||||
start = time.time()
|
||||
|
||||
for game in range(1, self.args.train_eps + 1):
|
||||
state = self.env.get_state()
|
||||
dead = False
|
||||
action = self.agent.choose_action(state, 0, dead)
|
||||
while not dead:
|
||||
state, points, dead = self.env.step(action)
|
||||
|
||||
# For debug convenience, you can check if your Q-table mathches ours for given setting of parameters
|
||||
# (see Debug Convenience part on homework 4 web page)
|
||||
if first_eat and points == 1:
|
||||
self.agent.save_model(utils.CHECKPOINT)
|
||||
first_eat = False
|
||||
|
||||
action = self.agent.choose_action(state, points, dead)
|
||||
|
||||
|
||||
points = self.env.get_points()
|
||||
self.points_results.append(points)
|
||||
if game % self.args.window == 0:
|
||||
print(
|
||||
"Games:", len(self.points_results) - window, "-", len(self.points_results),
|
||||
"Points (Average:", sum(self.points_results[-window:])/window,
|
||||
"Max:", max(self.points_results[-window:]),
|
||||
"Min:", min(self.points_results[-window:]),")",
|
||||
)
|
||||
self.env.reset()
|
||||
print("Training takes", time.time() - start, "seconds")
|
||||
self.agent.save_model(self.args.model_name)
|
||||
|
||||
def eval(self):
|
||||
print("Evaling Phase:")
|
||||
self.agent.eval()
|
||||
self.agent.load_model(self.args.model_name)
|
||||
points_results = []
|
||||
start = time.time()
|
||||
|
||||
for game in range(1, self.args.test_eps + 1):
|
||||
state = self.env.get_state()
|
||||
dead = False
|
||||
action = self.agent.choose_action(state, 0, dead)
|
||||
while not dead:
|
||||
state, points, dead = self.env.step(action)
|
||||
action = self.agent.choose_action(state, points, dead)
|
||||
points = self.env.get_points()
|
||||
points_results.append(points)
|
||||
self.env.reset()
|
||||
|
||||
print("Testing takes", time.time() - start, "seconds")
|
||||
print("Number of Games:", len(points_results))
|
||||
print("Average Points:", sum(points_results)/len(points_results))
|
||||
print("Max Points:", max(points_results))
|
||||
print("Min Points:", min(points_results))
|
||||
|
||||
def show_games(self):
|
||||
print("Display Games")
|
||||
self.env.display()
|
||||
pygame.event.pump()
|
||||
self.agent.eval()
|
||||
points_results = []
|
||||
end = False
|
||||
for game in range(1, self.args.show_eps + 1):
|
||||
state = self.env.get_state()
|
||||
dead = False
|
||||
action = self.agent.choose_action(state, 0, dead)
|
||||
count = 0
|
||||
while not dead:
|
||||
count +=1
|
||||
pygame.event.pump()
|
||||
keys = pygame.key.get_pressed()
|
||||
if keys[K_ESCAPE] or self.check_quit():
|
||||
end = True
|
||||
break
|
||||
state, points, dead = self.env.step(action)
|
||||
# Qlearning agent
|
||||
if not self.args.human:
|
||||
action = self.agent.choose_action(state, points, dead)
|
||||
# for human player
|
||||
else:
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_UP:
|
||||
action = 2
|
||||
elif event.key == pygame.K_DOWN:
|
||||
action = 3
|
||||
elif event.key == pygame.K_LEFT:
|
||||
action = 1
|
||||
elif event.key == pygame.K_RIGHT:
|
||||
action = 0
|
||||
if end:
|
||||
break
|
||||
self.env.reset()
|
||||
points_results.append(points)
|
||||
print("Game:", str(game)+"/"+str(self.args.show_eps), "Points:", points)
|
||||
if len(points_results) == 0:
|
||||
return
|
||||
print("Average Points:", sum(points_results)/len(points_results))
|
||||
|
||||
def check_quit(self):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
cfg = get_args()
|
||||
app = Application(cfg)
|
||||
app.execute()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -23,6 +23,7 @@ class SnakeEnv:
|
||||
state, points, dead = self.game.step(action)
|
||||
if self.render:
|
||||
self.draw(state, points, dead)
|
||||
# return state, reward, done
|
||||
return state, points, dead
|
||||
|
||||
def draw(self, state, points, dead):
|
||||
@@ -99,22 +100,16 @@ class SnakeEnv:
|
||||
self.render = True
|
||||
|
||||
class Snake:
|
||||
''' 定义贪吃蛇的类
|
||||
'''
|
||||
def __init__(self, snake_head_x, snake_head_y, food_x, food_y):
|
||||
# 初始化蛇头的位置
|
||||
self.init_snake_head_x, self.init_snake_head_y = snake_head_x, snake_head_y
|
||||
# 初始化食物的位置
|
||||
self.init_food_x, self.init_food_y = food_x, food_y
|
||||
self.init_snake_head_x,self.init_snake_head_y = snake_head_x,snake_head_y # 蛇头初始位置
|
||||
self.init_food_x, self.init_food_y = food_x, food_y # 食物初始位置
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.points = 0
|
||||
self.snake_head_x = self.init_snake_head_x
|
||||
self.snake_head_y = self.init_snake_head_y
|
||||
self.snake_body = []
|
||||
self.food_x = self.init_food_x
|
||||
self.food_y = self.init_food_y
|
||||
self.snake_head_x, self.snake_head_y = self.init_snake_head_x, self.init_snake_head_y
|
||||
self.food_x, self.food_y = self.init_food_x, self.init_food_y
|
||||
self.snake_body = [] # 蛇身的位置集合
|
||||
|
||||
def get_points(self):
|
||||
return self.points
|
||||
@@ -132,8 +127,10 @@ class Snake:
|
||||
]
|
||||
|
||||
def move(self, action):
|
||||
'''根据action指令移动蛇头,并返回是否撞死
|
||||
'''
|
||||
delta_x = delta_y = 0
|
||||
if action == 0:
|
||||
if action == 0: # 上
|
||||
delta_x = utils.GRID_SIZE
|
||||
elif action == 1:
|
||||
delta_x = - utils.GRID_SIZE
|
||||
@@ -141,33 +138,31 @@ class Snake:
|
||||
delta_y = - utils.GRID_SIZE
|
||||
elif action == 3:
|
||||
delta_y = utils.GRID_SIZE
|
||||
|
||||
old_body_head = None
|
||||
if len(self.snake_body) == 1:
|
||||
old_body_head = self.snake_body[0]
|
||||
|
||||
self.snake_body.append((self.snake_head_x, self.snake_head_y))
|
||||
self.snake_head_x += delta_x
|
||||
self.snake_head_y += delta_y
|
||||
|
||||
if len(self.snake_body) > self.points:
|
||||
if len(self.snake_body) > self.points: # 说明没有吃到食物
|
||||
del(self.snake_body[0])
|
||||
|
||||
self.handle_eatfood()
|
||||
|
||||
# colliding with the snake body or going backwards while its body length
|
||||
# greater than 1
|
||||
# 蛇长大于1时,蛇头与蛇身任一位置重叠则看作蛇与自身相撞
|
||||
if len(self.snake_body) >= 1:
|
||||
for seg in self.snake_body:
|
||||
if self.snake_head_x == seg[0] and self.snake_head_y == seg[1]:
|
||||
return True
|
||||
|
||||
# moving towards body direction, not allowing snake to go backwards while
|
||||
# its body length is 1
|
||||
# 蛇长为1时,如果蛇头与之前的位置重复则看作蛇与自身相撞
|
||||
if len(self.snake_body) == 1:
|
||||
if old_body_head == (self.snake_head_x, self.snake_head_y):
|
||||
return True
|
||||
|
||||
# collide with the wall
|
||||
# 蛇头是否撞墙
|
||||
if (self.snake_head_x < utils.GRID_SIZE or self.snake_head_y < utils.GRID_SIZE or
|
||||
self.snake_head_x + utils.GRID_SIZE > utils.DISPLAY_SIZE-utils.GRID_SIZE or self.snake_head_y + utils.GRID_SIZE > utils.DISPLAY_SIZE-utils.GRID_SIZE):
|
||||
return True
|
||||
@@ -183,15 +178,16 @@ class Snake:
|
||||
self.random_food()
|
||||
self.points += 1
|
||||
|
||||
|
||||
def random_food(self):
|
||||
'''生成随机位置的食物
|
||||
'''
|
||||
max_x = (utils.DISPLAY_SIZE - utils.WALL_SIZE - utils.GRID_SIZE)
|
||||
max_y = (utils.DISPLAY_SIZE - utils.WALL_SIZE - utils.GRID_SIZE)
|
||||
|
||||
self.food_x = random.randint(utils.WALL_SIZE, max_x)//utils.GRID_SIZE * utils.GRID_SIZE
|
||||
self.food_y = random.randint(utils.WALL_SIZE, max_y)//utils.GRID_SIZE * utils.GRID_SIZE
|
||||
|
||||
while self.check_food_on_snake():
|
||||
while self.check_food_on_snake(): # 食物不能生成在蛇身上
|
||||
self.food_x = random.randint(utils.WALL_SIZE, max_x)//utils.GRID_SIZE * utils.GRID_SIZE
|
||||
self.food_y = random.randint(utils.WALL_SIZE, max_y)//utils.GRID_SIZE * utils.GRID_SIZE
|
||||
|
||||
|
||||
Reference in New Issue
Block a user