mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
diffusers: work more better with more models.
fixed relative path problem with local models. fixed models on hub not always having a `fp16` branch.
This commit is contained in:
parent
50c48cffc7
commit
66d32b79b7
@ -21,6 +21,8 @@ from typing import Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import RevisionNotFoundError
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.errors import ConfigAttributeError
|
||||
from picklescan.scanner import scan_file_path
|
||||
@ -323,6 +325,8 @@ class ModelCache(object):
|
||||
# model_hash = huggingface_hub.get_hf_file_metadata(url).commit_hash
|
||||
elif 'path' in mconfig:
|
||||
name_or_path = Path(mconfig['path'])
|
||||
if not name_or_path.is_absolute():
|
||||
name_or_path = Path(Globals.root, name_or_path).resolve()
|
||||
# FIXME: What should the model_hash be? A hash of the unet weights? Of all files of all
|
||||
# the submodels hashed together? The commit ID from the repo?
|
||||
model_hash = "FIXME TOO"
|
||||
@ -335,7 +339,16 @@ class ModelCache(object):
|
||||
|
||||
if self.precision == 'float16':
|
||||
print(' | Using faster float16 precision')
|
||||
pipeline_args.update(revision="fp16", torch_dtype=torch.float16)
|
||||
|
||||
if not isinstance(name_or_path, Path):
|
||||
try:
|
||||
hf_hub_download(name_or_path, "model_index.json", revision="fp16")
|
||||
except RevisionNotFoundError as e:
|
||||
pass
|
||||
else:
|
||||
pipeline_args.update(revision="fp16")
|
||||
|
||||
pipeline_args.update(torch_dtype=torch.float16)
|
||||
else:
|
||||
# TODO: more accurately, "using the model's default precision."
|
||||
# How do we find out what that is?
|
||||
@ -363,7 +376,10 @@ class ModelCache(object):
|
||||
if 'repo_name' in mconfig:
|
||||
return mconfig['repo_name']
|
||||
elif 'path' in mconfig:
|
||||
return Path(mconfig['path'])
|
||||
path = Path(mconfig['path'])
|
||||
if not path.is_absolute():
|
||||
path = Path(Globals.root, path).resolve()
|
||||
return path
|
||||
else:
|
||||
raise ValueError("Model config must specify either repo_name or path.")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user