fix env.sh

only try to auto-detect CUDA/ROCm if torch is installed
This commit is contained in:
mauwii 2023-02-05 00:58:42 +01:00 committed by Matthias Wild
parent 6089f33e54
commit 1386d73302

View File

@ -1,24 +1,26 @@
#!/usr/bin/env bash #!/usr/bin/env bash
if [[ -z "$PIP_EXTRA_INDEX_URL" ]]; then if python -c "import torch" &>/dev/null; then
# Decide which container flavor to build if not specified if [[ -z "$PIP_EXTRA_INDEX_URL" ]]; then
if [[ -z "$CONTAINER_FLAVOR" ]]; then # Decide which container flavor to build if not specified
# Check for CUDA and ROCm if [[ -z "$CONTAINER_FLAVOR" ]]; then
CUDA_AVAILABLE=$(python -c "import torch;print(torch.cuda.is_available())") # Check for CUDA and ROCm
ROCM_AVAILABLE=$(python -c "import torch;print(torch.version.hip is not None)") CUDA_AVAILABLE=$(python -c "import torch;print(torch.cuda.is_available())")
if [[ "$(uname -s)" != "Darwin" && "${CUDA_AVAILABLE}" == "True" ]]; then ROCM_AVAILABLE=$(python -c "import torch;print(torch.version.hip is not None)")
CONTAINER_FLAVOR="cuda" if [[ "$(uname -s)" != "Darwin" && "${CUDA_AVAILABLE}" == "True" ]]; then
elif [[ "$(uname -s)" != "Darwin" && "${ROCM_AVAILABLE}" == "True" ]]; then CONTAINER_FLAVOR="cuda"
CONTAINER_FLAVOR="rocm" elif [[ "$(uname -s)" != "Darwin" && "${ROCM_AVAILABLE}" == "True" ]]; then
else CONTAINER_FLAVOR="rocm"
CONTAINER_FLAVOR="cpu" else
CONTAINER_FLAVOR="cpu"
fi
fi
# Set PIP_EXTRA_INDEX_URL based on container flavor
if [[ "$CONTAINER_FLAVOR" == "rocm" ]]; then
PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/rocm"
elif [[ "$CONTAINER_FLAVOR" == "cpu" ]]; then
PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
fi fi
fi
# Set PIP_EXTRA_INDEX_URL based on container flavor
if [[ "$CONTAINER_FLAVOR" == "rocm" ]]; then
PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/rocm"
elif [[ "$CONTAINER_FLAVOR" == "cpu" ]]; then
PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
fi fi
fi fi
@ -30,6 +32,7 @@ PLATFORM="${PLATFORM-Linux/${ARCH}}"
INVOKEAI_BRANCH="${INVOKEAI_BRANCH-$(git branch --show)}" INVOKEAI_BRANCH="${INVOKEAI_BRANCH-$(git branch --show)}"
CONTAINER_REGISTRY="${CONTAINER_REGISTRY-"ghcr.io"}" CONTAINER_REGISTRY="${CONTAINER_REGISTRY-"ghcr.io"}"
CONTAINER_REPOSITORY="${CONTAINER_REPOSITORY-"$(whoami)/${REPOSITORY_NAME}"}" CONTAINER_REPOSITORY="${CONTAINER_REPOSITORY-"$(whoami)/${REPOSITORY_NAME}"}"
CONTAINER_FLAVOR="${CONTAINER_FLAVOR-cuda}"
CONTAINER_TAG="${CONTAINER_TAG-"${INVOKEAI_BRANCH##*/}-${CONTAINER_FLAVOR}"}" CONTAINER_TAG="${CONTAINER_TAG-"${INVOKEAI_BRANCH##*/}-${CONTAINER_FLAVOR}"}"
CONTAINER_IMAGE="${CONTAINER_REGISTRY}/${CONTAINER_REPOSITORY}:${CONTAINER_TAG}" CONTAINER_IMAGE="${CONTAINER_REGISTRY}/${CONTAINER_REPOSITORY}:${CONTAINER_TAG}"
CONTAINER_IMAGE="${CONTAINER_IMAGE,,}" CONTAINER_IMAGE="${CONTAINER_IMAGE,,}"