29 lines
945 B
Python
29 lines
945 B
Python
import gym
|
|
|
|
class TimeLimit(gym.Wrapper):
    """Wrapper that forces an episode to end after a fixed number of steps.

    Args:
        env: The environment to wrap.
        max_episode_steps: Step count at which ``step`` starts reporting
            ``done=True``. ``None`` (the default) disables the limit.
    """

    def __init__(self, env, max_episode_steps=None):
        super(TimeLimit, self).__init__(env)
        self._max_episode_steps = max_episode_steps
        # Number of step() calls since construction or the last reset().
        self._elapsed_steps = 0

    def step(self, ac):
        """Step the wrapped env and apply the step limit.

        Args:
            ac: Action forwarded unchanged to the wrapped environment.

        Returns:
            The ``(observation, reward, done, info)`` tuple from the wrapped
            environment. Once the elapsed step count reaches the configured
            limit, ``done`` is forced to ``True`` and
            ``info['TimeLimit.truncated']`` is set to ``True``.
        """
        observation, reward, done, info = self.env.step(ac)
        self._elapsed_steps += 1
        # The default limit is None; comparing an int against None raises
        # TypeError on Python 3, so treat None as "no limit".
        limit = self._max_episode_steps
        if limit is not None and self._elapsed_steps >= limit:
            done = True
            info['TimeLimit.truncated'] = True
        return observation, reward, done, info

    def reset(self, **kwargs):
        """Reset the step counter and the wrapped env.

        Args:
            **kwargs: Forwarded unchanged to the wrapped env's ``reset``.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)
|
|
|
|
class ClipActionsWrapper(gym.Wrapper):
    """Wrapper that sanitises actions before they reach the wrapped env."""

    def step(self, action):
        """Replace non-finite entries, clip to the action space, then step.

        Args:
            action: Action passed through ``numpy.nan_to_num`` and then
                clipped to ``self.action_space.low`` / ``.high``.

        Returns:
            Whatever the wrapped environment's ``step`` returns.
        """
        import numpy as np

        finite_action = np.nan_to_num(action)
        space = self.action_space
        bounded_action = np.clip(finite_action, space.low, space.high)
        return self.env.step(bounded_action)

    def reset(self, **kwargs):
        """Forward ``reset`` and its keyword arguments to the wrapped env."""
        return self.env.reset(**kwargs)