hot update
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
@Email: johnjim0816@gmail.com
|
||||
@Date: 2020-06-12 00:50:49
|
||||
@LastEditor: John
|
||||
LastEditTime: 2022-08-18 14:27:18
|
||||
LastEditTime: 2022-08-23 23:59:54
|
||||
@Discription:
|
||||
@Environment: python 3.7.7
|
||||
'''
|
||||
@@ -20,26 +20,26 @@ import math
|
||||
import numpy as np
|
||||
|
||||
class DQN:
|
||||
def __init__(self,n_actions,model,memory,cfg):
|
||||
def __init__(self,model,memory,cfg):
|
||||
|
||||
self.n_actions = n_actions
|
||||
self.device = torch.device(cfg.device)
|
||||
self.gamma = cfg.gamma
|
||||
self.n_actions = cfg['n_actions']
|
||||
self.device = torch.device(cfg['device'])
|
||||
self.gamma = cfg['gamma']
|
||||
## e-greedy parameters
|
||||
self.sample_count = 0 # sample count for epsilon decay
|
||||
self.epsilon = cfg.epsilon_start
|
||||
self.epsilon = cfg['epsilon_start']
|
||||
self.sample_count = 0
|
||||
self.epsilon_start = cfg.epsilon_start
|
||||
self.epsilon_end = cfg.epsilon_end
|
||||
self.epsilon_decay = cfg.epsilon_decay
|
||||
self.batch_size = cfg.batch_size
|
||||
self.epsilon_start = cfg['epsilon_start']
|
||||
self.epsilon_end = cfg['epsilon_end']
|
||||
self.epsilon_decay = cfg['epsilon_decay']
|
||||
self.batch_size = cfg['batch_size']
|
||||
self.policy_net = model.to(self.device)
|
||||
self.target_net = model.to(self.device)
|
||||
## copy parameters from policy net to target net
|
||||
for target_param, param in zip(self.target_net.parameters(),self.policy_net.parameters()):
|
||||
target_param.data.copy_(param.data)
|
||||
# self.target_net.load_state_dict(self.policy_net.state_dict()) # or use this to copy parameters
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg['lr'])
|
||||
self.memory = memory
|
||||
self.update_flag = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user