diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index ac032d4955..fa640719d0 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -13,6 +13,7 @@ import requests from diffusers import DiffusionPipeline from diffusers import logging as dlogging import onnx +import torch from huggingface_hub import hf_hub_url, HfFolder, HfApi from omegaconf import OmegaConf from tqdm import tqdm @@ -23,6 +24,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo from invokeai.backend.util import download_with_resume +from invokeai.backend.util.devices import torch_dtype, choose_torch_device from ..util.logging import InvokeAILogger warnings.filterwarnings("ignore") @@ -416,15 +418,25 @@ class ModelInstall(object): does a save_pretrained() to the indicated staging area. """ _, name = repo_id.split("/") - revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"] + precision = torch_dtype(choose_torch_device()) + variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"] + model = None - for revision in revisions: + for variant in variants: try: - model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None) - except: # most errors are due to fp16 not being present. Fix this to catch other errors - pass + model = DiffusionPipeline.from_pretrained( + repo_id, + variant=variant, + torch_dtype=precision, + safety_checker=None, + ) + except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors + if "fp16" not in str(e): + print(e) + if model: break + if not model: logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.") return None