This commit is contained in:
JohnJim0816
2021-03-28 11:18:52 +08:00
parent 2df8d965d2
commit 6e4d966e1f
56 changed files with 497 additions and 165 deletions

View File

@@ -77,7 +77,7 @@ class BlackjackEnv(gym.Env):
self.natural = natural
# Start the first game
self._reset() # Number of
self.n_actions = 2
self.action_dim = 2
def reset(self):
return self._reset()

View File

@@ -31,7 +31,7 @@ class CliffWalkingEnv(discrete.DiscreteEnv):
self.shape = (4, 12)
nS = np.prod(self.shape)
n_actions = 4
action_dim = 4
# Cliff Location
self._cliff = np.zeros(self.shape, dtype=np.bool)
@@ -41,7 +41,7 @@ class CliffWalkingEnv(discrete.DiscreteEnv):
P = {}
for s in range(nS):
position = np.unravel_index(s, self.shape)
P[s] = { a : [] for a in range(n_actions) }
P[s] = { a : [] for a in range(action_dim) }
P[s][UP] = self._calculate_transition_prob(position, [-1, 0])
P[s][RIGHT] = self._calculate_transition_prob(position, [0, 1])
P[s][DOWN] = self._calculate_transition_prob(position, [1, 0])
@@ -51,7 +51,7 @@ class CliffWalkingEnv(discrete.DiscreteEnv):
isd = np.zeros(nS)
isd[np.ravel_multi_index((3,0), self.shape)] = 1.0
super(CliffWalkingEnv, self).__init__(nS, n_actions, P, isd)
super(CliffWalkingEnv, self).__init__(nS, action_dim, P, isd)
def render(self, mode='human', close=False):
self._render(mode, close)

View File

@@ -37,7 +37,7 @@ class GridworldEnv(discrete.DiscreteEnv):
self.shape = shape
nS = np.prod(shape)
n_actions = 4
action_dim = 4
MAX_Y = shape[0]
MAX_X = shape[1]
@@ -51,7 +51,7 @@ class GridworldEnv(discrete.DiscreteEnv):
y, x = it.multi_index
# P[s][a] = (prob, next_state, reward, is_done)
P[s] = {a : [] for a in range(n_actions)}
P[s] = {a : [] for a in range(action_dim)}
is_done = lambda s: s == 0 or s == (nS - 1)
reward = 0.0 if is_done(s) else -1.0
@@ -82,7 +82,7 @@ class GridworldEnv(discrete.DiscreteEnv):
# This should not be used in any model-free learning algorithm
self.P = P
super(GridworldEnv, self).__init__(nS, n_actions, P, isd)
super(GridworldEnv, self).__init__(nS, action_dim, P, isd)
def _render(self, mode='human', close=False):
""" Renders the current gridworld layout

View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: John
Email: johnjim0816@gmail.com
Date: 2021-03-24 22:12:19
LastEditor: John
LastEditTime: 2021-03-26 17:12:43
Discription:
Environment:
'''
import numpy as np
import random
class StochasticMDP:
def __init__(self):
self.end = False
self.curr_state = 2
self.action_dim = 2
self.state_dim = 6
self.p_right = 0.5
def reset(self):
self.end = False
self.curr_state = 2
state = np.zeros(self.state_dim)
state[self.curr_state - 1] = 1.
return state
def step(self, action):
if self.curr_state != 1:
if action == 1:
if random.random() < self.p_right and self.curr_state < self.state_dim:
self.curr_state += 1
else:
self.curr_state -= 1
if action == 0:
self.curr_state -= 1
if self.curr_state == self.state_dim:
self.end = True
state = np.zeros(self.state_dim)
state[self.curr_state - 1] = 1.
if self.curr_state == 1:
if self.end:
return state, 1.00, True, {}
else:
return state, 1.00/100.00, True, {}
else:
return state, 0.0, False, {}

View File

@@ -30,7 +30,7 @@ class WindyGridworldEnv(discrete.DiscreteEnv):
self.shape = (7, 10)
nS = np.prod(self.shape)
n_actions = 4
action_dim = 4
# Wind strength
winds = np.zeros(self.shape)
@@ -41,7 +41,7 @@ class WindyGridworldEnv(discrete.DiscreteEnv):
P = {}
for s in range(nS):
position = np.unravel_index(s, self.shape)
P[s] = { a : [] for a in range(n_actions) }
P[s] = { a : [] for a in range(action_dim) }
P[s][UP] = self._calculate_transition_prob(position, [-1, 0], winds)
P[s][RIGHT] = self._calculate_transition_prob(position, [0, 1], winds)
P[s][DOWN] = self._calculate_transition_prob(position, [1, 0], winds)
@@ -51,7 +51,7 @@ class WindyGridworldEnv(discrete.DiscreteEnv):
isd = np.zeros(nS)
isd[np.ravel_multi_index((3,0), self.shape)] = 1.0
super(WindyGridworldEnv, self).__init__(nS, n_actions, P, isd)
super(WindyGridworldEnv, self).__init__(nS, action_dim, P, isd)
def render(self, mode='human', close=False):
self._render(mode, close)