Files
easy-rl/docs/chapter1/RL_example.py
2023-07-21 21:54:22 +08:00

55 lines
2.0 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import gym
import numpy as np
class BespokeAgent:
def __init__(self, env):
pass
def decide(self, observation): # 决策
position, velocity = observation
lb = min(-0.09 * (position + 0.25) ** 2 + 0.03,
0.3 * (position + 0.9) ** 4 - 0.008)
ub = -0.07 * (position + 0.38) ** 2 + 0.07
if lb < velocity < ub:
action = 2
else:
action = 0
return action # 返回动作
def learn(self, *args): # 学习
pass
def play_montecarlo(env, agent, render=False, train=False):
episode_reward = 0. # 记录回合总奖励初始化为0
observation = env.reset() # 重置游戏环境,开始新回合
while True: # 不断循环,直到回合结束
if render: # 判断是否显示
env.render() # 显示图形界面,图形界面可以用 env.close() 语句关闭
action = agent.decide(observation)
next_observation, reward, done, _ = env.step(action) # 执行动作
episode_reward += reward # 收集回合奖励
if train: # 判断是否训练智能体
agent.learn(observation, action, reward, done) # 学习
if done: # 回合结束,跳出循环
break
observation = next_observation
return episode_reward # 返回回合总奖励
env = gym.make('MountainCar-v0')
env.seed(3) # 设置随机数种子,只是为了让结果可以精确复现,一般情况下可删去
agent = BespokeAgent(env)
print('观测空间 = {}'.format(env.observation_space))
print('动作空间 = {}'.format(env.action_space))
print('观测范围 = {} ~ {}'.format(env.observation_space.low,
env.observation_space.high))
print('动作数 = {}'.format(env.action_space.n))
episode_reward = play_montecarlo(env, agent, render=True)
print('回合奖励 = {}'.format(episode_reward))
episode_rewards = [play_montecarlo(env, agent) for _ in range(100)]
print('平均回合奖励 = {}'.format(np.mean(episode_rewards)))