autodiff
Concepts


- `x.is_leaf`: whether tensor `x` is a leaf node, i.e., `x.grad_fn == None` (it is not created by a `Function` operation; it is a leaf node of the DAG). Specifically, `is_leaf == True` covers the following cases:
  - `requires_grad == False` (e.g., the input data, the ground truth)
  - `requires_grad == True and grad_fn == None` (e.g., the trainable parameters)

  Therefore, non-leaf tensors are any outputs of a `Function` with `requires_grad == True` (e.g., predictions, the loss). Non-leaf tensors must have `requires_grad = True`; if you try to change it, an error is raised:

  ```python
  x = torch.rand(1, 3)
  y = nn.Linear(3, 5)(x)
  y.requires_grad = False
  # RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().
  ```
- `x.requires_grad`: whether tensor `x` requires grad, which affects:
  - In `forward()`, a `Function` is only recorded in the DAG if at least one of its inputs has `requires_grad == True`.
  - In `backward()`, only tensors with `is_leaf == True and requires_grad == True` will get `grad`.

  Defaults to `False`, unless set explicitly or wrapped in `nn.Parameter`. It is just a boolean flag, so setting it is not a `Function`:

  ```python
  # set at initialization
  x = torch.rand(3, requires_grad=True)

  # both of the following just change the flag in-place.
  x.requires_grad = True
  x.requires_grad_(True)  # also returns the tensor itself.

  # nn.Parameter (which is used inside nn.Module) automatically requires grad.
  x = nn.Parameter(torch.rand(3))
  # is_leaf: True, requires_grad: True, grad_fn: None, grad: None
  x = nn.Linear(3, 5).weight
  # is_leaf: True, requires_grad: True, grad_fn: None, grad: None
  ```
- `x.grad_fn`: the `torch.autograd.Function` (operation) that created `x`. This is where the computation DAG is saved (i.e., it records all the inputs and the exact operator used to compute `x`).
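
  For example, a minimal sketch to peek one step up the DAG via `grad_fn` (the `next_functions` attribute holds the parent nodes; exact class names and addresses vary by PyTorch version):

  ```python
  import torch

  x = torch.rand(3, requires_grad=True)
  y = (x * 2).sum()

  print(y.grad_fn)                 # <SumBackward0 ...>: the Function that created y
  print(y.grad_fn.next_functions)  # ((<MulBackward0 ...>, 0),): the Function that created x * 2
  ```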
- `x.backward()`: computes gradients through the DAG. It is usually called on a non-leaf tensor (`grad_fn != None`) such as the loss; calling it on a tensor that does not require grad raises an error:

  ```
  RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
  ```
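
  A minimal sketch of both cases (the squared-sum loss here is just an illustration):

  ```python
  import torch

  x = torch.rand(3, requires_grad=True)
  loss = (x ** 2).sum()   # non-leaf: grad_fn != None
  loss.backward()         # populates x.grad with d(loss)/dx = 2 * x

  y = torch.rand(3)       # requires_grad == False, grad_fn == None
  # y.sum().backward()    # would raise: RuntimeError: element 0 of tensors does not require grad ...
  ```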
- `x.grad`: the accumulated / populated gradients, default `None`. Only tensors with `is_leaf == True and requires_grad == True` will get their `grad` after a later `backward()`. If you access the `grad` of a non-leaf tensor, a warning will be reported:

  ```
  UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
  ```
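
  Since gradients are accumulated, repeated `backward()` calls add up unless `grad` is cleared; a small sketch (values assume the shown operations):

  ```python
  import torch

  x = torch.rand(3, requires_grad=True)

  (x * 2).sum().backward()
  print(x.grad)   # tensor([2., 2., 2.])

  (x * 2).sum().backward()
  print(x.grad)   # tensor([4., 4., 4.]): grads accumulate across backward() calls

  x.grad = None   # reset manually (optimizer.zero_grad() does this for its registered parameters)
  ```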
- `x.detach()`: creates a new tensor from `x` (sharing the same data, no copy) with `requires_grad = False` (which also means `is_leaf == True`). Note that `x.clone()` copies the data but keeps the gradient connection to the original tensor (thus `is_leaf == False`).
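
  A small sketch contrasting `detach()` and `clone()` (the exact `grad_fn` class name may differ across versions):

  ```python
  import torch

  x = torch.rand(3, requires_grad=True)

  d = x.detach()   # shares data with x, cut from the DAG
  print(d.is_leaf, d.requires_grad, d.grad_fn)  # True False None

  c = x.clone()    # copies data, stays connected to the DAG
  print(c.is_leaf, c.requires_grad, c.grad_fn)  # False True <CloneBackward0 ...>
  ```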
- `with torch.no_grad():`: temporarily disables all DAG recording (so `backward()` cannot be called), i.e., behaves as if no input requires grad. It can also be used as a decorator:

  ```python
  @torch.no_grad()
  def func(x):
      ...
  ```

  `torch.enable_grad()` can be used inside a `torch.no_grad()` block to re-enable gradient calculation, as sketched below.
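
  A minimal sketch of nesting the two contexts:

  ```python
  import torch

  x = torch.rand(3, requires_grad=True)

  with torch.no_grad():
      y = x * 2            # not recorded: y.requires_grad == False, y.grad_fn == None
      with torch.enable_grad():
          z = x * 2        # recorded again: z.requires_grad == True, z.grad_fn != None
  ```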
- `nn.Module.eval()`: not to be confused with `no_grad()`; it only switches special modules like `BatchNorm` and `Dropout` to their evaluation behavior. The correct way to perform inference is always to use both `no_grad()` and `eval()`:

  ```python
  model.eval()
  with torch.no_grad():
      y = model(x)
  ```
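
  For example, `Dropout` (one of those special modules) only changes its behavior via `train()` / `eval()`, not via `no_grad()`; a small illustration:

  ```python
  import torch
  import torch.nn as nn

  drop = nn.Dropout(p=0.5)
  x = torch.ones(1, 10)

  drop.train()
  print(drop(x))   # roughly half the entries zeroed, the rest scaled by 1 / (1 - p)

  drop.eval()
  print(drop(x))   # identity: all ones, dropout is disabled in eval mode
  ```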
- `with torch.set_grad_enabled(True/False):`: similar to `with torch.no_grad():` / `with torch.enable_grad():`, but takes a boolean so the mode can be chosen at runtime.
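
  A small sketch where the same code path is reused for training and inference by flipping a boolean (the `is_train` flag is just an example name):

  ```python
  import torch
  import torch.nn as nn

  model = nn.Linear(3, 5)
  x = torch.rand(1, 3)

  for is_train in (True, False):
      with torch.set_grad_enabled(is_train):
          y = model(x)
      print(y.requires_grad)   # True when is_train, False otherwise
  ```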
- `x.retain_grad()`: `x` must be a non-leaf tensor; calling this before `backward()` makes `x` keep its `grad` (just like the leaf tensors that require grad). The state can be checked via `x.retains_grad`, which is a boolean flag.
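
  A minimal sketch (the `y = x * 2` graph is just an illustration):

  ```python
  import torch

  x = torch.rand(3, requires_grad=True)
  y = x * 2              # non-leaf
  y.retain_grad()        # ask autograd to also populate y.grad
  print(y.retains_grad)  # True

  y.sum().backward()
  print(y.grad)          # tensor([1., 1., 1.]) instead of None
  print(x.grad)          # tensor([2., 2., 2.]) as usual
  ```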
- `torch.autograd.grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False)`: a way to manually compute and get the gradients.
  - `outputs`: tensor, usually a scalar such as the loss.
  - `inputs`: tensor, used to calculate \(\frac{d~\text{outputs}}{d~\text{inputs}}\).
  - `grad_outputs`: \(d~\text{outputs}\), should be the same shape as `outputs`; if left `None`, all ones are used.
  - `create_graph`: whether to keep the DAG so higher-order derivatives can be calculated.
  - `retain_graph`: defaults to `create_graph`; if `True`, do not free the DAG.

  Example use in an SDF to calculate normals:

  ```python
  with torch.set_grad_enabled(True):
      input.requires_grad_(True)
      sigma = self.backbone(input)
      normal = - torch.autograd.grad(torch.sum(sigma), input, create_graph=True)[0]  # [B, N, 3]
  return sigma, normal
  ```
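
  Since `create_graph=True` keeps the DAG, `torch.autograd.grad` can also be chained for higher-order derivatives; a small sketch:

  ```python
  import torch

  x = torch.tensor([2.0], requires_grad=True)
  y = x ** 3

  # first derivative: dy/dx = 3 * x^2 = 12
  dy_dx, = torch.autograd.grad(y, x, create_graph=True)
  print(dy_dx)     # tensor([12.], grad_fn=...)

  # second derivative: d2y/dx2 = 6 * x = 12
  d2y_dx2, = torch.autograd.grad(dy_dx, x)
  print(d2y_dx2)   # tensor([12.])
  ```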
Full Examples:

```python
import torch
import torch.nn as nn

def check(x):
    print(f'is_leaf: {x.is_leaf}, requires_grad: {x.requires_grad}, grad_fn: {x.grad_fn}, grad: {x.grad}')

#####################
### simple operations

x = torch.rand(3)
check(x)
# is_leaf: True, requires_grad: False, grad_fn: None, grad: None

x = torch.rand(3, requires_grad=True)
check(x)
# is_leaf: True, requires_grad: True, grad_fn: None, grad: None

x = torch.rand(3, requires_grad=True) + 1
check(x)
# is_leaf: False, requires_grad: True, grad_fn: <AddBackward0 object at 0x7f7df1added0>, grad: None
# note: + is a Function. therefore, it leaves an AddBackward0 grad_fn, and makes x non-leaf.

x = torch.rand(3, requires_grad=True).cuda()
check(x)
# is_leaf: False, requires_grad: True, grad_fn: <CopyBackwards object at 0x7f7e2efe90d0>, grad: None
# note: cuda() is also a Function!

x = torch.rand(3) + 1
check(x)
# is_leaf: True, requires_grad: False, grad_fn: None, grad: None
# note: if x doesn't require grad, the Function that operates on it will not be recorded, and the output is still a leaf.

x = torch.rand(3).cuda()
check(x)
# is_leaf: True, requires_grad: False, grad_fn: None, grad: None

##############
### nn.Module

x = torch.rand(3, requires_grad=True, device='cuda')
check(x)
# is_leaf: True, requires_grad: True, grad_fn: None, grad: None
# note: unlike torch.rand(3, requires_grad=True).cuda(), creating the tensor directly on the device keeps it a leaf.

lin = nn.Linear(3, 5)
check(lin.weight)
# is_leaf: True, requires_grad: True, grad_fn: None, grad: None
# note: nn.Module's parameters are wrapped by nn.Parameter, so they are leaves and require grad.

x = torch.rand(1, 3)
y = lin(x)
check(y)
# is_leaf: False, requires_grad: True, grad_fn: <AddmmBackward object at 0x7f7e4004c9d0>, grad: None

gt = torch.rand(1, 5)
l = nn.L1Loss()(y, gt)
check(l)
# is_leaf: False, requires_grad: True, grad_fn: <L1LossBackward object at 0x7f7e40021c90>, grad: None

l.backward()
check(lin.weight)
# is_leaf: True, requires_grad: True, grad_fn: None, grad: tensor([...])
# note: other tensors (x, y, l, gt) will still have grad == None.
```