I. Disabling local gradient calculation
torch.autograd.no_grad: a context manager that disables gradient calculation.
When you are sure that Tensor.backward() will not be called, disabling gradients reduces memory consumption. Set Tensor.requires_grad = True if you do need to compute the gradient.
Two ways to disable:
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False

>>> @torch.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
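For context, the typical use is inference: wrapping a model's forward pass in torch.no_grad() keeps autograd from recording the computation graph, which is where the memory saving comes from. The sketch below is only an illustration; the toy model and inputs are made up.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)        # toy model, purely for illustration
inputs = torch.randn(8, 4)

model.eval()
with torch.no_grad():          # nothing inside this block is recorded by autograd
    outputs = model(inputs)

print(outputs.requires_grad)   # False, so outputs.backward() is not possible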
II. Enabling local gradient calculation after it has been disabled
torch.autograd.enable_grad: a context manager that enables gradient calculation.
It enables gradient calculation inside a no_grad context. This context manager has no effect outside of no_grad.
The usage is similar to the above:
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
...     with torch.enable_grad():
...         y = x * 2
>>> y.requires_grad
True
>>> y.backward()
>>> x.grad
tensor([2.])

>>> @torch.enable_grad()
... def doubler(x):
...     return x * 2
>>> with torch.no_grad():
...     z = doubler(x)
>>> z.requires_grad
True
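As a rough sketch of why this is useful (input_gradient below is a hypothetical helper, not a PyTorch API): even when the caller has gradients switched off, enable_grad lets a helper compute a gradient it genuinely needs.

import torch

def input_gradient(x):
    # Hypothetical helper: compute d(sum(x**2))/dx even under torch.no_grad().
    with torch.enable_grad():
        x = x.detach().requires_grad_(True)   # fresh leaf that tracks gradients
        y = (x ** 2).sum()
        y.backward()
    return x.grad

with torch.no_grad():                          # caller has gradients disabled
    g = input_gradient(torch.tensor([1.0, 2.0, 3.0]))

print(g)                                       # tensor([2., 4., 6.])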
III. Switching gradient calculation on or off
torch.autograd.set_grad_enabled: a context manager that sets gradient calculation on or off.
It can be used as a context manager or called as a plain function:
>>> x = torch.tensor([1.], requires_grad=True)
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False

>>> torch.set_grad_enabled(True)
>>> y = x * 2
>>> y.requires_grad
True

>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
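A common pattern, sketched below with a made-up run_step function and toy model, is to share one code path between training and evaluation and let a boolean decide whether gradients are tracked.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                        # toy model, for illustration only

def run_step(inputs, is_train):
    # One code path for both phases; gradients are tracked only when training.
    with torch.set_grad_enabled(is_train):
        loss = model(inputs).pow(2).mean()
    return loss

x = torch.randn(8, 4)
print(run_step(x, is_train=True).requires_grad)   # True
print(run_step(x, is_train=False).requires_grad)  # False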
Conclusion:
Each of these three functions is straightforward on its own; when they are nested, the innermost (nearest) setting wins.
x = torch.tensor([1.], requires_grad=True)

with torch.enable_grad():
    torch.set_grad_enabled(False)
    y = x * 2
print(y.requires_grad)
>>> False

torch.set_grad_enabled(True)
with torch.no_grad():
    z = x * 2
print(z.requires_grad)
>>> False
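To check which setting is actually in effect at any point, torch.is_grad_enabled() can be queried; the short sketch below simply re-confirms the "nearest setting wins" behaviour described above.

import torch

with torch.enable_grad():
    print(torch.is_grad_enabled())        # True: enable_grad is the nearest setting
    with torch.no_grad():
        print(torch.is_grad_enabled())    # False: no_grad is now the nearest setting
    print(torch.is_grad_enabled())        # True again once the no_grad block exits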