-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #898 from anarkiwi/ts
Add experimental torch/serve container with vulkan support.
- Loading branch information
Showing
5 changed files
with
165 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
FROM ubuntu:22.04 as shader-compiler | ||
|
||
WORKDIR /root | ||
RUN apt-get update && \ | ||
apt-get install -y \ | ||
build-essential \ | ||
cmake \ | ||
git \ | ||
ninja-build \ | ||
python3-dev \ | ||
python3-pip | ||
RUN git clone https://github.com/google/shaderc -b v2023.6 | ||
WORKDIR /root/shaderc | ||
RUN ./utils/git-sync-deps | ||
WORKDIR /root/shaderc/build | ||
RUN cmake -GNinja -DCMAKE_BUILD_TYPE=Release -DSHADERC_SKIP_TESTS=on .. && ninja && ninja install | ||
|
||
FROM ubuntu:22.04 as pytorch-compiler | ||
# ENV USE_CUDA=1 | ||
ENV USE_VULKAN=1 | ||
ENV USE_VULKAN_SHADERC_RUNTIME=1 | ||
ENV USE_VULKAN_WRAPPER=0 | ||
WORKDIR /root | ||
RUN apt-get update && \ | ||
apt-get install -y \ | ||
build-essential \ | ||
cmake \ | ||
git \ | ||
libvulkan-dev \ | ||
ninja-build \ | ||
python3-dev \ | ||
python3-pip | ||
COPY --from=shader-compiler /usr/local /usr/local | ||
RUN git clone https://github.com/pytorch/pytorch -b v2.1.0 | ||
WORKDIR /root/pytorch | ||
COPY pytorch-patch.txt /root/pytorch | ||
RUN patch -p1 < pytorch-patch.txt | ||
RUN pip3 install -U pyyaml | ||
RUN python3 setup.py install && python3 setup.py develop | ||
RUN python3 -c "import torch ; assert(torch.is_vulkan_available())" | ||
|
||
FROM ubuntu:22.04 as torchserve-builder | ||
RUN apt-get update && \ | ||
apt-get install -y \ | ||
build-essential \ | ||
cmake \ | ||
git \ | ||
libvulkan-dev \ | ||
ninja-build \ | ||
openjdk-17-jdk \ | ||
python3-dev \ | ||
python3-pip | ||
WORKDIR /root | ||
RUN git clone https://github.com/pytorch/serve -b v0.8.2 | ||
WORKDIR /root/serve | ||
COPY torchserve-patch.txt /root/serve | ||
RUN patch -p1 < torchserve-patch.txt | ||
COPY --from=pytorch-compiler /usr/local /usr/local | ||
RUN python3 -c "import torch ; assert(torch.is_vulkan_available())" | ||
RUN python3 ./ts_scripts/install_dependencies.py --environment prod | ||
RUN pip3 install . | ||
|
||
FROM ubuntu:22.04 | ||
RUN apt-get update && \ | ||
apt-get install -y \ | ||
libvulkan1 \ | ||
openjdk-17-jre \ | ||
python3 \ | ||
python3-pip | ||
WORKDIR /root | ||
COPY --from=torchserve-builder /usr/local /usr/local | ||
RUN python3 -c "import torch ; assert(torch.is_vulkan_available())" | ||
RUN /usr/local/bin/torchserve --help |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
diff --git a/aten/src/ATen/native/vulkan/impl/Arithmetic.cpp b/aten/src/ATen/native/vulkan/impl/Arithmetic.cpp | ||
index ef6b794fbd9..900dd8ddf4c 100644 | ||
--- a/aten/src/ATen/native/vulkan/impl/Arithmetic.cpp | ||
+++ b/aten/src/ATen/native/vulkan/impl/Arithmetic.cpp | ||
@@ -17,6 +17,7 @@ api::ShaderInfo get_shader(const OpType type) { | ||
case OpType::DIV: | ||
return VK_KERNEL(div); | ||
} | ||
+ return VK_KERNEL(add); | ||
} | ||
|
||
struct Params final { | ||
diff --git a/aten/src/ATen/native/vulkan/impl/Registry.cpp b/aten/src/ATen/native/vulkan/impl/Registry.cpp | ||
index 3cf3148c874..b750914dfeb 100644 | ||
--- a/aten/src/ATen/native/vulkan/impl/Registry.cpp | ||
+++ b/aten/src/ATen/native/vulkan/impl/Registry.cpp | ||
@@ -33,7 +33,7 @@ const api::ShaderInfo& look_up_shader_info(const std::string& op_name) { | ||
const RegistryKeyMap& registry_key_map = registry_iterator->second; | ||
|
||
// Look for "override" and "catchall" keys | ||
- for (const std::string& key : {"override", "catchall"}) { | ||
+ for (const std::string key : {"override", "catchall"}) { | ||
const RegistryKeyMap::const_iterator registry_key_iterator = | ||
registry_key_map.find(key); | ||
if (registry_key_iterator != registry_key_map.end()) { |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
diff --git a/requirements/torch_cu118_linux.txt b/requirements/torch_cu118_linux.txt | ||
index d34969ef..fe823925 100644 | ||
--- a/requirements/torch_cu118_linux.txt | ||
+++ b/requirements/torch_cu118_linux.txt | ||
@@ -1,7 +1,7 @@ | ||
#pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 | ||
--extra-index-url https://download.pytorch.org/whl/cu118 | ||
-r torch_common.txt | ||
-torch==2.0.1+cu118; sys_platform == 'linux' | ||
+# torch==2.0.1+cu118; sys_platform == 'linux' | ||
torchvision==0.15.2+cu118; sys_platform == 'linux' | ||
torchtext==0.15.2; sys_platform == 'linux' | ||
torchaudio==2.0.2+cu118; sys_platform == 'linux' | ||
diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py | ||
index 6aef56db..0875d4c6 100644 | ||
--- a/ts_scripts/install_dependencies.py | ||
+++ b/ts_scripts/install_dependencies.py | ||
@@ -23,6 +23,7 @@ class Common: | ||
pass | ||
|
||
def install_torch_packages(self, cuda_version): | ||
+ return | ||
if cuda_version: | ||
if platform.system() == "Darwin": | ||
print( |