Files
easy-rl/codes/envs/snake/snake_env.py
JohnJim0816 e4690ac89f update
2021-04-16 14:59:23 +08:00

203 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import random
import pygame
import utils
class SnakeEnv:
    """Environment wrapper around ``Snake`` with optional pygame rendering.

    Rendering is off until ``display()`` is called; afterwards every ``step()``
    redraws the board.
    """

    def __init__(self, snake_head_x, snake_head_y, food_x, food_y):
        """Create the underlying game with the given initial head and food positions."""
        self.game = Snake(snake_head_x, snake_head_y, food_x, food_y)
        # Switched on by display() once the pygame window exists.
        self.render = False

    def get_actions(self):
        """Return the action ids accepted by ``step()`` (delegates to the game)."""
        return self.game.get_actions()

    def reset(self):
        """Reset the underlying game to its initial configuration.

        Returns whatever ``Snake.reset()`` returns.
        """
        return self.game.reset()

    def get_points(self):
        """Return the current score of the underlying game."""
        return self.game.get_points()

    def get_state(self):
        """Return ``[head_x, head_y, body, food_x, food_y]`` from the game."""
        return self.game.get_state()

    def step(self, action):
        """Apply ``action`` for one move and redraw if rendering is enabled.

        Returns:
            ``(state, points, dead)``. Note that ``points`` is the cumulative
            score, not a per-step reward; callers derive rewards themselves.
        """
        state, points, dead = self.game.step(action)
        if self.render:
            self.draw(state, points, dead)
        return state, points, dead

    def draw(self, state, points, dead):
        """Render one frame (walls, snake, food, score) and throttle the frame rate.

        Requires ``display()`` to have been called first (it creates
        ``self.screen``, ``self.font`` and ``self.clock``).
        """
        snake_head_x, snake_head_y, snake_body, food_x, food_y = state
        # Fill everything with the wall colour, then paint the playable area
        # black, leaving a GRID_SIZE-wide border.
        self.screen.fill(utils.BLUE)
        pygame.draw.rect(
            self.screen,
            utils.BLACK,
            [
                utils.GRID_SIZE,
                utils.GRID_SIZE,
                utils.DISPLAY_SIZE - utils.GRID_SIZE * 2,
                utils.DISPLAY_SIZE - utils.GRID_SIZE * 2,
            ],
        )
        # Snake head: outlined square with a 3px border.
        pygame.draw.rect(
            self.screen,
            utils.GREEN,
            [snake_head_x, snake_head_y, utils.GRID_SIZE, utils.GRID_SIZE],
            3,
        )
        # Snake body: outlined squares with a 1px border.
        for seg in snake_body:
            pygame.draw.rect(
                self.screen,
                utils.GREEN,
                [seg[0], seg[1], utils.GRID_SIZE, utils.GRID_SIZE],
                1,
            )
        # Food: filled square.
        pygame.draw.rect(
            self.screen,
            utils.RED,
            [food_x, food_y, utils.GRID_SIZE, utils.GRID_SIZE],
        )
        text_surface = self.font.render("Points: " + str(points), True, utils.BLACK)
        text_rect = text_surface.get_rect()
        # NOTE(review): 280 is presumably DISPLAY_SIZE / 2 for the stock
        # window size — confirm against utils before changing the window size.
        text_rect.center = (280, 25)
        self.screen.blit(text_surface, text_rect)
        pygame.display.flip()
        # Cap the frame rate: 1 FPS on death so the final frame stays visible,
        # 5 FPS otherwise.
        if dead:
            self.clock.tick(1)
        else:
            self.clock.tick(5)
        return

    def display(self):
        """Open the pygame window, draw the current state and enable rendering."""
        pygame.init()
        pygame.display.set_caption('MP4: Snake')
        self.clock = pygame.time.Clock()
        pygame.font.init()
        self.font = pygame.font.Font(pygame.font.get_default_font(), 15)
        # Keep the surface under its own name. Assigning it to ``self.display``
        # would shadow this method, so a second display() call would fail.
        self.screen = pygame.display.set_mode(
            (utils.DISPLAY_SIZE, utils.DISPLAY_SIZE), pygame.HWSURFACE
        )
        self.draw(self.game.get_state(), self.game.get_points(), False)
        self.render = True
