66 lines
2.3 KiB
Python
66 lines
2.3 KiB
Python
#!/usr/bin/env python
|
|
# coding=utf-8
|
|
'''
|
|
@Author: John
|
|
@Email: johnjim0816@gmail.com
|
|
@Date: 2020-06-11 10:02:35
|
|
@LastEditor: John
|
|
@LastEditTime: 2020-06-11 16:57:34
|
|
@Description:
|
|
@Environment: python 3.7.7
|
|
'''
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torchvision.transforms as T
|
|
from PIL import Image
|
|
|
|
# Preprocessing pipeline applied to each cropped frame tensor: convert it to
# a PIL image, resize it so its smaller edge is 40 px, and convert it back to
# a tensor.
# NOTE: Image.BICUBIC is used rather than the old Image.CUBIC alias. The alias
# named the same filter but was removed in Pillow 10, where it raises
# AttributeError at import time.
resize = T.Compose([T.ToPILImage(),
                    T.Resize(40, interpolation=Image.BICUBIC),
                    T.ToTensor()])
|
|
|
|
|
|
def get_cart_location(env,screen_width):
|
|
world_width = env.x_threshold * 2
|
|
scale = screen_width / world_width
|
|
return int(env.state[0] * scale + screen_width / 2.0) # MIDDLE OF CART
|
|
|
|
def _crop_columns(cart_location, screen_width, view_width):
    """Return a width-axis slice of exactly ``view_width`` columns around the cart.

    The window is centred on ``cart_location`` and clamped to the screen.
    Near either edge it is pinned to that edge instead of running off-screen.

    Args:
        cart_location: pixel column of the cart's centre.
        screen_width: width of the frame in pixels.
        view_width: desired window width in pixels (expected <= screen_width).

    Returns:
        ``slice(start, start + view_width)`` with
        ``0 <= start <= screen_width - view_width``.
    """
    start = cart_location - view_width // 2
    # Clamping the start, rather than computing both ends from the centre,
    # keeps the width at view_width even when view_width is odd.
    start = max(0, min(start, screen_width - view_width))
    return slice(start, start + view_width)


def get_screen(env, device):
    """Render ``env``, crop a window around the cart, and return it as a batch.

    Args:
        env: environment whose ``render(mode='rgb_array')`` returns an
            (H, W, channels) array. It must also provide the attributes read
            by ``get_cart_location`` (``x_threshold`` and ``state``).
        device: torch device the returned tensor is moved to.

    Returns:
        The cropped frame after ``resize``, with a leading batch dimension
        (shape (1, C, h, w)), on ``device``.
    """
    # The frame requested from gym is usually 400x600x3 but is sometimes
    # larger, such as 800x1200x3. Transpose it into torch order (CHW).
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    _, screen_height, screen_width = screen.shape
    # The cart is in the lower half of the frame, so strip off the top and
    # bottom and keep the rows between 40% and 80% of the height.
    screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
    # Keep a fixed-width window (60% of the frame width) around the cart.
    # The result is a rectangle, not a square.
    view_width = int(screen_width * 0.6)
    cart_location = get_cart_location(env, screen_width)
    screen = screen[:, :, _crop_columns(cart_location, screen_width, view_width)]
    # Convert to a contiguous float32 array and rescale by 1/255 (this assumes
    # 0-255 pixel values; TODO confirm for other renderers). This step copies.
    # torch.from_numpy then shares that memory, so it adds no copy of its own.
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    # Resize, then add a batch dimension (BCHW).
    return resize(screen).unsqueeze(0).to(device)
|
|
|
|
if __name__ == "__main__":
    # Manual visual check: render one CartPole frame and display the crop
    # produced by get_screen.
    import gym

    # `.unwrapped` returns the underlying environment without gym's wrappers.
    cartpole = gym.make('CartPole-v0').unwrapped
    # Use the GPU when torch can see one; otherwise fall back to the CPU.
    use_cuda = torch.cuda.is_available()
    target_device = torch.device("cuda" if use_cuda else "cpu")
    cartpole.reset()

    import matplotlib.pyplot as plt

    plt.figure()
    # get_screen returns a batch of one image in BCHW order on target_device.
    # Move it to the CPU, drop the batch axis, and put channels last for imshow.
    batch = get_screen(cartpole, target_device).cpu()
    image_hwc = batch.squeeze(0).permute(1, 2, 0).numpy()
    plt.imshow(image_hwc, interpolation='none')
    plt.title('Example extracted screen')
    plt.show()