Skip to content

Commit

Permalink
Merge pull request #898 from anarkiwi/ts
Browse files Browse the repository at this point in the history
Add experimental torch/serve container with vulkan support.
  • Loading branch information
anarkiwi authored Oct 5, 2023
2 parents 955ce95 + 46deb29 commit bf588a7
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 0 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/docker-extras.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,40 @@ on:
tags: 'v*'

jobs:
buildx-torchserve:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Get the version
id: get_version
run: echo ::set-output name=VERSION::$(echo $GITHUB_REF | cut -d / -f 3)
- name: Change for main
id: change_version
run: if [ "${{ steps.get_version.outputs.VERSION }}" == "main" ]; then echo ::set-output name=VERSION::latest; else echo ::set-output name=VERSION::${{ steps.get_version.outputs.VERSION }}; fi
- name: Set up qemu
uses: docker/setup-qemu-action@v3
with:
platforms: all
- name: Set up Docker Buildx
id: buildx
uses: docker/setup-buildx-action@v3
- name: Docker Login
env:
DOCKER_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
run: |
echo "${DOCKER_PASSWORD}" | docker login --username "${{ secrets.DOCKER_USERNAME }}" --password-stdin
if: github.repository == 'iqtlabs/gamutrf' && github.event_name == 'push'
- name: Build and push platforms
uses: docker/build-push-action@v5
with:
context: docker
file: docker/Dockerfile.torchserve
platforms: linux/amd64,linux/arm64
push: true
tags: iqtlabs/gamutrf-torchserve:${{ steps.change_version.outputs.VERSION }}
if: github.repository == 'iqtlabs/gamutrf' && github.event_name == 'push'
buildx-mqtt:
runs-on: ubuntu-latest
steps:
Expand Down
8 changes: 8 additions & 0 deletions .github/workflows/docker-test.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
name: docker-test
on: [push, pull_request]
jobs:
test-gamutrf-torchserve-image:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: docker build
run: |
cd docker
docker build -f Dockerfile.torchserve . -t iqtlabs/gamutrf-torchserve:latest
test-gamutrf-extra-images:
runs-on: ubuntu-latest
steps:
Expand Down
73 changes: 73 additions & 0 deletions docker/Dockerfile.torchserve
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
FROM ubuntu:22.04 as shader-compiler

WORKDIR /root
RUN apt-get update && \
apt-get install -y \
build-essential \
cmake \
git \
ninja-build \
python3-dev \
python3-pip
RUN git clone https://github.com/google/shaderc -b v2023.6
WORKDIR /root/shaderc
RUN ./utils/git-sync-deps
WORKDIR /root/shaderc/build
RUN cmake -GNinja -DCMAKE_BUILD_TYPE=Release -DSHADERC_SKIP_TESTS=on .. && ninja && ninja install

FROM ubuntu:22.04 as pytorch-compiler
# ENV USE_CUDA=1
ENV USE_VULKAN=1
ENV USE_VULKAN_SHADERC_RUNTIME=1
ENV USE_VULKAN_WRAPPER=0
WORKDIR /root
RUN apt-get update && \
apt-get install -y \
build-essential \
cmake \
git \
libvulkan-dev \
ninja-build \
python3-dev \
python3-pip
COPY --from=shader-compiler /usr/local /usr/local
RUN git clone https://github.com/pytorch/pytorch -b v2.1.0
WORKDIR /root/pytorch
COPY pytorch-patch.txt /root/pytorch
RUN patch -p1 < pytorch-patch.txt
RUN pip3 install -U pyyaml
RUN python3 setup.py install && python3 setup.py develop
RUN python3 -c "import torch ; assert(torch.is_vulkan_available())"

FROM ubuntu:22.04 as torchserve-builder
RUN apt-get update && \
apt-get install -y \
build-essential \
cmake \
git \
libvulkan-dev \
ninja-build \
openjdk-17-jdk \
python3-dev \
python3-pip
WORKDIR /root
RUN git clone https://github.com/pytorch/serve -b v0.8.2
WORKDIR /root/serve
COPY torchserve-patch.txt /root/serve
RUN patch -p1 < torchserve-patch.txt
COPY --from=pytorch-compiler /usr/local /usr/local
RUN python3 -c "import torch ; assert(torch.is_vulkan_available())"
RUN python3 ./ts_scripts/install_dependencies.py --environment prod
RUN pip3 install .

FROM ubuntu:22.04
RUN apt-get update && \
apt-get install -y \
libvulkan1 \
openjdk-17-jre \
python3 \
python3-pip
WORKDIR /root
COPY --from=torchserve-builder /usr/local /usr/local
RUN python3 -c "import torch ; assert(torch.is_vulkan_available())"
RUN /usr/local/bin/torchserve --help
25 changes: 25 additions & 0 deletions docker/pytorch-patch.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
diff --git a/aten/src/ATen/native/vulkan/impl/Arithmetic.cpp b/aten/src/ATen/native/vulkan/impl/Arithmetic.cpp
index ef6b794fbd9..900dd8ddf4c 100644
--- a/aten/src/ATen/native/vulkan/impl/Arithmetic.cpp
+++ b/aten/src/ATen/native/vulkan/impl/Arithmetic.cpp
@@ -17,6 +17,7 @@ api::ShaderInfo get_shader(const OpType type) {
case OpType::DIV:
return VK_KERNEL(div);
}
+ return VK_KERNEL(add);
}

struct Params final {
diff --git a/aten/src/ATen/native/vulkan/impl/Registry.cpp b/aten/src/ATen/native/vulkan/impl/Registry.cpp
index 3cf3148c874..b750914dfeb 100644
--- a/aten/src/ATen/native/vulkan/impl/Registry.cpp
+++ b/aten/src/ATen/native/vulkan/impl/Registry.cpp
@@ -33,7 +33,7 @@ const api::ShaderInfo& look_up_shader_info(const std::string& op_name) {
const RegistryKeyMap& registry_key_map = registry_iterator->second;

// Look for "override" and "catchall" keys
- for (const std::string& key : {"override", "catchall"}) {
+ for (const std::string key : {"override", "catchall"}) {
const RegistryKeyMap::const_iterator registry_key_iterator =
registry_key_map.find(key);
if (registry_key_iterator != registry_key_map.end()) {
25 changes: 25 additions & 0 deletions docker/torchserve-patch.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
diff --git a/requirements/torch_cu118_linux.txt b/requirements/torch_cu118_linux.txt
index d34969ef..fe823925 100644
--- a/requirements/torch_cu118_linux.txt
+++ b/requirements/torch_cu118_linux.txt
@@ -1,7 +1,7 @@
#pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
--extra-index-url https://download.pytorch.org/whl/cu118
-r torch_common.txt
-torch==2.0.1+cu118; sys_platform == 'linux'
+# torch==2.0.1+cu118; sys_platform == 'linux'
torchvision==0.15.2+cu118; sys_platform == 'linux'
torchtext==0.15.2; sys_platform == 'linux'
torchaudio==2.0.2+cu118; sys_platform == 'linux'
diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py
index 6aef56db..0875d4c6 100644
--- a/ts_scripts/install_dependencies.py
+++ b/ts_scripts/install_dependencies.py
@@ -23,6 +23,7 @@ class Common:
pass

def install_torch_packages(self, cuda_version):
+ return
if cuda_version:
if platform.system() == "Darwin":
print(

0 comments on commit bf588a7

Please sign in to comment.