From ce99901e2ea31528c19161da35798f09c424c8e5 Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 14 Nov 2024 23:31:40 +0530 Subject: [PATCH 01/28] feat: package updates with python311 --- algorithmic_efficiency/random_utils.py | 8 +-- docker/Dockerfile | 29 ++++++++- setup.cfg | 86 +++++++++++++------------- 3 files changed, 75 insertions(+), 48 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index cf1ea6c32..31317047e 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,8 +18,8 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_INT32 = 2**31 -MIN_INT32 = -MAX_INT32 +MAX_UINT32 = 2**31 +MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] @@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32) return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name diff --git a/docker/Dockerfile b/docker/Dockerfile index 9b72aea86..24d05b495 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,7 +11,34 @@ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 RUN echo "Setting up machine" RUN apt-get update RUN apt-get install -y curl tar -RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git ffmpeg + +# Install prerequisites +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + zlib1g-dev \ + libncurses5-dev \ + libssl-dev \ + libreadline-dev \ + libffi-dev \ + curl \ + libbz2-dev \ + liblzma-dev + +# Download and install Python 3.11 +RUN cd /tmp \ + && wget https://www.python.org/ftp/python/3.11.0/Python-3.11.0.tgz \ + && tar -xvzf Python-3.11.0.tgz \ + && cd Python-3.11.0 \ + && ./configure --enable-optimizations \ + && make -j$(nproc) \ + && make altinstall + +# Create symlinks for python and pip (use 'pip' instead of 'pip3') +RUN ln -s /usr/local/bin/python3.11 /usr/bin/python \ + && ln -s /usr/local/bin/pip3.11 /usr/bin/pip + RUN apt-get install libtcmalloc-minimal4 RUN apt-get install unzip RUN apt-get install pigz diff --git a/setup.cfg b/setup.cfg index 4afefd164..deeb1c6c4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,7 @@ classifiers = Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 Topic :: Scientific/Engineering :: Artificial Intelligence [options] @@ -34,22 +35,22 @@ setup_requires = setuptools_scm # Dependencies of the project: install_requires = - absl-py==1.4.0 + absl-py==2.1.1 # Pin to avoid unpinned install in dependencies that requires Python>=3.9. - networkx==3.1 - docker==7.0.0 - numpy>=1.23 - pandas>=2.0.1 - tensorflow==2.12.0 - tensorflow-datasets==4.9.2 - tensorflow-probability==0.20.0 - tensorflow-addons==0.20.0 + networkx==3.2.1 + docker==7.1.0 + numpy>=1.26.4 + pandas==2.2.3 + tensorflow==2.18.0 + tensorflow-datasets==4.9.7 + tensorflow-addons==0.23.0 gputil==1.4.0 - psutil==5.9.5 - clu==0.0.7 - matplotlib>=3.7.2 + psutil==6.1.0 + clu==0.0.12 + matplotlib>=3.9.2 tabulate==0.9.0 -python_requires = >=3.8 + wandb==0.18.7 +python_requires = >=3.11 ############################################################################### @@ -79,78 +80,77 @@ full_dev = # Dependencies for developing the package dev = - isort==5.12.0 - pylint==2.17.4 - pytest==7.3.1 - yapf==0.33.0 - pre-commit==3.3.1 + isort==5.13.2 + pylint==3.3.1 + pytest==8.3.3 + yapf==0.43.0 + pre-commit==4.0.1 # Workloads # criteo1tb = - scikit-learn==1.2.2 + scikit-learn==1.5.2 fastmri = - h5py==3.8.0 - scikit-image==0.20.0 + h5py==3.12.1 + scikit-image==0.24.0 ogbg = jraph==0.0.6.dev0 - scikit-learn==1.2.2 + scikit-learn==1.5.2 librispeech_conformer = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 + sentencepiece==0.2.0 + tensorflow-text==2.18.0 pydub==0.25.1 wmt = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 - sacrebleu==1.3.1 + sentencepiece==0.2.0 + tensorflow-text==2.18.0 + sacrebleu==2.4.3 # Frameworks # # JAX Core jax_core_deps = - flax==0.6.10 - optax==0.1.5 + flax==0.10.1 + optax==0.2.4 # Fix chex (optax dependency) version. # Not fixing it can raise dependency issues with our # jax version. # Todo(kasimbeg): verify if this is necessary after we # upgrade jax. - chex==0.1.7 - ml_dtypes==0.2.0 - protobuf==4.25.3 + chex==0.1.87 + ml_dtypes==0.4.1 + protobuf==4.25.5 # JAX CPU jax_cpu = - jax==0.4.10 - jaxlib==0.4.10 + jax==0.4.35 + jaxlib==0.4.35 %(jax_core_deps)s # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.10 - jaxlib==0.4.10+cuda12.cudnn88 + jax==0.4.35 + jaxlib==0.4.35 + jax-cuda12-plugin[with_cuda]==0.4.35 + jax-cuda12-pjrt==0.4.35 %(jax_core_deps)s # PyTorch CPU pytorch_cpu = - torch==2.1.0 - torchvision==0.16.0 + torch==2.5.0 + torchvision==0.20.0 # PyTorch GPU # Note: omit the cuda suffix and installing from the appropriate # wheel will result in using locally installed CUDA. pytorch_gpu = - torch==2.1.0 - torchvision==0.16.0 + torch==2.5.0 + torchvision==0.20.0 -# wandb -wandb = - wandb==0.16.5 ############################################################################### # Linting Configurations # From 21fb3f902d5744c8331be89f896c2376977f7f12 Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 14 Nov 2024 23:46:17 +0530 Subject: [PATCH 02/28] fix: absl package version change --- docker/Dockerfile | 12 +++++++----- setup.cfg | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 24d05b495..497ffb2c1 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -28,9 +28,9 @@ RUN apt-get update && apt-get install -y \ # Download and install Python 3.11 RUN cd /tmp \ - && wget https://www.python.org/ftp/python/3.11.0/Python-3.11.0.tgz \ - && tar -xvzf Python-3.11.0.tgz \ - && cd Python-3.11.0 \ + && wget https://www.python.org/ftp/python/3.11.10/Python-3.11.10.tgz \ + && tar -xvzf Python-3.11.10.tgz \ + && cd Python-3.11.10 \ && ./configure --enable-optimizations \ && make -j$(nproc) \ && make altinstall @@ -55,11 +55,13 @@ RUN echo "Setting up directories for data and experiment_runs" RUN mkdir -p data/ RUN mkdir -p experiment_runs/ +RUN pip install --upgrade pip + # Install Algorithmic efficiency repo RUN echo "Setting up algorithmic_efficiency repo" -ARG branch="main" +ARG branch="python311" ARG framework="both" -ARG git_url=https://github.com/mlcommons/algorithmic-efficiency.git +ARG git_url=https://github.com/init-22/algorithmic-efficiency.git RUN git clone $git_url && cd /algorithmic-efficiency RUN cd /algorithmic-efficiency && git checkout $branch diff --git a/setup.cfg b/setup.cfg index deeb1c6c4..e952513df 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,7 @@ setup_requires = setuptools_scm # Dependencies of the project: install_requires = - absl-py==2.1.1 + absl-py==2.1.0 # Pin to avoid unpinned install in dependencies that requires Python>=3.9. networkx==3.2.1 docker==7.1.0 From 67b9f15108486a1a29b348031e1b50a82fa55b40 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 00:09:04 +0530 Subject: [PATCH 03/28] fix: pytorch version change --- setup.cfg | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index e952513df..a74faa197 100644 --- a/setup.cfg +++ b/setup.cfg @@ -141,15 +141,15 @@ jax_gpu = # PyTorch CPU pytorch_cpu = - torch==2.5.0 - torchvision==0.20.0 + torch==2.5.1 + torchvision==0.20.1 # PyTorch GPU # Note: omit the cuda suffix and installing from the appropriate # wheel will result in using locally installed CUDA. pytorch_gpu = - torch==2.5.0 - torchvision==0.20.0 + torch==2.5.1 + torchvision==0.20.1 ############################################################################### From 78df36f2f0f173ad651b81527cda8d55f85028b0 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 00:42:26 +0530 Subject: [PATCH 04/28] fix: tf version to use numpy < 2 --- docker/Dockerfile | 2 -- setup.cfg | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 497ffb2c1..88fc55243 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -87,8 +87,6 @@ RUN if [ "$framework" = "jax" ] ; then \ RUN cd /algorithmic-efficiency && pip install -e '.[full]' -RUN cd /algorithmic-efficiency && pip install -e '.[wandb]' - RUN cd /algorithmic-efficiency && git fetch origin RUN cd /algorithmic-efficiency && git pull diff --git a/setup.cfg b/setup.cfg index a74faa197..2a300469a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,7 +41,7 @@ install_requires = docker==7.1.0 numpy>=1.26.4 pandas==2.2.3 - tensorflow==2.18.0 + tensorflow==2.17.0 tensorflow-datasets==4.9.7 tensorflow-addons==0.23.0 gputil==1.4.0 @@ -105,7 +105,7 @@ librispeech_conformer = wmt = sentencepiece==0.2.0 - tensorflow-text==2.18.0 + tensorflow-text==2.17.0 sacrebleu==2.4.3 # Frameworks # From 2584416e8cc82bb61ef7a1d2a395a25da919f93f Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 19:09:40 +0530 Subject: [PATCH 05/28] fix: librispeech requirement of tf-text rolled back to v2.17 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 2a300469a..078b694b8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -100,7 +100,7 @@ ogbg = librispeech_conformer = sentencepiece==0.2.0 - tensorflow-text==2.18.0 + tensorflow-text==2.17.0 pydub==0.25.1 wmt = From d603ce921b211918ce0e3d27742032f5e7ece674 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 19:11:38 +0530 Subject: [PATCH 06/28] fix: using the main repo and branch for testing --- docker/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 88fc55243..ee9136cbf 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -59,9 +59,9 @@ RUN pip install --upgrade pip # Install Algorithmic efficiency repo RUN echo "Setting up algorithmic_efficiency repo" -ARG branch="python311" +ARG branch="main" ARG framework="both" -ARG git_url=https://github.com/init-22/algorithmic-efficiency.git +ARG git_url=https://github.com/mlcommons/algorithmic-efficiency.git RUN git clone $git_url && cd /algorithmic-efficiency RUN cd /algorithmic-efficiency && git checkout $branch From be68f8cbf4a528804c78eff886ffd7e36e04fca8 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 16 Nov 2024 13:57:11 +0530 Subject: [PATCH 07/28] fix: overflow error resolved and PRNGKey to key --- algorithmic_efficiency/checkpoint_utils.py | 2 +- algorithmic_efficiency/random_utils.py | 10 +++++----- setup.cfg | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algorithmic_efficiency/checkpoint_utils.py index 29c1a821e..04dad0eb7 100644 --- a/algorithmic_efficiency/checkpoint_utils.py +++ b/algorithmic_efficiency/checkpoint_utils.py @@ -231,7 +231,7 @@ def save_checkpoint(framework: str, target=checkpoint_state, step=global_step, overwrite=True, - keep=np.Inf if save_intermediate_checkpoints else 1) + keep=np.inf if save_intermediate_checkpoints else 1) else: if not save_intermediate_checkpoints: checkpoint_files = gfile.glob( diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 31317047e..93dc263bd 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,7 +18,7 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_UINT32 = 2**31 +MAX_UINT32 = 2**32-1 MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] @@ -26,11 +26,11 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: if isinstance(seed, int): - return seed % 2**32 + return seed % MAX_UINT32 if isinstance(seed, list): - return [s % 2**32 for s in seed] + return [s % MAX_UINT32 for s in seed] if isinstance(seed, np.ndarray): - return np.array([s % 2**32 for s in seed.tolist()]) + return np.array([s % MAX_UINT32 for s in seed.tolist()]) def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType: def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.PRNGKey(seed) + return jax_rng.key(seed) return _PRNGKey(seed) diff --git a/setup.cfg b/setup.cfg index 078b694b8..6e6a1c957 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,9 +39,9 @@ install_requires = # Pin to avoid unpinned install in dependencies that requires Python>=3.9. networkx==3.2.1 docker==7.1.0 - numpy>=1.26.4 + numpy>=2.1.3 pandas==2.2.3 - tensorflow==2.17.0 + tensorflow==2.18.0 tensorflow-datasets==4.9.7 tensorflow-addons==0.23.0 gputil==1.4.0 @@ -100,12 +100,12 @@ ogbg = librispeech_conformer = sentencepiece==0.2.0 - tensorflow-text==2.17.0 + tensorflow-text==2.18.0 pydub==0.25.1 wmt = sentencepiece==0.2.0 - tensorflow-text==2.17.0 + tensorflow-text==2.18.0 sacrebleu==2.4.3 # Frameworks # From e890c893297a6e64cbfdc6d63f87ee7f7b4d385a Mon Sep 17 00:00:00 2001 From: init-22 Date: Wed, 20 Nov 2024 19:13:50 +0530 Subject: [PATCH 08/28] fix: minor changes in docs --- GETTING_STARTED.md | 2 +- algorithmic_efficiency/logger_utils.py | 2 +- setup.cfg | 3 --- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 006b972ec..aa493bc9f 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -35,7 +35,7 @@ The specs on the benchmarking machines are: > **Prerequisites:** > -> - Python minimum requirement >= 3.8 +> - Python minimum requirement >= 3.11 > - CUDA 12.1 > - NVIDIA Driver version 535.104.05 diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 609d996e6..155e55356 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -211,7 +211,7 @@ def _get_system_software_info() -> Dict: system_software_info['os_platform'] = \ platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29' system_software_info['python_version'] = platform.python_version( - ) # Ex. '3.8.10' + ) # Ex. '3.11.10' system_software_info['python_compiler'] = platform.python_compiler( ) # Ex. 'GCC 9.3.0' # Note: do not store hostname as that may be sensitive diff --git a/setup.cfg b/setup.cfg index 6e6a1c957..5023f1ba6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,9 +21,6 @@ classifiers = Intended Audience :: Science/Research License :: OSI Approved :: Apache Software License Operating System :: OS Independent - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 Programming Language :: Python :: 3.11 Topic :: Scientific/Engineering :: Artificial Intelligence From 1bc2a7b2d5de45309bbcab035bff587c9f19ef27 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 30 Nov 2024 13:07:46 +0530 Subject: [PATCH 09/28] fix: changing the python versions in workflow to pass the tests --- .github/workflows/CI.yml | 48 +++++++++++++------------- .github/workflows/linting.yml | 12 +++---- .github/workflows/traindiffs_tests.yml | 2 +- 3 files changed, 31 insertions(+), 31 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 05d94e896..fe2441bfe 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -25,10 +25,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -42,10 +42,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -59,10 +59,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -77,10 +77,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -96,10 +96,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -113,10 +113,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -130,10 +130,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -148,10 +148,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -166,10 +166,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -184,10 +184,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install pytest @@ -208,10 +208,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install pytest diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 89b5ef288..628fc012b 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install pylint run: | python -m pip install --upgrade pip @@ -27,10 +27,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install isort run: | python -m pip install --upgrade pip @@ -43,10 +43,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install yapf run: | python -m pip install --upgrade pip diff --git a/.github/workflows/traindiffs_tests.yml b/.github/workflows/traindiffs_tests.yml index 382f0dfe1..a2fdcb453 100644 --- a/.github/workflows/traindiffs_tests.yml +++ b/.github/workflows/traindiffs_tests.yml @@ -3,7 +3,7 @@ name: Containerized Training Differences Tests Jax vs PyTorch on: pull_request: branches: - - 'main' + - 'python311' jobs: build_and_push_docker_image: From 7a0fee3224e3d4e8602a2aca2819358bf97acf00 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 30 Nov 2024 22:59:52 +0530 Subject: [PATCH 10/28] fix: changing numpy compatible version --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 5023f1ba6..0aa4dce49 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,7 +36,7 @@ install_requires = # Pin to avoid unpinned install in dependencies that requires Python>=3.9. networkx==3.2.1 docker==7.1.0 - numpy>=2.1.3 + numpy>=2.0.2 pandas==2.2.3 tensorflow==2.18.0 tensorflow-datasets==4.9.7 From 7cdea1638ceb2a3c0019e95c0a63f0c36605064a Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 30 Nov 2024 23:07:52 +0530 Subject: [PATCH 11/28] adding key_data to check the CI tests --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 551173bf5..0024c35d4 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -210,7 +210,7 @@ def train_once( ) -> Tuple[spec.Timing, Dict[str, Any]]: _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) - + data_rng = jax.random.key_data(data_rng) # Workload setup. logging.info('Initializing dataset.') if hasattr(workload, '_eval_num_workers'): @@ -336,7 +336,7 @@ def train_once( step_rng = prng.fold_in(rng, global_step) data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) - + eval_rng = jax.random.key_data(eval_rng) with profiler.profile('Data selection'): batch = data_selection(workload, input_queue, From 7264c3f80d0bd38a1c50f107d715765a7c76dcdc Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 1 Dec 2024 14:41:05 +0530 Subject: [PATCH 12/28] fix: updated packge of sacrebleu changed the way it used to work, hence using the corpus_bleu from the main package --- algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py | 3 ++- setup.cfg | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 0ba49c2f6..327ca34ad 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -5,6 +5,7 @@ from absl import logging import jax +import sacrebleu import tensorflow as tf import torch import torch.distributed as dist @@ -162,7 +163,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = bleu.corpus_bleu(predictions, [references]).score + bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( diff --git a/setup.cfg b/setup.cfg index 0aa4dce49..23e86a13b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -104,7 +104,6 @@ wmt = sentencepiece==0.2.0 tensorflow-text==2.18.0 sacrebleu==2.4.3 - # Frameworks # # JAX Core From abbdc8262917fd8e38ba954f8cdaf478a5d8d1c7 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 1 Dec 2024 16:11:01 +0530 Subject: [PATCH 13/28] fix: temporarily commenting tfa --- .../workloads/imagenet_resnet/imagenet_jax/randaugment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 5f92b1482..d0bbecb8f 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,7 +8,7 @@ import math import tensorflow as tf -from tensorflow_addons import image as contrib_image +#from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. From 86029a742094a653e5bf9a6f17f0d42c0990671d Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 2 Dec 2024 22:10:24 +0530 Subject: [PATCH 14/28] fix: explicitly using mask kwarg to use MultiHeadDotProductAttention and also using sacrebleu --- .../workloads/imagenet_resnet/imagenet_jax/randaugment.py | 1 + algorithmic_efficiency/workloads/wmt/wmt_jax/models.py | 6 +++--- algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index d0bbecb8f..af1b763c1 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,6 +8,7 @@ import math import tensorflow as tf + #from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py index e4b5cd014..7bbc0b168 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py @@ -224,7 +224,7 @@ def __call__(self, inputs, encoder_mask=None): dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic)(cfg.attention_temp * x, x, - encoder_mask) + mask=encoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 @@ -288,7 +288,7 @@ def __call__(self, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic, - decode=cfg.decode)(cfg.attention_temp * x, x, decoder_mask) + decode=cfg.decode)(cfg.attention_temp * x, x, mask=decoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 else: @@ -311,7 +311,7 @@ def __call__(self, dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic)(cfg.attention_temp * y, encoded, - encoder_decoder_mask) + mask=encoder_decoder_mask) y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) y = y + x diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 046d5e469..442c85899 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -12,6 +12,7 @@ import jax.numpy as jnp import numpy as np import optax +import sacrebleu from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec @@ -203,7 +204,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = bleu.corpus_bleu(predictions, [references]).score + bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( From aca45a2b1e1df7e42a5108df8e30d49baf6ef6e2 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 2 Dec 2024 22:42:21 +0530 Subject: [PATCH 15/28] fix: using flax.core.pop instead of variables.pop, better way to update batch_stats --- .../workloads/imagenet_resnet/imagenet_jax/workload.py | 7 ++++--- .../workloads/imagenet_vit/imagenet_jax/workload.py | 3 ++- .../librispeech_conformer/librispeech_jax/workload.py | 7 ++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index d8de214f5..8ab4adbb9 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -11,6 +11,7 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax from jax import lax import jax.numpy as jnp @@ -79,8 +80,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics and # we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() # Create a shallow copy + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state def init_model_fn( @@ -111,7 +112,7 @@ def init_model_fn( input_shape = (1, 224, 224, 3) variables = jax.jit(model.init)({'params': rng}, jnp.ones(input_shape, model.dtype)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) model_state = jax_utils.replicate(model_state) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 2ad71ffd0..5f826d035 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -4,6 +4,7 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax import jax.numpy as jnp @@ -28,7 +29,7 @@ def initialized(self, key: spec.RandomState, variables = jax.jit( model.init)({'params': params_rng, 'dropout': dropout_rng}, jnp.ones(input_shape)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") return params, model_state def init_model_fn( diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index f4d1ab0f3..d805e8b17 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -3,6 +3,7 @@ from typing import Dict, Iterator, Optional, Tuple from flax import jax_utils +from flax.core import pop import flax.linen as nn import jax from jax import lax @@ -89,7 +90,7 @@ def init_model_fn( variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, *fake_input_batch) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -374,8 +375,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics and # we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state From 2618c5e6b1dcbdf48c2625f4cfbdca93fdc53993 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 2 Dec 2024 22:50:27 +0530 Subject: [PATCH 16/28] fix: changing the traindiffs_tests branch to main again --- .github/workflows/traindiffs_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/traindiffs_tests.yml b/.github/workflows/traindiffs_tests.yml index a2fdcb453..382f0dfe1 100644 --- a/.github/workflows/traindiffs_tests.yml +++ b/.github/workflows/traindiffs_tests.yml @@ -3,7 +3,7 @@ name: Containerized Training Differences Tests Jax vs PyTorch on: pull_request: branches: - - 'python311' + - 'main' jobs: build_and_push_docker_image: From 8c9062564c920e7fea8c3ee6abc8fce51d663c82 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 2 Dec 2024 23:23:09 +0530 Subject: [PATCH 17/28] fix: unfreeze() in test_param_shapes expect FrozenDict also added flax.core.pop instead of variables.pop --- .../workloads/cifar/cifar_jax/workload.py | 7 ++++--- tests/test_param_shapes.py | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index b019d1cee..6ec90b99a 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -5,6 +5,7 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax from jax import lax import jax.numpy as jnp @@ -75,8 +76,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics # and we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state def init_model_fn( @@ -93,7 +94,7 @@ def init_model_fn( input_shape = (1, 32, 32, 3) variables = jax.jit(model.init)({'params': rng}, jnp.ones(input_shape, model.dtype)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) model_state = jax_utils.replicate(model_state) diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index b67625213..4ad56c873 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -3,6 +3,7 @@ import jax import numpy as np import pytest +from flax.core import FrozenDict # isort: skip_file # pylint:disable=line-too-long @@ -51,8 +52,11 @@ def test_param_shapes(workload): jax_workload, pytorch_workload = get_workload(workload) # Compare number of parameter tensors of both models. + jax_workload_param_shapes = jax_workload.param_shapes + if isinstance(jax_workload_param_shapes, dict): + jax_workload_param_shapes = FrozenDict(jax_workload_param_shapes) jax_param_shapes = jax.tree_util.tree_leaves( - jax_workload.param_shapes.unfreeze()) + jax_workload_param_shapes.unfreeze()) pytorch_param_shapes = jax.tree_util.tree_leaves( pytorch_workload.param_shapes) if workload == 'wmt': From 1b587b75890c39c3b3ebf5359b7f82b260e06bc6 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 20:30:41 +0530 Subject: [PATCH 18/28] fix: formatting changes with yapf --- algorithmic_efficiency/profiler.py | 4 +-- algorithmic_efficiency/random_utils.py | 2 +- .../workloads/cifar/cifar_jax/workload.py | 2 +- .../fastmri/fastmri_pytorch/workload.py | 4 +-- .../imagenet_jax/randaugment.py | 8 ++---- .../imagenet_pytorch/workload.py | 4 +-- .../librispeech_jax/models.py | 10 +++---- .../librispeech_jax/spectrum_augmenter.py | 4 +-- .../librispeech_jax/workload.py | 2 +- .../librispeech_pytorch/workload.py | 9 +++--- .../librispeech_jax/models.py | 10 +++---- .../workloads/mnist/workload.py | 7 ++--- .../workloads/wmt/wmt_jax/models.py | 13 ++++----- .../workloads/wmt/wmt_pytorch/models.py | 4 +-- .../external_tuning/jax_nadamw_full_budget.py | 10 ++++--- .../jax_nadamw_target_setting.py | 10 ++++--- .../self_tuning/jax_nadamw_full_budget.py | 10 ++++--- .../self_tuning/jax_nadamw_target_setting.py | 10 ++++--- .../paper_baselines/nadamw/jax/submission.py | 10 ++++--- .../paper_baselines/sam/jax/submission.py | 8 +++--- .../shampoo/jax/distributed_shampoo.py | 28 +++++++------------ .../target_setting_algorithms/jax_nadamw.py | 10 ++++--- submission_runner.py | 4 +-- tests/modeldiffs/wmt/compare.py | 6 ++-- .../modeldiffs/wmt_attention_temp/compare.py | 6 ++-- tests/modeldiffs/wmt_glu_tanh/compare.py | 6 ++-- tests/modeldiffs/wmt_post_ln/compare.py | 6 ++-- 27 files changed, 98 insertions(+), 109 deletions(-) diff --git a/algorithmic_efficiency/profiler.py b/algorithmic_efficiency/profiler.py index fa2a1bee2..d73efd964 100644 --- a/algorithmic_efficiency/profiler.py +++ b/algorithmic_efficiency/profiler.py @@ -72,8 +72,8 @@ def _make_report( float(np.std(d)), len(d), float(np.sum(d)), - 100.0 * float(np.sum(d)) / total_duration) for a, - d in self.recorded_durations.items()] + 100.0 * float(np.sum(d)) / total_duration) + for a, d in self.recorded_durations.items()] report.sort(key=lambda x: x[5], reverse=True) total_calls = sum(x[3] for x in report) return report, total_calls, total_duration diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 93dc263bd..b5b30ce22 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,7 +18,7 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_UINT32 = 2**32-1 +MAX_UINT32 = 2**32 - 1 MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index 6ec90b99a..dd4643a60 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -76,7 +76,7 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics # and we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() + new_model_state = model_state.copy() new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index 74f6aa13d..a2f0828e3 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -252,9 +252,7 @@ def _eval_model_on_split(self, for _ in range(num_batches): batch = next(self._eval_iters[split]) batch_metrics = self._eval_model(params, batch, model_rng) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index af1b763c1..94c66033a 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -313,8 +313,7 @@ def build_lut(histo, step): # If step is zero, return the original image. Otherwise, build # lut from the full histogram and step and then index from it. result = tf.cond( - tf.equal(step, 0), - lambda: im, + tf.equal(step, 0), lambda: im, lambda: tf.gather(build_lut(histo, step), im)) return tf.cast(result, tf.uint8) @@ -549,7 +548,6 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): translate_const=100) image = tf.cond( tf.equal(i, op_to_select), - lambda selected_func=func, - selected_args=args: selected_func(image, *selected_args), - lambda: image) + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args), lambda: image) return image diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 3549911fa..0ed944191 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -309,9 +309,7 @@ def _eval_model_on_split(self, update_batch_norm=False) weights = batch.get('weights') batch_metrics = self._compute_metrics(logits, batch['targets'], weights) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index ed05f4335..db8cbc70a 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -153,8 +153,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), + self.output_channels) @nn.compact def __call__(self, inputs, paddings): @@ -442,12 +442,10 @@ def setup(self): dtype = self.config.dtype self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), + 'mean', lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), + 'var', lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py index 2a6f73d4d..c16740629 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py @@ -81,8 +81,8 @@ def _get_mask(self, jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), [batch_size, 1]) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights < - multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = (multiplicity_weights + < multiplicity_tensor).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index d805e8b17..64e41989f 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -375,7 +375,7 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics and # we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() + new_model_state = model_state.copy() new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 155b30920..31d069e88 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -260,8 +260,9 @@ def greedy_decode( idxs = torch.arange( fin_result.numel(), device=result.device).view(*fin_result.shape) mask = torch.arange( - fin_result.shape[1], device=result.device).view( - 1, -1) < result.count_nonzero(dim=1).view(-1, 1) + fin_result.shape[1], + device=result.device).view(1, -1) < result.count_nonzero(dim=1).view( + -1, 1) fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id] padding = fin_result == 0 return fin_result, padding @@ -329,9 +330,7 @@ def _eval_model_on_split(self, 'word_errors': word_errors, 'num_words': num_words, } - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py index f9eb732e9..c2fe540a6 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -139,8 +139,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), + self.output_channels) @nn.compact def __call__(self, inputs, paddings, train): @@ -273,12 +273,10 @@ def setup(self): dtype = self.dtype self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), + 'mean', lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), + 'var', lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index dcc195170..ad950b869 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -46,8 +46,7 @@ def _build_mnist_dataset( ds = ds.map( lambda x: { 'inputs': _normalize(x['image'], train_mean, train_stddev), - 'targets': x['label'], - }) + 'targets': x['label'],}) is_train = split == 'train' if cache: @@ -214,8 +213,6 @@ def _eval_model_on_split(self, batch, model_state, per_device_model_rngs) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py index 7bbc0b168..97fee032f 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py @@ -222,9 +222,8 @@ def __call__(self, inputs, encoder_mask=None): use_bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)(cfg.attention_temp * x, - x, - mask=encoder_mask) + deterministic=cfg.deterministic)( + cfg.attention_temp * x, x, mask=encoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 @@ -288,7 +287,8 @@ def __call__(self, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic, - decode=cfg.decode)(cfg.attention_temp * x, x, mask=decoder_mask) + decode=cfg.decode)( + cfg.attention_temp * x, x, mask=decoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 else: @@ -309,9 +309,8 @@ def __call__(self, use_bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)(cfg.attention_temp * y, - encoded, - mask=encoder_decoder_mask) + deterministic=cfg.deterministic)( + cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) y = y + x diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py index a1c7ce15e..089f1bfbb 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py @@ -942,8 +942,8 @@ def forward(self, # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_len, device=k.device) >= - cache_index).reshape(1, max_len) + attn_mask = (torch.arange(max_len, device=k.device) + >= cache_index).reshape(1, max_len) # Update sequence length to account for complete sequence. seq_len = k.size(1) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 98193f01f..ad4d8e6f5 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 66fdc4ebb..bde851468 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 4f53afb56..4122be181 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -132,8 +132,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -148,8 +149,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 60a1f784d..6b5faa6b8 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -132,8 +132,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -148,8 +149,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 98193f01f..ad4d8e6f5 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 85b3d7441..d33daadb8 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -67,8 +67,9 @@ def update_fn(updates, state, grad_fn_params_tuple): # the noised parameters in the same order as on the original gradients and # with the same 1e-6 epsilon that is used when clipping the gradients. updates = dual_vector(updates) - noised_params = jax.tree_util.tree_map( - lambda p, u: p + rho * u, params, updates) + noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u, + params, + updates) (_, (n_valid_examples, _)), updates = grad_fn(noised_params) # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), @@ -80,8 +81,7 @@ def update_fn(updates, state, grad_fn_params_tuple): sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) scaled_updates = jax.tree_map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, - lambda _: scaled_updates, + updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, lambda _: updates, None) updates, state = base_opt_update_fn(updates, state, params) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 725529cae..722dab06b 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -595,8 +595,8 @@ def matrix_inverse_pth_root( if padding_start is not None: # Zero out padding in identity as well for convergence checks. - ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) + < padding_start).astype(matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh( alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE) identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) if padding_start is not None: - ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) + < padding_start).astype(matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -1809,17 +1809,13 @@ def sharded_update_fn(grads, state, params): )) new_stats_flat = jax.tree_map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), + lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat, stats_flat, params_flat) outputs = jax.tree_map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), + lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) @@ -1923,8 +1919,8 @@ def _internal_inverse_pth_root_all(): errors = metrics.inverse_pth_root_errors errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( - jnp.isnan(errors), - errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) + jnp.isnan(errors), errors + >= inverse_failure_threshold).astype(new_preconditioners.dtype) # TODO(rohananil): Check for numerical instabilities. new_conditional_preconditioners = ( predicate * global_stats.preconditioners + @@ -2442,9 +2438,7 @@ def update_fn(grads, state, params): stats_grads = treedef.flatten_up_to(grads_custom) new_stats_flat = jax.tree_map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), + lambda g, s, p: _compute_stats(g, s, p, state.count), stats_grads, stats_flat, params_flat) @@ -2453,9 +2447,7 @@ def update_fn(grads, state, params): params_flat, state.count) outputs = jax.tree_map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), + lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 21f2a7b2b..fc866f80a 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -108,8 +108,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -124,8 +125,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/submission_runner.py b/submission_runner.py index 0024c35d4..a6bea1aa8 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -377,8 +377,8 @@ def train_once( train_state['is_time_remaining'] = ( train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) + >= workload.eval_period_time_sec or train_state['training_complete']): with profiler.profile('Evaluation'): del batch _reset_cuda_mem() diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 41fc5ee17..8f9154f53 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index 92ce4eb44..ff7103d43 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index b8d860479..d24d818a2 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 3f5469d8d..7d0556345 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) From c65d93e5b4adfa6e493e6101048738afd8dc15d9 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 20:37:32 +0530 Subject: [PATCH 19/28] fix: running yapf again with 0.32, earlier using 0.43 --- algorithmic_efficiency/profiler.py | 4 ++-- .../workloads/fastmri/fastmri_pytorch/workload.py | 4 +++- .../imagenet_resnet/imagenet_jax/randaugment.py | 8 +++++--- .../imagenet_resnet/imagenet_pytorch/workload.py | 4 +++- .../librispeech_conformer/librispeech_jax/models.py | 10 ++++++---- .../librispeech_jax/spectrum_augmenter.py | 4 ++-- .../librispeech_pytorch/workload.py | 9 +++++---- .../librispeech_deepspeech/librispeech_jax/models.py | 10 ++++++---- algorithmic_efficiency/workloads/mnist/workload.py | 7 +++++-- .../workloads/wmt/wmt_pytorch/models.py | 4 ++-- setup.cfg | 2 +- submission_runner.py | 4 ++-- tests/modeldiffs/wmt/compare.py | 6 +++--- tests/modeldiffs/wmt_attention_temp/compare.py | 6 +++--- tests/modeldiffs/wmt_glu_tanh/compare.py | 6 +++--- tests/modeldiffs/wmt_post_ln/compare.py | 6 +++--- 16 files changed, 54 insertions(+), 40 deletions(-) diff --git a/algorithmic_efficiency/profiler.py b/algorithmic_efficiency/profiler.py index d73efd964..fa2a1bee2 100644 --- a/algorithmic_efficiency/profiler.py +++ b/algorithmic_efficiency/profiler.py @@ -72,8 +72,8 @@ def _make_report( float(np.std(d)), len(d), float(np.sum(d)), - 100.0 * float(np.sum(d)) / total_duration) - for a, d in self.recorded_durations.items()] + 100.0 * float(np.sum(d)) / total_duration) for a, + d in self.recorded_durations.items()] report.sort(key=lambda x: x[5], reverse=True) total_calls = sum(x[3] for x in report) return report, total_calls, total_duration diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index a2f0828e3..74f6aa13d 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -252,7 +252,9 @@ def _eval_model_on_split(self, for _ in range(num_batches): batch = next(self._eval_iters[split]) batch_metrics = self._eval_model(params, batch, model_rng) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 94c66033a..af1b763c1 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -313,7 +313,8 @@ def build_lut(histo, step): # If step is zero, return the original image. Otherwise, build # lut from the full histogram and step and then index from it. result = tf.cond( - tf.equal(step, 0), lambda: im, + tf.equal(step, 0), + lambda: im, lambda: tf.gather(build_lut(histo, step), im)) return tf.cast(result, tf.uint8) @@ -548,6 +549,7 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): translate_const=100) image = tf.cond( tf.equal(i, op_to_select), - lambda selected_func=func, selected_args=args: selected_func( - image, *selected_args), lambda: image) + lambda selected_func=func, + selected_args=args: selected_func(image, *selected_args), + lambda: image) return image diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 0ed944191..3549911fa 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -309,7 +309,9 @@ def _eval_model_on_split(self, update_batch_norm=False) weights = batch.get('weights') batch_metrics = self._compute_metrics(logits, batch['targets'], weights) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index db8cbc70a..ed05f4335 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -153,8 +153,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), - self.output_channels) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) @nn.compact def __call__(self, inputs, paddings): @@ -442,10 +442,12 @@ def setup(self): dtype = self.config.dtype self.ra_mean = self.variable('batch_stats', - 'mean', lambda s: jnp.zeros(s, dtype), + 'mean', + lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', lambda s: jnp.ones(s, dtype), + 'var', + lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py index c16740629..2a6f73d4d 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py @@ -81,8 +81,8 @@ def _get_mask(self, jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), [batch_size, 1]) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights - < multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = (multiplicity_weights < + multiplicity_tensor).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 31d069e88..155b30920 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -260,9 +260,8 @@ def greedy_decode( idxs = torch.arange( fin_result.numel(), device=result.device).view(*fin_result.shape) mask = torch.arange( - fin_result.shape[1], - device=result.device).view(1, -1) < result.count_nonzero(dim=1).view( - -1, 1) + fin_result.shape[1], device=result.device).view( + 1, -1) < result.count_nonzero(dim=1).view(-1, 1) fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id] padding = fin_result == 0 return fin_result, padding @@ -330,7 +329,9 @@ def _eval_model_on_split(self, 'word_errors': word_errors, 'num_words': num_words, } - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py index c2fe540a6..f9eb732e9 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -139,8 +139,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), - self.output_channels) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) @nn.compact def __call__(self, inputs, paddings, train): @@ -273,10 +273,12 @@ def setup(self): dtype = self.dtype self.ra_mean = self.variable('batch_stats', - 'mean', lambda s: jnp.zeros(s, dtype), + 'mean', + lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', lambda s: jnp.ones(s, dtype), + 'var', + lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index ad950b869..dcc195170 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -46,7 +46,8 @@ def _build_mnist_dataset( ds = ds.map( lambda x: { 'inputs': _normalize(x['image'], train_mean, train_stddev), - 'targets': x['label'],}) + 'targets': x['label'], + }) is_train = split == 'train' if cache: @@ -213,6 +214,8 @@ def _eval_model_on_split(self, batch, model_state, per_device_model_rngs) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py index 089f1bfbb..a1c7ce15e 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py @@ -942,8 +942,8 @@ def forward(self, # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_len, device=k.device) - >= cache_index).reshape(1, max_len) + attn_mask = (torch.arange(max_len, device=k.device) >= + cache_index).reshape(1, max_len) # Update sequence length to account for complete sequence. seq_len = k.size(1) diff --git a/setup.cfg b/setup.cfg index 23e86a13b..e8044fe02 100644 --- a/setup.cfg +++ b/setup.cfg @@ -80,7 +80,7 @@ dev = isort==5.13.2 pylint==3.3.1 pytest==8.3.3 - yapf==0.43.0 + yapf==0.32.0 pre-commit==4.0.1 # Workloads # diff --git a/submission_runner.py b/submission_runner.py index a6bea1aa8..0024c35d4 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -377,8 +377,8 @@ def train_once( train_state['is_time_remaining'] = ( train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) - >= workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) >= + workload.eval_period_time_sec or train_state['training_complete']): with profiler.profile('Evaluation'): del batch _reset_cuda_mem() diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 8f9154f53..41fc5ee17 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index ff7103d43..92ce4eb44 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index d24d818a2..b8d860479 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 7d0556345..3f5469d8d 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) From 3afd1dff5e6bf0780c5ff77e2e7daedba74928cb Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 20:39:03 +0530 Subject: [PATCH 20/28] fix: running yapf again with 0.32, earlier using 0.43 --- .../external_tuning/jax_nadamw_full_budget.py | 10 +++---- .../jax_nadamw_target_setting.py | 10 +++---- .../self_tuning/jax_nadamw_full_budget.py | 10 +++---- .../self_tuning/jax_nadamw_target_setting.py | 10 +++---- .../paper_baselines/nadamw/jax/submission.py | 10 +++---- .../paper_baselines/sam/jax/submission.py | 8 +++--- .../shampoo/jax/distributed_shampoo.py | 28 ++++++++++++------- .../target_setting_algorithms/jax_nadamw.py | 10 +++---- 8 files changed, 46 insertions(+), 50 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index ad4d8e6f5..98193f01f 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index bde851468..66fdc4ebb 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 4122be181..4f53afb56 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -132,9 +132,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 6b5faa6b8..60a1f784d 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -132,9 +132,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index ad4d8e6f5..98193f01f 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index d33daadb8..85b3d7441 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -67,9 +67,8 @@ def update_fn(updates, state, grad_fn_params_tuple): # the noised parameters in the same order as on the original gradients and # with the same 1e-6 epsilon that is used when clipping the gradients. updates = dual_vector(updates) - noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u, - params, - updates) + noised_params = jax.tree_util.tree_map( + lambda p, u: p + rho * u, params, updates) (_, (n_valid_examples, _)), updates = grad_fn(noised_params) # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), @@ -81,7 +80,8 @@ def update_fn(updates, state, grad_fn_params_tuple): sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) scaled_updates = jax.tree_map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, + updates = jax.lax.cond(updates_norm > grad_clip, + lambda _: scaled_updates, lambda _: updates, None) updates, state = base_opt_update_fn(updates, state, params) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 722dab06b..725529cae 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -595,8 +595,8 @@ def matrix_inverse_pth_root( if padding_start is not None: # Zero out padding in identity as well for convergence checks. - ix = (jnp.arange(matrix_size, dtype=jnp.int32) - < padding_start).astype(matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( + matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh( alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE) identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) if padding_start is not None: - ix = (jnp.arange(matrix_size, dtype=jnp.int32) - < padding_start).astype(matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( + matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -1809,13 +1809,17 @@ def sharded_update_fn(grads, state, params): )) new_stats_flat = jax.tree_map( - lambda g, s, p: _compute_stats(g, s, p, state.count), + lambda g, + s, + p: _compute_stats(g, s, p, state.count), grads_flat, stats_flat, params_flat) outputs = jax.tree_map( - lambda g, s, p: _transform_grad(g, s, p, state.count), + lambda g, + s, + p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) @@ -1919,8 +1923,8 @@ def _internal_inverse_pth_root_all(): errors = metrics.inverse_pth_root_errors errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( - jnp.isnan(errors), errors - >= inverse_failure_threshold).astype(new_preconditioners.dtype) + jnp.isnan(errors), + errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) # TODO(rohananil): Check for numerical instabilities. new_conditional_preconditioners = ( predicate * global_stats.preconditioners + @@ -2438,7 +2442,9 @@ def update_fn(grads, state, params): stats_grads = treedef.flatten_up_to(grads_custom) new_stats_flat = jax.tree_map( - lambda g, s, p: _compute_stats(g, s, p, state.count), + lambda g, + s, + p: _compute_stats(g, s, p, state.count), stats_grads, stats_flat, params_flat) @@ -2447,7 +2453,9 @@ def update_fn(grads, state, params): params_flat, state.count) outputs = jax.tree_map( - lambda g, s, p: _transform_grad(g, s, p, state.count), + lambda g, + s, + p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index fc866f80a..21f2a7b2b 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -108,9 +108,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -125,9 +124,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): From 6ff2010d884e9d14911beab6dbce1a546a0a6213 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 21:21:03 +0530 Subject: [PATCH 21/28] fix: latest versions of typing dont support Text instead str is recommended --- algorithmic_efficiency/halton.py | 14 +++++++------- .../workloads/wmt/wmt_jax/workload.py | 2 +- .../workloads/wmt/wmt_pytorch/workload.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/algorithmic_efficiency/halton.py b/algorithmic_efficiency/halton.py index 9eb30861d..d710e3fce 100644 --- a/algorithmic_efficiency/halton.py +++ b/algorithmic_efficiency/halton.py @@ -10,13 +10,13 @@ import functools import itertools import math -from typing import Any, Callable, Dict, List, Sequence, Text, Tuple, Union +from typing import Any, Callable, Dict, List, Sequence, Tuple, Union from absl import logging from numpy import random -_SweepSequence = List[Dict[Text, Any]] -_GeneratorFn = Callable[[float], Tuple[Text, float]] +_SweepSequence = List[Dict[str, Any]] +_GeneratorFn = Callable[[float], Tuple[str, float]] def generate_primes(n: int) -> List[int]: @@ -195,10 +195,10 @@ def generate_sequence(num_samples: int, return halton_sequence -def _generate_double_point(name: Text, +def _generate_double_point(name: str, min_val: float, max_val: float, - scaling: Text, + scaling: str, halton_point: float) -> Tuple[str, float]: """Generate a float hyperparameter value from a Halton sequence point.""" if scaling not in ['linear', 'log']: @@ -234,7 +234,7 @@ def interval(start: int, end: int) -> Tuple[int, int]: return start, end -def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn: +def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn: min_val, max_val = range_endpoints return functools.partial(_generate_double_point, name, @@ -244,7 +244,7 @@ def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn: def uniform( - name: Text, search_points: Union[_DiscretePoints, + name: str, search_points: Union[_DiscretePoints, Tuple[int, int]]) -> _GeneratorFn: if isinstance(search_points, _DiscretePoints): return functools.partial(_generate_discrete_point, diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 442c85899..72108c9d9 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -16,7 +16,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import bleu +#from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_jax import decode from algorithmic_efficiency.workloads.wmt.wmt_jax import models from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 327ca34ad..b554b2ab3 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -16,7 +16,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import pytorch_utils from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import bleu +#from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_pytorch import decode from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import Transformer from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload From 55bacbd493c425fda147bc59aa97341f73b1ef17 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 21:24:18 +0530 Subject: [PATCH 22/28] fix: minor yapf --- algorithmic_efficiency/halton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/halton.py b/algorithmic_efficiency/halton.py index d710e3fce..1f36b07bf 100644 --- a/algorithmic_efficiency/halton.py +++ b/algorithmic_efficiency/halton.py @@ -245,7 +245,7 @@ def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn: def uniform( name: str, search_points: Union[_DiscretePoints, - Tuple[int, int]]) -> _GeneratorFn: + Tuple[int, int]]) -> _GeneratorFn: if isinstance(search_points, _DiscretePoints): return functools.partial(_generate_discrete_point, name, From 5eac985fcefc7fa0f93c2e4f28e0d71ca6db7d3d Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 7 Dec 2024 21:07:21 +0530 Subject: [PATCH 23/28] fix: going back to sacrebleu v1.3.1 --- algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py | 5 ++--- algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py | 5 ++--- setup.cfg | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 72108c9d9..046d5e469 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -12,11 +12,10 @@ import jax.numpy as jnp import numpy as np import optax -import sacrebleu from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec -#from algorithmic_efficiency.workloads.wmt import bleu +from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_jax import decode from algorithmic_efficiency.workloads.wmt.wmt_jax import models from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload @@ -204,7 +203,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score + bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index b554b2ab3..0ba49c2f6 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -5,7 +5,6 @@ from absl import logging import jax -import sacrebleu import tensorflow as tf import torch import torch.distributed as dist @@ -16,7 +15,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import pytorch_utils from algorithmic_efficiency import spec -#from algorithmic_efficiency.workloads.wmt import bleu +from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_pytorch import decode from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import Transformer from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload @@ -163,7 +162,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score + bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( diff --git a/setup.cfg b/setup.cfg index e8044fe02..a7c224407 100644 --- a/setup.cfg +++ b/setup.cfg @@ -103,7 +103,7 @@ librispeech_conformer = wmt = sentencepiece==0.2.0 tensorflow-text==2.18.0 - sacrebleu==2.4.3 + sacrebleu==1.3.1 # Frameworks # # JAX Core From 786771169b0f9bafe241692ac9411d30fccce62d Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 17 Dec 2024 21:13:16 +0530 Subject: [PATCH 24/28] feat: custom tf_addons support in TF2.18 --- .../imagenet_jax/custom_tf_addons.py | 433 ++++++++++++++++++ .../imagenet_jax/randaugment.py | 16 +- 2 files changed, 441 insertions(+), 8 deletions(-) create mode 100644 algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py new file mode 100644 index 000000000..eda67d226 --- /dev/null +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -0,0 +1,433 @@ +""" +Note: +The following code is adapted from: +https://github.com/tensorflow/addons/tree/master/tensorflow_addons/image + + +""" + +import math +from typing import Callable, List, Optional, Union + +import numpy as np +import tensorflow as tf + +_IMAGE_DTYPES = { + tf.dtypes.uint8, + tf.dtypes.int32, + tf.dtypes.int64, + tf.dtypes.float16, + tf.dtypes.float32, + tf.dtypes.float64, +} + +Number = Union[float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64,] + +TensorLike = Union[List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable,] + + +def get_ndims(image): + return image.get_shape().ndims or tf.rank(image) + + +def to_4D_image(image): + """Convert 2/3/4D image to 4D image. + + Args: + image: 2/3/4D `Tensor`. + + Returns: + 4D `Tensor` with the same type. + """ + with tf.control_dependencies([ + tf.debugging.assert_rank_in( + image, [2, 3, 4], message="`image` must be 2/3/4D tensor") + ]): + ndims = image.get_shape().ndims + if ndims is None: + return _dynamic_to_4D_image(image) + elif ndims == 2: + return image[None, :, :, None] + elif ndims == 3: + return image[None, :, :, :] + else: + return image + + +def _dynamic_to_4D_image(image): + shape = tf.shape(image) + original_rank = tf.rank(image) + # 4D image => [N, H, W, C] or [N, C, H, W] + # 3D image => [1, H, W, C] or [1, C, H, W] + # 2D image => [1, H, W, 1] + left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + new_shape = tf.concat( + [ + tf.ones(shape=left_pad, dtype=tf.int32), + shape, + tf.ones(shape=right_pad, dtype=tf.int32), + ], + axis=0, + ) + return tf.reshape(image, new_shape) + + +def from_4D_image(image, ndims): + """Convert back to an image with `ndims` rank. + + Args: + image: 4D `Tensor`. + ndims: The original rank of the image. + + Returns: + `ndims`-D `Tensor` with the same type. + """ + with tf.control_dependencies( + [tf.debugging.assert_rank(image, 4, + message="`image` must be 4D tensor")]): + if isinstance(ndims, tf.Tensor): + return _dynamic_from_4D_image(image, ndims) + elif ndims == 2: + return tf.squeeze(image, [0, 3]) + elif ndims == 3: + return tf.squeeze(image, [0]) + else: + return image + + +def _dynamic_from_4D_image(image, original_rank): + shape = tf.shape(image) + # 4D image <= [N, H, W, C] or [N, C, H, W] + # 3D image <= [1, H, W, C] or [1, C, H, W] + # 2D image <= [1, H, W, 1] + begin = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + end = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + new_shape = shape[begin:end] + return tf.reshape(image, new_shape) + + +def transform( + images: TensorLike, + transforms: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + output_shape: Optional[list] = None, + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Applies the given transform(s) to the image(s). + + Args: + images: A tensor of shape (num_images, num_rows, num_columns, + num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or + (num_rows, num_columns) (HW). + transforms: Projective transform matrix/matrices. A vector of length 8 or + tensor of size N x 8. If one row of transforms is + [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point + `(x, y)` to a transformed *input* point + `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, + where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to + the transform mapping input points to output points. Note that + gradients are not backpropagated into transformation parameters. + interpolation: Interpolation mode. + Supported values: "nearest", "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + output_shape: Output dimesion after the transform, [height, width]. + If None, output is the same size as input image. + + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, with the given + transform(s) applied. Transformed coordinates outside of the input image + will be filled with zeros. + + Raises: + TypeError: If `image` is an invalid type. + ValueError: If output shape is not 1-D int32 Tensor. + """ + with tf.name_scope(name or "transform"): + image_or_images = tf.convert_to_tensor(images, name="images") + transform_or_transforms = tf.convert_to_tensor( + transforms, name="transforms", dtype=tf.dtypes.float32) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + images = to_4D_image(image_or_images) + original_ndims = get_ndims(image_or_images) + + if output_shape is None: + output_shape = tf.shape(images)[1:3] + + output_shape = tf.convert_to_tensor( + output_shape, tf.dtypes.int32, name="output_shape") + + if not output_shape.get_shape().is_compatible_with([2]): + raise ValueError("output_shape must be a 1-D Tensor of 2 elements: " + "new_height, new_width") + + if len(transform_or_transforms.get_shape()) == 1: + transforms = transform_or_transforms[None] + elif transform_or_transforms.get_shape().ndims is None: + raise ValueError("transforms rank must be statically known") + elif len(transform_or_transforms.get_shape()) == 2: + transforms = transform_or_transforms + else: + transforms = transform_or_transforms + raise ValueError("transforms should have rank 1 or 2, but got rank %d" % + len(transforms.get_shape())) + + fill_value = tf.convert_to_tensor( + fill_value, dtype=tf.float32, name="fill_value") + output = tf.raw_ops.ImageProjectiveTransformV3( + images=images, + transforms=transforms, + output_shape=output_shape, + interpolation=interpolation.upper(), + fill_mode=fill_mode.upper(), + fill_value=fill_value, + ) + return from_4D_image(output, original_ndims) + + +def angles_to_projective_transforms( + angles: TensorLike, + image_height: TensorLike, + image_width: TensorLike, + name: Optional[str] = None, +) -> tf.Tensor: + """Returns projective transform(s) for the given angle(s). + + Args: + angles: A scalar angle to rotate all images by, or (for batches of + images) a vector with an angle to rotate each image in the batch. The + rank must be statically known (the shape is not `TensorShape(None)`. + image_height: Height of the image(s) to be transformed. + image_width: Width of the image(s) to be transformed. + + Returns: + A tensor of shape (num_images, 8). Projective transforms which can be + given to `transform` op. + """ + with tf.name_scope(name or "angles_to_projective_transforms"): + angle_or_angles = tf.convert_to_tensor( + angles, name="angles", dtype=tf.dtypes.float32) + if len(angle_or_angles.get_shape()) == 0: + angles = angle_or_angles[None] + elif len(angle_or_angles.get_shape()) == 1: + angles = angle_or_angles + else: + raise ValueError("angles should have rank 0 or 1.") + cos_angles = tf.math.cos(angles) + sin_angles = tf.math.sin(angles) + x_offset = ((image_width - 1) - + (cos_angles * (image_width - 1) - sin_angles * + (image_height - 1))) / 2.0 + y_offset = ((image_height - 1) - + (sin_angles * (image_width - 1) + cos_angles * + (image_height - 1))) / 2.0 + num_angles = tf.shape(angles)[0] + return tf.concat( + values=[ + cos_angles[:, None], + -sin_angles[:, None], + x_offset[:, None], + sin_angles[:, None], + cos_angles[:, None], + y_offset[:, None], + tf.zeros((num_angles, 2), tf.dtypes.float32), + ], + axis=1, + ) + + +def rotate( + images: TensorLike, + angles: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Rotate image(s) counterclockwise by the passed angle(s) in radians. + + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` + (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). + angles: A scalar angle to rotate all images by, or (if `images` has rank 4) + a vector of length num_images, with an angle for each image in the + batch. + interpolation: Interpolation mode. Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, rotated by the given + angle(s). Empty space due to the rotation will be filled with zeros. + + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or "rotate"): + image_or_images = tf.convert_to_tensor(images) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + images = to_4D_image(image_or_images) + original_ndims = get_ndims(image_or_images) + + image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None] + image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None] + output = transform( + images, + angles_to_projective_transforms(angles, image_height, image_width), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + ) + return from_4D_image(output, original_ndims) + + +def translations_to_projective_transforms(translations: TensorLike, + name: Optional[str] = None + ) -> tf.Tensor: + """Returns projective transform(s) for the given translation(s). + + Args: + translations: A 2-element list representing `[dx, dy]` or a matrix of + 2-element lists representing `[dx, dy]` to translate for each image + (for a batch of images). The rank must be statically known + (the shape is not `TensorShape(None)`). + name: The name of the op. + Returns: + A tensor of shape `(num_images, 8)` projective transforms which can be + given to `tfa.image.transform`. + """ + with tf.name_scope(name or "translations_to_projective_transforms"): + translation_or_translations = tf.convert_to_tensor( + translations, name="translations", dtype=tf.dtypes.float32) + if translation_or_translations.get_shape().ndims is None: + raise TypeError( + "translation_or_translations rank must be statically known") + elif len(translation_or_translations.get_shape()) == 1: + translations = translation_or_translations[None] + elif len(translation_or_translations.get_shape()) == 2: + translations = translation_or_translations + else: + raise TypeError("Translations should have rank 1 or 2.") + num_translations = tf.shape(translations)[0] + # The translation matrix looks like: + # [[1 0 -dx] + # [0 1 -dy] + # [0 0 1]] + # where the last entry is implicit. + # Translation matrices are always float32. + return tf.concat( + values=[ + tf.ones((num_translations, 1), tf.dtypes.float32), + tf.zeros((num_translations, 1), tf.dtypes.float32), + -translations[:, 0, None], + tf.zeros((num_translations, 1), tf.dtypes.float32), + tf.ones((num_translations, 1), tf.dtypes.float32), + -translations[:, 1, None], + tf.zeros((num_translations, 2), tf.dtypes.float32), + ], + axis=1, + ) + + +@tf.function +def translate( + images: TensorLike, + translations: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Translate image(s) by the passed vectors(s). + + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` (NHWC), + `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). The rank must be statically known (the + shape is not `TensorShape(None)`). + translations: A vector representing `[dx, dy]` or (if `images` has rank 4) + a matrix of length num_images, with a `[dx, dy]` vector for each image + in the batch. + interpolation: Interpolation mode. Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + Returns: + Image(s) with the same type and shape as `images`, translated by the + given vector(s). Empty space due to the translation will be filled with + zeros. + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or "translate"): + return transform( + images, + translations_to_projective_transforms(translations), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + ) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index af1b763c1..f3a946245 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -9,7 +9,9 @@ import tensorflow as tf -#from tensorflow_addons import image as contrib_image +from .custom_tf_addons import rotate +from .custom_tf_addons import transform +from .custom_tf_addons import translate # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. @@ -177,19 +179,19 @@ def rotate(image, degrees, replace): # In practice, we should randomize the rotation degrees by flipping # it negatively half the time, but that's done on 'degrees' outside # of the function. - image = contrib_image.rotate(wrap(image), radians) + image = rotate(wrap(image), radians) return unwrap(image, replace) def translate_x(image, pixels, replace): """Equivalent of PIL Translate in X dimension.""" - image = contrib_image.translate(wrap(image), [-pixels, 0]) + image = translate(wrap(image), [-pixels, 0]) return unwrap(image, replace) def translate_y(image, pixels, replace): """Equivalent of PIL Translate in Y dimension.""" - image = contrib_image.translate(wrap(image), [0, -pixels]) + image = translate(wrap(image), [0, -pixels]) return unwrap(image, replace) @@ -199,8 +201,7 @@ def shear_x(image, level, replace): # with a matrix form of: # [1 level # 0 1]. - image = contrib_image.transform( - wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) + image = transform(wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) return unwrap(image, replace) @@ -210,8 +211,7 @@ def shear_y(image, level, replace): # with a matrix form of: # [1 0 # level 1]. - image = contrib_image.transform( - wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) + image = transform(wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) return unwrap(image, replace) From d6dd2e8e16145e73f69664bc81690ac06857319b Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 17 Dec 2024 21:50:11 +0530 Subject: [PATCH 25/28] fix: resolving pylint issues in custom_tf_addons --- .../imagenet_jax/custom_tf_addons.py | 27 +++++++++---------- .../imagenet_jax/randaugment.py | 4 +-- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index eda67d226..79aef6791 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -6,8 +6,7 @@ """ -import math -from typing import Callable, List, Optional, Union +from typing import List, Optional, Union import numpy as np import tensorflow as tf @@ -48,7 +47,7 @@ def get_ndims(image): return image.get_shape().ndims or tf.rank(image) -def to_4D_image(image): +def to_4d_image(image): """Convert 2/3/4D image to 4D image. Args: @@ -63,7 +62,7 @@ def to_4D_image(image): ]): ndims = image.get_shape().ndims if ndims is None: - return _dynamic_to_4D_image(image) + return _dynamic_to_4d_image(image) elif ndims == 2: return image[None, :, :, None] elif ndims == 3: @@ -72,7 +71,7 @@ def to_4D_image(image): return image -def _dynamic_to_4D_image(image): +def _dynamic_to_4d_image(image): shape = tf.shape(image) original_rank = tf.rank(image) # 4D image => [N, H, W, C] or [N, C, H, W] @@ -91,7 +90,7 @@ def _dynamic_to_4D_image(image): return tf.reshape(image, new_shape) -def from_4D_image(image, ndims): +def from_4d_image(image, ndims): """Convert back to an image with `ndims` rank. Args: @@ -105,7 +104,7 @@ def from_4D_image(image, ndims): [tf.debugging.assert_rank(image, 4, message="`image` must be 4D tensor")]): if isinstance(ndims, tf.Tensor): - return _dynamic_from_4D_image(image, ndims) + return _dynamic_from_4d_image(image, ndims) elif ndims == 2: return tf.squeeze(image, [0, 3]) elif ndims == 3: @@ -114,7 +113,7 @@ def from_4D_image(image, ndims): return image -def _dynamic_from_4D_image(image, original_rank): +def _dynamic_from_4d_image(image, original_rank): shape = tf.shape(image) # 4D image <= [N, H, W, C] or [N, C, H, W] # 3D image <= [1, H, W, C] or [1, C, H, W] @@ -183,7 +182,7 @@ def transform( transforms, name="transforms", dtype=tf.dtypes.float32) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - images = to_4D_image(image_or_images) + images = to_4d_image(image_or_images) original_ndims = get_ndims(image_or_images) if output_shape is None: @@ -217,7 +216,7 @@ def transform( fill_mode=fill_mode.upper(), fill_value=fill_value, ) - return from_4D_image(output, original_ndims) + return from_4d_image(output, original_ndims) def angles_to_projective_transforms( @@ -271,7 +270,7 @@ def angles_to_projective_transforms( ) -def rotate( +def rotate_img( images: TensorLike, angles: TensorLike, interpolation: str = "nearest", @@ -286,7 +285,7 @@ def rotate( `(num_images, num_rows, num_columns, num_channels)` (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or `(num_rows, num_columns)` (HW). - angles: A scalar angle to rotate all images by, or (if `images` has rank 4) + angles: A scalar angle to rotate all images by (if `images` has rank 4) a vector of length num_images, with an angle for each image in the batch. interpolation: Interpolation mode. Supported values: "nearest", @@ -317,7 +316,7 @@ def rotate( image_or_images = tf.convert_to_tensor(images) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - images = to_4D_image(image_or_images) + images = to_4d_image(image_or_images) original_ndims = get_ndims(image_or_images) image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None] @@ -329,7 +328,7 @@ def rotate( fill_mode=fill_mode, fill_value=fill_value, ) - return from_4D_image(output, original_ndims) + return from_4d_image(output, original_ndims) def translations_to_projective_transforms(translations: TensorLike, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index f3a946245..dd00146cd 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -9,7 +9,7 @@ import tensorflow as tf -from .custom_tf_addons import rotate +from .custom_tf_addons import rotate_img from .custom_tf_addons import transform from .custom_tf_addons import translate @@ -179,7 +179,7 @@ def rotate(image, degrees, replace): # In practice, we should randomize the rotation degrees by flipping # it negatively half the time, but that's done on 'degrees' outside # of the function. - image = rotate(wrap(image), radians) + image = rotate_img(wrap(image), radians) return unwrap(image, replace) From a0b587aed0ccecb794a46e2ba99713c56ed69f93 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 17 Dec 2024 22:04:59 +0530 Subject: [PATCH 26/28] resolved pyline and changed the pylint version to current version of main --- .../imagenet_jax/custom_tf_addons.py | 20 ++++++++++++------- setup.cfg | 2 +- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index 79aef6791..3d6939218 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -241,12 +241,15 @@ def angles_to_projective_transforms( with tf.name_scope(name or "angles_to_projective_transforms"): angle_or_angles = tf.convert_to_tensor( angles, name="angles", dtype=tf.dtypes.float32) + + if len(angle_or_angles.get_shape()) not in (0, 1): + raise ValueError("angles should have rank 0 or 1.") + if len(angle_or_angles.get_shape()) == 0: angles = angle_or_angles[None] - elif len(angle_or_angles.get_shape()) == 1: - angles = angle_or_angles else: - raise ValueError("angles should have rank 0 or 1.") + angles = angle_or_angles + cos_angles = tf.math.cos(angles) sin_angles = tf.math.sin(angles) x_offset = ((image_width - 1) - @@ -352,12 +355,15 @@ def translations_to_projective_transforms(translations: TensorLike, if translation_or_translations.get_shape().ndims is None: raise TypeError( "translation_or_translations rank must be statically known") - elif len(translation_or_translations.get_shape()) == 1: + + if len(translation_or_translations.get_shape()) not in (1, 2): + raise TypeError("Translations should have rank 1 or 2.") + + if len(translation_or_translations.get_shape()) == 1: translations = translation_or_translations[None] - elif len(translation_or_translations.get_shape()) == 2: - translations = translation_or_translations else: - raise TypeError("Translations should have rank 1 or 2.") + translations = translation_or_translations + num_translations = tf.shape(translations)[0] # The translation matrix looks like: # [[1 0 -dx] diff --git a/setup.cfg b/setup.cfg index a7c224407..7977267bd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -78,7 +78,7 @@ full_dev = # Dependencies for developing the package dev = isort==5.13.2 - pylint==3.3.1 + pylint==2.16.1 pytest==8.3.3 yapf==0.32.0 pre-commit==4.0.1 From 9393145ba91b9432c1732f5bd9d8865c2cb232f8 Mon Sep 17 00:00:00 2001 From: init-22 Date: Wed, 18 Dec 2024 20:58:42 +0530 Subject: [PATCH 27/28] fix: removing tensorflow addons from setup cfg --- setup.cfg | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 7977267bd..2d246b48b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,7 +40,6 @@ install_requires = pandas==2.2.3 tensorflow==2.18.0 tensorflow-datasets==4.9.7 - tensorflow-addons==0.23.0 gputil==1.4.0 psutil==6.1.0 clu==0.0.12 From 53eff1d469635408aff5d80a28f3248c4bd79464 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 20 Dec 2024 00:41:47 +0530 Subject: [PATCH 28/28] fix: adding absolute paths for custom_tf_addons in randaugment --- .../imagenet_resnet/imagenet_jax/randaugment.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index dd00146cd..e920331bc 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -9,9 +9,12 @@ import tensorflow as tf -from .custom_tf_addons import rotate_img -from .custom_tf_addons import transform -from .custom_tf_addons import translate +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + rotate_img +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + transform +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + translate # This signifies the max integer that the controller RNN could predict for the # augmentation scheme.