From a833c53d733ae7c9b1d1c0f39c33fa8bc8234d76 Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Wed, 25 Oct 2023 14:45:02 -0700 Subject: [PATCH] Test SD with updated Triton --- .github/workflows/nv-inference.yml | 2 +- .../ops/transformer/inference/triton_ops.py | 9 +----- requirements/requirements-sd.txt | 2 +- tests/unit/inference/test_inference.py | 31 +++++++++++++++++++ 4 files changed, 34 insertions(+), 10 deletions(-) diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml index 065f8b93f1e0..80481eb6bf4a 100644 --- a/.github/workflows/nv-inference.yml +++ b/.github/workflows/nv-inference.yml @@ -39,7 +39,7 @@ jobs: - name: Install deepspeed run: | - pip install .[dev,1bit,autotuning,inf,triton] + pip install .[dev,1bit,autotuning,inf,triton,sd] ds_report - name: Python environment diff --git a/deepspeed/ops/transformer/inference/triton_ops.py b/deepspeed/ops/transformer/inference/triton_ops.py index 0c9c53ab1de1..56e98f72a07c 100644 --- a/deepspeed/ops/transformer/inference/triton_ops.py +++ b/deepspeed/ops/transformer/inference/triton_ops.py @@ -18,7 +18,6 @@ def _fwd_kernel( K, V, sm_scale, - TMP, Out, stride_qz, stride_qh, @@ -57,7 +56,6 @@ def _fwd_kernel( k_ptrs = K + off_k v_ptrs = V + off_v # initialize pointer to m and l - t_ptrs = TMP + off_hz * N_CTX + offs_m m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -69,8 +67,7 @@ def _fwd_kernel( # -- compute qk ---- k = tl.load(k_ptrs + start_n * stride_kn) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, trans_b=True) + qk = tl.dot(q, tl.trans(k)) qk *= sm_scale # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) @@ -87,8 +84,6 @@ def _fwd_kernel( p = p * p_scale[:, None] # scale acc acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load acc = acc * acc_scale[:, None] # update acc v = tl.load(v_ptrs + start_n * stride_vk) @@ -115,7 +110,6 @@ def forward(self, q, k, v, sm_scale, block_128=True): Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] o = torch.empty_like(q) grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) - tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 _fwd_kernel[grid]( @@ -123,7 +117,6 @@ def forward(self, q, k, v, sm_scale, block_128=True): k, v, sm_scale, - tmp, o, q.stride(0), q.stride(1), diff --git a/requirements/requirements-sd.txt b/requirements/requirements-sd.txt index 7b988876f54d..086a8e3f4879 100644 --- a/requirements/requirements-sd.txt +++ b/requirements/requirements-sd.txt @@ -1,2 +1,2 @@ diffusers -triton==2.0.0.dev20221202 +triton diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 894f040be207..1d9167b95bac 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -341,6 +341,37 @@ def test( assert assert_fn(bs_output, ds_output) +# Setup for these models is different from other pipelines, so we add a separate test +@pytest.mark.inference +class TestStableDiffusion(DistributedTest): + world_size = 1 + + def test(self): + from diffusers import DiffusionPipeline + + prompt = "a dog on a rocket" + model = "prompthero/midjourney-v4-diffusion" + local_rank = int(os.getenv("LOCAL_RANK", "0")) + device = torch.device(f"cuda:{local_rank}") + + pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.half) + pipe = pipe.to(device) + baseline_image = pipe(prompt, guidance_scale=7.5).images[0] + + pipe = deepspeed.init_inference( + pipe, + mp_size=1, + dtype=torch.half, + replace_method="auto", + replace_with_kernel_inject=True, + enable_cuda_graph=False, + ) + deepspeed_image = pipe(prompt, guidance_scale=7.5).images[0] + + # Need to determine a heuristic for checking if images are "similar" + #assert baseline_image == deepspeed_image + + @pytest.mark.seq_inference @pytest.mark.parametrize("model_w_task", [("EleutherAI/gpt-neo-1.3B", "text-generation"), ("EleutherAI/gpt-neox-20b", "text-generation"),