diff --git a/docker/env.sh b/docker/env.sh index 2196d78425..6987e12fb8 100644 --- a/docker/env.sh +++ b/docker/env.sh @@ -1,24 +1,26 @@ #!/usr/bin/env bash -if [[ -z "$PIP_EXTRA_INDEX_URL" ]]; then - # Decide which container flavor to build if not specified - if [[ -z "$CONTAINER_FLAVOR" ]]; then - # Check for CUDA and ROCm - CUDA_AVAILABLE=$(python -c "import torch;print(torch.cuda.is_available())") - ROCM_AVAILABLE=$(python -c "import torch;print(torch.version.hip is not None)") - if [[ "$(uname -s)" != "Darwin" && "${CUDA_AVAILABLE}" == "True" ]]; then - CONTAINER_FLAVOR="cuda" - elif [[ "$(uname -s)" != "Darwin" && "${ROCM_AVAILABLE}" == "True" ]]; then - CONTAINER_FLAVOR="rocm" - else - CONTAINER_FLAVOR="cpu" +if python -c "import torch" &>/dev/null; then + if [[ -z "$PIP_EXTRA_INDEX_URL" ]]; then + # Decide which container flavor to build if not specified + if [[ -z "$CONTAINER_FLAVOR" ]]; then + # Check for CUDA and ROCm + CUDA_AVAILABLE=$(python -c "import torch;print(torch.cuda.is_available())") + ROCM_AVAILABLE=$(python -c "import torch;print(torch.version.hip is not None)") + if [[ "$(uname -s)" != "Darwin" && "${CUDA_AVAILABLE}" == "True" ]]; then + CONTAINER_FLAVOR="cuda" + elif [[ "$(uname -s)" != "Darwin" && "${ROCM_AVAILABLE}" == "True" ]]; then + CONTAINER_FLAVOR="rocm" + else + CONTAINER_FLAVOR="cpu" + fi + fi + # Set PIP_EXTRA_INDEX_URL based on container flavor + if [[ "$CONTAINER_FLAVOR" == "rocm" ]]; then + PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/rocm" + elif [[ "$CONTAINER_FLAVOR" == "cpu" ]]; then + PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" fi - fi - # Set PIP_EXTRA_INDEX_URL based on container flavor - if [[ "$CONTAINER_FLAVOR" == "rocm" ]]; then - PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/rocm" - elif [[ "$CONTAINER_FLAVOR" == "cpu" ]]; then - PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" fi fi @@ -30,6 +32,7 @@ PLATFORM="${PLATFORM-Linux/${ARCH}}" INVOKEAI_BRANCH="${INVOKEAI_BRANCH-$(git branch --show)}" CONTAINER_REGISTRY="${CONTAINER_REGISTRY-"ghcr.io"}" CONTAINER_REPOSITORY="${CONTAINER_REPOSITORY-"$(whoami)/${REPOSITORY_NAME}"}" +CONTAINER_FLAVOR="${CONTAINER_FLAVOR-cuda}" CONTAINER_TAG="${CONTAINER_TAG-"${INVOKEAI_BRANCH##*/}-${CONTAINER_FLAVOR}"}" CONTAINER_IMAGE="${CONTAINER_REGISTRY}/${CONTAINER_REPOSITORY}:${CONTAINER_TAG}" CONTAINER_IMAGE="${CONTAINER_IMAGE,,}"