add DQN_cnn
This commit is contained in:
66
codes/DQN_cnn/env.py
Normal file
66
codes/DQN_cnn/env.py
Normal file
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
'''
|
||||
@Author: John
|
||||
@Email: johnjim0816@gmail.com
|
||||
@Date: 2020-06-11 10:02:35
|
||||
@LastEditor: John
|
||||
@LastEditTime: 2020-06-11 16:57:34
|
||||
@Discription:
|
||||
@Environment: python 3.7.7
|
||||
'''
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
|
||||
resize = T.Compose([T.ToPILImage(),
|
||||
T.Resize(40, interpolation=Image.CUBIC),
|
||||
T.ToTensor()])
|
||||
|
||||
|
||||
def get_cart_location(env,screen_width):
|
||||
world_width = env.x_threshold * 2
|
||||
scale = screen_width / world_width
|
||||
return int(env.state[0] * scale + screen_width / 2.0) # MIDDLE OF CART
|
||||
|
||||
def get_screen(env,device):
|
||||
# Returned screen requested by gym is 400x600x3, but is sometimes larger
|
||||
# such as 800x1200x3. Transpose it into torch order (CHW).
|
||||
screen = env.render(mode='rgb_array').transpose((2, 0, 1))
|
||||
# Cart is in the lower half, so strip off the top and bottom of the screen
|
||||
_, screen_height, screen_width = screen.shape
|
||||
screen = screen[:, int(screen_height*0.4):int(screen_height * 0.8)]
|
||||
view_width = int(screen_width * 0.6)
|
||||
cart_location = get_cart_location(env,screen_width)
|
||||
if cart_location < view_width // 2:
|
||||
slice_range = slice(view_width)
|
||||
elif cart_location > (screen_width - view_width // 2):
|
||||
slice_range = slice(-view_width, None)
|
||||
else:
|
||||
slice_range = slice(cart_location - view_width // 2,
|
||||
cart_location + view_width // 2)
|
||||
# Strip off the edges, so that we have a square image centered on a cart
|
||||
screen = screen[:, :, slice_range]
|
||||
# Convert to float, rescale, convert to torch tensor
|
||||
# (this doesn't require a copy)
|
||||
screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
|
||||
screen = torch.from_numpy(screen)
|
||||
# Resize, and add a batch dimension (BCHW)
|
||||
return resize(screen).unsqueeze(0).to(device)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
import gym
|
||||
env = gym.make('CartPole-v0').unwrapped
|
||||
# if gpu is to be used
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
env.reset()
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.figure()
|
||||
plt.imshow(get_screen(env,device).cpu().squeeze(0).permute(1, 2, 0).numpy(),
|
||||
interpolation='none')
|
||||
plt.title('Example extracted screen')
|
||||
plt.show()
|
||||
Reference in New Issue
Block a user