-
Notifications
You must be signed in to change notification settings - Fork 2
/
compile_torch.sh
executable file
·158 lines (134 loc) · 4.58 KB
/
compile_torch.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#!/bin/bash
# simple usage: work_path=/home/yhao/pt ./compile_torch.sh
# make sure you have activated the correct conda environment before running this script
# Strict mode: abort on a failing command, on use of an unset variable, and on
# a failing pipeline stage. IFS is narrowed to newline/tab, so unquoted
# expansions in this script are NOT split on spaces.
set -euo pipefail
IFS=$'\n\t'

# Build parallelism. Exported (not just declared) so that child processes
# started by this script can read it from their environment; a MAX_JOBS value
# supplied by the caller takes precedence over the default of 256.
# NOTE(review): presumably consumed by the PyTorch build -- confirm.
declare -rx MAX_JOBS="${MAX_JOBS:-256}"
declare -r DEFAULT_WORK_PATH="/home/yhao/p9"

# Settings below are read from the caller's environment and frozen read-only.
#   work_path      directory that holds the source checkouts
#   clean_install  1 = delete the checkouts before building
#   clean_upgrade  accepted, but not read anywhere in the visible script
#   clean_torch    1 = run `python setup.py clean` before building pytorch
#   torch_only     1 = stop after pytorch has been built
#   debug          1 = build pytorch with DEBUG=1
#   torch_commit   commit to check out (takes priority over torch_branch)
#   torch_branch   branch to check out when torch_commit is empty
#   torch_pull     1 = `git pull` after the checkout
#   no_torchbench  1 = skip the torchbench installation
declare -r work_path="${work_path:-$DEFAULT_WORK_PATH}"
declare -r clean_install="${clean_install:-0}"
declare -r clean_upgrade="${clean_upgrade:-0}"
declare -r clean_torch="${clean_torch:-0}"
declare -r torch_only="${torch_only:-0}"
declare -r debug="${debug:-0}"
declare -r torch_commit="${torch_commit:-}"
declare -r torch_branch="${torch_branch:-main}"
declare -r torch_pull="${torch_pull:-0}"
declare -r no_torchbench="${no_torchbench:-0}"

# GPU-related build switches, exported for the build processes.
export USE_ROCM=0
export USE_NCCL=1
#######################################
# Report a fatal error and terminate the script.
# Arguments: $1 - human-readable description of the failure
# Outputs:   writes "ERROR: <description>" to stderr
# Returns:   never returns; exits the shell with status 1
#######################################
error_exit() {
  local reason="$1"
  printf 'ERROR: %s\n' "$reason" >&2
  exit 1
}
#######################################
# Bring the checkout of one package up to date, including its submodules.
# Globals:   work_path (read)
# Arguments: $1 - name of the checkout directory under $work_path
# Outputs:   progress line on stdout; git output passes through
# Returns:   0 on success; calls error_exit (exit 1) on the first failing step
# Note:      changes the current directory to $work_path/$1 and leaves it
#            there for the caller.
#######################################
git_upgrade_pack() {
  local pkg="$1"

  echo "Upgrading package: $pkg"

  if ! cd "$work_path/$pkg"; then
    error_exit "Failed to change directory to $pkg"
  fi
  if ! git pull; then
    error_exit "Failed to pull latest changes for $pkg"
  fi
  if ! git submodule sync; then
    error_exit "Failed to sync submodules for $pkg"
  fi
  if ! git submodule update --init --recursive; then
    error_exit "Failed to update submodules for $pkg"
  fi
}
#######################################
# Update one package checkout and reinstall it from source.
# Arguments: $1 - checkout directory name under $work_path (the script passes
#                 data, text, vision and audio)
# Outputs:   progress lines on stdout; tool output passes through
# Returns:   0 on success; calls error_exit (exit 1) if the clean or install
#            step fails
#######################################
function upgrade_pack() {
  local package_name="$1"
  # The checkout directory is named "<x>" but the installed distribution is
  # "torch<x>" (torchdata, torchtext, torchvision, torchaudio). Uninstalling
  # the bare directory name would leave the old build in place and could
  # remove an unrelated package that happens to share the short name.
  local dist_name="torch${package_name}"

  echo "Installing package: $package_name"

  # Leaves the current directory inside the package checkout, which is where
  # the setup.py calls below have to run.
  git_upgrade_pack "$package_name"

  # Best effort: a failing uninstall must not abort the upgrade.
  pip uninstall -y "$dist_name" || true

  python setup.py clean || error_exit "Failed to clean $package_name"
  python setup.py install || error_exit "Failed to install $package_name"
  echo "$package_name installation completed successfully"
}
# Print the effective configuration so a build log records what was requested.
echo "work_path: ${work_path}"
echo "clean_install: ${clean_install}"
echo "clean_torch: ${clean_torch}"
echo "torch_only: ${torch_only}"
echo "torch_branch: ${torch_branch}"
echo "torch_commit: ${torch_commit}"
echo "torch_pull: ${torch_pull}"
echo "debug: ${debug}"
echo "no_torchbench: ${no_torchbench}"

