diffusers: work more better with more models.

fixed relative path problem with local models.

fixed models on hub not always having a `fp16` branch.
This commit is contained in:
Kevin Turner 2022-12-10 08:29:12 -08:00
parent 50c48cffc7
commit 66d32b79b7

View File

@ -21,6 +21,8 @@ from typing import Union
import torch
import transformers
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import RevisionNotFoundError
from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError
from picklescan.scanner import scan_file_path
@ -323,6 +325,8 @@ class ModelCache(object):
# model_hash = huggingface_hub.get_hf_file_metadata(url).commit_hash
elif 'path' in mconfig:
name_or_path = Path(mconfig['path'])
if not name_or_path.is_absolute():
name_or_path = Path(Globals.root, name_or_path).resolve()
# FIXME: What should the model_hash be? A hash of the unet weights? Of all files of all
# the submodels hashed together? The commit ID from the repo?
model_hash = "FIXME TOO"
@ -335,7 +339,16 @@ class ModelCache(object):
if self.precision == 'float16':
print(' | Using faster float16 precision')
pipeline_args.update(revision="fp16", torch_dtype=torch.float16)
if not isinstance(name_or_path, Path):
try:
hf_hub_download(name_or_path, "model_index.json", revision="fp16")
except RevisionNotFoundError as e:
pass
else:
pipeline_args.update(revision="fp16")
pipeline_args.update(torch_dtype=torch.float16)
else:
# TODO: more accurately, "using the model's default precision."
# How do we find out what that is?
@ -363,7 +376,10 @@ class ModelCache(object):
if 'repo_name' in mconfig:
return mconfig['repo_name']
elif 'path' in mconfig:
return Path(mconfig['path'])
path = Path(mconfig['path'])
if not path.is_absolute():
path = Path(Globals.root, path).resolve()
return path
else:
raise ValueError("Model config must specify either repo_name or path.")