Installer should download fp16 models if user has specified 'auto' in config (#4129)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [X] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [X] Yes
- [ ] No


## Description

At install time, when the user's config specified "auto" precision, the
installer was downloading the fp32 models even when an fp16 model would
be appropriate for the OS.


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Closes #4127
This commit is contained in:
StAlKeR7779 2023-08-05 01:56:25 +03:00 committed by GitHub
commit 3d93851dba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -13,6 +13,7 @@ import requests
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
import onnx
import torch
from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf
from tqdm import tqdm
@ -23,6 +24,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
from invokeai.backend.util import download_with_resume
from invokeai.backend.util.devices import torch_dtype, choose_torch_device
from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore")
@ -416,15 +418,25 @@ class ModelInstall(object):
does a save_pretrained() to the indicated staging area.
"""
_, name = repo_id.split("/")
revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"]
precision = torch_dtype(choose_torch_device())
variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
model = None
for revision in revisions:
for variant in variants:
try:
model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None)
except: # most errors are due to fp16 not being present. Fix this to catch other errors
pass
model = DiffusionPipeline.from_pretrained(
repo_id,
variant=variant,
torch_dtype=precision,
safety_checker=None,
)
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
if "fp16" not in str(e):
print(e)
if model:
break
if not model:
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
return None