Separate forward and backward compilation for default partition
ghstack-source-id: 4de63f2aff78e0575fc342e13688308c542aa62f
Pull Request resolved: #856
anjali411 committed Jun 7, 2022
1 parent 130582c commit 4cc60ae
Showing 1 changed file with 23 additions and 16 deletions.
39 changes: 23 additions & 16 deletions functorch/_src/aot_autograd.py
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -68,20 +68,20 @@ def joint_forward_backward(
                 grad_primals.append(p)
 
         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
+                grad_outputs=needed_cotangents,
                 allow_unused=True,
             )
         backward_out_iter = iter(backward_out)
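
Note: the tangents -> cotangents rename in this hunk is a terminology fix. The vectors passed to torch.autograd.grad via grad_outputs drive a vector-Jacobian product, so they are cotangents of the outputs; tangents belong to forward-mode JVPs. A minimal standalone illustration of that role (not part of this commit):

import torch

x = torch.randn(3, requires_grad=True)
y = x * 2            # forward pass
v = torch.ones(3)    # cotangent for y, supplied as grad_outputs

# torch.autograd.grad computes the vector-Jacobian product v^T @ dy/dx.
(grad_x,) = torch.autograd.grad(y, x, grad_outputs=v)
print(grad_x)        # tensor([2., 2., 2.])
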
@@ -140,12 +140,14 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
-
+    joint_inputs = None
+    fw_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs, joint_inputs, fw_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,29 +161,34 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1
 
                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
                 # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
                 # print(fw_module.code, bw_module.code)
 
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
 
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
             ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
+        def backward(ctx, *flat_grad_outs):
+            nonlocal compiled_bw
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                with torch.set_grad_enabled(grad_state):
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args)
+                fw_module, bw_module = partition_fn(fx_g, joint_inputs)
+                compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args)
             out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
             return tuple(out)
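
As the diff reads, the two hunks above split compilation in two: when partition_fn is default_partition the backward module is still compiled eagerly during the first forward(), while for any other partitioner bw_compiler is now invoked lazily on the first backward() call, where the actual cotangents are available. A minimal, self-contained sketch of that deferred-compilation pattern, using hypothetical stand-ins make_lazily_compiled_fn, fw_fn, bw_fn, and compile_fn (none of these are functorch APIs):

import torch

def make_lazily_compiled_fn(fw_fn, bw_fn, compile_fn):
    compiled_fw = None
    compiled_bw = None

    class LazyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            nonlocal compiled_fw
            if compiled_fw is None:
                # Compile the forward only; the backward is compiled on demand.
                compiled_fw = compile_fn(fw_fn)
            ctx.save_for_backward(x)
            return compiled_fw(x)

        @staticmethod
        def backward(ctx, grad_out):
            nonlocal compiled_bw
            if compiled_bw is None:
                # First backward call: compile the backward now that a real
                # cotangent (grad_out) is in hand.
                compiled_bw = compile_fn(bw_fn)
            (x,) = ctx.saved_tensors
            return compiled_bw(x, grad_out)

    return LazyFunction.apply

# Usage: square as the forward, its VJP as the backward.
f = make_lazily_compiled_fn(
    fw_fn=lambda x: x * x,
    bw_fn=lambda x, g: 2 * x * g,
    compile_fn=lambda fn: fn,  # identity "compiler" for the sketch
)
x = torch.randn(4, requires_grad=True)
f(x).sum().backward()
print(x.grad)  # equals 2 * x
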
