From 66d32b79b736ec8686523ea7b9aacba924bf0673 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sat, 10 Dec 2022 08:29:12 -0800 Subject: [PATCH] diffusers: work more better with more models. fixed relative path problem with local models. fixed models on hub not always having a `fp16` branch. --- ldm/invoke/model_cache.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py index ee32ba12e5..f57dc6de19 100644 --- a/ldm/invoke/model_cache.py +++ b/ldm/invoke/model_cache.py @@ -21,6 +21,8 @@ from typing import Union import torch import transformers +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import RevisionNotFoundError from omegaconf import OmegaConf from omegaconf.errors import ConfigAttributeError from picklescan.scanner import scan_file_path @@ -323,6 +325,8 @@ class ModelCache(object): # model_hash = huggingface_hub.get_hf_file_metadata(url).commit_hash elif 'path' in mconfig: name_or_path = Path(mconfig['path']) + if not name_or_path.is_absolute(): + name_or_path = Path(Globals.root, name_or_path).resolve() # FIXME: What should the model_hash be? A hash of the unet weights? Of all files of all # the submodels hashed together? The commit ID from the repo? model_hash = "FIXME TOO" @@ -335,7 +339,16 @@ class ModelCache(object): if self.precision == 'float16': print(' | Using faster float16 precision') - pipeline_args.update(revision="fp16", torch_dtype=torch.float16) + + if not isinstance(name_or_path, Path): + try: + hf_hub_download(name_or_path, "model_index.json", revision="fp16") + except RevisionNotFoundError as e: + pass + else: + pipeline_args.update(revision="fp16") + + pipeline_args.update(torch_dtype=torch.float16) else: # TODO: more accurately, "using the model's default precision." # How do we find out what that is? @@ -363,7 +376,10 @@ class ModelCache(object): if 'repo_name' in mconfig: return mconfig['repo_name'] elif 'path' in mconfig: - return Path(mconfig['path']) + path = Path(mconfig['path']) + if not path.is_absolute(): + path = Path(Globals.root, path).resolve() + return path else: raise ValueError("Model config must specify either repo_name or path.")