diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 05d94e896..fe2441bfe 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -25,10 +25,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -42,10 +42,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -59,10 +59,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -77,10 +77,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -96,10 +96,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -113,10 +113,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -130,10 +130,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -148,10 +148,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -166,10 +166,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -184,10 +184,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install pytest @@ -208,10 +208,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install pytest diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 89b5ef288..628fc012b 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install pylint run: | python -m pip install --upgrade pip @@ -27,10 +27,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install isort run: | python -m pip install --upgrade pip @@ -43,10 +43,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install yapf run: | python -m pip install --upgrade pip diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 006b972ec..aa493bc9f 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -35,7 +35,7 @@ The specs on the benchmarking machines are: > **Prerequisites:** > -> - Python minimum requirement >= 3.8 +> - Python minimum requirement >= 3.11 > - CUDA 12.1 > - NVIDIA Driver version 535.104.05 diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algorithmic_efficiency/checkpoint_utils.py index 29c1a821e..04dad0eb7 100644 --- a/algorithmic_efficiency/checkpoint_utils.py +++ b/algorithmic_efficiency/checkpoint_utils.py @@ -231,7 +231,7 @@ def save_checkpoint(framework: str, target=checkpoint_state, step=global_step, overwrite=True, - keep=np.Inf if save_intermediate_checkpoints else 1) + keep=np.inf if save_intermediate_checkpoints else 1) else: if not save_intermediate_checkpoints: checkpoint_files = gfile.glob( diff --git a/algorithmic_efficiency/halton.py b/algorithmic_efficiency/halton.py index 9eb30861d..1f36b07bf 100644 --- a/algorithmic_efficiency/halton.py +++ b/algorithmic_efficiency/halton.py @@ -10,13 +10,13 @@ import functools import itertools import math -from typing import Any, Callable, Dict, List, Sequence, Text, Tuple, Union +from typing import Any, Callable, Dict, List, Sequence, Tuple, Union from absl import logging from numpy import random -_SweepSequence = List[Dict[Text, Any]] -_GeneratorFn = Callable[[float], Tuple[Text, float]] +_SweepSequence = List[Dict[str, Any]] +_GeneratorFn = Callable[[float], Tuple[str, float]] def generate_primes(n: int) -> List[int]: @@ -195,10 +195,10 @@ def generate_sequence(num_samples: int, return halton_sequence -def _generate_double_point(name: Text, +def _generate_double_point(name: str, min_val: float, max_val: float, - scaling: Text, + scaling: str, halton_point: float) -> Tuple[str, float]: """Generate a float hyperparameter value from a Halton sequence point.""" if scaling not in ['linear', 'log']: @@ -234,7 +234,7 @@ def interval(start: int, end: int) -> Tuple[int, int]: return start, end -def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn: +def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn: min_val, max_val = range_endpoints return functools.partial(_generate_double_point, name, @@ -244,8 +244,8 @@ def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn: def uniform( - name: Text, search_points: Union[_DiscretePoints, - Tuple[int, int]]) -> _GeneratorFn: + name: str, search_points: Union[_DiscretePoints, + Tuple[int, int]]) -> _GeneratorFn: if isinstance(search_points, _DiscretePoints): return functools.partial(_generate_discrete_point, name, diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 609d996e6..155e55356 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -211,7 +211,7 @@ def _get_system_software_info() -> Dict: system_software_info['os_platform'] = \ platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29' system_software_info['python_version'] = platform.python_version( - ) # Ex. '3.8.10' + ) # Ex. '3.11.10' system_software_info['python_compiler'] = platform.python_compiler( ) # Ex. 'GCC 9.3.0' # Note: do not store hostname as that may be sensitive diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index cf1ea6c32..b5b30ce22 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,30 +18,30 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_INT32 = 2**31 -MIN_INT32 = -MAX_INT32 +MAX_UINT32 = 2**32 - 1 +MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] def _signed_to_unsigned(seed: SeedType) -> SeedType: if isinstance(seed, int): - return seed % 2**32 + return seed % MAX_UINT32 if isinstance(seed, list): - return [s % 2**32 for s in seed] + return [s % MAX_UINT32 for s in seed] if isinstance(seed, np.ndarray): - return np.array([s % 2**32 for s in seed.tolist()]) + return np.array([s % MAX_UINT32 for s in seed.tolist()]) def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32) return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType: def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.PRNGKey(seed) + return jax_rng.key(seed) return _PRNGKey(seed) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index 8268c6ca3..6bbf9c64b 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -5,6 +5,7 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax from jax import lax import jax.numpy as jnp @@ -75,8 +76,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics # and we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state def init_model_fn( @@ -93,7 +94,7 @@ def init_model_fn( input_shape = (1, 32, 32, 3) variables = jax.jit(model.init)({'params': rng}, jnp.ones(input_shape, model.dtype)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) model_state = jax_utils.replicate(model_state) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py new file mode 100644 index 000000000..3d6939218 --- /dev/null +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -0,0 +1,438 @@ +""" +Note: +The following code is adapted from: +https://github.com/tensorflow/addons/tree/master/tensorflow_addons/image + + +""" + +from typing import List, Optional, Union + +import numpy as np +import tensorflow as tf + +_IMAGE_DTYPES = { + tf.dtypes.uint8, + tf.dtypes.int32, + tf.dtypes.int64, + tf.dtypes.float16, + tf.dtypes.float32, + tf.dtypes.float64, +} + +Number = Union[float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64,] + +TensorLike = Union[List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable,] + + +def get_ndims(image): + return image.get_shape().ndims or tf.rank(image) + + +def to_4d_image(image): + """Convert 2/3/4D image to 4D image. + + Args: + image: 2/3/4D `Tensor`. + + Returns: + 4D `Tensor` with the same type. + """ + with tf.control_dependencies([ + tf.debugging.assert_rank_in( + image, [2, 3, 4], message="`image` must be 2/3/4D tensor") + ]): + ndims = image.get_shape().ndims + if ndims is None: + return _dynamic_to_4d_image(image) + elif ndims == 2: + return image[None, :, :, None] + elif ndims == 3: + return image[None, :, :, :] + else: + return image + + +def _dynamic_to_4d_image(image): + shape = tf.shape(image) + original_rank = tf.rank(image) + # 4D image => [N, H, W, C] or [N, C, H, W] + # 3D image => [1, H, W, C] or [1, C, H, W] + # 2D image => [1, H, W, 1] + left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + new_shape = tf.concat( + [ + tf.ones(shape=left_pad, dtype=tf.int32), + shape, + tf.ones(shape=right_pad, dtype=tf.int32), + ], + axis=0, + ) + return tf.reshape(image, new_shape) + + +def from_4d_image(image, ndims): + """Convert back to an image with `ndims` rank. + + Args: + image: 4D `Tensor`. + ndims: The original rank of the image. + + Returns: + `ndims`-D `Tensor` with the same type. + """ + with tf.control_dependencies( + [tf.debugging.assert_rank(image, 4, + message="`image` must be 4D tensor")]): + if isinstance(ndims, tf.Tensor): + return _dynamic_from_4d_image(image, ndims) + elif ndims == 2: + return tf.squeeze(image, [0, 3]) + elif ndims == 3: + return tf.squeeze(image, [0]) + else: + return image + + +def _dynamic_from_4d_image(image, original_rank): + shape = tf.shape(image) + # 4D image <= [N, H, W, C] or [N, C, H, W] + # 3D image <= [1, H, W, C] or [1, C, H, W] + # 2D image <= [1, H, W, 1] + begin = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + end = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + new_shape = shape[begin:end] + return tf.reshape(image, new_shape) + + +def transform( + images: TensorLike, + transforms: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + output_shape: Optional[list] = None, + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Applies the given transform(s) to the image(s). + + Args: + images: A tensor of shape (num_images, num_rows, num_columns, + num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or + (num_rows, num_columns) (HW). + transforms: Projective transform matrix/matrices. A vector of length 8 or + tensor of size N x 8. If one row of transforms is + [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point + `(x, y)` to a transformed *input* point + `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, + where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to + the transform mapping input points to output points. Note that + gradients are not backpropagated into transformation parameters. + interpolation: Interpolation mode. + Supported values: "nearest", "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + output_shape: Output dimesion after the transform, [height, width]. + If None, output is the same size as input image. + + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, with the given + transform(s) applied. Transformed coordinates outside of the input image + will be filled with zeros. + + Raises: + TypeError: If `image` is an invalid type. + ValueError: If output shape is not 1-D int32 Tensor. + """ + with tf.name_scope(name or "transform"): + image_or_images = tf.convert_to_tensor(images, name="images") + transform_or_transforms = tf.convert_to_tensor( + transforms, name="transforms", dtype=tf.dtypes.float32) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + images = to_4d_image(image_or_images) + original_ndims = get_ndims(image_or_images) + + if output_shape is None: + output_shape = tf.shape(images)[1:3] + + output_shape = tf.convert_to_tensor( + output_shape, tf.dtypes.int32, name="output_shape") + + if not output_shape.get_shape().is_compatible_with([2]): + raise ValueError("output_shape must be a 1-D Tensor of 2 elements: " + "new_height, new_width") + + if len(transform_or_transforms.get_shape()) == 1: + transforms = transform_or_transforms[None] + elif transform_or_transforms.get_shape().ndims is None: + raise ValueError("transforms rank must be statically known") + elif len(transform_or_transforms.get_shape()) == 2: + transforms = transform_or_transforms + else: + transforms = transform_or_transforms + raise ValueError("transforms should have rank 1 or 2, but got rank %d" % + len(transforms.get_shape())) + + fill_value = tf.convert_to_tensor( + fill_value, dtype=tf.float32, name="fill_value") + output = tf.raw_ops.ImageProjectiveTransformV3( + images=images, + transforms=transforms, + output_shape=output_shape, + interpolation=interpolation.upper(), + fill_mode=fill_mode.upper(), + fill_value=fill_value, + ) + return from_4d_image(output, original_ndims) + + +def angles_to_projective_transforms( + angles: TensorLike, + image_height: TensorLike, + image_width: TensorLike, + name: Optional[str] = None, +) -> tf.Tensor: + """Returns projective transform(s) for the given angle(s). + + Args: + angles: A scalar angle to rotate all images by, or (for batches of + images) a vector with an angle to rotate each image in the batch. The + rank must be statically known (the shape is not `TensorShape(None)`. + image_height: Height of the image(s) to be transformed. + image_width: Width of the image(s) to be transformed. + + Returns: + A tensor of shape (num_images, 8). Projective transforms which can be + given to `transform` op. + """ + with tf.name_scope(name or "angles_to_projective_transforms"): + angle_or_angles = tf.convert_to_tensor( + angles, name="angles", dtype=tf.dtypes.float32) + + if len(angle_or_angles.get_shape()) not in (0, 1): + raise ValueError("angles should have rank 0 or 1.") + + if len(angle_or_angles.get_shape()) == 0: + angles = angle_or_angles[None] + else: + angles = angle_or_angles + + cos_angles = tf.math.cos(angles) + sin_angles = tf.math.sin(angles) + x_offset = ((image_width - 1) - + (cos_angles * (image_width - 1) - sin_angles * + (image_height - 1))) / 2.0 + y_offset = ((image_height - 1) - + (sin_angles * (image_width - 1) + cos_angles * + (image_height - 1))) / 2.0 + num_angles = tf.shape(angles)[0] + return tf.concat( + values=[ + cos_angles[:, None], + -sin_angles[:, None], + x_offset[:, None], + sin_angles[:, None], + cos_angles[:, None], + y_offset[:, None], + tf.zeros((num_angles, 2), tf.dtypes.float32), + ], + axis=1, + ) + + +def rotate_img( + images: TensorLike, + angles: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Rotate image(s) counterclockwise by the passed angle(s) in radians. + + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` + (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). + angles: A scalar angle to rotate all images by (if `images` has rank 4) + a vector of length num_images, with an angle for each image in the + batch. + interpolation: Interpolation mode. Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, rotated by the given + angle(s). Empty space due to the rotation will be filled with zeros. + + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or "rotate"): + image_or_images = tf.convert_to_tensor(images) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + images = to_4d_image(image_or_images) + original_ndims = get_ndims(image_or_images) + + image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None] + image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None] + output = transform( + images, + angles_to_projective_transforms(angles, image_height, image_width), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + ) + return from_4d_image(output, original_ndims) + + +def translations_to_projective_transforms(translations: TensorLike, + name: Optional[str] = None + ) -> tf.Tensor: + """Returns projective transform(s) for the given translation(s). + + Args: + translations: A 2-element list representing `[dx, dy]` or a matrix of + 2-element lists representing `[dx, dy]` to translate for each image + (for a batch of images). The rank must be statically known + (the shape is not `TensorShape(None)`). + name: The name of the op. + Returns: + A tensor of shape `(num_images, 8)` projective transforms which can be + given to `tfa.image.transform`. + """ + with tf.name_scope(name or "translations_to_projective_transforms"): + translation_or_translations = tf.convert_to_tensor( + translations, name="translations", dtype=tf.dtypes.float32) + if translation_or_translations.get_shape().ndims is None: + raise TypeError( + "translation_or_translations rank must be statically known") + + if len(translation_or_translations.get_shape()) not in (1, 2): + raise TypeError("Translations should have rank 1 or 2.") + + if len(translation_or_translations.get_shape()) == 1: + translations = translation_or_translations[None] + else: + translations = translation_or_translations + + num_translations = tf.shape(translations)[0] + # The translation matrix looks like: + # [[1 0 -dx] + # [0 1 -dy] + # [0 0 1]] + # where the last entry is implicit. + # Translation matrices are always float32. + return tf.concat( + values=[ + tf.ones((num_translations, 1), tf.dtypes.float32), + tf.zeros((num_translations, 1), tf.dtypes.float32), + -translations[:, 0, None], + tf.zeros((num_translations, 1), tf.dtypes.float32), + tf.ones((num_translations, 1), tf.dtypes.float32), + -translations[:, 1, None], + tf.zeros((num_translations, 2), tf.dtypes.float32), + ], + axis=1, + ) + + +@tf.function +def translate( + images: TensorLike, + translations: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Translate image(s) by the passed vectors(s). + + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` (NHWC), + `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). The rank must be statically known (the + shape is not `TensorShape(None)`). + translations: A vector representing `[dx, dy]` or (if `images` has rank 4) + a matrix of length num_images, with a `[dx, dy]` vector for each image + in the batch. + interpolation: Interpolation mode. Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + Returns: + Image(s) with the same type and shape as `images`, translated by the + given vector(s). Empty space due to the translation will be filled with + zeros. + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or "translate"): + return transform( + images, + translations_to_projective_transforms(translations), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + ) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 5f92b1482..e920331bc 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,7 +8,13 @@ import math import tensorflow as tf -from tensorflow_addons import image as contrib_image + +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + rotate_img +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + transform +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + translate # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. @@ -176,19 +182,19 @@ def rotate(image, degrees, replace): # In practice, we should randomize the rotation degrees by flipping # it negatively half the time, but that's done on 'degrees' outside # of the function. - image = contrib_image.rotate(wrap(image), radians) + image = rotate_img(wrap(image), radians) return unwrap(image, replace) def translate_x(image, pixels, replace): """Equivalent of PIL Translate in X dimension.""" - image = contrib_image.translate(wrap(image), [-pixels, 0]) + image = translate(wrap(image), [-pixels, 0]) return unwrap(image, replace) def translate_y(image, pixels, replace): """Equivalent of PIL Translate in Y dimension.""" - image = contrib_image.translate(wrap(image), [0, -pixels]) + image = translate(wrap(image), [0, -pixels]) return unwrap(image, replace) @@ -198,8 +204,7 @@ def shear_x(image, level, replace): # with a matrix form of: # [1 level # 0 1]. - image = contrib_image.transform( - wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) + image = transform(wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) return unwrap(image, replace) @@ -209,8 +214,7 @@ def shear_y(image, level, replace): # with a matrix form of: # [1 0 # level 1]. - image = contrib_image.transform( - wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) + image = transform(wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) return unwrap(image, replace) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index 2747fc2db..91cdec60a 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -11,6 +11,7 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax from jax import lax import jax.numpy as jnp @@ -79,8 +80,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics and # we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() # Create a shallow copy + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state def init_model_fn( @@ -111,7 +112,7 @@ def init_model_fn( input_shape = (1, 224, 224, 3) variables = jax.jit(model.init)({'params': rng}, jnp.ones(input_shape, model.dtype)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) model_state = jax_utils.replicate(model_state) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 2ad71ffd0..5f826d035 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -4,6 +4,7 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax import jax.numpy as jnp @@ -28,7 +29,7 @@ def initialized(self, key: spec.RandomState, variables = jax.jit( model.init)({'params': params_rng, 'dropout': dropout_rng}, jnp.ones(input_shape)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") return params, model_state def init_model_fn( diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index e362f973b..05faf1135 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -3,6 +3,7 @@ from typing import Dict, Iterator, Optional, Tuple from flax import jax_utils +from flax.core import pop import flax.linen as nn import jax from jax import lax @@ -89,7 +90,7 @@ def init_model_fn( variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, *fake_input_batch) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -378,8 +379,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics and # we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py index e4b5cd014..97fee032f 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py @@ -222,9 +222,8 @@ def __call__(self, inputs, encoder_mask=None): use_bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)(cfg.attention_temp * x, - x, - encoder_mask) + deterministic=cfg.deterministic)( + cfg.attention_temp * x, x, mask=encoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 @@ -288,7 +287,8 @@ def __call__(self, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic, - decode=cfg.decode)(cfg.attention_temp * x, x, decoder_mask) + decode=cfg.decode)( + cfg.attention_temp * x, x, mask=decoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 else: @@ -309,9 +309,8 @@ def __call__(self, use_bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)(cfg.attention_temp * y, - encoded, - encoder_decoder_mask) + deterministic=cfg.deterministic)( + cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) y = y + x diff --git a/docker/Dockerfile b/docker/Dockerfile index 9b72aea86..ee9136cbf 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,7 +11,34 @@ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 RUN echo "Setting up machine" RUN apt-get update RUN apt-get install -y curl tar -RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git ffmpeg + +# Install prerequisites +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + zlib1g-dev \ + libncurses5-dev \ + libssl-dev \ + libreadline-dev \ + libffi-dev \ + curl \ + libbz2-dev \ + liblzma-dev + +# Download and install Python 3.11 +RUN cd /tmp \ + && wget https://www.python.org/ftp/python/3.11.10/Python-3.11.10.tgz \ + && tar -xvzf Python-3.11.10.tgz \ + && cd Python-3.11.10 \ + && ./configure --enable-optimizations \ + && make -j$(nproc) \ + && make altinstall + +# Create symlinks for python and pip (use 'pip' instead of 'pip3') +RUN ln -s /usr/local/bin/python3.11 /usr/bin/python \ + && ln -s /usr/local/bin/pip3.11 /usr/bin/pip + RUN apt-get install libtcmalloc-minimal4 RUN apt-get install unzip RUN apt-get install pigz @@ -28,6 +55,8 @@ RUN echo "Setting up directories for data and experiment_runs" RUN mkdir -p data/ RUN mkdir -p experiment_runs/ +RUN pip install --upgrade pip + # Install Algorithmic efficiency repo RUN echo "Setting up algorithmic_efficiency repo" ARG branch="main" @@ -58,8 +87,6 @@ RUN if [ "$framework" = "jax" ] ; then \ RUN cd /algorithmic-efficiency && pip install -e '.[full]' -RUN cd /algorithmic-efficiency && pip install -e '.[wandb]' - RUN cd /algorithmic-efficiency && git fetch origin RUN cd /algorithmic-efficiency && git pull diff --git a/setup.cfg b/setup.cfg index 4afefd164..2d246b48b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,9 +21,7 @@ classifiers = Intended Audience :: Science/Research License :: OSI Approved :: Apache Software License Operating System :: OS Independent - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 Topic :: Scientific/Engineering :: Artificial Intelligence [options] @@ -34,22 +32,21 @@ setup_requires = setuptools_scm # Dependencies of the project: install_requires = - absl-py==1.4.0 + absl-py==2.1.0 # Pin to avoid unpinned install in dependencies that requires Python>=3.9. - networkx==3.1 - docker==7.0.0 - numpy>=1.23 - pandas>=2.0.1 - tensorflow==2.12.0 - tensorflow-datasets==4.9.2 - tensorflow-probability==0.20.0 - tensorflow-addons==0.20.0 + networkx==3.2.1 + docker==7.1.0 + numpy>=2.0.2 + pandas==2.2.3 + tensorflow==2.18.0 + tensorflow-datasets==4.9.7 gputil==1.4.0 - psutil==5.9.5 - clu==0.0.7 - matplotlib>=3.7.2 + psutil==6.1.0 + clu==0.0.12 + matplotlib>=3.9.2 tabulate==0.9.0 -python_requires = >=3.8 + wandb==0.18.7 +python_requires = >=3.11 ############################################################################### @@ -79,78 +76,76 @@ full_dev = # Dependencies for developing the package dev = - isort==5.12.0 - pylint==2.17.4 - pytest==7.3.1 - yapf==0.33.0 - pre-commit==3.3.1 + isort==5.13.2 + pylint==2.16.1 + pytest==8.3.3 + yapf==0.32.0 + pre-commit==4.0.1 # Workloads # criteo1tb = - scikit-learn==1.2.2 + scikit-learn==1.5.2 fastmri = - h5py==3.8.0 - scikit-image==0.20.0 + h5py==3.12.1 + scikit-image==0.24.0 ogbg = jraph==0.0.6.dev0 - scikit-learn==1.2.2 + scikit-learn==1.5.2 librispeech_conformer = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 + sentencepiece==0.2.0 + tensorflow-text==2.18.0 pydub==0.25.1 wmt = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 + sentencepiece==0.2.0 + tensorflow-text==2.18.0 sacrebleu==1.3.1 - # Frameworks # # JAX Core jax_core_deps = - flax==0.6.10 - optax==0.1.5 + flax==0.10.1 + optax==0.2.4 # Fix chex (optax dependency) version. # Not fixing it can raise dependency issues with our # jax version. # Todo(kasimbeg): verify if this is necessary after we # upgrade jax. - chex==0.1.7 - ml_dtypes==0.2.0 - protobuf==4.25.3 + chex==0.1.87 + ml_dtypes==0.4.1 + protobuf==4.25.5 # JAX CPU jax_cpu = - jax==0.4.10 - jaxlib==0.4.10 + jax==0.4.35 + jaxlib==0.4.35 %(jax_core_deps)s # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.10 - jaxlib==0.4.10+cuda12.cudnn88 + jax==0.4.35 + jaxlib==0.4.35 + jax-cuda12-plugin[with_cuda]==0.4.35 + jax-cuda12-pjrt==0.4.35 %(jax_core_deps)s # PyTorch CPU pytorch_cpu = - torch==2.1.0 - torchvision==0.16.0 + torch==2.5.1 + torchvision==0.20.1 # PyTorch GPU # Note: omit the cuda suffix and installing from the appropriate # wheel will result in using locally installed CUDA. pytorch_gpu = - torch==2.1.0 - torchvision==0.16.0 + torch==2.5.1 + torchvision==0.20.1 -# wandb -wandb = - wandb==0.16.5 ############################################################################### # Linting Configurations # diff --git a/submission_runner.py b/submission_runner.py index 9f9b8ff42..4d494f607 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -213,7 +213,7 @@ def train_once( ) -> Tuple[spec.Timing, Dict[str, Any]]: _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) - + data_rng = jax.random.key_data(data_rng) # Workload setup. logging.info('Initializing dataset.') if hasattr(workload, '_eval_num_workers'): @@ -342,8 +342,10 @@ def train_once( not train_state['training_complete']: step_rng = prng.fold_in(rng, global_step) + data_select_rng, update_rng, prep_eval_rng, eval_rng = \ prng.split(step_rng, 4) + eval_rng = jax.random.key_data(eval_rng) with profiler.profile('Data selection'): batch = data_selection(workload, diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index b67625213..4ad56c873 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -3,6 +3,7 @@ import jax import numpy as np import pytest +from flax.core import FrozenDict # isort: skip_file # pylint:disable=line-too-long @@ -51,8 +52,11 @@ def test_param_shapes(workload): jax_workload, pytorch_workload = get_workload(workload) # Compare number of parameter tensors of both models. + jax_workload_param_shapes = jax_workload.param_shapes + if isinstance(jax_workload_param_shapes, dict): + jax_workload_param_shapes = FrozenDict(jax_workload_param_shapes) jax_param_shapes = jax.tree_util.tree_leaves( - jax_workload.param_shapes.unfreeze()) + jax_workload_param_shapes.unfreeze()) pytorch_param_shapes = jax.tree_util.tree_leaves( pytorch_workload.param_shapes) if workload == 'wmt':