where
where syntax: torch.where(condition, x, y)
- Requirement: condition, x and y must be broadcastable to a common shape (in the simplest case all three have the same shape); the returned tensor has that shape as well
- condition: boolean tensor. Where an element is True, the corresponding element of x is used; otherwise the corresponding element of y is used
- x: source of the output elements where condition is True
- y: source of the output elements where condition is False
Example 1: c = max(a, b)
    import torch

    a = torch.randn(3, 3)
    b = torch.randn(3, 3)
    print("a:\n{}\n".format(a))
    print("b:\n{}\n".format(b))

    c = torch.where(a > b, a, b)
    print("c:\n{}\n".format(c))
Example 2: classification prediction
    import torch

    """
    Suppose p is the probability output by a binary classifier.
    When p > 0.5 the predicted label is 1;
    otherwise the predicted label is 0.
    """
    p = torch.rand(3, 3)
    print("p:\n{}\n".format(p))

    c = torch.where(p > 0.5, 1, 0)
    print("c:\n{}\n".format(c))
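As a quick check of the two examples above, here is a minimal sketch with fixed tensors (the values are my own illustration, not from the original post). It assumes a PyTorch version that provides torch.maximum, and it passes x and y as tensors rather than Python scalars so that it also runs on older releases.

```python
import torch

a = torch.tensor([[1.0, 5.0], [7.0, 2.0]])
b = torch.tensor([[3.0, 4.0], [6.0, 8.0]])

# Element-wise maximum written with where, compared against torch.maximum
c = torch.where(a > b, a, b)
same_as_maximum = torch.equal(c, torch.maximum(a, b))
print(c.tolist())  # [[3.0, 5.0], [7.0, 8.0]]

# Broadcasting: a (2, 1) column against a (1, 2) row gives a (2, 2) result
col = torch.tensor([[0.0], [10.0]])
row = torch.tensor([[1.0, 20.0]])
d = torch.where(col > row, col, row)
print(d.tolist())  # [[1.0, 20.0], [10.0, 20.0]]

# Thresholding probabilities into 0/1 labels, as in Example 2
p = torch.tensor([0.2, 0.7, 0.5])
labels = torch.where(p > 0.5, torch.tensor(1), torch.tensor(0))
same_as_cast = torch.equal(labels, (p > 0.5).long())
print(labels.tolist())  # [0, 1, 0]
```

For a plain 0/1 threshold, casting the boolean mask with (p > 0.5).long() gives the same labels; where becomes more useful when the two branches carry arbitrary values.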
gather
gather syntax: torch.gather(input, dim, index). The parameter descriptions below reflect my own understanding
- input: the lookup ("dictionary") tensor that values are taken from
- dim: the dimension along which the lookup is performed
- index: the tensor of indices to be replaced; its dtype must be int64
one-dimensional
Example:
    import torch

    _index = torch.trunc(torch.rand(8) * 10).long()
    _input = torch.tensor(
        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
    print("_index:\n{}\n".format(_index))
    print("_input:\n{}\n".format(_input))

    _output = torch.gather(_input, 0, _index)
    print("_output:\n{}\n".format(_output))
- The length of input must be greater than the largest element in index.
- Each element of index is treated as a position in input: it is replaced by the element of input at that position, and the tensor of replaced values is returned.
- The returned tensor has the same shape as index
two-dimensional
dim=1
Example 1: same as input
    import torch

    _index = torch.trunc(torch.rand(2, 8) * 10).long()
    _input = torch.tensor(
        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
    print("_index:\n{}\n".format(_index))

    # Expand input to 2x10
    _input = _input.expand(2, 10)
    print("_input:\n{}\n".format(_input))

    _output = torch.gather(_input, 1, _index)
    print("_output:\n{}\n".format(_output))
- input must have at least as many rows as index (here they have the same number), and the column requirement is the same as in the one-dimensional case: more columns than the largest element in index
- It can be understood as two one-dimensional gathers, one per row
- Both rows of index are replaced using the same lookup row
Example 2: different input
    import torch

    _index = torch.trunc(torch.rand(2, 8) * 10).long()
    _input = torch.tensor(
        [[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
         [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]])
    print("_index:\n{}\n".format(_index))
    print("_input:\n{}\n".format(_input))

    _output = torch.gather(_input, 1, _index)
    print("_output:\n{}\n".format(_output))
- The requirements on input are the same as in the previous example
- Here each row of index is replaced using its own lookup row
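For dim=1 the lookup rule is out[i][j] = input[i][index[i][j]]. The following sketch, with small fixed tensors chosen for illustration, checks gather against that rule written as explicit loops.

```python
import torch

_input = torch.tensor([[10, 11, 12, 13],
                       [100, 101, 102, 103]])
_index = torch.tensor([[0, 3, 3],
                       [2, 1, 0]])

out = torch.gather(_input, 1, _index)

# Reference implementation: row i of index looks up values in row i of input
manual = torch.empty_like(_index)
for i in range(_index.size(0)):
    for j in range(_index.size(1)):
        manual[i, j] = _input[i, _index[i, j]]

print(out.tolist())  # [[10, 13, 13], [102, 101, 100]]
```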
dim=0
Example:
    import torch

    _index = torch.trunc(torch.rand(2, 8) * 10).long()
    _input = torch.tensor(
        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
    print("_index:\n{}\n".format(_index))

    # Turn input into a column vector and expand it to as many columns as index
    _input = _input.view(10, 1).expand(10, 8)
    print("_input:\n{}\n".format(_input))

    _output = torch.gather(_input, 0, _index)
    print("_output:\n{}\n".format(_output))
- Similar to dim=1, but the replacement is done column by column
- The requirements on input mirror those for dim=1 with rows and columns swapped: the number of rows must exceed the largest element in index
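For dim=0 the lookup rule becomes out[i][j] = input[index[i][j]][j], so each column of index looks up values in the matching column of input. A minimal sketch with fixed tensors of my own choosing:

```python
import torch

_input = torch.tensor([[10, 20, 30],
                       [11, 21, 31],
                       [12, 22, 32]])
_index = torch.tensor([[0, 2, 1],
                       [2, 2, 0]])

out = torch.gather(_input, 0, _index)

# Reference implementation: column j of index looks up values in column j of input
manual = torch.empty_like(_index)
for i in range(_index.size(0)):
    for j in range(_index.size(1)):
        manual[i, j] = _input[_index[i, j], j]

print(out.tolist())  # [[10, 22, 31], [12, 22, 30]]
```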
three-dimensional
    import torch

    _index = torch.trunc(torch.rand(2, 2, 8) * 10).long()
    _input = torch.tensor(
        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
    print("_index:\n{}\n".format(_index))

    _input = _input.expand(2, 2, 10)
    print("_input:\n{}\n".format(_input))

    _output = torch.gather(_input, 2, _index)
    print("_output:\n{}\n".format(_output))
- Works the same way as the one- and two-dimensional cases, here along the last dimension (dim=2)
Some notes on gather
- Requirements on input: it must have the same number of dimensions as index. In every dimension other than dim, input must be at least as long as index (in the examples above the lengths are equal). In the dimension given by dim, the length of input must be greater than the largest element in index.
- Whether each row / column of input is the same: you can give input different rows / columns so that each row / column of index is replaced with different content. In my experience this is rarely needed, and building such an input by hand gets tedious in higher dimensions. What I use most often is a single lookup vector that is shared by every index to be replaced.
- About gather: using gather is like converting Chinese characters to Pinyin. One dimension corresponds to converting a sentence, two dimensions to converting an article, and three dimensions to converting a book. For example: [I, love, you] → [wo, ai, ni].
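To make the shape requirements concrete, here is a small sketch with illustrative names (base, bigger). It shows that a shared lookup vector gives the same result as plain indexing. It also shows that, per the PyTorch documentation, input only needs to be at least as long as index in the non-dim dimensions. Finally, an index equal to the length of the dim dimension is rejected. The last check is expected to raise on CPU; I have not verified the behaviour on other devices.

```python
import torch

base = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
index = torch.tensor([[1, 9, 0],
                      [4, 4, 7]])

# Shared lookup row: expanding base for gather matches plain indexing base[index]
via_gather = torch.gather(base.expand(2, 10), 1, index)
via_indexing = base[index]
print(via_gather.tolist())  # [[11, 19, 10], [14, 14, 17]]

# input may have more rows than index; only the first index.size(0) rows are used
bigger = torch.arange(30).view(3, 10)
out = torch.gather(bigger, 1, index)
print(out.tolist())  # [[1, 9, 0], [14, 14, 17]]

# An index equal to the length of the dim dimension is out of range
try:
    torch.gather(base, 0, torch.tensor([10]))
    raised = False
except (RuntimeError, IndexError):
    raised = True
print(raised)
```

When every row shares one lookup vector, base[index] avoids building the expanded input at all; gather is the tool to reach for when the lookup content differs per row or column.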
summary
Both where and gather can be reproduced with hand-written code, but the built-in versions run as optimized native code and are usually much faster than Python loops. Prefer where and gather whenever they fit the problem.