diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index 57a7ac68fc..268c46c85c 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -68,20 +68,20 @@ def joint_forward_backward(
                 grad_primals.append(p)
 
         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
+                grad_outputs=needed_cotangents,
                 allow_unused=True,
             )
         backward_out_iter = iter(backward_out)
@@ -140,12 +140,14 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
-
+    joint_inputs = None
+    fw_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs, joint_inputs, fw_outs
            if compiled_fw is None:
                with torch.set_grad_enabled(grad_state):
                    out = flat_fn(*flat_tensor_args)
@@ -159,19 +161,19 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1
 
                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                     fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                    # print(fw_module.code, bw_module.code)
 
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
             ctx.save_for_backward(*fw_outs[num_outs:])
@@ -179,9 +181,14 @@ def forward(ctx, *flat_tensor_args):
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
+        def backward(ctx, *flat_grad_outs):
+            nonlocal compiled_bw
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                with torch.set_grad_enabled(grad_state):
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args)
+                    fw_module, bw_module = partition_fn(fx_g, joint_inputs)
+                    compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args)
             out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
             return tuple(out)
 
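
A minimal sketch (not part of the patch) of how the lazily-compiled backward path above can be exercised: with a non-default partition_fn, compiled_bw stays None after forward() and is only built on the first backward() call. It assumes the functorch.compile API of this revision (aot_function, min_cut_rematerialization_partition); print_compile is a hypothetical pass-through compiler used only for illustration.

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition


def print_compile(fx_module, example_inputs):
    # Pass-through "compiler": show the partitioned graph and run it as-is.
    print(fx_module.code)
    return fx_module


def fn(x, y):
    return (x * y).sin().sum()


aot_fn = aot_function(
    fn,
    fw_compiler=print_compile,
    bw_compiler=print_compile,
    partition_fn=min_cut_rematerialization_partition,  # not default_partition
)

x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
loss = aot_fn(x, y)  # forward graph is traced, partitioned, and compiled here
loss.backward()      # backward graph is traced and compiled lazily here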