Full code:
https://github.com/ColinFred/Reinforce_Learning_Pytorch/tree/main/RL/DQN
1. Prioritized Replay
Standard experience replay samples transitions uniformly at random, which is not particularly efficient: to the agent, different transitions carry different amounts of useful information. Prioritized Replay was therefore proposed. Its basic idea is to break the uniform sampling and give larger sampling weight to the transitions from which the agent can learn the most.
An ideal criterion would be: the more the agent can learn from a transition, the larger its weight. One quantity that satisfies this criterion is the TD error δ. The larger the TD error, the larger the gap between the current value estimate in that state and the TD target, the larger the resulting update, and hence the more the agent learns from that transition.
In short, a sampling priority is attached to each Transition in the original replay buffer.
Prioritized Replay DQN involves three main changes:
1. To make storing and priority-based sampling efficient, a SumTree is used as the underlying storage.
The original paper describes two ways to compute the sampling probability: proportional priority and rank-based priority. With proportional priority, the probability of a transition being sampled is proportional to the magnitude of its TD error; with rank-based priority, the probability is proportional to the transition's rank when sorted by priority. Here proportional priority is used, so the probability of a Transition being drawn is directly proportional to its TD error.
In addition, to guarantee that every newly stored Transition is sampled at least once, new Transitions are given the current maximal priority.
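Written out (following the formulation in the prioritized experience replay paper; the exponent α and the small constant ε correspond to the a and e parameters in the code below), the proportional priority and sampling probability of transition j are
p_j = |\delta_j| + \epsilon, \qquad P(j) = \frac{p_j^{\alpha}}{\sum_k p_k^{\alpha}}
where ε keeps the priority non-zero even when the TD error vanishes, and α controls how strongly prioritization is applied (α = 0 recovers uniform sampling).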
2. When computing the objective function, each sample is weighted by an importance-sampling weight w_j that compensates for the bias introduced by non-uniform sampling (the larger a transition's sampling probability, the smaller its weight):
\frac{1}{m}\sum\limits_{j=1}^m w_j (y_j-Q(s_j, a_j, w))^2
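For reference, the original paper defines the importance-sampling weight as
w_j = \left( \frac{1}{N} \cdot \frac{1}{P(j)} \right)^{\beta} \Big/ \max_i w_i
where N is the size of the replay buffer and β is annealed towards 1 over the course of training. The implementation below omits these weights for simplicity and trains with a plain MSE loss.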
3. Every time the Q-network parameters are updated, the TD error \delta_j = y_j - Q(s_j, a_j, w) needs to be recalculated.
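In the implementation below, the TD target y_j takes the Double DQN form, with the action selected by the policy network and evaluated by the target network:
y_j = r_j + \gamma \, Q_{\text{target}}\!\left(s'_j,\ \arg\max_{a} Q_{\text{policy}}(s'_j, a)\right)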
2. Code
The prioritized experience replay implementation below builds on the Double DQN and Dueling DQN code from the previous posts.
SumTree and ReplayMemory_Per
SumTree mainly implements add(), which stores a new experience together with its priority; get(), which samples according to priority; and update(), which updates the priority of a stored Transition.
ReplayMemory_Per mainly implements push(), which inserts a new experience; sample(), which draws a batch of Transitions according to priority; and update(), which updates the priorities of existing experiences.
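The code in this post assumes that the Transition namedtuple and a few hyperparameter constants are defined elsewhere in the repository. A minimal sketch of what those definitions might look like (the concrete values are illustrative assumptions, not taken from the repo):

```python
from collections import namedtuple

# Assumed definition, following the common PyTorch DQN tutorial convention;
# the field order matches how push() is called in PerDQN.store_transition below.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

# Assumed hyperparameters (illustrative values only)
BATCH_SIZE = 32
GAMMA = 0.99
EPISILO = 0.9  # epsilon-greedy threshold, spelled as in the original code
```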
```python
import numpy as np
from random import uniform


class SumTree:
    write = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)   # internal nodes hold priority sums
        self.data = np.zeros(capacity, dtype=object)
        self.n_entries = 0

    # update to the root node
    def _propagate(self, idx, change):
        parent = (idx - 1) // 2
        self.tree[parent] += change
        if parent != 0:
            self._propagate(parent, change)

    # find sample on leaf node
    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1
        if left >= len(self.tree):
            return idx
        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])

    def total(self):
        return self.tree[0]

    # store priority and sample
    def add(self, p, data):
        idx = self.write + self.capacity - 1
        self.data[self.write] = data
        self.update(idx, p)
        self.write += 1
        if self.write >= self.capacity:
            self.write = 0
        if self.n_entries < self.capacity:
            self.n_entries += 1

    # update priority
    def update(self, idx, p):
        change = p - self.tree[idx]
        self.tree[idx] = p
        self._propagate(idx, change)

    # get priority and sample
    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1
        return (idx, self.tree[idx], self.data[dataIdx])


class ReplayMemory_Per(object):
    # stored as ( s, a, r, s_ ) in SumTree
    def __init__(self, capacity=1000, a=0.6, e=0.01):
        self.tree = SumTree(capacity)
        self.memory_size = capacity
        self.prio_max = 0.1
        self.a = a  # priority exponent alpha
        self.e = e  # small constant keeping priorities non-zero

    def push(self, *args):
        data = Transition(*args)
        # new transitions receive the current maximal priority so they get sampled at least once
        p = (np.abs(self.prio_max) + self.e) ** self.a  # proportional priority
        self.tree.add(p, data)

    def sample(self, batch_size):
        idxs = []
        segment = self.tree.total() / batch_size
        sample_datas = []
        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            s = uniform(a, b)
            idx, p, data = self.tree.get(s)
            sample_datas.append(data)
            idxs.append(idx)
        return idxs, sample_datas

    def update(self, idxs, errors):
        self.prio_max = max(self.prio_max, max(np.abs(errors)))
        for i, idx in enumerate(idxs):
            p = (np.abs(errors[i]) + self.e) ** self.a
            self.tree.update(idx, p)

    def size(self):
        return self.tree.n_entries
```
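A quick standalone sanity check of ReplayMemory_Per might look like this; the dummy transition contents and error values are placeholders for illustration, not real environment data:

```python
# Hypothetical standalone check of the prioritized buffer (dummy values, no environment).
memory = ReplayMemory_Per(capacity=100)

# Fill with dummy (state, action, next_state, reward) tuples.
for i in range(10):
    memory.push([0.0, 1.0], i % 2, [1.0, 0.0], 0.5)

idxs, transitions = memory.sample(4)        # draw 4 transitions, biased towards high priority
memory.update(idxs, [0.3, -0.1, 0.8, 0.2])  # feed back TD errors; |error| becomes the new priority
print(memory.size())                        # number of transitions currently stored
```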
Every time the Q-network parameters are updated, the TD errors are recomputed and the priorities in the SumTree are updated:
```python
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from random import randrange


class PerDQN:
    def __init__(self, n_action, n_state, learning_rate):
        self.n_action = n_action
        self.n_state = n_state
        self.memory = ReplayMemory_Per(capacity=100)
        self.memory_counter = 0
        self.model_policy = DNN(self.n_state, self.n_action)
        self.model_target = DNN(self.n_state, self.n_action)
        self.model_target.load_state_dict(self.model_policy.state_dict())
        self.model_target.eval()
        self.optimizer = optim.Adam(self.model_policy.parameters(), lr=learning_rate)

    def store_transition(self, s, a, r, s_):
        state = torch.FloatTensor([s])
        action = torch.LongTensor([a])
        reward = torch.FloatTensor([r])
        next_state = torch.FloatTensor([s_])
        self.memory.push(state, action, next_state, reward)

    def choose_action(self, state):
        state = torch.FloatTensor(state)
        if np.random.rand() <= EPISILO:  # greedy policy with probability EPISILO
            with torch.no_grad():
                q_value = self.model_policy(state)
            action = q_value.max(0)[1].view(1, 1).item()
        else:  # random policy
            action = torch.tensor([randrange(self.n_action)], dtype=torch.long).item()
        return action

    def learn(self):
        if self.memory.size() < BATCH_SIZE:
            return
        idxs, transitions = self.memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action).unsqueeze(1)
        reward_batch = torch.cat(batch.reward)
        next_state_batch = torch.cat(batch.next_state)

        # Q(s, a) from the policy network
        state_action_values = self.model_policy(state_batch).gather(1, action_batch)
        # Double DQN target: action chosen by the policy network, evaluated by the target network
        next_action_batch = torch.unsqueeze(self.model_policy(next_state_batch).max(1)[1], 1)
        next_state_values = self.model_target(next_state_batch).gather(1, next_action_batch)
        expected_state_action_values = (next_state_values * GAMMA) + reward_batch.unsqueeze(1)

        # recompute TD errors and update the priorities in the SumTree
        td_errors = (state_action_values - expected_state_action_values).detach().squeeze().tolist()
        self.memory.update(idxs, td_errors)  # update td error

        loss = F.mse_loss(state_action_values, expected_state_action_values)
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.model_policy.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

    def update_target_network(self):
        self.model_target.load_state_dict(self.model_policy.state_dict())
```
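Finally, a minimal sketch of how PerDQN could be driven in a training loop; the CartPole-v0 environment, episode count, target-update interval, and the older gym reset/step API used here are assumptions for illustration, not details taken from the repository:

```python
import gym

# Illustrative training loop; environment and hyperparameters are placeholders.
env = gym.make('CartPole-v0')
agent = PerDQN(n_action=env.action_space.n,
               n_state=env.observation_space.shape[0],
               learning_rate=1e-3)

for episode in range(200):
    state = env.reset()
    done = False
    while not done:
        action = agent.choose_action(state)
        next_state, reward, done, _ = env.step(action)   # older 4-tuple gym API
        agent.store_transition(state, action, reward, next_state)
        agent.learn()          # samples by priority and refreshes TD errors in the SumTree
        state = next_state
    if episode % 10 == 0:
        agent.update_target_network()  # periodically sync the target network
```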
References
- https://zhuanlan.zhihu.com/p/128176891
- https://www.cnblogs.com/jiangxinyang/p/10112381.html