This commit is contained in:
johnjim0816
2022-06-09 19:06:37 +08:00
parent 46f71ddb81
commit 621c81278d
4 changed files with 35 additions and 9 deletions

5
codes/.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
.DS_Store
.ipynb_checkpoints
__pycache__
.vscode
test.py

View File

@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-09 20:25:52
@LastEditor: John
LastEditTime: 2021-09-16 00:55:30
LastEditTime: 2022-06-09 19:04:44
@Discription:
@Environment: python 3.7.7
'''
@@ -104,9 +104,9 @@ class DDPG:
# 从经验回放中(replay memory)中随机采样一个批量的转移(transition)
state, action, reward, next_state, done = self.memory.sample(self.batch_size)
# 转变为张量
state = torch.FloatTensor(state).to(self.device)
next_state = torch.FloatTensor(next_state).to(self.device)
action = torch.FloatTensor(action).to(self.device)
state = torch.FloatTensor(np.array(state)).to(self.device)
next_state = torch.FloatTensor(np.array(next_state)).to(self.device)
action = torch.FloatTensor(np.array(action)).to(self.device)
reward = torch.FloatTensor(reward).unsqueeze(1).to(self.device)
done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(self.device)

View File

@@ -5,11 +5,12 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-11 20:58:21
@LastEditor: John
LastEditTime: 2022-02-10 06:23:27
LastEditTime: 2022-06-09 19:05:20
@Discription:
@Environment: python 3.7.7
'''
import sys,os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
curr_path = os.path.dirname(os.path.abspath(__file__)) # 当前文件所在绝对路径
parent_path = os.path.dirname(curr_path) # 父路径
sys.path.append(parent_path) # 添加路径到系统路径sys.path
@@ -20,7 +21,6 @@ import torch
from env import NormalizedActions,OUNoise
from ddpg import DDPG
from DDPG.train import train,test
from common.utils import save_results,make_dir
from common.utils import plot_rewards
@@ -37,7 +37,7 @@ class Config:
"cuda" if torch.cuda.is_available() else "cpu") # 检测GPUgjgjlkhfsf风刀霜的撒发十
self.seed = 10 # 随机种子置0则不设置随机种子
self.train_eps = 300 # 训练的回合数
self.test_eps = 50 # 测试的回合数
self.test_eps = 20 # 测试的回合数
################################################################################
################################## 算法超参数 ###################################
@@ -68,7 +68,7 @@ def env_agent_config(cfg,seed=1):
return env,agent
def train(cfg, env, agent):
print('开始训练!')
print(f'环境:{cfg.env_name},算法:{cfg.algo},设备:{cfg.device}')
print(f'环境:{cfg.env_name},算法:{cfg.algo_name},设备:{cfg.device}')
ou_noise = OUNoise(env.action_space) # 动作噪声
rewards = [] # 记录所有回合的奖励
ma_rewards = [] # 记录所有回合的滑动平均奖励
@@ -99,7 +99,7 @@ def train(cfg, env, agent):
def test(cfg, env, agent):
print('开始测试!')
print(f'环境:{cfg.env_name}, 算法:{cfg.algo}, 设备:{cfg.device}')
print(f'环境:{cfg.env_name}, 算法:{cfg.algo_name}, 设备:{cfg.device}')
rewards = [] # 记录所有回合的奖励
ma_rewards = [] # 记录所有回合的滑动平均奖励
for i_ep in range(cfg.test_eps):

21
codes/LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2020 John Jim
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.