Deep Reinforcement Learning (DRL) V: Prioritized Experience Replay (DQN)

Posted by jokullsolberg on Tue, 08 Mar 2022 10:16:29 +0100

Full code:

https://github.com/ColinFred/Reinforce_Learning_Pytorch/tree/main/RL/DQN

1, Prioritized Replay

In ordinary experience replay, transitions are sampled uniformly at random, which is not particularly efficient: to the agent, different transitions are of different importance. The Prioritized Replay method was therefore proposed. Its basic idea is to break uniform sampling and give larger sampling weight to the samples the agent can learn the most from.

An ideal criterion would be: the more the agent can learn from a sample, the larger its weight. One quantity that satisfies this criterion is the TD error δ. The larger the TD error, the larger the gap between the value function in this state and the TD target, the larger the update the sample produces, and therefore the more the agent learns from it.

In short, a sampling priority is attached to each Transition in the original replay buffer.

There are three main changes in prioritized replay DQN:

1. To make prioritized storage and sampling efficient, a SumTree is used as the storage structure;

The original paper proposes two ways to compute the sampling probability: proportional priority and rank-based priority. With proportional priority, the probability of a sample being drawn is proportional to the magnitude of its TD error; with rank-based priority, the probability is proportional to the rank of the Transition's priority. Here proportional priority is used, so the probability of a Transition being drawn is directly proportional to its TD error (the formulas are summarized after this list).

Moreover, to make sure that every newly stored Transition gets a chance to be sampled, a new Transition is assigned the current maximum priority.

2. When computing the objective function, each sample is weighted with an importance-sampling weight $w_j$ that corrects the bias introduced by non-uniform sampling (the higher a sample's sampling probability, the smaller its weight):

$\frac{1}{m}\sum\limits_{j=1}^{m} w_j \left(y_j - Q(s_j, a_j, w)\right)^2$

3. Every time the Q network parameters are updated, the TD error $\delta_j = y_j - Q(s_j, a_j, w)$ needs to be recomputed.
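
For reference, the quantities above can be written out explicitly, following the proportional variant of the original PER paper; $\alpha$ and $\beta$ are the priority and importance-sampling exponents and $\epsilon$ is a small positive constant (they correspond to a and e in the code below, which has no $\beta$ because it applies a plain MSE loss without the importance-sampling weights):

$p_j = |\delta_j| + \epsilon$ (priority of Transition $j$)

$P(j) = \dfrac{p_j^{\alpha}}{\sum_i p_i^{\alpha}}$ (probability of drawing Transition $j$)

$w_j = \dfrac{\left(N \cdot P(j)\right)^{-\beta}}{\max_i w_i}$ (importance-sampling weight, with $N$ the buffer size)

For example, with $\alpha = 0.6$ and $\epsilon = 0.01$, two transitions with TD errors 2.0 and 0.1 get priorities $2.01^{0.6} \approx 1.52$ and $0.11^{0.6} \approx 0.27$, so the first one is roughly 5 to 6 times more likely to be drawn.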

2, Code

The prioritized experience replay code below is combined with the Double DQN and Dueling DQN from the previous posts.

SumTree and ReplayMemory_Per

SumTree mainly implements: add() to insert an experience; get() to sample by priority; update() to update the priority of a Transition.

ReplayMemory_Per mainly implements: push() to insert a new experience; sample() to sample Transitions by priority; update() to update the priorities of existing experiences.

import numpy as np
from collections import namedtuple
from random import uniform

# Transition container; the field order matches how transitions are pushed in store_transition() below
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class SumTree:
    write = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.n_entries = 0

    # update to the root node
    def _propagate(self, idx, change):
        parent = (idx - 1) // 2

        self.tree[parent] += change

        if parent != 0:
            self._propagate(parent, change)

    # find sample on leaf node
    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self.tree):
            return idx

        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])

    def total(self):
        return self.tree[0]

    # store priority and sample
    def add(self, p, data):
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1

    # update priority
    def update(self, idx, p):
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    # get priority and sample
    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[dataIdx])


class ReplayMemory_Per(object):
    # stored as ( s, a, r, s_ ) in SumTree
    def __init__(self, capacity=1000, a=0.6, e=0.01):
        self.tree = SumTree(capacity)
        self.memory_size = capacity
        self.prio_max = 0.1  # running maximum priority; assigned to newly pushed transitions
        self.a = a  # priority exponent alpha
        self.e = e  # small constant that keeps every priority strictly positive

    def push(self, *args):
        data = Transition(*args)
        p = (np.abs(self.prio_max) + self.e) ** self.a  # proportional priority
        self.tree.add(p, data)

    def sample(self, batch_size):
        idxs = []
        # stratified sampling: split the total priority mass into batch_size segments
        # and draw one transition from each segment
        segment = self.tree.total() / batch_size
        sample_datas = []

        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            s = uniform(a, b)
            idx, p, data = self.tree.get(s)

            sample_datas.append(data)
            idxs.append(idx)
        return idxs, sample_datas

    def update(self, idxs, errors):
        self.prio_max = max(self.prio_max, max(np.abs(errors)))
        for i, idx in enumerate(idxs):
            p = (np.abs(errors[i]) + self.e) ** self.a
            self.tree.update(idx, p)

    def size(self):
        return self.tree.n_entries
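
As a quick illustration (not part of the original code), here is a minimal sketch of using ReplayMemory_Per on its own; the transition contents and TD errors are dummy values made up purely to show the push / sample / update cycle:

memory = ReplayMemory_Per(capacity=8)

# new transitions are pushed with the current maximum priority
for i in range(8):
    memory.push(i, 0, i + 1, 1.0)  # dummy (state, action, next_state, reward)

# transitions with larger priority are more likely to be drawn
idxs, batch = memory.sample(4)

# pretend these are the recomputed TD errors of the sampled transitions
fake_td_errors = [0.5, 2.0, 0.1, 1.5]
memory.update(idxs, fake_td_errors)

print(memory.size())  # prints 8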

Every time the Q network parameters are updated, the TD errors of the sampled transitions need to be recomputed and their priorities in the SumTree updated:


import torch
import torch.nn.functional as F
import torch.optim as optim
from random import randrange

# EPISILO (the greedy-action threshold), GAMMA, BATCH_SIZE and the DNN network class
# are defined in the full code linked above


class PerDQN:
    def __init__(self, n_action, n_state, learning_rate):

        self.n_action = n_action
        self.n_state = n_state

        self.memory = ReplayMemory_Per(capacity=100)
        self.memory_counter = 0

        self.model_policy = DNN(self.n_state, self.n_action)
        self.model_target = DNN(self.n_state, self.n_action)
        self.model_target.load_state_dict(self.model_policy.state_dict())
        self.model_target.eval()

        self.optimizer = optim.Adam(self.model_policy.parameters(), lr=learning_rate)

    def store_transition(self, s, a, r, s_):
        state = torch.FloatTensor([s])
        action = torch.LongTensor([a])
        reward = torch.FloatTensor([r])
        next_state = torch.FloatTensor([s_])
        self.memory.push(state, action, next_state, reward)

    def choose_action(self, state):
        state = torch.FloatTensor(state)
        if np.random.rand() <= EPISILO:  # greedy policy: exploit with probability EPISILO
            with torch.no_grad():
                q_value = self.model_policy(state)
                action = q_value.max(0)[1].view(1, 1).item()
        else:  # random policy
            action = torch.tensor([randrange(self.n_action)], dtype=torch.long).item()

        return action

    def learn(self):
        if self.memory.size() < BATCH_SIZE:
            return
        idxs, transitions = self.memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action).unsqueeze(1)
        reward_batch = torch.cat(batch.reward)
        next_state_batch = torch.cat(batch.next_state)

        state_action_values = self.model_policy(state_batch).gather(1, action_batch)  # Q(s_j, a_j) from the policy network

        # Double DQN: the policy network selects the next action, the target network evaluates it
        next_action_batch = torch.unsqueeze(self.model_policy(next_state_batch).max(1)[1], 1)
        next_state_values = self.model_target(next_state_batch).gather(1, next_action_batch).detach()
        expected_state_action_values = (next_state_values * GAMMA) + reward_batch.unsqueeze(1)

        td_errors = (state_action_values - expected_state_action_values).detach().squeeze().tolist()
        self.memory.update(idxs, td_errors)  # update td error
        loss = F.mse_loss(state_action_values, expected_state_action_values)

        self.optimizer.zero_grad()
        loss.backward()
        for param in self.model_policy.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

    def update_target_network(self):
        self.model_target.load_state_dict(self.model_policy.state_dict())
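
To show how the pieces fit together, here is a minimal sketch of a training loop (not from the original post). It assumes the classic Gym API (env.reset() returns the observation, env.step() returns four values) and a CartPole-v1 environment; the episode count, learning rate and target-update interval are arbitrary choices:

import gym

env = gym.make("CartPole-v1")
agent = PerDQN(n_action=env.action_space.n,
               n_state=env.observation_space.shape[0],
               learning_rate=1e-3)

for episode in range(200):
    state = env.reset()
    done = False
    while not done:
        action = agent.choose_action(state)
        next_state, reward, done, info = env.step(action)
        agent.store_transition(state, action, reward, next_state)
        agent.learn()  # samples by priority and refreshes the TD errors in the SumTree
        state = next_state
    if episode % 10 == 0:
        agent.update_target_network()  # periodically sync the target network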
