Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rocm-ci.yaml #76

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ concurrency:

jobs:
lint_and_typecheck:
if: false
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
Expand All @@ -37,6 +38,7 @@ jobs:
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet: pre-commit/[email protected]

build:
if: false
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
runs-on: ${{ matrix.os }}
timeout-minutes: 60
Expand All @@ -46,13 +48,13 @@ jobs:
include:
- name-prefix: "with 3.10"
python-version: "3.10"
os: ubuntu-20.04-16core
os: ubuntu-20.04-16core # update to custom rocm runner
enable-x64: 1
prng-upgrade: 1
num_generated_cases: 1
- name-prefix: "with 3.12"
python-version: "3.12"
os: ubuntu-20.04-16core
os: ubuntu-20.04-16core # Update to customer rocm runner
enable-x64: 0
prng-upgrade: 0
num_generated_cases: 1
Expand Down Expand Up @@ -97,6 +99,7 @@ jobs:


documentation:
if: false
name: Documentation - test code snippets
runs-on: ubuntu-latest
timeout-minutes: 10
Expand Down Expand Up @@ -134,6 +137,7 @@ jobs:


documentation_render:
if: false
name: Documentation - render documentation
runs-on: ubuntu-latest
timeout-minutes: 10
Expand Down Expand Up @@ -165,6 +169,7 @@ jobs:


jax2tf_test:
if: false
name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
runs-on: ${{ matrix.os }}
timeout-minutes: 30
Expand Down
38 changes: 38 additions & 0 deletions .github/workflows/rocm-ci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: ROCM CI

on: [push]

jobs:
build-docker: # strategy and matrix come here
runs-on: mi-250
steps:
- name: Print Environment Variables
run: printenv
- uses: actions/checkout@v4
- name: Build Docker
env:
DOCKER_IMG_NAME: rocm_jax_r6_1_3_py3_10_id${{ github.run_id }}
# XLA_CLONE_DIR:
run: |
./build/rocm/ci_build.sh --rocm_version 6.1.3 \
--keep_image --py_version 3.10
- name: Archive jax wheels
uses: actions/upload-artifact@v4
with:
name: rocm_jax_r6_1_3_py3_10_id${{ github.run_id }}
path: ${{ github.workspace }}/wheelhouse/*.whl
- name: Detect GPUs
run: |
docker run \
--group-add video \
$RENDER_DEVICES \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
--shm-size=64G \
rocm/rocm-terminal rocm-smi
- name: Run tests
env:
DOCKER_IMG_NAME: rocm_jax_r6_1_3_py3_10_id${{ github.run_id }}
run: |
./build/rocm/ci_build test ${DOCKER_IMG_NAME}

29 changes: 22 additions & 7 deletions build/rocm/ci_build
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,15 @@ def dist_wheels(

cmd = ["docker", "run"]

# docker run fails when mounting ./
own_path = os.path.dirname(os.path.abspath(__file__))
repo_path = os.path.abspath(os.path.join(own_path, "..",".."))
whl_path = os.path.join(repo_path, "wheelhouse")
mounts = [
"-v",
"./:/jax",
"%s:/jax" % repo_path,
"-v",
"./wheelhouse:/wheelhouse",
"%s:/wheelhouse" % whl_path,
]

if xla_path:
Expand Down Expand Up @@ -130,10 +134,16 @@ def _fetch_jax_metadata(xla_path):

jax_version = subprocess.check_output(cmd, env=env)

def safe_decode(x):
if isinstance(x, str):
return x
else:
return x.decode("utf8")

return {
"jax_version": jax_version.decode("utf8").strip(),
"jax_commit": jax_commit.decode("utf8").strip(),
"xla_commit": xla_commit.decode("utf8").strip(),
"jax_version": safe_decode(jax_version).strip(),
"jax_commit": safe_decode(jax_commit).strip(),
"xla_commit": safe_decode(xla_commit).strip(),
}


Expand Down Expand Up @@ -198,15 +208,20 @@ def test(image_name):
cmd = [
"docker",
"run",
"-it",
"--rm",
]

if os.isatty(sys.stdout.fileno()):
cmd.append("-it")

# NOTE(mrodden): we need jax source dir for the unit test code only,
# JAX and jaxlib are already installed from wheels
# docker run fails when mounting ./
own_path = os.path.dirname(os.path.abspath(__file__))
repo_path = os.path.abspath(os.path.join(own_path, "..",".."))
mounts = [
"-v",
"./:/jax",
"%s:/jax" % repo_path,
]

cmd.extend(mounts)
Expand Down
4 changes: 3 additions & 1 deletion build/rocm/ci_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,10 @@ WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}"
BUILD_TAG="${BUILD_TAG:-jax}"

# Determine the docker image name and BUILD_TAG.
DOCKER_IMG_NAME="${BUILD_TAG}.${CONTAINER_TYPE}"
DOCKER_IMG_NAME_DEFAULT="${BUILD_TAG}.${CONTAINER_TYPE}"

# Let the env override the image name
DOCKER_IMG_NAME="${DOCKER_IMG_NAME:-$DOCKER_IMG_NAME_DEFAULT}"
# Under Jenkins matrix build, the build tag may contain characters such as
# commas (,) and equal signs (=), which are not valid inside docker image names.
DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | sed -e 's/=/_/g' -e 's/,/-/g')
Expand Down
17 changes: 17 additions & 0 deletions build/rocm/tools/build_wheels.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import subprocess
import shutil
import sys
import stat


LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -254,6 +255,22 @@ def main():
if os.path.basename(whl).startswith("jax-"):
LOG.info("Copying %s into %s" % (whl, wheelhouse_dir))
shutil.copy(whl, wheelhouse_dir)
# delete the 'dist' directory since it causes permissions issues
logging.info('Deleting dist, egg-info and cache directory')
shutil.rmtree(os.path.join(args.jax_path, "dist"))
shutil.rmtree(os.path.join(args.jax_path, "jax.egg-info"))
shutil.rmtree(os.path.join(args.jax_path, "jax", "__pycache__"))

# make the wheels delete-abl by the runner
whl_house = os.path.join(args.jax_path, "wheelhouse")
logging.info(f'Changing permissions for {whl_house}')
mode = (stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR |
stat.S_IRGRP | stat.S_IWGRP | stat.S_IXGRP |
stat.S_IROTH | stat.S_IWOTH | stat.S_IXOTH )
for item in os.listdir(whl_house):
whl_path = os.path.join(whl_house, item)
if os.path.isfile(whl_path):
os.chmod(whl_path, mode)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion build/rocm/tools/get_rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,12 @@ def setup_repos_el8(rocm_version_str):
"""
[amdgpu]
name=amdgpu
baseurl=https://repo.radeon.com/amdgpu/latest/rhel/8.8/main/x86_64/
baseurl=https://repo.radeon.com/amdgpu/%s/rhel/8.8/main/x86_64/
enabled=1
gpgcheck=1
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key
"""
% rocm_version_str
)


Expand Down
Loading