This commit is contained in:
johnjim0816
2022-07-21 00:13:44 +08:00
parent bab7f6fe8c
commit 0f38e23baf
34 changed files with 665 additions and 422 deletions

View File

@@ -73,11 +73,11 @@ class Critic(nn.Module):
return x
class DDPG:
def __init__(self, n_states, n_actions, cfg):
self.device = cfg.device
self.critic = Critic(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
self.actor = Actor(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
self.target_critic = Critic(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
self.target_actor = Actor(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
self.device = torch.device(cfg.device)
self.critic = Critic(n_states, n_actions, cfg.hidden_dim).to(self.device)
self.actor = Actor(n_states, n_actions, cfg.hidden_dim).to(self.device)
self.target_critic = Critic(n_states, n_actions, cfg.hidden_dim).to(self.device)
self.target_actor = Actor(n_states, n_actions, cfg.hidden_dim).to(self.device)
# 复制参数到目标网络
for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()):

View File

@@ -0,0 +1,18 @@
{
"algo_name": "DDPG",
"env_name": "Pendulum-v1",
"train_eps": 300,
"test_eps": 20,
"gamma": 0.99,
"critic_lr": 0.001,
"actor_lr": 0.0001,
"memory_capacity": 8000,
"batch_size": 128,
"target_update": 2,
"soft_tau": 0.01,
"hidden_dim": 256,
"device": "cpu",
"result_path": "C:\\Users\\24438\\Desktop\\rl-tutorials/outputs/DDPG/outputs/Pendulum-v1/20220713-225402/results//",
"model_path": "C:\\Users\\24438\\Desktop\\rl-tutorials/outputs/DDPG/outputs/Pendulum-v1/20220713-225402/models/",
"save_fig": true
}

View File

@@ -1,18 +0,0 @@
------------------ start ------------------
algo_name : DDPG
env_name : Pendulum-v1
train_eps : 300
test_eps : 20
gamma : 0.99
critic_lr : 0.001
actor_lr : 0.0001
memory_capacity : 8000
batch_size : 128
target_update : 2
soft_tau : 0.01
hidden_dim : 256
result_path : c:\Users\24438\Desktop\rl-tutorials\codes\DDPG/outputs/Pendulum-v1/20220713-225402/results/
model_path : c:\Users\24438\Desktop\rl-tutorials\codes\DDPG/outputs/Pendulum-v1/20220713-225402/models/
save_fig : True
device : cuda
------------------- end -------------------

View File

@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-11 20:58:21
@LastEditor: John
LastEditTime: 2022-07-13 22:53:11
LastEditTime: 2022-07-21 00:05:41
@Description:
@Environment: python 3.7.7
'''
@@ -41,14 +41,13 @@ def get_args():
parser.add_argument('--target_update',default=2,type=int)
parser.add_argument('--soft_tau',default=1e-2,type=float)
parser.add_argument('--hidden_dim',default=256,type=int)
parser.add_argument('--device',default='cpu',type=str,help="cpu or cuda")
parser.add_argument('--result_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/results/' )
parser.add_argument('--model_path',default=curr_path + "/outputs/" + parser.parse_args().env_name + \
'/' + curr_time + '/models/' ) # path to save models
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
args = parser.parse_args()
args.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu") # check GPU
parser.add_argument('--save_fig',default=True,type=bool,help="if save figure or not")
args = parser.parse_args()
return args
def env_agent_config(cfg,seed=1):
@@ -122,11 +121,11 @@ if __name__ == "__main__":
save_args(cfg)
agent.save(path=cfg.model_path)
save_results(rewards, ma_rewards, tag='train', path=cfg.result_path)
plot_rewards(rewards, ma_rewards, cfg, tag="train") # 画出结果
plot_rewards(rewards, ma_rewards, cfg, tag="train")
# testing
env,agent = env_agent_config(cfg,seed=10)
agent.load(path=cfg.model_path)
rewards,ma_rewards = test(cfg,env,agent)
save_results(rewards,ma_rewards,tag = 'test',path = cfg.result_path)
plot_rewards(rewards, ma_rewards, cfg, tag="test") # 画出结果
plot_rewards(rewards, ma_rewards, cfg, tag="test")