Installer should download fp16 models if user has specified 'auto' in config

- Closes #4127
This commit is contained in:
Lincoln Stein 2023-08-01 22:06:27 -04:00
parent 4599575e65
commit 4d22cafdad

View File

@ -13,6 +13,7 @@ import requests
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from diffusers import logging as dlogging from diffusers import logging as dlogging
import onnx import onnx
import torch
from huggingface_hub import hf_hub_url, HfFolder, HfApi from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf from omegaconf import OmegaConf
from tqdm import tqdm from tqdm import tqdm
@ -23,6 +24,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
from invokeai.backend.util import download_with_resume from invokeai.backend.util import download_with_resume
from invokeai.backend.util.devices import torch_dtype, choose_torch_device
from ..util.logging import InvokeAILogger from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -416,13 +418,17 @@ class ModelInstall(object):
does a save_pretrained() to the indicated staging area. does a save_pretrained() to the indicated staging area.
""" """
_, name = repo_id.split("/") _, name = repo_id.split("/")
revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"] precision = torch_dtype(choose_torch_device())
revisions = ["fp16", "main"] if precision == torch.float16 else ["main"]
model = None model = None
for revision in revisions: for revision in revisions:
try: try:
model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None) model = DiffusionPipeline.from_pretrained(
except: # most errors are due to fp16 not being present. Fix this to catch other errors repo_id, revision=revision, safety_checker=None, torch_dtype=precision
pass )
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
if "fp16" not in str(e):
print(e)
if model: if model:
break break
if not model: if not model: