Update manifest and pin newer numpy in containers #444

Merged · 4 commits · Jan 4, 2024
4 changes: 4 additions & 0 deletions .github/container/Dockerfile.jax
@@ -85,6 +85,10 @@ ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/

 RUN mkdir -p /opt/pip-tools.d
 RUN <<"EOF" bash -ex
+# Pin a newer numpy so that pip's dependency resolver will accept the newer
+# versions of dependent packages (e.g. chex) that both require newer numpy
+# and carry fixes for compatibility with newer JAX versions.
+echo "numpy >= 1.24.1" >> /opt/pip-tools.d/requirements-jax.in
 echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/requirements-jax.in
 echo "jaxlib @ file://$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/requirements-jax.in
 EOF
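For context on what the `numpy >= 1.24.1` line buys: a floor pin admits any release at or above the floor, so the resolver is free to pick newer dependents. A minimal stdlib sketch of that comparison (a toy comparator for final releases only, not pip's actual PEP 440 resolver):

```python
def at_least(version: str, floor: str) -> bool:
    """Return True if `version` satisfies a `>= floor` pin.

    Toy comparator: splits on dots and compares numeric tuples, which is
    enough for final releases like 1.24.1 (no pre/post/dev segments).
    """
    parse = lambda v: tuple(int(part) for part in v.split("."))
    return parse(version) >= parse(floor)

print(at_least("1.23.5", "1.24.1"))  # False: below the floor, rejected
print(at_least("1.24.1", "1.24.1"))  # True: the floor itself is allowed
print(at_least("1.26.3", "1.24.1"))  # True: anything newer passes
```

Because the pin is only a lower bound, later `pip-compile` runs over the same `.in` file can keep moving numpy forward without touching the Dockerfile.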
29 changes: 15 additions & 14 deletions .github/container/manifest.yaml
@@ -2,49 +2,50 @@
 jax:
   url: https://github.com/google/jax.git
   tracking_ref: main
-  latest_verified_commit: 595117b70c11055e569480b80907d8c8a9901805
+  latest_verified_commit: afa2f1e420de3d2cfd684cff080a3808ee66daf5
   mode: git-clone
 xla:
   url: https://github.com/openxla/xla.git
   tracking_ref: main
-  latest_verified_commit: 78a5297d8e4301cb3ba2514061f56f89104e3d88
+  latest_verified_commit: 64a7946ffd048daf65ef330fc4ca5e4c3c1482a0
   mode: git-clone
 flax:
   url: https://github.com/google/flax.git
   mirror_url: https://github.com/nvjax-svc-0/flax.git
   tracking_ref: main
-  latest_verified_commit: 230b0d77e98da22b6e574c3cbff743ca1504bfca
+  latest_verified_commit: 85dfad242e56098849dbf05e7e4657b3a40820f9
   mode: git-clone
   patches:
     pull/3340/head: file://patches/flax/PR-3340.patch # Add Sharding Annotations to Flax Modules
 transformer-engine:
   url: https://github.com/NVIDIA/TransformerEngine.git
   tracking_ref: main
-  latest_verified_commit: 92c1e500dd14608e54f75df8276baa1104c61d48
+  latest_verified_commit: d155eaac8e08d42e67d7efd812ee2a69954de816
   mode: git-clone
 t5x:
   url: https://github.com/google-research/t5x.git
   mirror_url: https://github.com/nvjax-svc-0/t5x.git
   tracking_ref: main
-  latest_verified_commit: 1bfd2f15e5e77b09d60301367f67fdc9bb756b46
+  latest_verified_commit: dbc4b6f426862d5a742a2104a17524f53dd442f0
   mode: git-clone
   patches:
     mirror/patch/partial-checkpoint-restore: file://patches/t5x/mirror-patch-partial-checkpoint-restore.patch # pull/1392/head # https://github.com/google-research/t5x/pull/1392: Add support for partial checkpoint restore
     mirror/patch/dali-support: file://patches/t5x/mirror-patch-dali-support.patch # pull/1393/head # https://github.com/google-research/t5x/pull/1393: Adds DALI support to t5x
-    mirror/patch/t5x_te_in_contrib_noindent: file://patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch # pull/1391/head # https://github.com/google-research/t5x/pull/1391: Adds transformer engine support and GPU optimizations to T5x (enables H100)
+    # mirror/patch/t5x_te_in_contrib_noindent: file://patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch # pull/1391/head # https://github.com/google-research/t5x/pull/1391: Adds transformer engine support and GPU optimizations to T5x (enables H100)
+    mirror/ashors/fix_rng_dtype: file://patches/t5x/mirror-ashors-fix_rng_dtype.patch # fix on top of (and incorporating) https://github.com/google-research/t5x/pull/1391
 paxml:
   url: https://github.com/google/paxml.git
   mirror_url: https://github.com/nvjax-svc-0/paxml.git
   tracking_ref: main
-  latest_verified_commit: 7ae682d4d99630008e190b96c5296990297175c2
+  latest_verified_commit: 60e9e29bd3c6cc53bb4462f8c03bd5408daacd7b
   mode: git-clone
   patches:
     pull/46/head: file://patches/paxml/PR-46.patch # adds Transformer Engine support
 praxis:
   url: https://github.com/google/praxis.git
   mirror_url: https://github.com/nvjax-svc-0/praxis.git
   tracking_ref: main
-  latest_verified_commit: b6f32fa0fc6721db1cec75972b0f569c82095956
+  latest_verified_commit: 5b70196ffba154e78a5f78ce9175854b18cf936d
   mode: git-clone
   patches:
     pull/27/head: file://patches/praxis/PR-27.patch # This PR allows XLA:GPU to detect the MHA pattern more easily to call fused kernels from cublas.
@@ -53,7 +54,7 @@ lingvo:
   # Used only in ARM pax builds
   url: https://github.com/tensorflow/lingvo.git
   tracking_ref: master
-  latest_verified_commit: 0274fa20b4ff194c1c118b94b5f778caa5d9a84a
+  latest_verified_commit: ab71210c31706b190ebdd3bd3573ed833e693587
   mode: git-clone
 tensorflow-text:
   # Used only in ARM pax builds
@@ -68,18 +69,18 @@ pydantic:
 fiddle:
   url: https://github.com/google/fiddle.git
   tracking_ref: main
-  latest_verified_commit: d409cf95164599a88e49d2b6a23a0972a7170b0b
+  latest_verified_commit: a9e98709d4b109bf04ed61cee5ff366c50c82463
   mode: pip-vcs
 # Used by t5x
 airio:
   url: https://github.com/google/airio.git
   tracking_ref: main
-  latest_verified_commit: 69b3ec4ded478ad9cacdc97652a9d086a6a644c4
+  latest_verified_commit: 0e31f368b12d298e133b3a774e27d9bb0e85d087
   mode: pip-vcs
 clu:
   url: https://github.com/google/CommonLoopUtils.git
   tracking_ref: main
-  latest_verified_commit: 7ba2a9d83a3bc1a97b59482c2f02dc4b3614bc31
+  latest_verified_commit: f30bc441a14f0ccf8eaff79800f486a846613a8c
   mode: pip-vcs
 dllogger:
   url: https://github.com/NVIDIA/dllogger.git
@@ -94,10 +95,10 @@ jestimator:
 optax:
   url: https://github.com/deepmind/optax.git
   tracking_ref: master
-  latest_verified_commit: bf987e15eacf6efeb1a1a51b8868c094c3a15f9b
+  latest_verified_commit: bc22961422eb2397a4639ec945da0bea73d624d6
   mode: pip-vcs
 seqio:
   url: https://github.com/google/seqio.git
   tracking_ref: main
-  latest_verified_commit: 515d917bf58da4103a2bbf39c3716213c36aff03
+  latest_verified_commit: b582c96cb83f1472925c2b50b90059ad1da8c138
   mode: pip-vcs
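The manifest above follows a regular schema: every entry carries `url`, `tracking_ref`, `latest_verified_commit`, and `mode` (`git-clone` or `pip-vcs`), plus optional `patches`. A hypothetical validator sketch, not part of the repo's actual tooling, run here against two entries copied from the manifest:

```python
import re

VALID_MODES = {"git-clone", "pip-vcs"}
COMMIT_RE = re.compile(r"^[0-9a-f]{40}$")  # full lowercase git SHA-1

def validate_entry(name: str, entry: dict) -> list[str]:
    """Return a list of problems found in one manifest entry (empty = OK)."""
    problems = []
    for key in ("url", "tracking_ref", "latest_verified_commit", "mode"):
        if key not in entry:
            problems.append(f"{name}: missing key {key!r}")
    if entry.get("mode") not in VALID_MODES:
        problems.append(f"{name}: unknown mode {entry.get('mode')!r}")
    if not COMMIT_RE.match(entry.get("latest_verified_commit", "")):
        problems.append(f"{name}: malformed latest_verified_commit")
    return problems

# Two entries transcribed from the manifest above (post-update values).
manifest = {
    "jax": {
        "url": "https://github.com/google/jax.git",
        "tracking_ref": "main",
        "latest_verified_commit": "afa2f1e420de3d2cfd684cff080a3808ee66daf5",
        "mode": "git-clone",
    },
    "optax": {
        "url": "https://github.com/deepmind/optax.git",
        "tracking_ref": "master",
        "latest_verified_commit": "bc22961422eb2397a4639ec945da0bea73d624d6",
        "mode": "pip-vcs",
    },
}

for name, entry in manifest.items():
    assert validate_entry(name, entry) == [], name
print("manifest entries OK")
```

A check like this catches the most common manifest mistake, pasting a short or truncated commit hash, before a container build gets as far as cloning anything.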
4 changes: 2 additions & 2 deletions .github/container/patches/flax/PR-3340.patch
@@ -373,7 +373,7 @@ index 076fd680..6eff2dd1 100644


--
-2.25.1
+2.43.0


From d1f3ec337b85b5c5377aab72d814adfc89dd4af5 Mon Sep 17 00:00:00 2001
@@ -436,5 +436,5 @@ index 999acf2c..8e031c77 100644
else:
bias = None
--
-2.25.1
+2.43.0

18 changes: 9 additions & 9 deletions .github/container/patches/paxml/PR-46.patch
@@ -683,7 +683,7 @@ index 587181d..e7fe54a 100644

train_state_partition_specs = (
--
-2.25.1
+2.43.0


From 9d6b6db6039d7e6658dd179e5838379c7dc967e3 Mon Sep 17 00:00:00 2001
@@ -717,7 +717,7 @@ index d44ca67..2b9dba4 100644
assert self.packed_input == False
assert len(self.moe_layers) == 0
--
-2.25.1
+2.43.0


From 1612dc7a1f77f0a515eb4801087a8b4f0756e5b9 Mon Sep 17 00:00:00 2001
@@ -744,7 +744,7 @@ index 2b9dba4..ef20305 100644
return x_out

--
-2.25.1
+2.43.0


From 71507dc4b1396252e6fa746d1299854c204f0c51 Mon Sep 17 00:00:00 2001
@@ -771,7 +771,7 @@ index e7fe54a..4093c3b 100644
vars_with_opt = tasks_lib.filter_vars_for_grad_or_opt(
mdl_vars, excluded_for_learner
--
-2.25.1
+2.43.0


From 2a8233302c7e42b7dc7628c41abb637518d15c29 Mon Sep 17 00:00:00 2001
@@ -808,7 +808,7 @@ index ef20305..fed1601 100644
finally:
pass
--
-2.25.1
+2.43.0


From 2a6e5a960f438653b4c9cbeb0c016225af853279 Mon Sep 17 00:00:00 2001
@@ -975,7 +975,7 @@ index fed1601..5914e54 100644
def update_fp8_metas_if_needed(mdl_vars, grads):
return TransformerEngineHelper.get_helper().update_fp8_metas_if_needed(mdl_vars, grads)
--
-2.25.1
+2.43.0


From b57188225e7890dfc54d70db7d89fcb32e61e762 Mon Sep 17 00:00:00 2001
@@ -1136,7 +1136,7 @@ index 4093c3b..2e8fc35 100644
grads, states.opt_states[0], vars_with_opt, wps_with_opt
)
--
-2.25.1
+2.43.0


From c43766ee2e8cda686176a3895e87150b10d5de5e Mon Sep 17 00:00:00 2001
@@ -1528,7 +1528,7 @@ index fd482df..b271258 100644
@contextmanager
def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
--
-2.25.1
+2.43.0


From abc0fabc3e2ffb42d1f62254ad42448a39cbd128 Mon Sep 17 00:00:00 2001
@@ -1563,5 +1563,5 @@ index b271258..cbac7cf 100644

class TransformerEngineHelperBase:
--
-2.25.1
+2.43.0

2 changes: 1 addition & 1 deletion .github/container/patches/praxis/PR-27.patch
@@ -34,5 +34,5 @@ index a35ce8b..52886bc 100644
self.add_summary('attention_mask', atten_mask)
if self.attention_extra_logit is None:
--
-2.25.1
+2.43.0

4 changes: 2 additions & 2 deletions .github/container/patches/praxis/PR-36.patch
@@ -247,7 +247,7 @@ index ab6cff3..c79dac9 100644
# Annotate the inputs before the pipeline to prevent unexpected
# propagation from earlier layers.
--
-2.25.1
+2.43.0


From ff1745796009cf1ec59f463f8e776c66f1286938 Mon Sep 17 00:00:00 2001
@@ -358,5 +358,5 @@ index e3b2f7c..b31526e 100644
trans_in_fn=_get_to_f32_converter(bf16_vars_to_convert),
trans_out_fn=_get_to_bf16_converter(bf16_vars_to_convert),
--
-2.25.1
+2.43.0