class Snake:
    """Grid-based snake game state and rules (no rendering).

    Positions are pixel coordinates that change in steps of ``utils.GRID_SIZE``.
    ``snake_body`` holds body segments as ``(x, y)`` tuples, oldest (tail)
    first; the head is stored separately in ``snake_head_x``/``snake_head_y``.
    """

    def __init__(self, snake_head_x, snake_head_y, food_x, food_y):
        """Remember the initial head/food positions and start a fresh game."""
        # Initial head position, restored by reset().
        self.init_snake_head_x, self.init_snake_head_y = snake_head_x, snake_head_y
        # Initial food position, restored by reset().
        self.init_food_x, self.init_food_y = food_x, food_y
        self.reset()

    def reset(self):
        """Restore initial head/food positions, clear the body and zero the score."""
        self.points = 0
        self.snake_head_x, self.snake_head_y = self.init_snake_head_x, self.init_snake_head_y
        self.food_x, self.food_y = self.init_food_x, self.init_food_y
        # Body segment positions, tail first; empty means the snake is only a head.
        self.snake_body = []

    def get_points(self):
        """Return the current score (number of food items eaten)."""
        return self.points

    def get_actions(self):
        """Return the action ids understood by ``move()``."""
        return [0, 1, 2, 3]

    def get_state(self):
        """Return ``[head_x, head_y, body, food_x, food_y]``.

        ``body`` is the live ``snake_body`` list, not a copy.
        """
        return [
            self.snake_head_x,
            self.snake_head_y,
            self.snake_body,
            self.food_x,
            self.food_y,
        ]

    def move(self, action):
        """Move the head one grid cell according to ``action`` and report death.

        Action mapping (pixel deltas): 0 -> +x, 1 -> -x, 2 -> -y, 3 -> +y;
        any other value gives no displacement. In pygame screen coordinates
        y grows downward, so this is presumably right/left/up/down — confirm
        against the agent's action encoding.

        Returns:
            True if, after the move, the head overlaps the body or the wall.
        """
        delta_x = delta_y = 0
        if action == 0:
            delta_x = utils.GRID_SIZE
        elif action == 1:
            delta_x = -utils.GRID_SIZE
        elif action == 2:
            delta_y = -utils.GRID_SIZE
        elif action == 3:
            delta_y = utils.GRID_SIZE

        # Remember the single body segment (if that is all there is) for the
        # reversal check below.
        old_body_head = None
        if len(self.snake_body) == 1:
            old_body_head = self.snake_body[0]

        # The old head position becomes the newest body segment.
        self.snake_body.append((self.snake_head_x, self.snake_head_y))
        self.snake_head_x += delta_x
        self.snake_head_y += delta_y

        # Trim the tail unless the body is still shorter than the score. Since
        # handle_eatfood() increments points after this check, growth shows
        # up on the move after the food is eaten.
        if len(self.snake_body) > self.points:
            del self.snake_body[0]
        self.handle_eatfood()

        # Snake longer than just a head: overlapping any body segment is a
        # self-collision.
        if len(self.snake_body) >= 1:
            for seg in self.snake_body:
                if self.snake_head_x == seg[0] and self.snake_head_y == seg[1]:
                    return True

        # With exactly one body segment, moving onto the cell that segment
        # occupied before this move (i.e. reversing) also counts as a
        # self-collision.
        if len(self.snake_body) == 1:
            if old_body_head == (self.snake_head_x, self.snake_head_y):
                return True

        # Wall collision. NOTE(review): the wall thickness here is GRID_SIZE,
        # while random_food() uses WALL_SIZE; they presumably coincide —
        # confirm in utils.
        if (self.snake_head_x < utils.GRID_SIZE
                or self.snake_head_y < utils.GRID_SIZE
                or self.snake_head_x + utils.GRID_SIZE > utils.DISPLAY_SIZE - utils.GRID_SIZE
                or self.snake_head_y + utils.GRID_SIZE > utils.DISPLAY_SIZE - utils.GRID_SIZE):
            return True
        return False

    def step(self, action):
        """Apply ``action`` and return ``(state, points, dead)``."""
        is_dead = self.move(action)
        return self.get_state(), self.get_points(), is_dead

    def handle_eatfood(self):
        """If the head is on the food, respawn the food and add one point."""
        if (self.snake_head_x == self.food_x) and (self.snake_head_y == self.food_y):
            self.random_food()
            self.points += 1

    def random_food(self):
        """Move the food to a uniformly random grid cell not covered by the snake.

        Candidate cells on each axis are the grid-aligned positions
        ``v // GRID_SIZE * GRID_SIZE`` for ``v`` in
        ``[WALL_SIZE, DISPLAY_SIZE - WALL_SIZE - GRID_SIZE]``.

        Picking directly among cells fixes two problems with the old approach
        (sampling a pixel and flooring it):
        - the flooring made the last cell on each axis almost never chosen;
        - the rejection loop never terminated on a full board.

        If no cell is free, the food position is left unchanged.
        """
        grid = utils.GRID_SIZE
        lowest = utils.WALL_SIZE // grid * grid
        highest = (utils.DISPLAY_SIZE - utils.WALL_SIZE - grid) // grid * grid
        cells = range(lowest, highest + 1, grid)

        # The food must not spawn on the snake's head or body.
        occupied = {(self.snake_head_x, self.snake_head_y)}
        occupied.update((seg[0], seg[1]) for seg in self.snake_body)

        free = [(x, y) for x in cells for y in cells if (x, y) not in occupied]
        if not free:
            return
        self.food_x, self.food_y = random.choice(free)

    def check_food_on_snake(self):
        """Return True if the food currently overlaps the head or any body segment."""
        if self.food_x == self.snake_head_x and self.food_y == self.snake_head_y:
            return True
        for seg in self.snake_body:
            if self.food_x == seg[0] and self.food_y == seg[1]:
                return True
        return False