# Deep reinforcement learning (DRL) V: prioritized experience replay (DQN)

Posted by jokullsolberg on Tue, 08 Mar 2022 10:16:29 +0100

Full code:

https://github.com/ColinFred/Reinforce_Learning_Pytorch/tree/main/RL/DQN

## 1. Prioritized experience replay

Standard experience replay samples transitions uniformly from the replay buffer, which is not very efficient, because transitions are not equally useful to the agent. Prioritized Experience Replay (PER) was proposed to address this. Its basic idea is to abandon uniform sampling and give a larger sampling weight to the transitions the agent can learn the most from.

A natural criterion is that the more the agent can learn from a transition, the larger its weight should be. One quantity that meets this criterion is the TD error δ. The larger |δ| is, the further the current value estimate for that state is from its TD target. The agent's update will therefore be larger, and it learns more from that transition.
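
To make the criterion concrete, the sketch below computes the one-step TD error $\delta_j = r_j + \gamma \max_{a'} Q(s'_j, a') - Q(s_j, a_j)$ for a made-up batch of three transitions and uses $|\delta_j|$ as the raw priority. It uses the plain DQN target for simplicity (the code later in this post uses a Double DQN target) and ignores terminal states. `td_errors` is a hypothetical helper written for illustration.

```python
import numpy as np

def td_errors(q_sa, rewards, q_next, gamma=0.99):
    """One-step TD errors: delta = r + gamma * max_a' Q(s', a') - Q(s, a)."""
    targets = rewards + gamma * q_next.max(axis=1)
    return targets - q_sa

# Toy batch of three transitions (made-up numbers, no terminal states)
q_sa = np.array([1.0, 0.5, 2.0])        # Q(s_j, a_j) from the current network
rewards = np.array([1.0, 0.0, 0.0])
q_next = np.array([[0.0, 1.0],          # Q(s'_j, a') for each of two actions
                   [0.5, 0.5],
                   [3.0, 1.0]])

delta = td_errors(q_sa, rewards, q_next, gamma=0.9)
priorities = np.abs(delta)              # larger |delta| -> higher sampling priority
```

Here the first transition's estimate is furthest from its target, so it would be replayed most often. The second one is already almost consistent with its target and gets the lowest priority.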

In short, a sampling priority is attached to each transition stored in the original replay buffer.

Prioritized replay changes DQN in three main ways:

1. To make storing and sampling by priority efficient, transitions are stored in a SumTree.

The original paper proposes two ways to compute a transition's sampling probability: proportional prioritization and rank-based prioritization. With proportional prioritization the priority comes directly from the magnitude of the TD error, $p_i = |\delta_i| + \epsilon$. With rank-based prioritization $p_i = 1/\mathrm{rank}(i)$, where transitions are ranked by $|\delta_i|$. In both cases the sampling probability is $P(i) = p_i^\alpha / \sum_k p_k^\alpha$. This post uses proportional prioritization, so the probability of drawing a transition grows with its TD error: it is proportional to $(|\delta_i| + \epsilon)^\alpha$.
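
The two schemes can be compared with a short numpy sketch. The function names are hypothetical. `alpha = 0.6` and `eps = 0.01` mirror the defaults `a` and `e` of ReplayMemory_Per below, and ranks are assigned by sorting |δ| in descending order.

```python
import numpy as np

def proportional_probs(td_errors, alpha=0.6, eps=0.01):
    """P(i) = p_i^alpha / sum_k p_k^alpha, with p_i = |delta_i| + eps."""
    p = (np.abs(td_errors) + eps) ** alpha
    return p / p.sum()

def rank_based_probs(td_errors, alpha=0.6):
    """P(i) proportional to p_i^alpha, with p_i = 1 / rank(i) and rank 1 = largest |delta|."""
    td_errors = np.asarray(td_errors)
    order = np.argsort(-np.abs(td_errors))        # indices sorted by |delta|, largest first
    ranks = np.empty(len(td_errors))
    ranks[order] = np.arange(1, len(td_errors) + 1)
    p = (1.0 / ranks) ** alpha
    return p / p.sum()

deltas = np.array([2.0, -0.5, 0.1, 10.0])
probs_prop = proportional_probs(deltas)
probs_rank = rank_based_probs(deltas)
# Rank-based only looks at the ordering, so an outlier TD error leaves P(i) unchanged
probs_rank_outlier = rank_based_probs(np.array([2.0, -0.5, 0.1, 1000.0]))
```

Growing the largest TD error from 10 to 1000 would change the proportional probabilities but not the rank-based ones. This is why rank-based prioritization is less sensitive to outliers. Setting `alpha = 0` makes either scheme fall back to uniform sampling.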

In addition, a newly stored transition is assigned the maximum priority seen so far. This ensures that every stored transition is sampled at least once.

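2. When computing the objective function, each sample's squared TD error is multiplied by a weight $w_j$. In the original paper this is an importance-sampling weight, $w_j = \left(N \cdot P(j)\right)^{-\beta} / \max_i w_i$ with $N$ the buffer size. It corrects the bias introduced by non-uniform sampling. Because transitions with a large TD error are sampled more often, they receive a smaller weight: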
$$\frac{1}{m}\sum_{j=1}^{m} w_j \left(y_j - Q(s_j, a_j, w)\right)^2$$

3. Every time the Q-network parameters are updated, the TD errors $\delta_j = y_j - Q(s_j, a_j, w)$ of the sampled transitions must be recomputed, and their priorities updated.
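
Steps 2 and 3 can be sketched in numpy as follows, with made-up sampling probabilities and TD errors. The helper names are hypothetical. The weight formula follows the PER paper, and $\beta = 0.4$ is an assumed starting value: in the paper β is annealed toward 1 during training. The weights are normalized here by the largest weight in the batch, so they can only scale updates down.

```python
import numpy as np

def importance_weights(probs, buffer_size, beta=0.4):
    """w_j = (N * P(j))^(-beta), divided by the largest weight so that w_j <= 1."""
    w = (buffer_size * np.asarray(probs)) ** (-beta)
    return w / w.max()   # normalized by the largest weight in this batch

def weighted_td_loss(td_errors, weights):
    """(1/m) * sum_j w_j * delta_j^2, with delta_j = y_j - Q(s_j, a_j, w)."""
    return np.mean(weights * np.asarray(td_errors) ** 2)

probs = np.array([0.05, 0.01, 0.002])   # P(j) of three sampled transitions (made up)
deltas = np.array([1.5, 0.4, 0.1])      # their recomputed TD errors (made up)
w = importance_weights(probs, buffer_size=100, beta=0.4)
loss = weighted_td_loss(deltas, w)
new_priorities = (np.abs(deltas) + 0.01) ** 0.6   # step 3: priorities from the new TD errors
```

The most frequently sampled transition (P = 0.05) ends up with the smallest weight. With β = 0 every weight is 1 and the loss reduces to a plain mean squared error. That is what the `F.mse_loss` call in `learn()` below computes. Using the weights there would also require `sample()` to return P(j), for example the sampled leaf priority divided by `tree.total()`.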

## 2. Code

This implementation combines prioritized experience replay with the Double DQN and Dueling DQN from the previous parts of this series.

### SumTree and ReplayMemory_Per

SumTree mainly implements add() to store an experience together with its priority, get() to sample by priority, and update() to change the priority of a transition.

ReplayMemory_Per mainly implements push() to insert a new experience, sample() to sample transitions by priority, and update() to update the priorities of existing experiences.
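
Here is a small worked example of the array layout SumTree uses. With capacity N, the N leaves sit at indices N − 1 through 2N − 2. Every internal node i stores the sum of its children 2i + 1 and 2i + 2, so tree[0] is the total priority. `build_sum_tree` and `retrieve` are hypothetical helpers written for illustration, and `retrieve` is an iterative version of `SumTree._retrieve`. With priorities [3, 1, 2, 4] (total 10), each leaf owns a slice of [0, 10] as wide as its priority: [0, 3], (3, 4], (4, 6] and (6, 10].

```python
import numpy as np

def build_sum_tree(priorities):
    """Array-backed sum tree: leaves at [capacity-1, 2*capacity-1), parent = sum of its two children."""
    capacity = len(priorities)
    tree = np.zeros(2 * capacity - 1)
    tree[capacity - 1:] = priorities
    for i in range(capacity - 2, -1, -1):      # fill internal nodes bottom-up
        tree[i] = tree[2 * i + 1] + tree[2 * i + 2]
    return tree

def retrieve(tree, s):
    """Walk from the root to the leaf whose cumulative-priority interval contains s."""
    idx = 0
    while 2 * idx + 1 < len(tree):
        left = 2 * idx + 1
        if s <= tree[left]:
            idx = left                          # s falls inside the left subtree
        else:
            s -= tree[left]                     # skip the left subtree's mass
            idx = left + 1
    return idx

priorities = [3.0, 1.0, 2.0, 4.0]
capacity = len(priorities)
tree = build_sum_tree(priorities)               # tree[0] holds the total priority
leaf = retrieve(tree, 5.0)
data_idx = leaf - capacity + 1                  # same leaf-to-data mapping as SumTree.get()
```

For s = 5 the walk goes right at the root (left sum 4), continues with s = 1, and stops at the leaf with priority 2, i.e. data index 2. A lookup and a priority update (which walks the same path back to the root) each touch O(log N) nodes, instead of the O(N) needed to scan a flat list of priorities.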

```python
import numpy as np
from collections import namedtuple
from random import uniform

# Field order matches the push() call in PerDQN.store_transition()
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class SumTree:
    def __init__(self, capacity):
        self.capacity = capacity
        # Internal nodes hold the sum of their children; the last `capacity` nodes are leaves holding priorities
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.write = 0       # next data slot to (over)write, used as a ring buffer
        self.n_entries = 0

    # propagate a priority change up to the root node
    def _propagate(self, idx, change):
        parent = (idx - 1) // 2
        self.tree[parent] += change
        if parent != 0:
            self._propagate(parent, change)

    # find the leaf whose cumulative-priority interval contains s
    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self.tree):
            return idx

        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])

    # the total priority is stored at the root
    def total(self):
        return self.tree[0]

    # store a priority and its sample
    def add(self, p, data):
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1

    # update the priority of a leaf
    def update(self, idx, p):
        change = p - self.tree[idx]
        self.tree[idx] = p
        self._propagate(idx, change)

    # get (tree index, priority, sample) for a value s in [0, total]
    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1
        return (idx, self.tree[idx], self.data[dataIdx])


class ReplayMemory_Per(object):
    # stored as Transition(state, action, next_state, reward) in the SumTree
    def __init__(self, capacity=1000, a=0.6, e=0.01):
        self.tree = SumTree(capacity)
        self.memory_size = capacity
        self.prio_max = 0.1  # largest absolute TD error seen so far
        self.a = a           # alpha: how strongly priorities skew sampling (0 = uniform)
        self.e = e           # small constant so that no transition has zero priority

    def push(self, *args):
        data = Transition(*args)
        # new transitions get the maximum priority so that each is sampled at least once
        p = (np.abs(self.prio_max) + self.e) ** self.a  # proportional priority
        self.tree.add(p, data)

    def sample(self, batch_size):
        idxs = []
        sample_datas = []
        # split [0, total] into batch_size equal segments and draw one sample from each
        segment = self.tree.total() / batch_size

        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            s = uniform(a, b)
            idx, p, data = self.tree.get(s)

            sample_datas.append(data)
            idxs.append(idx)
        return idxs, sample_datas

    def update(self, idxs, errors):
        self.prio_max = max(self.prio_max, max(np.abs(errors)))
        for i, idx in enumerate(idxs):
            p = (np.abs(errors[i]) + self.e) ** self.a
            self.tree.update(idx, p)

    def size(self):
        return self.tree.n_entries
```
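
A quick Monte Carlo check shows that the stratified sampling in ReplayMemory_Per.sample() draws each transition with probability proportional to its priority. The sketch below is self-contained, so it replaces the SumTree with an equivalent prefix-sum lookup (`np.cumsum` plus `np.searchsorted`). `stratified_sample` is a hypothetical helper, not part of the linked repository.

```python
import numpy as np

def stratified_sample(priorities, batch_size, rng):
    """Mimic ReplayMemory_Per.sample(): split [0, total] into batch_size equal
    segments and draw one uniform point from each. A prefix-sum lookup with
    np.searchsorted stands in for the SumTree descent."""
    cumsum = np.cumsum(priorities)
    segment = cumsum[-1] / batch_size
    low = segment * np.arange(batch_size)
    high = segment * np.arange(1, batch_size + 1)
    s = rng.uniform(low, high)
    # first index whose cumulative priority is >= s (same rule as `s <= tree[left]`)
    idx = np.searchsorted(cumsum, s)
    return np.minimum(idx, len(priorities) - 1)   # guard against float round-off at the top end

rng = np.random.default_rng(0)
priorities = np.array([1.0, 2.0, 3.0, 4.0])
counts = np.zeros(len(priorities))
for _ in range(5000):
    np.add.at(counts, stratified_sample(priorities, batch_size=8, rng=rng), 1)
freq = counts / counts.sum()
print(freq)   # should be close to priorities / priorities.sum()
```

Compared with drawing batch_size independent points from [0, total], the one-point-per-segment scheme spreads each batch across the whole priority range. It still keeps the overall sampling frequencies proportional to the priorities.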


Every time the Q-network parameters are updated, the TD errors must be recalculated and the priorities in the SumTree updated.
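
Without the PyTorch tensors, the target and TD-error computation in learn() below can be sketched in numpy. This sketch uses made-up Q-values. `double_dqn_td_errors` is a hypothetical helper, and like learn() it ignores terminal states. The policy network selects the next action and the target network evaluates it, which is the Double DQN target. The absolute errors then become the new priorities, using the same formula as ReplayMemory_Per.update() with a = 0.6 and e = 0.01.

```python
import numpy as np

def double_dqn_td_errors(q_sa, q_policy_next, q_target_next, rewards, gamma=0.99):
    """TD errors with a Double DQN target (no terminal-state handling, as in learn())."""
    next_actions = q_policy_next.argmax(axis=1)          # policy network chooses a'
    rows = np.arange(len(next_actions))
    next_values = q_target_next[rows, next_actions]      # target network evaluates a'
    targets = rewards + gamma * next_values
    return q_sa - targets                                # same sign as learn(): Q(s, a) - y

q_sa = np.array([1.0, 0.5])                  # Q_policy(s_j, a_j) for the sampled actions
q_policy_next = np.array([[1.0, 2.0], [0.5, 0.1]])
q_target_next = np.array([[3.0, 0.5], [0.2, 4.0]])
rewards = np.array([1.0, 0.0])

errors = double_dqn_td_errors(q_sa, q_policy_next, q_target_next, rewards, gamma=0.9)
new_priorities = (np.abs(errors) + 0.01) ** 0.6   # as in ReplayMemory_Per.update()
```

In this toy batch the target network's own maxima (3.0 and 4.0) are much larger than the values of the actions the policy network picks (0.5 and 0.2). Decoupling action selection from evaluation avoids always bootstrapping from the largest, possibly overestimated, value. The sign of the error does not matter for the priorities, since only |δ| is used.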


import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from random import randrange

# DNN, EPISILO, BATCH_SIZE and GAMMA come from the full code linked above;
# Transition and ReplayMemory_Per are defined in the previous block.


class PerDQN:
    def __init__(self, n_action, n_state, learning_rate):

        self.n_action = n_action
        self.n_state = n_state

        self.memory = ReplayMemory_Per(capacity=100)
        self.memory_counter = 0

        self.model_policy = DNN(self.n_state, self.n_action)
        self.model_target = DNN(self.n_state, self.n_action)
        self.model_target.eval()

        self.optimizer = optim.Adam(self.model_policy.parameters(), lr=learning_rate)

    def store_transition(self, s, a, r, s_):
        state = torch.FloatTensor([s])
        action = torch.LongTensor([a])
        reward = torch.FloatTensor([r])
        next_state = torch.FloatTensor([s_])
        self.memory.push(state, action, next_state, reward)

    def choose_action(self, state):
        state = torch.FloatTensor(state)
        if np.random.rand() <= EPISILO:  # greedy policy with probability EPISILO (uniform draw in [0, 1))
            with torch.no_grad():
                q_value = self.model_policy(state)
            action = q_value.max(0)[1].item()  # index of the largest Q-value
        else:  # random policy
            action = randrange(self.n_action)

        return action

    def learn(self):
        if self.memory.size() < BATCH_SIZE:
            return
        idxs, transitions = self.memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action).unsqueeze(1)
        reward_batch = torch.cat(batch.reward)
        next_state_batch = torch.cat(batch.next_state)

        state_action_values = self.model_policy(state_batch).gather(1, action_batch)

        # Double DQN: the policy network picks the next action, the target network evaluates it
        with torch.no_grad():
            next_action_batch = self.model_policy(next_state_batch).max(1)[1].unsqueeze(1)
            next_state_values = self.model_target(next_state_batch).gather(1, next_action_batch)
        expected_state_action_values = (next_state_values * GAMMA) + reward_batch.unsqueeze(1)

        td_errors = (state_action_values - expected_state_action_values).detach().view(-1).tolist()
        self.memory.update(idxs, td_errors)  # update priorities with the new TD errors
        loss = F.mse_loss(state_action_values, expected_state_action_values)

        self.optimizer.zero_grad()  # clear gradients left over from the previous update
        loss.backward()
        for param in self.model_policy.parameters():