Add prims.copy_to_out_ #1194

Open · wants to merge 8 commits into main
Conversation

@shino16 (Contributor) commented Sep 24, 2024:

Fixes #1173. This adds a new primitive prims.copy_to_out_(computed, *, out), which is used instead of prims.copy_ for the update in in-place ops. Unlike prims.copy_, prims.copy_to_out_(computed, *, out) assumes that computed is not used by subsequent ops, so out can simply alias computed.

Benchmarked commits: main at 63887b3, #1193 at 2781a20

                                                    compilation (s)   execution (ms)
eager                                                    0.0              10.86
torch.compile(adam.step, backend=thunder), main         94.0              11.52
torch.compile(adam.step, backend=thunder), #1193        48.0               6.00

The usage rule for prims.copy_to_out_ is (link):

# WARN: `computed` must be an intermediate tensor used solely for this `copy_to_out_` call,
# e.g. copy_to_out_(add(a, b), out=a). Thunder does not guarantee that `computed` retains
# the correct values after copy_to_out_ returns. For a general-purpose copy, use prims.copy_ instead.

This rule follows from the fact that any copy onto out is propagated to its alias, computed.

To prevent users from using prims.copy_to_out_ inappropriately, I made the sanity check on prims.copy_to_out_ rather conservative. When enabled, it raises an error when computed is

  • used as an input to other ops within the nvFuser region,
  • defined outside of the region (because it may be modified in-place), or
  • used outside of the region (because it may not have the correct value).

See tests for examples.
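
For illustration, here is a minimal user-level sketch of where the new primitive shows up (a sketch assuming this PR's functionalization; the exact proxy names and trace layout will differ):

import torch
import thunder

def inplace_update(param, grad):
    # Thunder functionalizes the in-place add: it computes an out-of-place
    # result t = param + grad and then writes it back with
    # prims.copy_to_out_(t, out=param). Since t is not used afterwards,
    # param may simply alias t.
    param.add_(grad)
    return param

jitted = thunder.jit(inplace_update)
param, grad = torch.randn(4), torch.randn(4)
jitted(param, grad)
# print(thunder.last_traces(jitted)[-1])  # inspect the emitted copy_to_out_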

@t-vi (Collaborator) left a comment:

Thank you @shino16.
I think the overall approach is good and, to my mind, we can merge when it fully works, but I want to let @crcrpar and/or @IvanYashchuk take a look.

There seem to be a few cases to look at, though:

FAILED thunder/tests/test_inplace_functionalization.py::test_inplace_to_alias_func_args_nvfuser_cuda_thunder.dtypes.float32 - NotImplementedError: <TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(2, 2))> (the 'computed' argument of 'prims.copy_to_out_') is used outside of the nvFuser region. Copies onto <TensorProxy(name="a", dtype=thunder.dtypes.float32, shape=(2, 2))> or None in the region may propagate to <TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(2, 2))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_inplace_functionalization.py::test_inplace_copy_on_fusion_inputs_issue_791_nvfuser_cuda_None - NotImplementedError: <TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(2, 2))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="x", dtype=thunder.dtypes.float32, shape=(2, 2))> or <TensorProxy(name="t1", dtype=thunder.dtypes.float32, shape=(2, 2))> in the region may propagate to <TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(2, 2))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_inplace_functionalization.py::test_multiple_inplace_to_multiple_args_nvfuser_cuda_None - NotImplementedError: <TensorProxy(name="t12", dtype=thunder.dtypes.float32, shape=(2, 2))> (the 'computed' argument of 'prims.copy_to_out_') is used outside of the nvFuser region. Copies onto <TensorProxy(name="t_1_1", dtype=thunder.dtypes.float32, shape=(2, 2))> or None in the region may propagate to <TensorProxy(name="t12", dtype=thunder.dtypes.float32, shape=(2, 2))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_inplace_functionalization.py::test_single_tensor_adam_like_nvfuser_cuda_None - NotImplementedError: <TensorProxy(name="t2", dtype=thunder.dtypes.float32, shape=(4,))> (the 'computed' argument of 'prims.copy_to_out_') is used outside of the nvFuser region. Copies onto <TensorProxy(name="exp_avg", dtype=thunder.dtypes.float32, shape=(4,))> or None in the region may propagate to <TensorProxy(name="t2", dtype=thunder.dtypes.float32, shape=(4,))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-long-context-like] - NotImplementedError: <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="value", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> or None in the region may propagate to <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-llama1-like] - NotImplementedError: <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="value", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> or None in the region may propagate to <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-llama2-like] - NotImplementedError: <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="value", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> or None in the region may propagate to <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-falcon-7b-like] - NotImplementedError: <TensorProxy(name="t85", dtype=thunder.dtypes.float32, shape=(1, 1, 3, 64))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="value", dtype=thunder.dtypes.float32, shape=(1, 1, 3, 64))> or None in the region may propagate to <TensorProxy(name="t85", dtype=thunder.dtypes.float32, shape=(1, 1, 3, 64))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-falcon-40b-like] - NotImplementedError: <TensorProxy(name="t87", dtype=thunder.dtypes.float32, shape=(1, 64, 3, 4))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="value", dtype=thunder.dtypes.float32, shape=(1, 64, 3, 4))> or None in the region may propagate to <TensorProxy(name="t87", dtype=thunder.dtypes.float32, shape=(1, 64, 3, 4))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-codellama2-like] - NotImplementedError: <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="value", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> or None in the region may propagate to <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.

A Collaborator commented:
Could you educate me on why we would want to prefer clang.copy_to_out_ over prims.copy_to_out_?
To allow out to have a different dtype from computed? If so, why wouldn't the prim allow it?

@t-vi (Collaborator) replied on Sep 24, 2024:

Traditionally, we have resolved typing before the prim level when decomposing, so having clang.copy_to_out_ decompose to (optionally) prims.convert_element_type plus prims.copy_to_out_ seems to match the patterns we have.
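
Schematically, the decomposition pattern described above could look like the following (a hypothetical sketch assuming clang's existing maybe_convert_to_dtype helper; the PR's actual clang-level implementation may differ, and like other clang ops it would only be called while tracing):

from thunder.clang import maybe_convert_to_dtype
from thunder.core import prims

def copy_to_out_(computed, *, out):
    # Resolve typing at the clang level: optionally insert
    # prims.convert_element_type via maybe_convert_to_dtype ...
    computed = maybe_convert_to_dtype(computed, out.dtype)
    # ... so the primitive itself only ever sees matching dtypes.
    return prims.copy_to_out_(computed, out=out)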

Comment on lines +557 to +563
for copy_bsyms in bsym_to_copy_bsyms[bsym]:
functionalized_bsyms.extend(copy_bsyms)
copy_bsym = functionalized_bsyms[-1]
# wrap_return_value_together_with_argments places all the arguments in the return value
# We swap these arguments in the return value with the outputs of copies onto them
# This prevents subsequent transforms from ordering the return statement before those copies
swap_map_for_return[variableify(copy_bsym.flat_proxy_args[0])] = copy_bsym.flat_proxy_outs[0]
@shino16 (Contributor, Author) replied:

All the changes in this file are just variable renaming, except for this loop and line 547, which were introduced in commit 47f35d9.

This fixes a bug that caused some of the test failures mentioned in #1194 (review). When bsym (key_bsym on line 547) is associated with multiple copy_bsyms, bsym_to_copy_bsyms[bsym] previously looked like [(reshape,) copy, (reshape,) copy, ...]. Now it looks like [[(reshape,) copy], [(reshape,) copy], ...], and we iterate through all copies.
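
Schematically, with placeholder strings standing in for the bound symbols (illustrative only, not Thunder's actual objects):

reshape_0, copy_0, reshape_1, copy_1 = "reshape_0", "copy_0", "reshape_1", "copy_1"

# Before: one flat list per key_bsym, so the groups belonging to different
# in-place ops were indistinguishable and only the last copy was handled.
before = [reshape_0, copy_0, reshape_1, copy_1]

# After: one sub-list per in-place op; the loop shown above extends the
# functionalized bsyms with each group and takes its last element as the copy.
after = [[reshape_0, copy_0], [reshape_1, copy_1]]

for copy_bsyms in after:
    copy_bsym = copy_bsyms[-1]  # the copy for this particular in-place op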

@shino16 shino16 marked this pull request as draft September 25, 2024 08:09
@shino16 (Contributor, Author) commented Sep 25, 2024:

As of now, the test thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-llama1-like] does NOT pass.

Minimal reproducible example:

import torch
import thunder
from functools import partial

@partial(thunder.jit, disable_inplace_copy_check=True)
def f(q, k, v, mask, idx, src):
    q.index_copy_(2, idx, src)
    k.index_copy_(2, idx, src)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)

q = torch.randn((1, 4, 2, 16), device='cuda', dtype=torch.float32)
k = torch.randn((1, 4, 3, 16), device='cuda', dtype=torch.float32)
v = torch.randn((1, 4, 3, 16), device='cuda', dtype=torch.float32)
mask = torch.ones((1, 1, 2, 3), device='cuda', dtype=torch.bool)
idx = torch.arange(2).to(device='cuda')
src = torch.randn((1, 4, 2, 16), device='cuda', dtype=torch.float32)

f(q, k, v, mask, idx, src)

Execution trace:

def computation(q, k, v, mask, idx, src):
  # q: "cuda:0 f32[1, 4, 2, 16]"
  # k: "cuda:0 f32[1, 4, 3, 16]"
  # v: "cuda:0 f32[1, 4, 3, 16]"
  # mask: "cuda:0 b8[1, 1, 2, 3]"
  # idx: "cuda:0 i64[2]"
  # src: "cuda:0 f32[1, 4, 2, 16]"
  t0 = torch.index_copy(q, 2, idx, src)  # t0: "cuda:0 f32[1, 4, 2, 16]"
    # t0 = ltorch.index_copy(q, 2, idx, src)  # t0: "cuda:0 f32[1, 4, 2, 16]"
      # t0 = prims.index_copy(q, idx, src, 2)  # t0: "cuda:0 f32[1, 4, 2, 16]"
  t2 = torch.index_copy(k, 2, idx, src)  # t2: "cuda:0 f32[1, 4, 3, 16]"
    # t2 = ltorch.index_copy(k, 2, idx, src)  # t2: "cuda:0 f32[1, 4, 3, 16]"
      # t2 = prims.index_copy(k, idx, src, 2)  # t2: "cuda:0 f32[1, 4, 3, 16]"
  [t1, t3] = nvFusion0(t0, q, t2, k)
    # t1 = prims.copy_to_out_(t0, out=q)  # t1: "cuda:0 f32[1, 4, 2, 16]"
    # t3 = prims.copy_to_out_(t2, out=k)  # t3: "cuda:0 f32[1, 4, 3, 16]"
  del q, k
  (t19, _, _, _) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t0, t2, v, mask, 0.0, False, None)
  del t0, t2
  return t19

_inplace_copy_sanity_check raises an error for this, because

  • t0 is passed to the sdpa operator, and
  • if nvFusion0 had a copy onto q in the form of prims.copy_(XX, copy_to=q), XX would be propagated to t0.

Note that, in the trace before it is handed to the nvFuser executor, prims.copy_to_out_(t0, out=q) is placed after the sdpaex operator thanks to functionalization.

Trace just before nvFuser
def computation(q, k, v, mask, idx, src):
  # q: "cuda:0 f32[1, 4, 2, 16]"
  # k: "cuda:0 f32[1, 4, 3, 16]"
  # v: "cuda:0 f32[1, 4, 3, 16]"
  # mask: "cuda:0 b8[1, 1, 2, 3]"
  # idx: "cuda:0 i64[2]"
  # src: "cuda:0 f32[1, 4, 2, 16]"
  # Functionalized from `t1 = index_copy_(q,2,idx,src)`
  t0 = ltorch.index_copy(q, 2, idx, src)  # t0: "cuda:0 f32[1, 4, 2, 16]"
    # t0 = prims.index_copy(q, idx, src, 2)  # t0: "cuda:0 f32[1, 4, 2, 16]"
  # Functionalized from `t3 = index_copy_(k,2,idx,src)`
  t2 = ltorch.index_copy(k, 2, idx, src)  # t2: "cuda:0 f32[1, 4, 3, 16]"
    # t2 = prims.index_copy(k, idx, src, 2)  # t2: "cuda:0 f32[1, 4, 3, 16]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:60:          return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)
  # ['t1', 't3'] are replaced by ['t0', 't2'], respectively
  t19 = ltorch.scaled_dot_product_attention(t0, t2, v, mask, 0.0, False, scale=None)  # t19: "cuda:0 f32[1, 4, 2, 16]"
    # subsymbols omitted
  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:58:          q.index_copy_(2, idx, src)
  t1 = prims.copy_to_out_(t0, out=q)  # t1: "cuda:0 f32[1, 4, 2, 16]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:59:          k.index_copy_(2, idx, src)
  t3 = prims.copy_to_out_(t2, out=k)  # t3: "cuda:0 f32[1, 4, 3, 16]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:60:          return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)
  return {'output': t19, 'flat_args': [t1, t3, v, mask, idx, src]}

Possible solutions

  • Functionalize prims.copy_ so that there are never multiple copies onto the same tensor. Under this assumption, we can be sure that q never changes after prims.copy_to_out_(t0, out=q) in the previous example (see the toy sketch after this list).
  • Preserve the order between operations involving t0 and prims.copy_to_out_(t0, out=q). This order is enforced by functionalization (link), but it does not establish a dependency relationship in terms of inputs and outputs, so this would mean exposing fixes only for nvFuser.
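
As a toy illustration of the first idea (not Thunder's actual trace or BoundSymbol API; the names here are made up), deduplicating copies so that only the final write onto each destination survives could look like this:

from dataclasses import dataclass

@dataclass
class CopyOp:
    src: str  # name of the computed (source) proxy
    dst: str  # name of the destination (copy_to / out) proxy

def keep_last_copy_per_destination(copies: list[CopyOp]) -> list[CopyOp]:
    # If only the final copy onto each destination remains, a later
    # copy_to_out_(t0, out=q) is the only write into q, so no other copy
    # can propagate through the alias back into t0.
    last_index = {c.dst: i for i, c in enumerate(copies)}
    return [c for i, c in enumerate(copies) if last_index[c.dst] == i]

# Example: two writes into "q"; only the second survives.
copies = [CopyOp("t0", "q"), CopyOp("t5", "q"), CopyOp("t2", "k")]
assert keep_last_copy_per_destination(copies) == [CopyOp("t5", "q"), CopyOp("t2", "k")]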

@shino16 (Contributor, Author) commented Sep 25, 2024:

I disabled the sanity check for test_litgpt_variants_kvcache.

@shino16 shino16 marked this pull request as ready for review September 25, 2024 11:03
@t-vi (Collaborator) commented Sep 25, 2024:

> Note that, before passing the trace to the nvFuser executor, prims.copy_to_out_(t0, out=q) is put after the sdpaex operator thanks to functionalization.

Ugh.

Could it be that inplace_copy_ is particular here?

I wonder if something along the lines of @IvanYashchuk's planned primitive for dataflow healing would be useful.

@shino16 (Contributor, Author) commented Sep 25, 2024:

The same happens when we use Tensor.add_ instead of Tensor.index_copy_ too, so this is a generic issue.
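
For example, a hypothetical variant of the repro above that uses Tensor.add_ (illustrative only, not the exact reproduction) exhibits the same ordering problem:

from functools import partial
import torch
import thunder

@partial(thunder.jit, disable_inplace_copy_check=True)
def g(q, k, v, mask, src):
    # Tensor.add_ instead of Tensor.index_copy_: the functionalized adds are
    # still consumed by sdpa before copy_to_out_ writes back into q and k.
    q.add_(src)
    k.add_(src)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)

shape = (1, 4, 2, 16)
q, k, v, src = (torch.randn(shape, device='cuda', dtype=torch.float32) for _ in range(4))
mask = torch.ones((1, 1, 2, 2), device='cuda', dtype=torch.bool)
g(q, k, v, mask, src)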

My internship period is about to end, so I can no longer spend much time on this issue. Maybe we can close this PR for now and wait for the functionalization of copy_ or the dependency-establishment work to mature; in the meantime, we could merge #1177 or take other strategies if the performance regression in #1173 is urgent.

@shino16 mentioned this pull request on Sep 27, 2024.
Successfully merging this pull request may close: PR #1110 nearly doubles the compilation & execution time of a copy-heavy program