Skip to content

Commit

Permalink
Fix tapas scatter (huggingface#20149)
Browse files Browse the repository at this point in the history
* First draft

* Remove scatter dependency

* Add require_torch

* update vectorized sum test, add clone call

* remove artifacts

* fix style

* fix style v2

* remove "scatter" mentions from the code base

* fix isort error

Co-authored-by: Niels Rogge <[email protected]>
Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
3 people authored Nov 14, 2022
1 parent f711d68 commit 78a471f
Show file tree
Hide file tree
Showing 19 changed files with 77 additions and 194 deletions.
3 changes: 0 additions & 3 deletions docker/transformers-all-latest-gpu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ RUN python3 -m pip uninstall -y flax jax
# TODO: remove this line once the conflict is resolved in these libraries.
RUN python3 -m pip install --no-cache-dir git+https://github.com/onnx/tensorflow-onnx.git@ddca3a5eb2d912f20fe7e0568dd1a3013aee9fa3

# Use installed torch version for `torch-scatter` to avid to deal with PYTORCH='pre'.
# If torch is nightly version, the link is likely to be invalid, but the installation falls back to the latest torch-scatter
RUN python3 -m pip install --no-cache-dir torch-scatter -f https://data.pyg.org/whl/torch-$(python3 -c "from torch import version; print(version.__version__.split('+')[0])")+$CUDA.html
RUN python3 -m pip install --no-cache-dir intel_extension_for_pytorch==$INTEL_TORCH_EXT+cpu -f https://software.intel.com/ipex-whl-stable

RUN python3 -m pip install --no-cache-dir git+https://github.com/facebookresearch/detectron2.git pytesseract
Expand Down
1 change: 0 additions & 1 deletion docker/transformers-doc-builder/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ RUN apt-get -y update && apt-get install -y libsndfile1-dev && apt install -y te
# Torch needs to be installed before deepspeed
RUN python3 -m pip install --no-cache-dir ./transformers[deepspeed]

RUN python3 -m pip install --no-cache-dir torch-scatter -f https://data.pyg.org/whl/torch-$(python -c "from torch import version; print(version.__version__.split('+')[0])")+cpu.html
RUN python3 -m pip install --no-cache-dir torchvision git+https://github.com/facebookresearch/detectron2.git pytesseract
RUN python3 -m pip install --no-cache-dir pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com
RUN python3 -m pip install -U "itsdangerous<2.1.0"
Expand Down
6 changes: 0 additions & 6 deletions docker/transformers-past-gpu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,4 @@ RUN python3 ./transformers/utils/past_ci_versions.py --framework $FRAMEWORK --ve
RUN echo "INSTALL_CMD = $INSTALL_CMD"
RUN $INSTALL_CMD

# Having installation problems for torch-scatter with torch <= 1.6. Disable so we have the same set of tests.
# (This part will be removed once the logic of using `past_ci_versions.py` is used in other Dockerfile files.)
# # Use installed torch version for `torch-scatter`.
# # (The env. variable $CUDA is defined in `past_ci_versions.py`)
# RUN [ "$FRAMEWORK" = "pytorch" ] && python3 -m pip install --no-cache-dir torch-scatter -f https://data.pyg.org/whl/torch-$(python3 -c "from torch import version; print(version.__version__.split('+')[0])")+$CUDA.html || echo "torch-scatter not to be installed"

RUN python3 -m pip install -U "itsdangerous<2.1.0"
3 changes: 1 addition & 2 deletions docs/source/en/model_doc/tapas.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ To summarize:

<frameworkcontent>
<pt>
Initializing a model with a pre-trained base and randomly initialized classification heads from the hub can be done as shown below. Be sure to have installed the
[torch-scatter](https://github.com/rusty1s/pytorch_scatter) dependency:
Initializing a model with a pre-trained base and randomly initialized classification heads from the hub can be done as shown below.

```py
>>> from transformers import TapasConfig, TapasForQuestionAnswering
Expand Down
59 changes: 20 additions & 39 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_scatter_available,
is_sentencepiece_available,
is_speech_available,
is_tensorflow_text_available,
Expand Down Expand Up @@ -784,28 +783,6 @@
]
)

try:
if not is_scatter_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_scatter_objects

_import_structure["utils.dummy_scatter_objects"] = [
name for name in dir(dummy_scatter_objects) if not name.startswith("_")
]
else:
_import_structure["models.tapas"].extend(
[
"TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
"TapasForMaskedLM",
"TapasForQuestionAnswering",
"TapasForSequenceClassification",
"TapasModel",
"TapasPreTrainedModel",
"load_tf_weights_in_tapas",
]
)


# PyTorch-backed objects
try:
Expand Down Expand Up @@ -2027,6 +2004,17 @@
"Swinv2PreTrainedModel",
]
)
_import_structure["models.tapas"].extend(
[
"TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
"TapasForMaskedLM",
"TapasForQuestionAnswering",
"TapasForSequenceClassification",
"TapasModel",
"TapasPreTrainedModel",
"load_tf_weights_in_tapas",
]
)
_import_structure["models.t5"].extend(
[
"T5_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -3845,22 +3833,6 @@
TableTransformerPreTrainedModel,
)

try:
if not is_scatter_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_scatter_objects import *
else:
from .models.tapas import (
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TapasForMaskedLM,
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
TapasPreTrainedModel,
load_tf_weights_in_tapas,
)

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
Expand Down Expand Up @@ -4854,6 +4826,15 @@
T5PreTrainedModel,
load_tf_weights_in_t5,
)
from .models.tapas import (
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TapasForMaskedLM,
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
TapasPreTrainedModel,
load_tf_weights_in_tapas,
)
from .models.time_series_transformer import (
TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TimeSeriesTransformerForPrediction,
Expand Down
1 change: 0 additions & 1 deletion src/transformers/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@
is_rjieba_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_scatter_available,
is_scipy_available,
is_sentencepiece_available,
is_sklearn_available,
Expand Down
29 changes: 6 additions & 23 deletions src/transformers/models/tapas/modeling_tapas.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,14 @@
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scatter_available,
logging,
replace_return_docstrings,
requires_backends,
)
from .configuration_tapas import TapasConfig


logger = logging.get_logger(__name__)

# soft dependency
if is_scatter_available():
try:
from torch_scatter import scatter
except OSError:
logger.error(
"TAPAS models are not usable since `torch_scatter` can't be loaded. "
"It seems you have `torch_scatter` installed with the wrong CUDA version. "
"Please try to reinstall it following the instructions here: https://github.com/rusty1s/pytorch_scatter."
)

_CONFIG_FOR_DOC = "TapasConfig"
_TOKENIZER_FOR_DOC = "TapasTokenizer"
_TOKENIZER_FOR_DOC = "google/tapas-base"
Expand Down Expand Up @@ -862,7 +849,6 @@ class TapasModel(TapasPreTrainedModel):
"""

def __init__(self, config, add_pooling_layer=True):
requires_backends(self, "scatter")
super().__init__(config)
self.config = config

Expand Down Expand Up @@ -1798,12 +1784,9 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
# changed "view" by "reshape" in the following line
flat_values = values.reshape(flattened_shape.tolist())

segment_means = scatter(
src=flat_values,
index=flat_index.indices.long(),
dim=0,
dim_size=int(flat_index.num_segments),
reduce=segment_reduce_fn,
out = torch.zeros(int(flat_index.num_segments), dtype=flat_values.dtype)
segment_means = out.scatter_reduce(
dim=0, index=flat_index.indices.long(), src=flat_values, reduce=segment_reduce_fn, include_self=False
)

# Unflatten the values.
Expand All @@ -1816,7 +1799,7 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
dim=0,
)

output_values = segment_means.view(new_shape.tolist())
output_values = segment_means.clone().view(new_shape.tolist())
output_index = range_index_map(index.batch_shape(), index.num_segments)
return output_values, output_index

Expand Down Expand Up @@ -1901,7 +1884,7 @@ def reduce_max(values, index, name="segmented_reduce_max"):
output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
"""
return _segment_reduce(values, index, "max", name)
return _segment_reduce(values, index, "amax", name)


def reduce_min(values, index, name="segmented_reduce_min"):
Expand All @@ -1928,7 +1911,7 @@ def reduce_min(values, index, name="segmented_reduce_min"):
output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
"""
return _segment_reduce(values, index, "min", name)
return _segment_reduce(values, index, "amin", name)


# End of everything related to segmented tensors
Expand Down
19 changes: 0 additions & 19 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
is_pytorch_quantization_available,
is_rjieba_available,
is_safetensors_available,
is_scatter_available,
is_scipy_available,
is_sentencepiece_available,
is_soundfile_availble,
Expand Down Expand Up @@ -319,16 +318,6 @@ def require_intel_extension_for_pytorch(test_case):
)(test_case)


def require_torch_scatter(test_case):
"""
Decorator marking a test that requires PyTorch scatter.
These tests are skipped when PyTorch scatter isn't installed.
"""
return unittest.skipUnless(is_scatter_available(), "test requires PyTorch scatter")(test_case)


def require_tensorflow_probability(test_case):
"""
Decorator marking a test that requires TensorFlow probability.
Expand Down Expand Up @@ -405,14 +394,6 @@ def require_pytesseract(test_case):
return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case)


def require_scatter(test_case):
"""
Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
installed.
"""
return unittest.skipUnless(is_scatter_available(), "test requires PyTorch Scatter")(test_case)


def require_pytorch_quantization(test_case):
"""
Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch
Expand Down
1 change: 0 additions & 1 deletion src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@
is_safetensors_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_scatter_available,
is_scipy_available,
is_sentencepiece_available,
is_sklearn_available,
Expand Down
42 changes: 42 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -5099,6 +5099,48 @@ def load_tf_weights_in_t5(*args, **kwargs):
requires_backends(load_tf_weights_in_t5, ["torch"])


TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST = None


class TapasForMaskedLM(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class TapasForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class TapasForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class TapasModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class TapasPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


def load_tf_weights_in_tapas(*args, **kwargs):
requires_backends(load_tf_weights_in_tapas, ["torch"])


TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None


Expand Down
45 changes: 0 additions & 45 deletions src/transformers/utils/dummy_scatter_objects.py

This file was deleted.

Loading

0 comments on commit 78a471f

Please sign in to comment.