# Build-time dependencies.
# If you hit an error like "version `GLIBCXX_3.4.30' not found", add
# `-c conda-forge` to the first command below, and create the environment with
# `conda create -n pt_compiled -c conda-forge python=3.10`.
conda install -y magma-cuda121 -c pytorch
conda install -y ccache cmake ninja mkl mkl-include libpng libjpeg-turbo -c conda-forge
# graphviz

# Use the active conda environment as the CMake prefix; when CONDA_PREFIX is
# unset or empty, fall back to the parent of the directory containing `conda`.
CMAKE_PREFIX_PATH="${CONDA_PREFIX:-}"
if [ -z "$CMAKE_PREFIX_PATH" ]; then
  conda_bin="$(command -v conda)" || error_exit "conda not found on PATH"
  CMAKE_PREFIX_PATH="$(dirname "$conda_bin")/../"
fi
export CMAKE_PREFIX_PATH

cd "$work_path" || error_exit "Failed to change to work directory"

# Clean install: throw away every checkout and fetch fresh copies.
if [ "$clean_install" -eq 1 ]; then
  echo "Performing clean installation..."
  rm -rf -- pytorch text vision audio benchmark data
  for repo in pytorch text data vision audio benchmark; do
    # NOTE(review): the loop body was empty (a syntax error in bash). Cloning
    # each repository from the pytorch GitHub organisation is the assumed
    # intent -- confirm the URL and protocol (https vs ssh) before relying on it.
    git clone --recursive "https://github.com/pytorch/${repo}.git" \
      || error_exit "Failed to clone $repo"
  done
fi
#######################################
# Announce that the build finished.
# Outputs:   completion message on stdout
# Returns:   0, even when the optional notifier is missing or fails
#######################################
notify_finish() {
  printf '%s\n' "PyTorch compilation completed successfully"

  # `notify` is optional; nothing more to do when it is not available.
  command -v notify > /dev/null 2>&1 || return 0

  # Best effort: a failing notifier must not fail the build.
  notify "PyTorch Compilation is done" || true
}
# Report Ctrl-C / termination and exit non-zero. Installed before the build
# steps: a trap registered as the last statement of the script never covers
# any of the work it is meant to guard.
trap 'echo "Script execution interrupted"; exit 1' INT TERM

pip uninstall -y torch

# install pytorch
cd "$work_path/pytorch" || error_exit "Failed to change directory to pytorch"
git fetch
if [ -n "$torch_commit" ]; then
  # A specific commit takes priority over the branch setting.
  git checkout "$torch_commit"
  echo "warnging: you are using a specific commit. don't forget to create a new branch if you want to make changes"
else
  git checkout "$torch_branch"
fi
if [ "$torch_pull" -eq 1 ]; then
  git pull
fi
git submodule sync
git submodule update --init --recursive
pip install -r requirements.txt
make triton
if [ "$clean_torch" -eq 1 ]; then
  python setup.py clean
fi
# The debug build is selected with an explicit branch instead of an unquoted
# "env DEBUG=1" prefix variable: IFS is set to newline/tab at the top of the
# script, so such a prefix is not split on the space and would be executed as
# a single (non-existent) command name.
if [ "$debug" -eq 1 ]; then
  env DEBUG=1 python setup.py develop
else
  python setup.py develop
fi
if [ "$torch_only" -eq 1 ]; then
  notify_finish
  exit 0
fi

# install torchdata
cd "$work_path" || error_exit "Failed to change to work directory"
upgrade_pack data

# install torchtext
cd "$work_path" || error_exit "Failed to change to work directory"
# Assign and export separately so that a missing compiler is detected instead
# of silently exporting an empty value.
CC="$(command -v gcc)" || error_exit "gcc not found on PATH"
CXX="$(command -v g++)" || error_exit "g++ not found on PATH"
export CC CXX
upgrade_pack text

# install torchvision
export FORCE_CUDA=1
upgrade_pack vision

# install torchaudio
upgrade_pack audio

if [ "$no_torchbench" -eq 1 ]; then
  notify_finish
  exit 0
fi

# install torchbench
pip install pyyaml
cd "$work_path/benchmark" || error_exit "Failed to change directory to benchmark"
git pull
git submodule sync
git submodule update --init --recursive
python install.py
echo "torchbench installation is done"
notify_finish