hot update

This commit is contained in:
johnjim0816
2022-08-24 11:33:06 +08:00
parent ad65dd17cd
commit 62a7364c72
40 changed files with 2129 additions and 179 deletions

View File

@@ -5,7 +5,7 @@
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:50:49
@LastEditor: John
LastEditTime: 2022-08-18 14:27:18
LastEditTime: 2022-08-23 23:59:54
@Discription:
@Environment: python 3.7.7
'''
@@ -20,26 +20,26 @@ import math
import numpy as np
class DQN:
def __init__(self,n_actions,model,memory,cfg):
def __init__(self,model,memory,cfg):
self.n_actions = n_actions
self.device = torch.device(cfg.device)
self.gamma = cfg.gamma
self.n_actions = cfg['n_actions']
self.device = torch.device(cfg['device'])
self.gamma = cfg['gamma']
## e-greedy parameters
self.sample_count = 0 # sample count for epsilon decay
self.epsilon = cfg.epsilon_start
self.epsilon = cfg['epsilon_start']
self.sample_count = 0
self.epsilon_start = cfg.epsilon_start
self.epsilon_end = cfg.epsilon_end
self.epsilon_decay = cfg.epsilon_decay
self.batch_size = cfg.batch_size
self.epsilon_start = cfg['epsilon_start']
self.epsilon_end = cfg['epsilon_end']
self.epsilon_decay = cfg['epsilon_decay']
self.batch_size = cfg['batch_size']
self.policy_net = model.to(self.device)
self.target_net = model.to(self.device)
## copy parameters from policy net to target net
for target_param, param in zip(self.target_net.parameters(),self.policy_net.parameters()):
target_param.data.copy_(param.data)
# self.target_net.load_state_dict(self.policy_net.state_dict()) # or use this to copy parameters
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg['lr'])
self.memory = memory
self.update_flag = False