From 6802a0c4e9868041aa825f629c5e983df96e3cab Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Tue, 29 Oct 2024 16:56:28 +0100 Subject: [PATCH] Add transformers 4.46 compatiblity (#2078) * transformers 4.46 * setup * uupdate setup * fix t5 * update python (3.8 eol) * fix onnx test * fixed deberta, onnxruntime tests in series passing * fix bt * fixed t5_forward for real, because it's also used by blip-2 as well * fix Phi3 * fix opt * vision encoder decoder * fix setup * style * fix encoder decoder * fixed transformers branch * branch * allow 4.47 * remove patch * add opt * add test * fix OPT ONNX export and inference * add test * update setup * style * merge tests * update tes num beams * add test transformers version * add architectures depending on transformers * add warning * revert * update test generation length * style --------- Co-authored-by: IlyasMoutawwakil --- .github/workflows/check_code_quality.yml | 2 +- .github/workflows/test_benckmark.yml | 30 +- .github/workflows/test_cli.yml | 4 +- .github/workflows/test_export_onnx.yml | 44 +-- .github/workflows/test_export_onnx_cli.yml | 30 +- .../workflows/test_export_onnx_cli_timm.yml | 26 +- .github/workflows/test_export_onnx_timm.yml | 27 +- .github/workflows/test_exporters_common.yml | 2 +- .github/workflows/test_exporters_slow.yml | 2 +- .github/workflows/test_fx.yml | 2 +- .github/workflows/test_offline.yml | 2 +- .github/workflows/test_onnx.yml | 2 +- .github/workflows/test_onnxruntime.yml | 13 +- .github/workflows/test_onnxruntime_slow.yml | 2 +- .github/workflows/test_optimum_common.yml | 39 +-- .github/workflows/test_utils.yml | 2 +- optimum/bettertransformer/models/attention.py | 326 ++++++++++++------ .../models/decoder_models.py | 4 +- optimum/bettertransformer/transformation.py | 36 +- optimum/exporters/onnx/model_configs.py | 49 ++- optimum/exporters/onnx/model_patcher.py | 3 +- optimum/exporters/onnx/utils.py | 6 +- optimum/onnxruntime/modeling_decoder.py | 4 +- optimum/utils/__init__.py | 1 + optimum/utils/import_utils.py | 16 + setup.py | 24 +- tests/bettertransformer/test_audio.py | 20 +- tests/bettertransformer/test_common.py | 12 +- tests/bettertransformer/test_decoder.py | 8 +- tests/bettertransformer/test_encoder.py | 4 +- .../bettertransformer/test_encoder_decoder.py | 2 +- tests/bettertransformer/test_gpu.py | 4 +- tests/bettertransformer/testing_utils.py | 18 +- tests/onnx/test_onnx_export_custom_module.py | 17 +- tests/onnxruntime/test_modeling.py | 61 ++-- tests/onnxruntime/utils_onnxruntime_tests.py | 1 + 36 files changed, 541 insertions(+), 304 deletions(-) diff --git a/.github/workflows/check_code_quality.yml b/.github/workflows/check_code_quality.yml index c429b706bf..861684cfa4 100644 --- a/.github/workflows/check_code_quality.yml +++ b/.github/workflows/check_code_quality.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8] + python-version: [3.9] os: [ubuntu-20.04] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/test_benckmark.yml b/.github/workflows/test_benckmark.yml index 7f7f2ace32..e859e845d6 100644 --- a/.github/workflows/test_benckmark.yml +++ b/.github/workflows/test_benckmark.yml @@ -4,9 +4,9 @@ name: Benchmark suite / Python - Test on: push: - branches: [ main ] + branches: [main] pull_request: - branches: [ main ] + branches: [main] concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -17,20 +17,20 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + python-version: [3.9] os: [ubuntu-20.04] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - pip install wheel - pip install .[tests,onnxruntime,benchmark] - - name: Test with unittest - run: | - python -m unittest discover --start-directory tests/benchmark --pattern 'test_*.py' + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install wheel + pip install .[tests,onnxruntime,benchmark] + - name: Test with unittest + run: | + python -m unittest discover --start-directory tests/benchmark --pattern 'test_*.py' diff --git a/.github/workflows/test_cli.yml b/.github/workflows/test_cli.yml index ecb19d23aa..2efab40aab 100644 --- a/.github/workflows/test_cli.yml +++ b/.github/workflows/test_cli.yml @@ -17,7 +17,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + python-version: [3.9] os: [ubuntu-20.04, macos-13] runs-on: ${{ matrix.os }} @@ -34,7 +34,7 @@ jobs: run: | pip install --upgrade pip pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - pip install .[tests,exporters,exporters-tf] + pip install .[tests,exporters-tf] - name: Test with pytest run: | diff --git a/.github/workflows/test_export_onnx.yml b/.github/workflows/test_export_onnx.yml index 56ef674cb4..0cd19a1724 100644 --- a/.github/workflows/test_export_onnx.yml +++ b/.github/workflows/test_export_onnx.yml @@ -2,9 +2,9 @@ name: Exporters ONNX / Python - Test on: push: - branches: [ main ] + branches: [main] pull_request: - branches: [ main ] + branches: [main] concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -15,27 +15,27 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + python-version: [3.9] os: [ubuntu-20.04] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies for pytorch export - run: | - pip install .[tests,exporters] - - name: Test with unittest - working-directory: tests - run: | - pytest exporters/onnx/test_onnx_*.py -s -n auto -m "not tensorflow_test and not timm_test" --durations=0 - - name: Install dependencies for tensorflow export - run: | - pip install .[tests,exporters-tf] - - name: Test with unittest - working-directory: tests - run: | - pytest exporters/onnx/test_onnx_*.py -n auto -m "tensorflow_test" -s --durations=0 + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for pytorch export + run: | + pip install .[tests,exporters] + - name: Test with unittest + working-directory: tests + run: | + pytest exporters/onnx/test_onnx_*.py -s -n auto -m "not tensorflow_test and not timm_test" --durations=0 + - name: Install dependencies for tensorflow export + run: | + pip install .[tests,exporters-tf] + - name: Test with unittest + working-directory: tests + run: | + pytest exporters/onnx/test_onnx_*.py -n auto -m "tensorflow_test" -s --durations=0 diff --git a/.github/workflows/test_export_onnx_cli.yml b/.github/workflows/test_export_onnx_cli.yml index 8fa4ebb045..618a140c14 100644 --- a/.github/workflows/test_export_onnx_cli.yml +++ b/.github/workflows/test_export_onnx_cli.yml @@ -2,9 +2,9 @@ name: Exporters ONNX CLI / Python - Test on: push: - branches: [ main ] + branches: [main] pull_request: - branches: [ main ] + branches: [main] concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -15,20 +15,20 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + python-version: [3.9] os: [ubuntu-20.04] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies for pytorch export - run: | - pip install .[tests,exporters] - - name: Test with unittest - working-directory: tests - run: | - pytest exporters/onnx/test_exporters_onnx_cli.py -n auto -m "not tensorflow_test and not timm_test" -s --durations=0 + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for pytorch export + run: | + pip install .[tests,exporters] + - name: Test with unittest + working-directory: tests + run: | + pytest exporters/onnx/test_exporters_onnx_cli.py -n auto -m "not tensorflow_test and not timm_test" -s --durations=0 diff --git a/.github/workflows/test_export_onnx_cli_timm.yml b/.github/workflows/test_export_onnx_cli_timm.yml index 76a535fceb..b92d5551ba 100644 --- a/.github/workflows/test_export_onnx_cli_timm.yml +++ b/.github/workflows/test_export_onnx_cli_timm.yml @@ -14,20 +14,20 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + python-version: [3.9] os: [ubuntu-20.04] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies for pytorch export - run: | - pip install .[tests,exporters] - - name: Test with unittest - working-directory: tests - run: | - RUN_SLOW=1 pytest exporters/onnx/test_exporters_onnx_cli.py -n auto -k "timm" -s --durations=0 + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for pytorch export + run: | + pip install .[tests,exporters] + - name: Test with unittest + working-directory: tests + run: | + RUN_SLOW=1 pytest exporters/onnx/test_exporters_onnx_cli.py -n auto -k "timm" -s --durations=0 diff --git a/.github/workflows/test_export_onnx_timm.yml b/.github/workflows/test_export_onnx_timm.yml index 339e3e93de..c16d20fbc1 100644 --- a/.github/workflows/test_export_onnx_timm.yml +++ b/.github/workflows/test_export_onnx_timm.yml @@ -14,21 +14,20 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + python-version: [3.9] os: [ubuntu-20.04] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies for pytorch export - run: | - pip install .[tests,exporters] - - name: Test with unittest - working-directory: tests - run: | - RUN_SLOW=1 pytest exporters/onnx/ -s -n auto -k "timm" --durations=0 - + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for pytorch export + run: | + pip install .[tests,exporters] + - name: Test with unittest + working-directory: tests + run: | + RUN_SLOW=1 pytest exporters/onnx/ -s -n auto -k "timm" --durations=0 diff --git a/.github/workflows/test_exporters_common.yml b/.github/workflows/test_exporters_common.yml index 8e8c3360c1..11f6038afe 100644 --- a/.github/workflows/test_exporters_common.yml +++ b/.github/workflows/test_exporters_common.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + python-version: [3.9] os: [ubuntu-20.04] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/test_exporters_slow.yml b/.github/workflows/test_exporters_slow.yml index b22fdd7fd2..453389d63f 100644 --- a/.github/workflows/test_exporters_slow.yml +++ b/.github/workflows/test_exporters_slow.yml @@ -14,7 +14,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + python-version: [3.9] os: [ubuntu-20.04] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/test_fx.yml b/.github/workflows/test_fx.yml index f0366cf0d1..a4e6dd3cd2 100644 --- a/.github/workflows/test_fx.yml +++ b/.github/workflows/test_fx.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + python-version: [3.9] os: [ubuntu-20.04, macos-13] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/test_offline.yml b/.github/workflows/test_offline.yml index 90b0108e51..20911fe6db 100644 --- a/.github/workflows/test_offline.yml +++ b/.github/workflows/test_offline.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + python-version: [3.9] os: [ubuntu-20.04] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/test_onnx.yml b/.github/workflows/test_onnx.yml index 22a1172079..dd1f3bee63 100644 --- a/.github/workflows/test_onnx.yml +++ b/.github/workflows/test_onnx.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + python-version: [3.9] os: [ubuntu-20.04, macos-14] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/test_onnxruntime.yml b/.github/workflows/test_onnxruntime.yml index a72bedb1ab..0ab95752d0 100644 --- a/.github/workflows/test_onnxruntime.yml +++ b/.github/workflows/test_onnxruntime.yml @@ -17,8 +17,11 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + transformers-version: ["latest"] os: [ubuntu-20.04, windows-2019, macos-13] + include: + - transformers-version: "4.45.*" + os: ubuntu-20.04 runs-on: ${{ matrix.os }} steps: @@ -33,10 +36,10 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - - name: Setup Python ${{ matrix.python-version }} + - name: Setup Python uses: actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} + python-version: 3.9 - name: Install dependencies run: | @@ -44,6 +47,10 @@ jobs: pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu pip install .[tests,onnxruntime] + - name: Install transformers ${{ matrix.transformers-version }} + if: ${{ matrix.transformers-version != 'latest' }} + run: pip install transformers==${{ matrix.transformers-version }} + - name: Test with pytest (in series) working-directory: tests run: | diff --git a/.github/workflows/test_onnxruntime_slow.yml b/.github/workflows/test_onnxruntime_slow.yml index 20371f7915..c5679e5b30 100644 --- a/.github/workflows/test_onnxruntime_slow.yml +++ b/.github/workflows/test_onnxruntime_slow.yml @@ -14,7 +14,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + python-version: [3.9] os: [ubuntu-20.04] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/test_optimum_common.yml b/.github/workflows/test_optimum_common.yml index ded149c9b6..5ad42807a5 100644 --- a/.github/workflows/test_optimum_common.yml +++ b/.github/workflows/test_optimum_common.yml @@ -4,9 +4,9 @@ name: Optimum common / Python - Test on: push: - branches: [ main ] + branches: [main] pull_request: - branches: [ main ] + branches: [main] concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -17,25 +17,24 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.8, 3.9] + python-version: [3.9] os: [ubuntu-20.04, windows-2019, macos-13] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install .[tests] - ls -l optimum/ - - name: Test with unittest - shell: bash - run: | - # Setting HUGGINGFACE_CO_STAGING to true for only one job of the matrix as the staging tests cannot run in parallel. - export HUGGINGFACE_CO_STAGING=${{ matrix.python-version == '3.8' && matrix.os == 'ubuntu-20.04' }} - pytest tests/test_*.py - + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[tests] + ls -l optimum/ + - name: Test with unittest + shell: bash + run: | + # Setting HUGGINGFACE_CO_STAGING to true for only one job of the matrix as the staging tests cannot run in parallel. + export HUGGINGFACE_CO_STAGING=${{ matrix.python-version == '3.8' && matrix.os == 'ubuntu-20.04' }} + pytest tests/test_*.py diff --git a/.github/workflows/test_utils.yml b/.github/workflows/test_utils.yml index 1ef33ced08..b5f2e27fc6 100644 --- a/.github/workflows/test_utils.yml +++ b/.github/workflows/test_utils.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-20.04, macos-13] - python-version: [3.8, 3.9] + python-version: [3.9] runs-on: ${{ matrix.os }} steps: diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py index 22b8faf1c2..c8c91a04e4 100644 --- a/optimum/bettertransformer/models/attention.py +++ b/optimum/bettertransformer/models/attention.py @@ -387,137 +387,243 @@ def opt_forward( # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward -def t5_forward( - self, - hidden_states, - mask=None, - key_value_states=None, - position_bias=None, - past_key_value=None, - layer_head_mask=None, - query_length=None, - use_cache=False, - output_attentions=False, - **kwargs, -): - raise_on_head_mask(layer_head_mask) +if check_if_transformers_greater("4.45.99"): - if output_attentions is True: - raise ValueError("output_attentions=True can not be supported with BetterTransformer.") - if len(self.pruned_heads) > 0: - raise ValueError(f"Setting `pruned_heads` is unsupported with BetterTransformer, found {self.pruned_heads}.") - batch_size, seq_length = hidden_states.shape[:2] - - real_seq_length = seq_length - - if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + def t5_forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=query_states.device, dtype=query_states.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=query_states.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=position_bias_masked, + dropout_p=self.dropout if self.training else 0.0, + is_causal=False, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, past_key_value, position_bias) + + return outputs + +else: + + def t5_forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + **kwargs, + ): + raise_on_head_mask(layer_head_mask) + + if output_attentions is True: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + if len(self.pruned_heads) > 0: + raise ValueError( + f"Setting `pruned_heads` is unsupported with BetterTransformer, found {self.pruned_heads}." + ) + + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" if key_value_states is None: # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: # cross-attn # (batch_size, n_heads, seq_length, dim_per_head) hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - # get key/value states - key_states = project( - hidden_states, - self.k, - key_value_states, - past_key_value[0] if past_key_value is not None else None, - ) - value_states = project( - hidden_states, - self.v, - key_value_states, - past_key_value[1] if past_key_value is not None else None, - ) + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + ) + value_states = project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + ) - dropout_p = self.dropout if self.training else 0.0 - query_states = self.scale * query_states - if position_bias is None and not self.has_relative_attention_bias: - if mask is None: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=None, dropout_p=dropout_p, is_causal=False - ) - elif mask is not None: + dropout_p = self.dropout if self.training else 0.0 + query_states = self.scale * query_states + if position_bias is None and not self.has_relative_attention_bias: attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=mask, dropout_p=dropout_p, is_causal=False ) - if position_bias is None: - if not self.has_relative_attention_bias: - position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), - device=value_states.device, - dtype=value_states.dtype, - ) - if self.gradient_checkpointing and self.training: - position_bias.requires_grad = True + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=value_states.device, + dtype=value_states.dtype, + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=value_states.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + if self.has_relative_attention_bias: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=position_bias, + dropout_p=dropout_p, + is_causal=False, + ) else: - position_bias = self.compute_bias(real_seq_length, key_length, device=value_states.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - - if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) - - if self.has_relative_attention_bias: attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=position_bias, dropout_p=dropout_p, is_causal=False ) - else: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=position_bias, dropout_p=dropout_p, is_causal=False - ) - attn_output = unshape(attn_output) # (batch_size, seq_length, dim) - attn_output = self.o(attn_output) + attn_output = unshape(attn_output) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) - return outputs + return outputs # Adapted from transformers.models.bart.modeling_bart.BartAttention.forward diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py index 52d28d076d..e8045e695c 100644 --- a/optimum/bettertransformer/models/decoder_models.py +++ b/optimum/bettertransformer/models/decoder_models.py @@ -327,9 +327,9 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): setattr(self, "relative_attention_bias", layer.relative_attention_bias) self.original_layers_mapping["relative_attention_bias"] = "relative_attention_bias" - self.module_mapping = None - + self.layer_idx = getattr(layer, "layer_idx", None) self.is_decoder = layer.is_decoder + self.module_mapping = None def forward(self, *args, **kwargs): return t5_forward(self, *args, **kwargs) diff --git a/optimum/bettertransformer/transformation.py b/optimum/bettertransformer/transformation.py index a101757b6f..b138862752 100644 --- a/optimum/bettertransformer/transformation.py +++ b/optimum/bettertransformer/transformation.py @@ -20,7 +20,13 @@ import torch from packaging.version import parse -from ..utils import check_if_pytorch_greater, is_accelerate_available, recurse_getattr, recurse_setattr +from ..utils import ( + check_if_pytorch_greater, + check_if_torch_greater, + is_accelerate_available, + recurse_getattr, + recurse_setattr, +) from .models import BetterTransformerManager @@ -213,15 +219,18 @@ def transform( hf_config = model.config if hf_config.model_type in ["falcon", "gpt_bigcode", "llama", "whisper"]: raise ValueError( - f"Transformers now supports natively BetterTransformer optimizations (torch.nn.functional.scaled_dot_product_attention) for the model type {hf_config.model_type}. As such, there is no need to use `model.to_bettertransformers()` or `BetterTransformer.transform(model)` from the Optimum library. Please upgrade to transformers>=4.36 and torch>=2.1.1 to use it. Details: https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention." + f"Transformers now supports natively BetterTransformer optimizations (torch.nn.functional.scaled_dot_product_attention) for the model type {hf_config.model_type}. " + "As such, there is no need to use `model.to_bettertransformers()` or `BetterTransformer.transform(model)` from the Optimum library. " + "Please upgrade to transformers>=4.36 and torch>=2.1.1 to use it. " + "Details: https://huggingface.co/docs/transformers/perf_infer_gpu_one#pytorch-scaled-dot-product-attention." ) - # Check if we have to load the model using `accelerate` - if hasattr(model, "hf_device_map"): - load_accelerate = True - hf_device_map = model.hf_device_map - else: - load_accelerate = False + if hasattr(hf_config, "_attn_implementation") and hf_config._attn_implementation == "sdpa": + raise ValueError( + "This model already uses BetterTransformer optimizations from Transformers (torch.nn.functional.scaled_dot_product_attention). " + "As such, there is no need to use `model.to_bettertransformers()` or `BetterTransformer.transform(model)` from the Optimum library. " + "Details: https://huggingface.co/docs/transformers/perf_infer_gpu_one#pytorch-scaled-dot-product-attention." + ) if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True: raise Exception( @@ -241,11 +250,20 @@ def transform( f" Currently supported models are: {BetterTransformerManager.MODEL_MAPPING.keys()}." ) - if parse(torch.__version__) <= parse("1.14"): + if not check_if_torch_greater("2.0"): raise ValueError( f"BetterTransformer requires torch>=2.0 but {torch.__version__} is installed. Please upgrade PyTorch." ) + hf_config = model.config + + # Check if we have to load the model using `accelerate` + if hasattr(model, "hf_device_map"): + load_accelerate = True + hf_device_map = model.hf_device_map + else: + load_accelerate = False + if load_accelerate: # Remove the hooks from the original model to avoid weights being on `meta` device. remove_hook_from_module(model, recurse=True) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index e77f649f69..9e57128c27 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -155,7 +155,7 @@ class SplinterOnnxConfig(BertOnnxConfig): class DistilBertOnnxConfig(BertOnnxConfig): - DEFAULT_ONNX_OPSET = 11 + DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for transformers>=4.46.0 @property def inputs(self) -> Dict[str, Dict[int, str]]: @@ -266,10 +266,18 @@ class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig -class OPTOnnxConfig(TextDecoderOnnxConfig): - # OPT does not require position_ids input. - DEFAULT_ONNX_OPSET = 13 - NORMALIZED_CONFIG_CLASS = NormalizedTextConfig +# OPT does not take position_ids as input for transfomers < v4.46, needs it for transformers >= v4.46 +if check_if_transformers_greater("4.45.99"): + + class OPTOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + +else: + + class OPTOnnxConfig(TextDecoderOnnxConfig): + DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): @@ -304,6 +312,15 @@ class Phi3OnnxConfig(PhiOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA MIN_TRANSFORMERS_VERSION = version.parse("4.41.0") + def __init__(self, *args, **kwargs): + # TODO : replace check_if_transformers_greater with is_transformers_available + if check_if_transformers_greater("4.46.0") and not check_if_transformers_greater("4.46.1"): + logger.error( + "Found transformers v4.46.0 while trying to exporting a Phi3 model, this specific version of transformers is not supported. " + "Please upgrade to v4.46.1 or higher, or downgrade your transformers version" + ) + super().__init__(*args, **kwargs) + class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35 @@ -480,7 +497,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int class T5OnnxConfig(TextSeq2SeqOnnxConfig): - DEFAULT_ONNX_OPSET = 13 + DEFAULT_ONNX_OPSET = 14 # T5 uses aten::triu that requires opset>=14 DUMMY_INPUT_GENERATOR_CLASSES = TextSeq2SeqOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[:-1] + ( T5DummySeq2SeqPastKeyValuesGenerator, ) @@ -2027,6 +2044,7 @@ class TrOCROnnxConfig(TextSeq2SeqOnnxConfig): class VisionEncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig ATOL_FOR_VALIDATION = 1e-3 + DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyVisionEncoderDecoderPastKeyValuesGenerator) @@ -2156,8 +2174,21 @@ class Pix2StructOnnxConfig(OnnxSeq2SeqConfigWithPast): DummySeq2SeqPastKeyValuesGenerator, DummyPix2StructInputGenerator, ) - # Min operator needs to support int64, which is the case for opset>=12 - DEFAULT_ONNX_OPSET = 12 + + DEFAULT_ONNX_OPSET = 14 # use 'aten::triu' now which is opset 14 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO : replace check_if_transformers_greater with is_transformers_available + if ( + check_if_transformers_greater("4.46.0") + and not check_if_transformers_greater("4.46.1") + and self._behavior is ConfigBehavior.DECODER + ): + logger.error( + "Found transformers v4.46.0 while trying to exporting a Pix2Struct model, this specific version of transformers is not supported. " + "Please upgrade to v4.46.1 or higher, or downgrade your transformers version" + ) @property def inputs(self): @@ -2310,3 +2341,5 @@ def overwrite_shape_and_generate_input( class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig + + DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 34ed5fcae4..fdfb0e280f 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -34,11 +34,10 @@ if _transformers_version > version.parse("4.34.99"): - from transformers.modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask + from transformers.modeling_attn_mask_utils import AttentionMaskConverter if _transformers_version >= version.parse("4.36"): from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa else: - _prepare_4d_causal_attention_mask = None _prepare_4d_causal_attention_mask_for_sdpa = None AttentionMaskConverter = None diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 675566ba23..56249bbf5c 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -27,7 +27,7 @@ is_diffusers_available, logging, ) -from ...utils.import_utils import _diffusers_version +from ...utils.import_utils import _diffusers_version, check_if_transformers_greater from ..utils import ( _get_submodels_and_export_configs, ) @@ -89,6 +89,10 @@ } +if check_if_transformers_greater("4.45.99"): + MODEL_TYPES_REQUIRING_POSITION_IDS.add("opt") + + def check_onnxruntime_requirements(minimum_version: version.Version): """ Checks that ONNX Runtime is installed and if version is recent enough. diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index bda3ec98d9..984d7f22eb 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -582,7 +582,8 @@ def _from_pretrained( init_cls = ORTFalconForCausalLM elif config.model_type == "mpt": init_cls = ORTMPTForCausalLM - elif config.model_type == "opt": + # if model was exported with position_ids it means the model was exported with transformers >= v4.46 + elif config.model_type == "opt" and "position_ids" not in input_dims: init_cls = ORTOPTForCausalLM elif config.model_type == "gpt_bigcode": init_cls = ORTGPTBigCodeForCausalLM @@ -839,7 +840,6 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg attention_mask = kwargs.get("attention_mask", None) use_cache = kwargs.get("use_cache", None) - return { "input_ids": input_ids, "past_key_values": past_key_values, diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 5d5044e63e..db7d1f6975 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -29,6 +29,7 @@ TRANSFORMERS_MINIMUM_VERSION, check_if_diffusers_greater, check_if_pytorch_greater, + check_if_torch_greater, check_if_transformers_greater, is_accelerate_available, is_auto_gptq_available, diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py index 4a57fda79c..35a6294ab5 100644 --- a/optimum/utils/import_utils.py +++ b/optimum/utils/import_utils.py @@ -193,6 +193,22 @@ def check_if_diffusers_greater(target_version: str) -> bool: return version.parse(_diffusers_version) >= version.parse(target_version) +def check_if_torch_greater(target_version: str) -> bool: + """ + Checks whether the current install of torch is greater than or equal to the target version. + + Args: + target_version (str): version used as the reference for comparison. + + Returns: + bool: whether the check is True or not. + """ + if not is_torch_available(): + return False + + return torch_version >= version.parse(target_version) + + @contextmanager def require_numpy_strictly_lower(package_version: str, message: str): if not version.parse(np.__version__) < version.parse(package_version): diff --git a/setup.py b/setup.py index 822d8be1b8..82892bfcc8 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ REQUIRED_PKGS = [ "coloredlogs", "sympy", - "transformers[sentencepiece]>=4.29", + "transformers>=4.29", "torch>=1.11", "packaging", "numpy", @@ -37,9 +37,9 @@ "diffusers>=0.17.0", "torchaudio", "einops", - "invisible-watermark", "timm", "scikit-learn", + "sentencepiece", "rjieba", ] @@ -54,7 +54,7 @@ "datasets>=1.2.1", "evaluate", "protobuf>=3.20.1", - "transformers<4.46.0", + "transformers<4.47.0", ], "onnxruntime-gpu": [ "onnx", @@ -63,10 +63,20 @@ "evaluate", "protobuf>=3.20.1", "accelerate", # ORTTrainer requires it. - "transformers<4.46.0", + "transformers<4.47.0", + ], + "exporters": [ + "onnx", + "onnxruntime", + "timm", + "transformers<4.47.0", + ], + "exporters-gpu": [ + "onnx", + "onnxruntime-gpu", + "timm", + "transformers<4.47.0", ], - "exporters": ["onnx", "onnxruntime", "timm", "transformers<4.46.0"], - "exporters-gpu": ["onnx", "onnxruntime-gpu", "timm", "transformers<4.46.0"], "exporters-tf": [ "tensorflow>=2.4,<=2.12.1", "tf2onnx", @@ -76,7 +86,7 @@ "h5py", "numpy<1.24.0", "datasets<=2.16", - "transformers[sentencepiece]>=4.26,<4.38", + "transformers>=4.26,<4.38", ], "diffusers": ["diffusers"], "intel": "optimum-intel>=1.18.0", diff --git a/tests/bettertransformer/test_audio.py b/tests/bettertransformer/test_audio.py index be01a92d44..caca91e27c 100644 --- a/tests/bettertransformer/test_audio.py +++ b/tests/bettertransformer/test_audio.py @@ -35,7 +35,7 @@ class TestsWhisper(unittest.TestCase): def test_error_message(self): - model = AutoModel.from_pretrained("openai/whisper-tiny") + model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager") with self.assertRaises(ValueError) as cm: model = BetterTransformer.transform(model) @@ -82,15 +82,19 @@ def _test_fp16_inference( set_seed(0) if not use_to_operator: - hf_random_model = automodel_class.from_pretrained(model_id, torch_dtype=torch.float16).to(0) + hf_random_model = automodel_class.from_pretrained( + model_id, torch_dtype=torch.float16, attn_implementation="eager" + ).to(0) converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=False) - hf_random_model = automodel_class.from_pretrained(model_id, torch_dtype=torch.float16).to(0) + hf_random_model = automodel_class.from_pretrained( + model_id, torch_dtype=torch.float16, attn_implementation="eager" + ).to(0) else: - hf_random_model = automodel_class.from_pretrained(model_id).to(0) + hf_random_model = automodel_class.from_pretrained(model_id, attn_implementation="eager").to(0) converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=False) - hf_random_model = automodel_class.from_pretrained(model_id).to(0) + hf_random_model = automodel_class.from_pretrained(model_id, attn_implementation="eager").to(0) hf_random_model = hf_random_model.to(torch.float16) converted_model = converted_model.to(torch.float16) @@ -147,7 +151,7 @@ def test_generation(self, test_name: str, model_type: str, batch_size: int): model_id = MODELS_DICT[model_type] processor = AutoProcessor.from_pretrained(model_id) - model = AutoModel.from_pretrained(model_id) + model = AutoModel.from_pretrained(model_id, attn_implementation="eager") text = ["This is me and me"] if batch_size > 1: @@ -217,14 +221,14 @@ def test_logits(self, model_type: str): inputs = self.prepare_inputs_for_class(model_id, model_type) torch.manual_seed(0) - hf_random_model = AutoModel.from_pretrained(model_id).eval() + hf_random_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval() random_config = hf_random_model.config torch.manual_seed(0) converted_model = BetterTransformer.transform(hf_random_model) torch.manual_seed(0) - hf_random_model = AutoModel.from_pretrained(model_id).eval() + hf_random_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval() random_config = hf_random_model.config self.assertFalse( diff --git a/tests/bettertransformer/test_common.py b/tests/bettertransformer/test_common.py index 35b89d2ed2..b8bc0a3b3d 100644 --- a/tests/bettertransformer/test_common.py +++ b/tests/bettertransformer/test_common.py @@ -28,7 +28,7 @@ class BetterTransformerIntegrationTests(unittest.TestCase): def test_raise_error_on_double_transform_call(self): - model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-BertModel") + model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-BertModel", attn_implementation="eager") with self.assertRaises(Exception) as cm: bt_model = BetterTransformer.transform(model) @@ -59,7 +59,7 @@ def test_raise_on_save(self, model_type: str): ) for model_id in model_ids: with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname: - hf_model = AutoModel.from_pretrained(model_id).eval() + hf_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval() bt_model = BetterTransformer.transform(hf_model, keep_original_model=False) bt_model.save_pretrained(tmpdirname) @@ -73,7 +73,7 @@ def test_conversion(self, model_type: str): MODELS_DICT[model_type] if isinstance(MODELS_DICT[model_type], tuple) else (MODELS_DICT[model_type],) ) for model_id in model_ids: - hf_random_model = AutoModel.from_pretrained(model_id) + hf_random_model = AutoModel.from_pretrained(model_id, attn_implementation="eager") converted_model = BetterTransformer.transform(hf_random_model) self.assertTrue( @@ -99,7 +99,7 @@ def test_raise_save_pretrained_error(self, test_name: str, model_type: str, keep ) for model_id in model_ids: # get hf and bt model - hf_model = AutoModel.from_pretrained(model_id) + hf_model = AutoModel.from_pretrained(model_id, attn_implementation="eager") # get bt model and invert it bt_model = BetterTransformer.transform(hf_model, keep_original_model=keep_original_model) @@ -145,9 +145,11 @@ def test_raise_activation_fun(self, model_type: str): )() # random config class for the model to test hf_random_config.hidden_act = "silu" - hf_random_model = AutoModel.from_config(hf_random_config).eval() + hf_random_model = AutoModel.from_config(hf_random_config, attn_implementation="eager").eval() + with self.assertRaises(ValueError) as cm: _ = BetterTransformer.transform(hf_random_model, keep_original_model=True) + self.assertTrue("Activation function" in str(cm.exception)) def test_dict_class_consistency(self): diff --git a/tests/bettertransformer/test_decoder.py b/tests/bettertransformer/test_decoder.py index bab8f376fc..e2bc6ddc2f 100644 --- a/tests/bettertransformer/test_decoder.py +++ b/tests/bettertransformer/test_decoder.py @@ -131,7 +131,7 @@ def test_logits_with_cache(self, test_name: str, model_type: str, batch_size: in model_id = MODELS_DICT[model_type] - model = AutoModelForCausalLM.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager") normalized_config = NormalizedConfigManager.get_normalized_config_class(model.config.model_type)(model.config) @@ -167,7 +167,7 @@ def test_generation(self, test_name: str, model_type: str, batch_size: int, padd model_id = MODELS_DICT[model_type] tokenizer = AutoTokenizer.from_pretrained(model_id) - model = AutoModelForCausalLM.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager") if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: if tokenizer.eos_token != "": @@ -224,7 +224,9 @@ def test_invert_model_logits(self, test_name: str, model_type: str, keep_origina @require_torch_gpu @require_accelerate def test_accelerate_compatibility_cpu_gpu(self, keep_original_model=True, max_memory=None): - hf_model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", max_memory=max_memory).eval() + hf_model = AutoModelForCausalLM.from_pretrained( + "gpt2", device_map="auto", max_memory=max_memory, attn_implementation="eager" + ).eval() bt_model = BetterTransformer.transform( hf_model, keep_original_model=keep_original_model, max_memory=max_memory ) diff --git a/tests/bettertransformer/test_encoder.py b/tests/bettertransformer/test_encoder.py index 74aacaed58..7dd42c43b0 100644 --- a/tests/bettertransformer/test_encoder.py +++ b/tests/bettertransformer/test_encoder.py @@ -181,7 +181,9 @@ def check_accelerate_compatibility_cpu_gpu(self, keep_original_model=True, max_m If this works for roberta, it should work for all other models too. """ - hf_model = AutoModel.from_pretrained("xlm-roberta-base", device_map="auto", max_memory=max_memory).eval() + hf_model = AutoModel.from_pretrained( + "xlm-roberta-base", device_map="auto", max_memory=max_memory, attn_implementation="eager" + ).eval() bt_model = BetterTransformer.transform( hf_model, keep_original_model=keep_original_model, max_memory=max_memory ) diff --git a/tests/bettertransformer/test_encoder_decoder.py b/tests/bettertransformer/test_encoder_decoder.py index 8d05923522..5ce4d62b12 100644 --- a/tests/bettertransformer/test_encoder_decoder.py +++ b/tests/bettertransformer/test_encoder_decoder.py @@ -153,7 +153,7 @@ def test_generation(self, test_name: str, model_type: str, batch_size: int, padd model_id = MODELS_DICT[model_type] tokenizer = AutoTokenizer.from_pretrained(model_id) - model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + model = AutoModelForSeq2SeqLM.from_pretrained(model_id, attn_implementation="eager") if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token diff --git a/tests/bettertransformer/test_gpu.py b/tests/bettertransformer/test_gpu.py index b992b90d3c..ada38e408f 100644 --- a/tests/bettertransformer/test_gpu.py +++ b/tests/bettertransformer/test_gpu.py @@ -26,7 +26,9 @@ def timing_cuda(model, num_batches, input_ids, masks, decoder_input_ids): def benchmark(model_name: str, num_batches: int, batch_size: int, max_seqlen: int, is_half: bool): - hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16 if is_half else None).eval() + hf_model = AutoModel.from_pretrained( + model_name, torch_dtype=torch.float16 if is_half else None, attn_implementation="eager" + ).eval() hf_model = hf_model.to("cuda:0") bt_model = BetterTransformer.transform(hf_model, keep_original_model=True) diff --git a/tests/bettertransformer/testing_utils.py b/tests/bettertransformer/testing_utils.py index 098882180a..f79cbb3451 100644 --- a/tests/bettertransformer/testing_utils.py +++ b/tests/bettertransformer/testing_utils.py @@ -136,10 +136,12 @@ def _test_fp16_inference( torch.manual_seed(0) if not use_to_operator: - hf_random_model = automodel_class.from_pretrained(model_id, torch_dtype=torch.float16).to(0) + hf_random_model = automodel_class.from_pretrained( + model_id, torch_dtype=torch.float16, attn_implementation="eager" + ).to(0) converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=True) else: - hf_random_model = automodel_class.from_pretrained(model_id).to(0) + hf_random_model = automodel_class.from_pretrained(model_id, attn_implementation="eager").to(0) converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=True) hf_random_model = hf_random_model.to(torch.float16) converted_model = converted_model.to(torch.float16) @@ -169,7 +171,7 @@ def _test_fp16_inference( def _test_logits_backward(self, model_id: str, model_type: str, **preprocessor_kwargs): inputs = self.prepare_inputs_for_class(model_id=model_id, model_type=model_type, **preprocessor_kwargs) - hf_random_model = AutoModel.from_pretrained(model_id).eval() + hf_random_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval() random_config = hf_random_model.config # I could not obtain reproducible results with `torch.manual_seed` nor with @@ -309,7 +311,7 @@ def _test_train_decoder(self, model_id: str, model_type: str, **kwargs): """ inputs = self.prepare_inputs_for_class(model_id=model_id, model_type=model_type, **kwargs) - hf_random_model = AutoModel.from_pretrained(model_id).eval() + hf_random_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval() bt_model = BetterTransformer.transform(hf_random_model, keep_original_model=True) bt_model.train() @@ -328,7 +330,7 @@ def _test_invert_modules(self, model_id, keep_original_model=False): r""" Test that the inverse converted model and hf model have the same modules """ - hf_model = AutoModel.from_pretrained(model_id) + hf_model = AutoModel.from_pretrained(model_id, attn_implementation="eager") hf_modules = list(hf_model.modules()) bt_model = BetterTransformer.transform(hf_model, keep_original_model=keep_original_model) @@ -349,7 +351,7 @@ def _test_invert_modules(self, model_id, keep_original_model=False): def _test_save_load_invertible(self, model_id, keep_original_model=True): with tempfile.TemporaryDirectory() as tmpdirname: - hf_model = AutoModel.from_pretrained(model_id).eval() + hf_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval() hf_model_state_dict = copy.deepcopy(hf_model.state_dict()) bt_model = BetterTransformer.transform(hf_model, keep_original_model=keep_original_model) @@ -362,7 +364,7 @@ def _test_save_load_invertible(self, model_id, keep_original_model=True): # saving a normal transformers bark model fails because of shared tensors bt_model.save_pretrained(tmpdirname, safe_serialization=hf_model.config.model_type != "bark") - bt_model_from_load = AutoModel.from_pretrained(tmpdirname) + bt_model_from_load = AutoModel.from_pretrained(tmpdirname, attn_implementation="eager") self.assertEqual( set(bt_model.state_dict().keys()), @@ -397,7 +399,7 @@ def _test_invert_model_logits( """ inputs = self.prepare_inputs_for_class(model_id, model_type=model_type, **preprocessor_kwargs) - hf_model = AutoModel.from_pretrained(model_id) + hf_model = AutoModel.from_pretrained(model_id, attn_implementation="eager") hf_model = hf_model.eval() with torch.inference_mode(): diff --git a/tests/onnx/test_onnx_export_custom_module.py b/tests/onnx/test_onnx_export_custom_module.py index a144d5cd84..4398c14f01 100644 --- a/tests/onnx/test_onnx_export_custom_module.py +++ b/tests/onnx/test_onnx_export_custom_module.py @@ -24,6 +24,8 @@ import torch from transformers.models.deberta import modeling_deberta + from optimum.utils import check_if_torch_greater + class StableDropoutTestCase(TestCase): """Tests export of StableDropout module.""" @@ -50,8 +52,8 @@ def test_training(self): training=training, ) - # Expected to fail with opset_version < 12 - with self.assertRaises(Exception): + if check_if_torch_greater("2.5"): + # Expected to pass with opset_version < 12 on torch >= 2.5 torch.onnx.export( sd, input, @@ -60,3 +62,14 @@ def test_training(self): do_constant_folding=do_constant_folding, training=training, ) + else: + # Expected to fail with opset_version < 12 on torch < 2.5 + with self.assertRaises(Exception): + torch.onnx.export( + sd, + input, + devnull, + opset_version=11, + do_constant_folding=do_constant_folding, + training=training, + ) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index da450b8e31..597eb581e2 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -54,6 +54,7 @@ AutoModelForTokenClassification, AutoModelForVision2Seq, AutoTokenizer, + GenerationConfig, MBartForConditionalGeneration, Pix2StructForConditionalGeneration, # Pix2Struct does not work with AutoModel PretrainedConfig, @@ -106,7 +107,7 @@ DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, logging, ) -from optimum.utils.import_utils import is_diffusers_available +from optimum.utils.import_utils import check_if_transformers_greater, is_diffusers_available from optimum.utils.testing_utils import ( grid_parameters, remove_directory, @@ -2326,10 +2327,12 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): "llama", "mistral", "mpt", - "phi3", - "qwen2", + "opt", ] + if check_if_transformers_greater("4.40"): + SUPPORTED_ARCHITECTURES.extend(["gemma", "phi3", "qwen2"]) + FULL_GRID = { "model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [False, True], @@ -2338,7 +2341,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): ORTMODEL_CLASS = ORTModelForCausalLM TASK = "text-generation" - GENERATION_LENGTH = 100 + GENERATION_LENGTH = 90 SPEEDUP_CACHE = 1.1 @parameterized.expand([(False,), (True,)]) @@ -2411,7 +2414,7 @@ def test_merge_from_onnx_and_save(self, model_arch): self.assertNotIn(ONNX_DECODER_WITH_PAST_NAME, folder_contents) self.assertNotIn(ONNX_WEIGHTS_NAME, folder_contents) - @parameterized.expand(grid_parameters({**FULL_GRID, "num_beams": [1, 3]})) + @parameterized.expand(grid_parameters({**FULL_GRID, "num_beams": [1, 4]})) def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, num_beams: int): use_io_binding = None if use_cache is False: @@ -2474,25 +2477,39 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach # TODO: remove once https://github.com/huggingface/transformers/pull/26873 is released, falcon is broken in transformers new_tokens = 5 - onnx_outputs = onnx_model.generate( - **tokens, - num_beams=num_beams, - do_sample=False, - min_new_tokens=new_tokens, - max_new_tokens=new_tokens, - eos_token_id=None, - ) + gen_kwargs = { + "max_new_tokens": new_tokens, + "min_new_tokens": new_tokens, + "eos_token_id": None, + "num_beams": num_beams, + } - transformers_outputs = transformers_model.generate( - **tokens, - num_beams=num_beams, - do_sample=False, - min_new_tokens=new_tokens, - max_new_tokens=new_tokens, - eos_token_id=None, - ) + beam_search_gen_config = GenerationConfig(do_sample=False, **gen_kwargs) + + if use_cache and num_beams == 4: + beam_sample_gen_config = GenerationConfig(do_sample=True, **gen_kwargs) + group_beam_search_gen_config = GenerationConfig( + do_sample=False, num_beam_groups=2, diversity_penalty=0.0000001, **gen_kwargs + ) + gen_configs = ( + beam_search_gen_config, + beam_sample_gen_config, + group_beam_search_gen_config, + ) + else: + gen_configs = (beam_search_gen_config,) - self.assertTrue(torch.allclose(onnx_outputs, transformers_outputs)) + for gen_config in gen_configs: + set_seed(SEED) + with torch.no_grad(): + transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config) + set_seed(SEED) + onnx_outputs = onnx_model.generate(**tokens, generation_config=gen_config) + + self.assertTrue( + torch.equal(onnx_outputs, transformers_outputs), + f"Failed with generation config : {gen_config}, transformers outputs {transformers_outputs}, ONNX model outputs {onnx_outputs}", + ) gc.collect() diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 5071d0081a..e3d5423785 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -125,6 +125,7 @@ "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", "mt5": "lewtun/tiny-random-mt5", "nystromformer": "hf-internal-testing/tiny-random-NystromformerModel", + "opt": "hf-internal-testing/tiny-random-OPTModel", "pegasus": "hf-internal-testing/tiny-random-PegasusModel", "perceiver_text": "hf-internal-testing/tiny-random-language_perceiver", "perceiver_vision": "hf-internal-testing/tiny-random-vision_perceiver_conv",