
Commit

Merge branch 'main' into inplace-update-prim
shino16 authored Sep 24, 2024
2 parents eb61e6a + a92ed64 commit ffbca34
Showing 5 changed files with 288 additions and 20 deletions.
18 changes: 11 additions & 7 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -346,11 +346,6 @@ def __init__(
self.global_batch_size % self.micro_batch_size * world_size == 0
), f"Global Batch Size {self.global_batch_size} should be a multiple Micro Batch Size {self.micro_batch_size} * World Size {world_size}."

if self.checkpoint_activations and "thunder" in self.compile:
warnings.warn(
"Activations checkpointing is configured, but Thunder does not support checkpointing. Checkpointing will be ignored."
)
self.checkpoint_activations = False
self.skip_data_sync = skip_data_sync

# Profiling Args
@@ -534,6 +529,10 @@ def setup_distributed(self, model):
return model

def setup_activation_checkpointing(self):
if "thunder" in self.compile:
# checkpointing is an option to thunder.jit
return

if any(isinstance(mod, CheckpointWrapper) for mod in self.model.modules()):
warnings.warn(
"FSDP checkpointing is configured, but the model already contains checkpointed layers."
@@ -569,6 +568,11 @@ def setup_compile(self, model):

executors.insert(0, transformer_engine_ex)

jit_options = {
"enable_saved_for_backward_recomputation": self.checkpoint_activations,
"recomputation_policy": None,
}

if "dynamo" in self.compile:
if self.distributed_mode == "fsdp2":
print("Resetting cache size for when fsdp2 and using thunder as backend torch.compile")
@@ -588,14 +592,14 @@ def setup_compile(self, model):
raise ValueError(
"TransformerEngine executor cannot be used as an executor of Thunder when Thunder is used as torch.compile backend"
)
backend = ThunderCompiler(executors=executors)
backend = ThunderCompiler(executors=executors, **jit_options)
# Because Lightning Fabric is imported in this script it monkey patches the torch.compile function
# https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421
# using __wrapped__ to access the original torch.compile function did not work
# so we are using the lower level torch._dynamo.optimize function
model = torch._dynamo.optimize(backend=backend)(model)
else:
model = thunder.jit(model, executors=executors)
model = thunder.jit(model, executors=executors, **jit_options)

elif self.compile != "eager":
raise ValueError(f"Invalid compile option: {self.compile}")
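For reference, the benchmark now forwards activation checkpointing to Thunder as compile options instead of wrapping modules, and the same keyword arguments are passed straight into `thunder.jit` (or `ThunderCompiler`). A minimal standalone sketch of that usage — the toy model and shapes are illustrative, not taken from the benchmark:

```python
import torch
import thunder

# Toy stand-in for the benchmark's LitGPT model.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.GELU(),
    torch.nn.Linear(16, 16),
)

# The benchmark's checkpoint_activations setting maps onto this compile option;
# saved-for-backward tensors are then recomputed in the backward trace.
# recomputation_policy=None means every eligible tensor may be recomputed.
jit_model = thunder.jit(
    model,
    enable_saved_for_backward_recomputation=True,
    recomputation_policy=None,
)

out = jit_model(torch.randn(4, 16))
out.sum().backward()
```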
68 changes: 57 additions & 11 deletions thunder/core/interpreter.py
@@ -1342,7 +1342,12 @@ def is_jitting_with_raise():

# Guard against opaque functions which interrupt jitting.
if (ctx := get_interpretercompilectx_if_available()) is not None:
raise InterpreterError(f"Lookaside was not triggered, but there is an active compile context: {ctx}")
# nested try to delete ctx from locals
try:
raise InterpreterError(f"Lookaside was not triggered, but there is an active compile context: {ctx}")
except InterpreterError:
del ctx
raise

return False

@@ -1495,8 +1500,9 @@ def set_builtins(globals, builtins_dict):
except Exception as e:
# We need to cheat a bit to get a Python frame here...
python_frame = frame.get_or_make_python_frame()
tb = TracebackType(e.__traceback__, python_frame, python_frame.f_lasti, python_frame.f_lineno)
raise e.with_traceback(tb)
e.__traceback__ = TracebackType(e.__traceback__, python_frame, python_frame.f_lasti, python_frame.f_lineno)
del e
raise # re-raises e

if mode == "eval":
return res
@@ -6254,14 +6260,24 @@ def thunder_interpreter_generator():
res, status = _run_frame(frame, compilectx, runtimectx, send_value=send_value)
except Exception as e:
msg = f"Encountered exception {type(e).__name__}: {e}"
raise InterpreterError(msg) from e
# nested try ... raise to delete e from locals
try:
raise InterpreterError(msg) from e
except InterpreterError:
del e
raise
if status is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
e = runtimectx.curexc
assert isinstance(e, BaseException)
runtimectx.curexc = None
if isinstance(e, StopIteration):
return unwrap(e.value)
raise e
# nested try except to delete e from locals
try:
raise e
except BaseException:
del e
raise
if status == INTERPRETER_SIGNALS.RETURN_VALUE:
return # TODO: should this return res?
assert status == INTERPRETER_SIGNALS.YIELD_VALUE
@@ -6284,14 +6300,24 @@ async def thunder_interpreter_async_generator():
res, status = _run_frame(frame, compilectx, runtimectx, send_value=send_value)
except Exception as e:
msg = f"Encountered exception {type(e).__name__}: {e}"
raise InterpreterError(msg) from e
# nested try ... raise to delete e from locals
try:
raise InterpreterError(msg) from e
except InterpreterError:
del e
raise
if status is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
e = runtimectx.curexc
assert isinstance(e, BaseException)
runtimectx.curexc = None
if isinstance(e, StopIteration):
return
raise e
# nested try except to delete e from locals
try:
raise e
except BaseException:
del e
raise
if status == INTERPRETER_SIGNALS.RETURN_VALUE:
return # TODO: should this return res?
assert status == INTERPRETER_SIGNALS.YIELD_VALUE
@@ -6314,14 +6340,24 @@ async def thunder_interpreter_coroutine():
res, status = _run_frame(frame, compilectx, runtimectx, send_value=send_value)
except Exception as e:
msg = f"Encountered exception {type(e).__name__}: {e}"
raise InterpreterError(msg) from e
# nested try ... raise to delete e from locals
try:
raise InterpreterError(msg) from e
except InterpreterError:
del e
raise
if status is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
e = runtimectx.curexc
assert isinstance(e, BaseException)
runtimectx.curexc = None
if isinstance(e, StopIteration):
return unwrap(e.value)
raise e
# nested try except to delete e from locals
try:
raise e
except BaseException:
del e
raise
if status == INTERPRETER_SIGNALS.RETURN_VALUE:
return unwrap(res)
assert status == INTERPRETER_SIGNALS.YIELD_VALUE
@@ -7134,7 +7170,12 @@ def fn_2(args, kwargs):
msg = (
f"Encountered exception {type(e).__name__}: {e} while tracing {fn}:{os.linesep}" f"{traceback_str}"
)
raise InterpreterError(msg) from e
# nested try ... raise to delete e from locals
try:
raise InterpreterError(msg) from e
except InterpreterError:
del e
raise

# NOTE: Wrapped functions are valid to assign new attributes to.
fn_._last_interpreter_log = runtimectx.interp_log # type: ignore
@@ -7143,7 +7184,12 @@ def fn_2(args, kwargs):
e = runtimectx.curexc
assert isinstance(e, BaseException), e
runtimectx.curexc = None
raise e
# The below is "raise e" but deleting e from the scope
try:
raise e
except Exception:
del e
raise

return interpretation_result

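The repeated "nested try ... delete from locals" hunks above all apply the same idiom: re-raising a locally bound exception creates a reference cycle (exception → traceback → frame → local name → exception) that can keep frame locals such as input tensors alive until the cyclic garbage collector runs. Deleting the local inside an `except` block before re-raising breaks that cycle. A minimal sketch of the idiom outside of Thunder, relying only on CPython reference counting:

```python
import weakref


def reraise_without_cycle(x):
    err = RuntimeError("boom")
    # Nested try/except so the local name `err` can be dropped before the
    # exception propagates. Without `del err`, err.__traceback__ would point
    # at this frame, whose locals point back at err -- a cycle that also pins
    # `x` until gc.collect() happens to run.
    try:
        raise err
    except RuntimeError:
        del err
        raise


class Payload:  # stands in for a large tensor held by the frame
    pass


payload = Payload()
ref = weakref.ref(payload)
try:
    reraise_without_cycle(payload)
except RuntimeError:
    pass
del payload
# With the `del err` above, nothing keeps the frame alive, so the payload is
# freed by reference counting alone (expected to hold on CPython).
assert ref() is None
```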
103 changes: 101 additions & 2 deletions thunder/core/transforms.py
@@ -25,14 +25,16 @@
trace_interpreter_skip_list,
)
from thunder.core.proxies import (
CollectionProxy,
NumberProxy,
Proxy,
TensorProxy,
FloatProxy,
variableify,
unvariableify,
FutureTensorProxy,
)
from thunder.core.compile_data import get_compile_data
from thunder.core.compile_data import get_compile_data, get_compile_option
from thunder.core.langctxs import langctx, Languages
from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten, tree_flatten_with_dataclass
from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol
@@ -57,6 +59,7 @@
const_as,
sequencify,
ProxyDict,
find_producer_symbols,
)
import thunder.clang as clang
from thunder.clang import (
@@ -76,7 +79,7 @@
unwrap_return_value,
VJPDual,
)
from thunder.core.vjp_utils import make_aug_forward_and_backward
from thunder.core.vjp_utils import make_aug_forward_and_backward, get_saved_for_backward_tensors
from thunder.extend import Executor
import thunder.torch as ltorch

@@ -3075,4 +3078,100 @@ def backward_fn(saved_for_backward, cotangents):
_update_backward_with_new_saved_for_backward(backward_trace, only_used_bw_saved_for_backward)
forward_trace.set_provenance(TraceProvenance("Augmented forward pass"))
backward_trace.set_provenance(TraceProvenance("Backward pass"))

enable_saved_for_backward_recomputation: None | bool = get_compile_option(
"enable_saved_for_backward_recomputation", "Enable save for backward tensors recomputation."
)
if enable_saved_for_backward_recomputation:
forward_trace, backward_trace = recompute_saved_for_backward(forward_trace, backward_trace)

return ForwardBackwardTraces(forward_trace, backward_trace)


def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Trace, Trace]:
"""Generates the pair of traces with rematerializaion of the saved-for-backward tensors.
Args:
fwd_trace (Trace): forward trace where to get the saved for backward from.
bwd_trace (Trace): backward trace where to recompute the saved for backward to.
Returns:
tuple[Trace, Trace]: A tuple containing the new forward and backward traces.
"""

start_time_ns = time.perf_counter_ns()

saved_for_bw = get_saved_for_backward_tensors(fwd_trace)
fwd_trace_args = {variableify(j) for j in fwd_trace.args}
old_saved_for_bwd = {variableify(j) for j in saved_for_bw}

all_rematerializable = old_saved_for_bwd - fwd_trace_args

remat_policy: None | Callable[[set[Variable]], set[Variable]] = get_compile_option(
"recomputation_policy",
"A callable that accepts a set of variables and returns a set of the variables that are allowed to be recomputed from the forward in the backward trace. The compile option `enable_saved_for_backward_recomputation` needs to be true for this policy to take effect.",
)

if remat_policy:
rematerializable = remat_policy(all_rematerializable)
else:
rematerializable = all_rematerializable

producers = find_producer_symbols(fwd_trace, tuple(unvariableify(i) for i in rematerializable), fwd_trace.args)

required_fw_args = fwd_trace_args & old_saved_for_bwd
recomputed_tensors_from_producers = set()
for prod in producers:
for prod_arg in prod.flat_args:
prod_arg = variableify(prod_arg)
if prod_arg in fwd_trace_args:
required_fw_args.add(prod_arg)
for prod_out in prod.flat_outs:
recomputed_tensors_from_producers.add(variableify(prod_out))

required_saved_for_bwd = all_rematerializable - rematerializable - recomputed_tensors_from_producers
new_saved_for_backward = tuple(unvariableify(i) for i in required_fw_args | required_saved_for_bwd)

new_fwd_trace = from_trace(fwd_trace)
new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy()
new_return_args = (fwd_trace.output[0], (new_saved_for_backward, fwd_trace.output[1][1]))
new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=())

new_bwd_trace = from_trace(bwd_trace)
# In cases where C0 name is carried from previous trace it must be removed
# as the proxy needs to register with that specific name to follow the backward
# trace standard signature.
new_bwd_trace.names.discard("C0")

with tracectx(new_bwd_trace):
unpack_args = (CollectionProxy(new_saved_for_backward, name="C0"), len(new_saved_for_backward))

# Here we make sure that the signature of the backward trace is the same as the one we expect.
# This part of the trace is the unpacking of the tuple passed from the forward trace,
# more specifically, C0 unpacks into the saved for backward tensors and C1 into the cotangents
# used to compute the vector-Jacobian product.
assert bwd_trace.bound_symbols[4].sym.id == prims.PrimIDs.UNPACK_SEQUENCE
assert bwd_trace.bound_symbols[4].args[0].name == "C0"
assert bwd_trace.bound_symbols[5].sym.id == prims.PrimIDs.UNPACK_SEQUENCE
assert bwd_trace.bound_symbols[5].args[0].name == "C1"

for idx, bsym in enumerate(bwd_trace.bound_symbols):
if idx == 4:
new_unpack = prims.unpack_sequence.bind(*unpack_args, output=new_saved_for_backward)
new_bwd_trace.bound_symbols.append(new_unpack)
elif idx == 6:
new_bwd_trace.bound_symbols.extend(producers)
new_bwd_trace.bound_symbols.append(bsym)
else:
new_bwd_trace.bound_symbols.append(bsym)

new_bwd_trace.args = [(new_saved_for_backward, fwd_trace.output[1][1]), *bwd_trace.args[1:]]

elapsed_time_ns = time.perf_counter_ns() - start_time_ns
new_bwd_trace.set_provenance(
TraceProvenance(f"Saved for backward remat trace (took {elapsed_time_ns * 1e-6:.2f} milliseconds)")
)
new_fwd_trace.set_provenance(
TraceProvenance(f"Saved for backward remat trace (took {elapsed_time_ns * 1e-6:.2f} milliseconds)")
)

return new_fwd_trace, new_bwd_trace
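As a usage note, the `recomputation_policy` compile option introduced here is a callable from the set of rematerialization candidates to the subset that may actually be recomputed; anything left out of the returned set stays in the forward trace's saved-for-backward output. A minimal sketch of such a policy — the cap-by-count heuristic below is illustrative, not part of the commit:

```python
import torch
import thunder
from thunder.core.proxies import unvariableify


def recompute_at_most(n):
    """Build a policy that lets at most `n` saved-for-backward tensors be recomputed."""

    def policy(candidates):
        # `candidates` is the set of Variables that recompute_saved_for_backward
        # considers rematerializable; pick a deterministic subset by proxy name.
        chosen = sorted(candidates, key=lambda v: unvariableify(v).name)[:n]
        return set(chosen)

    return policy


# Toy model for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Tanh(), torch.nn.Linear(32, 32))
jit_model = thunder.jit(
    model,
    enable_saved_for_backward_recomputation=True,
    recomputation_policy=recompute_at_most(4),
)

out = jit_model(torch.randn(8, 32))
out.sum().backward()
```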
23 changes: 23 additions & 0 deletions thunder/tests/test_interpreter.py
@@ -827,6 +827,29 @@ def foo():
assert weak_x() is None


def test_uncaught_exception_no_leak():

class Identity(torch.nn.Module):
def forward(self, x):
raise RuntimeError("FOOBAR")
return x

def main():
with torch.device("cpu"):
model = thunder.jit(Identity())
x = torch.randn(16, 16)

try:
model(x)
except:
pass
return weakref.ref(x)

weak_x = main()

assert weak_x() is None


def test_walrus_operator(jit):
def foo(a, b):
c = (a := b)