diff --git a/Dockerfiles/compose.rocm.yaml b/Dockerfiles/compose.rocm.yaml index cf7dd12b..256ec6ad 100644 --- a/Dockerfiles/compose.rocm.yaml +++ b/Dockerfiles/compose.rocm.yaml @@ -12,7 +12,7 @@ services: container_name: reGen tty: true #environment: - # - FLASH_ATTENTION_USE_TRITON_ROCM=TRUE + # - FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE group_add: - video volumes: diff --git a/Dockerfiles/entrypoint.sh b/Dockerfiles/entrypoint.sh index e520e146..a392f2d9 100644 --- a/Dockerfiles/entrypoint.sh +++ b/Dockerfiles/entrypoint.sh @@ -24,7 +24,7 @@ if [ ! -z "${ROCM_VERSION_SHORT}" ]; then # Determine if the user has a flash attention supported card. SUPPORTED_CARD=$(rocminfo | grep -c -e gfx1100 -e gfx1101 -e gfx1102) - if [ "$SUPPORTED_CARD" -gt 0 ]; then export FLASH_ATTENTION_USE_TRITON_ROCM="${FLASH_ATTENTION_USE_TRITON_ROCM:=TRUE}"; fi + if [ "$SUPPORTED_CARD" -gt 0 ]; then export FLASH_ATTENTION_TRITON_AMD_ENABLE="${FLASH_ATTENTION_TRITON_AMD_ENABLE:=TRUE}"; fi #export PYTORCH_TUNABLEOP_ENABLED=1 export MIOPEN_FIND_MODE="FAST" diff --git a/horde-bridge-rocm.sh b/horde-bridge-rocm.sh index e90bf8d2..1006f6c7 100755 --- a/horde-bridge-rocm.sh +++ b/horde-bridge-rocm.sh @@ -7,7 +7,7 @@ CONDA_ENV_PATH="$SCRIPT_DIR/conda/envs/linux/lib" # Determine if the user has a flash attention supported card. SUPPORTED_CARD=$(rocminfo | grep -c -e gfx1100 -e gfx1101 -e gfx1102) -if [ "$SUPPORTED_CARD" -gt 0 ]; then export FLASH_ATTENTION_USE_TRITON_ROCM="${FLASH_ATTENTION_USE_TRITON_ROCM:=TRUE}"; fi +if [ "$SUPPORTED_CARD" -gt 0 ]; then export FLASH_ATTENTION_TRITON_AMD_ENABLE="${FLASH_ATTENTION_TRITON_AMD_ENABLE:=TRUE}"; fi export MIOPEN_FIND_MODE="FAST" # Check if we are running in WSL2 diff --git a/horde_worker_regen/amd_go_fast/install_amd_go_fast.sh b/horde_worker_regen/amd_go_fast/install_amd_go_fast.sh index dcfcc5a7..8fc33c52 100755 --- a/horde_worker_regen/amd_go_fast/install_amd_go_fast.sh +++ b/horde_worker_regen/amd_go_fast/install_amd_go_fast.sh @@ -1,6 +1,6 @@ #!/bin/bash -if [ "${FLASH_ATTENTION_USE_TRITON_ROCM^^}" == "TRUE" ]; then +if [ "${FLASH_ATTENTION_TRITON_AMD_ENABLE^^}" == "TRUE" ]; then if ! pip install -U pytest git+https://github.com/Dao-AILab/flash-attention; then echo "Tried to install flash attention and failed!" else diff --git a/update-runtime-rocm.sh b/update-runtime-rocm.sh index 31c02656..9d8b7419 100755 --- a/update-runtime-rocm.sh +++ b/update-runtime-rocm.sh @@ -29,7 +29,7 @@ CONDA_ENVIRONMENT_FILE=environment.rocm.yaml # Determine if the user has a flash attention supported card. SUPPORTED_CARD=$(rocminfo | grep -c -e gfx1100 -e gfx1101 -e gfx1102) -if [ "$SUPPORTED_CARD" -gt 0 ]; then export FLASH_ATTENTION_USE_TRITON_ROCM="${FLASH_ATTENTION_USE_TRITON_ROCM:=TRUE}"; fi +if [ "$SUPPORTED_CARD" -gt 0 ]; then export FLASH_ATTENTION_TRITON_AMD_ENABLE="${FLASH_ATTENTION_TRITON_AMD_ENABLE:=TRUE}"; fi wget -qO- https://github.com/mamba-org/micromamba-releases/releases/latest/download/micromamba-linux-64.tar.bz2 | tar -xvj -C "${SCRIPT_DIR}" if [ ! -f "$SCRIPT_DIR/conda/envs/linux/bin/python" ]; then