update RL_example.py

This commit is contained in:
qiwang067
2023-07-21 22:36:46 +08:00
parent 270e89d5b9
commit 64e99353a0

View File

@@ -2,7 +2,7 @@ import gym
import numpy as np
class BespokeAgent:
class Agent:
def __init__(self, env):
pass
@@ -21,7 +21,7 @@ class BespokeAgent:
pass
def play_montecarlo(env, agent, render=False, train=False):
def play(env, agent, render=False, train=False):
episode_reward = 0. # 记录回合总奖励初始化为0
observation = env.reset() # 重置游戏环境,开始新回合
while True: # 不断循环,直到回合结束
@@ -39,16 +39,16 @@ def play_montecarlo(env, agent, render=False, train=False):
env = gym.make('MountainCar-v0')
env.seed(3) # 设置随机种子,只是为了让结果可以精确复现,一般情况下可删去
agent = BespokeAgent(env)
env.seed(3) # 设置随机种子,让结果可复现
agent = Agent(env)
print('观测空间 = {}'.format(env.observation_space))
print('动作空间 = {}'.format(env.action_space))
print('观测范围 = {} ~ {}'.format(env.observation_space.low,
env.observation_space.high))
print('动作数 = {}'.format(env.action_space.n))
episode_reward = play_montecarlo(env, agent, render=True)
episode_reward = play(env, agent, render=True)
print('回合奖励 = {}'.format(episode_reward))
episode_rewards = [play_montecarlo(env, agent) for _ in range(100)]
episode_rewards = [play(env, agent) for _ in range(100)]
print('平均回合奖励 = {}'.format(np.mean(episode_rewards)))