#!/usr/bin/env python
# coding=utf-8
'''
@Author: John
@Email: johnjim0816@gmail.com
@Date: 2020-06-11 10:02:35
@LastEditor: John
@LastEditTime: 2020-06-11 16:57:34
@Description: Helpers that turn a rendered CartPole frame into a torch image tensor
@Environment: python 3.7.7
'''
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

# Post-processing pipeline applied to every cropped frame:
# tensor -> PIL image -> resize with size argument 40 (bicubic) -> tensor.
# Image.BICUBIC is used instead of the legacy alias Image.CUBIC, which newer
# Pillow releases no longer provide; both name the same resampling filter.
resize = T.Compose([T.ToPILImage(),
                    T.Resize(40, interpolation=Image.BICUBIC),
                    T.ToTensor()])


def get_cart_location(env, screen_width):
    """Return the pixel column of the middle of the cart.

    Args:
        env: environment exposing ``x_threshold`` (half of the world width)
            and ``state``, whose first element is the cart position in
            world units.
        screen_width: width of the rendered frame in pixels.

    Returns:
        int: horizontal pixel coordinate of the cart centre; world position
        0 maps to the middle of the screen.
    """
    world_width = env.x_threshold * 2
    scale = screen_width / world_width  # pixels per world unit
    return int(env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART


def get_screen(env, device):
    """Render the environment and return a cropped, resized frame tensor.

    The frame is cropped vertically to the 40%-80% band of its height and
    horizontally to a window 60% of its width that follows the cart (pinned
    to the left/right border when the cart is close to it). The crop is then
    scaled by 1/255, passed through the module-level ``resize`` pipeline and
    given a leading batch dimension.

    Args:
        env: environment supporting ``render(mode='rgb_array')`` plus the
            attributes read by ``get_cart_location``.
        device: torch device the resulting tensor is moved to.

    Returns:
        torch.Tensor: float tensor in BCHW layout with batch size 1.
    """
    # Returned screen requested by gym is 400x600x3, but is sometimes larger
    # such as 800x1200x3. Transpose it into torch order (CHW).
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    # Cart is in the lower half, so strip off the top and bottom of the screen
    _, screen_height, screen_width = screen.shape
    screen = screen[:, int(screen_height*0.4):int(screen_height * 0.8)]
    view_width = int(screen_width * 0.6)
    cart_location = get_cart_location(env, screen_width)
    if cart_location < view_width // 2:
        # Cart too close to the left border: take the leftmost window.
        slice_range = slice(view_width)
    elif cart_location > (screen_width - view_width // 2):
        # Cart too close to the right border: take the rightmost window.
        slice_range = slice(-view_width, None)
    else:
        # Otherwise centre the window on the cart.
        slice_range = slice(cart_location - view_width // 2,
                            cart_location + view_width // 2)
    # Strip off the left/right edges so the view follows the cart
    screen = screen[:, :, slice_range]
    # Convert to float and rescale (assumes 0-255 pixel values from the
    # renderer), then wrap as a torch tensor; torch.from_numpy shares memory
    # with the contiguous array, so that step doesn't require a copy.
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    # Resize, and add a batch dimension (BCHW)
    return resize(screen).unsqueeze(0).to(device)


if __name__ == "__main__":
    # Demo: render one CartPole frame and display the extracted screen.
    import gym
    env = gym.make('CartPole-v0').unwrapped
    # if gpu is to be used
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env.reset()
    import matplotlib.pyplot as plt

    plt.figure()
    # Drop the batch dimension and go from CHW back to HWC for imshow.
    plt.imshow(get_screen(env, device).cpu().squeeze(0).permute(1, 2, 0).numpy(),
               interpolation='none')
    plt.title('Example extracted screen')
    plt.show()