PyTorch: locally disabling and enabling gradient computation

Posted by hellonoko on Fri, 04 Oct 2019 10:57:50 +0200

I. Disabling local gradient computation

torch.autograd.no_grad: A context manager that disables gradient computation.

When you are certain that Tensor.backward() will not be called, disabling gradients reduces memory consumption, even for computations whose inputs have requires_grad=True. If you do need to compute a gradient for a tensor, set requires_grad=True on it.

It can be used in two ways, as a context manager or as a decorator:

>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
...   y = x * 2
>>> y.requires_grad
False


>>> @torch.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
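In practice, the most common use of no_grad is wrapping inference, where no graph is needed and memory would otherwise be wasted recording one. A minimal sketch, assuming a throwaway nn.Linear model and random inputs (neither is from the original examples):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)      # hypothetical model, for illustration only
inputs = torch.randn(4, 10)

model.eval()                  # eval() switches layer behavior; it does NOT disable autograd
with torch.no_grad():         # this is what prevents the graph from being recorded
    outputs = model(inputs)

print(outputs.requires_grad)  # False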

II. Enabling local gradient computation after disabling it

torch.autograd.enable_grad: A context manager that enables gradient computation.

It enables gradient computation inside a no_grad context; outside of no_grad, this context manager has no effect.

The usage is similar to the above:

>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
...   with torch.enable_grad():
...     y = x * 2
>>> y.requires_grad
True
>>> y.backward()
>>> x.grad
tensor([2.])

>>> @torch.enable_grad()
... def doubler(x):
...     return x * 2
>>> with torch.no_grad():
...     z = doubler(x)
>>> z.requires_grad
True
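This pattern is handy when most of an evaluation should run graph-free but one quantity still needs a gradient. A minimal sketch under that assumption (the variable names are illustrative, not from the original post):

import torch

x = torch.tensor([3.], requires_grad=True)

with torch.no_grad():
    baseline = x * 10          # graph-free part of the computation
    with torch.enable_grad():
        y = x ** 2             # a graph is recorded for this expression only

y.backward()
print(baseline.requires_grad)  # False
print(x.grad)                  # tensor([6.]), since d(x^2)/dx = 2x = 6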

III. Setting gradient computation on or off

torch.autograd.set_grad_enabled(): A context manager that sets gradient computation on or off according to its boolean argument.

It can be used either as a context manager or called directly as a function:

>>> x = torch.tensor([1.], requires_grad=True)
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...   y = x * 2
>>> y.requires_grad
False
>>> torch.set_grad_enabled(True)
>>> y = x * 2
>>> y.requires_grad
True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
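Because it takes a boolean, set_grad_enabled is convenient for a single code path shared between training and evaluation. A minimal sketch, with is_train standing in for whatever flag your code carries:

import torch

x = torch.tensor([1.], requires_grad=True)

def forward(x, is_train):
    # grad mode follows the flag, so one function serves both phases
    with torch.set_grad_enabled(is_train):
        return x * 2

print(forward(x, is_train=True).requires_grad)   # True
print(forward(x, is_train=False).requires_grad)  # False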

Conclusion:

Each of these three is straightforward when used on its own; when they are nested, the innermost (most recently applied) setting wins, a principle of proximity.

x = torch.tensor([1.], requires_grad=True)

with torch.enable_grad():
    torch.set_grad_enabled(False)  # the later function call overrides the outer context
    y = x * 2
    print(y.requires_grad)
# prints: False

torch.set_grad_enabled(True)
with torch.no_grad():              # the inner no_grad overrides the global setting
    z = x * 2
    print(z.requires_grad)
# prints: False
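
The proximity rule holds at any nesting depth: whichever context is innermost where a tensor is created decides. A quick check:

import torch

x = torch.tensor([1.], requires_grad=True)

with torch.no_grad():
    with torch.enable_grad():
        with torch.no_grad():
            w = x * 2   # innermost context here is no_grad
        v = x * 2       # innermost context here is enable_grad

print(w.requires_grad)  # False
print(v.requires_grad)  # True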

Reference resources:

https://pytorch.org/docs/stable/autograd.html