PyTorch learning 11: where and gather

Posted by ee12csvt on Thu, 30 Dec 2021 21:36:41 +0100

where

where syntax: torch.where(condition, x, y)

  • Requirement: condition, x and y must have the same shape (or be broadcastable to a common shape), and the return value is a tensor of that shape
  • condition: the condition tensor. Where an element is True, the corresponding element of x is taken; otherwise the corresponding element of y is taken
  • x: the tensor whose elements are selected where condition is True
  • y: the tensor whose elements are selected where condition is False

Example 1: c = max(a, b)

import torch

a = torch.randn(3, 3)
b = torch.randn(3, 3)

print("a:\n{}\n".format(a))
print("b:\n{}\n".format(b))

c = torch.where(a>b, a, b)

print("c:\n{}\n".format(c))

Example 2: classification prediction

import torch

"""
hypothesis p It is the probability after the second classification prediction
 When p>0.5 Identification is 1 when
 Otherwise, the ID is 0
"""

p = torch.rand(3, 3)

print("p:\n{}\n".format(p))

c = torch.where(p>0.5, 1, 0)

print("c:\n{}\n".format(c))

gather

gather syntax: torch.gather(input, dim, index). The following parameter descriptions are based on my own understanding.

  • input: the "dictionary" tensor that supplies the replacement values
  • dim: the dimension along which the lookup is performed
  • index: the tensor of indices to be replaced; its dtype must be int64 (torch.long)

one-dimensional

Example:

import torch

# Random integer indices in the range [0, 10)
_index = torch.trunc(torch.rand(8) * 10).long()

_input = torch.tensor(
    [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

print("_index:\n{}\n".format(_index))

print("_input:\n{}\n".format(_input))
_output = torch.gather(_input, 0, _index)
print("_output:\n{}\n".format(_output))

  • The length of input must be greater than the maximum element of index (indices are 0-based).
  • Each element of index is treated as an index into input, is replaced with the corresponding element of input, and the resulting tensor is returned (see the sketch below).
  • The returned tensor has the same shape as index.
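
As a minimal sketch of this interpretation (the index values below are hand-picked for illustration), a one-dimensional gather along dim 0 gives the same result as plain fancy indexing:

import torch

_index = torch.tensor([3, 0, 7, 7, 2])
_input = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

# output[i] = _input[_index[i]]
print(torch.gather(_input, 0, _index))  # tensor([13, 10, 17, 17, 12])
print(_input[_index])                   # tensor([13, 10, 17, 17, 12])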

two-dimensional

dim=1

Example 1: identical rows in input

import torch

_index = torch.trunc(torch.rand(2, 8) * 10).long()

_input = torch.tensor(
    [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

print("_index:\n{}\n".format(_index))

# Expand input to 2x10
_input = _input.expand(2, 10)
print("_input:\n{}\n".format(_input))
_output = torch.gather(_input, 1, _index)
print("_output:\n{}\n".format(_output))

  • input must have at least as many rows as index, and its number of columns must satisfy the same rule as in the one-dimensional case (greater than the maximum element of index)
  • It can be understood as two one-dimensional gathers
  • Both rows of index are looked up against identical rows of input (a loop-equivalent sketch follows this list)
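
A loop-equivalent sketch of the dim=1 case (index values hand-picked for illustration): for every position, output[i][j] = input[i][index[i][j]].

import torch

_index = torch.tensor([[3, 0, 7],
                       [1, 1, 9]])
_input = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19]).expand(2, 10)

# dim=1: output[i][j] = _input[i][_index[i][j]]
_output = torch.gather(_input, 1, _index)

# The same computation written as explicit loops
_loop = torch.zeros_like(_index)
for i in range(_index.size(0)):
    for j in range(_index.size(1)):
        _loop[i][j] = _input[i][_index[i][j]]

print(torch.equal(_output, _loop))  # True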

Example 2: different rows in input

import torch

_index = torch.trunc(torch.rand(2, 8) * 10).long()

_input = torch.tensor(
    [[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
     [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]])

print("_index:\n{}\n".format(_index))

print("_input:\n{}\n".format(_input))
_output = torch.gather(_input, 1, _index)
print("_output:\n{}\n".format(_output))

  • The requirements on input are the same as in the previous example
  • Each row of index is now looked up against a different row of input

dim=0

Example:

import torch

_index = torch.trunc(torch.rand(2, 8) * 10).long()

_input = torch.tensor(
    [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

print("_index:\n{}\n".format(_index))

# Reshape input into a column vector and expand it so its number of columns matches index
_input = _input.view(10, 1).expand(10, 8)
print("_input:\n{}\n".format(_input))
_output = torch.gather(_input, 0, _index)
print("_output:\n{}\n".format(_output))

  • Similar to dim=1, but the lookup now runs along each column (a loop-equivalent sketch follows this list)
  • The requirements on input mirror dim=1, with the roles of rows and columns swapped
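
A loop-equivalent sketch of the dim=0 case (again with hand-picked indices): output[i][j] = input[index[i][j]][j].

import torch

_index = torch.tensor([[3, 0, 7],
                       [1, 1, 9]])
_input = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19]).view(10, 1).expand(10, 3)

# dim=0: output[i][j] = _input[_index[i][j]][j]
_output = torch.gather(_input, 0, _index)

# The same computation written as explicit loops
_loop = torch.zeros_like(_index)
for i in range(_index.size(0)):
    for j in range(_index.size(1)):
        _loop[i][j] = _input[_index[i][j]][j]

print(torch.equal(_output, _loop))  # True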

three-dimensional

import torch

_index = torch.trunc(torch.rand(2, 2, 8) * 10).long()

_input = torch.tensor(
    [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

print("_index:\n{}\n".format(_index))

_input = _input.expand(2, 2, 10)
print("_input:\n{}\n".format(_input))
_output = torch.gather(_input, 2, _index)
print("_output:\n{}\n".format(_output))

  • Similar to the one- and two-dimensional cases (a sketch of the dim=2 rule follows)
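
A rough sketch of the dim=2 rule, output[i][j][k] = input[i][j][index[i][j][k]], with small hand-picked indices:

import torch

_index = torch.tensor([[[3, 0], [7, 7]],
                       [[1, 9], [2, 2]]])
_input = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19]).expand(2, 2, 10)

# Every slice of _input is [10, ..., 19], so the output is simply 10 + _index
_output = torch.gather(_input, 2, _index)
print(_output)  # values: [[[13, 10], [17, 17]], [[11, 19], [12, 12]]]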

Some understanding of gather

  • input requirements:

    input must have the same number of dimensions as index, and along every dimension except the one given by dim its length must be at least that of the corresponding dimension of index (in the examples above they are equal).

    Along the dimension given by dim, the length of input must be greater than the maximum element of index.

  • Whether each row / column of input is the same:

    You can build an input whose rows / columns differ, which changes what each part of index is replaced with. In practice this is rarely worth the effort and becomes cumbersome in higher dimensions; personally I most often let every row / column of input share the same values (e.g. via expand), so that all of index is replaced against a single dictionary.

  • About gather:

    gather can be thought of as replacing Chinese characters with their Pinyin: one dimension replaces a sentence, two dimensions an article, three dimensions a whole book. For example: [I, love, you] → [wo, ai, ni]. (A toy sketch of this analogy follows the list.)
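
A toy sketch of the analogy (the tiny vocabulary and ID assignments below are made up purely for illustration):

import torch

pinyin_table = ["ni", "wo", "ai"]         # Pinyin ID -> Pinyin
word_to_pinyin = torch.tensor([1, 2, 0])  # word ID -> Pinyin ID: 0 "I"->"wo", 1 "love"->"ai", 2 "you"->"ni"

sentence = torch.tensor([0, 1, 2])        # [I, love, you] encoded as word IDs

# gather looks up the Pinyin ID of every word in the sentence
pinyin_ids = torch.gather(word_to_pinyin, 0, sentence)
print([pinyin_table[i] for i in pinyin_ids.tolist()])  # ['wo', 'ai', 'ni']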

summary

Everything that where and gather do can be written by hand, but their underlying implementations are much faster than equivalent hand-written code, so prefer where and gather whenever possible.

Topics: AI PyTorch Deep Learning