Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
diffusers: work better with more models.
Fixed a relative-path problem with local models; fixed handling of hub models that don't have an `fp16` branch.
This commit is contained in:
parent 50c48cffc7
commit 66d32b79b7
@@ -21,6 +21,8 @@ from typing import Union
 
 import torch
 import transformers
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import RevisionNotFoundError
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
 from picklescan.scanner import scan_file_path
@@ -323,6 +325,8 @@ class ModelCache(object):
             # model_hash = huggingface_hub.get_hf_file_metadata(url).commit_hash
         elif 'path' in mconfig:
             name_or_path = Path(mconfig['path'])
+            if not name_or_path.is_absolute():
+                name_or_path = Path(Globals.root, name_or_path).resolve()
             # FIXME: What should the model_hash be? A hash of the unet weights? Of all files of all
             #    the submodels hashed together? The commit ID from the repo?
             model_hash = "FIXME TOO"
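The path fix above is small but worth spelling out: a relative `path` entry in the model config is now resolved against the InvokeAI root (`Globals.root`) instead of being passed through to diffusers as-is. A minimal standalone sketch of that behaviour (the names `resolve_model_path` and `root_dir` are illustrative, with `root_dir` standing in for `Globals.root`):

from pathlib import Path

def resolve_model_path(configured_path: str, root_dir: str) -> Path:
    # Relative entries are interpreted against the application root,
    # mirroring the Globals.root handling in the hunk above.
    path = Path(configured_path)
    if not path.is_absolute():
        path = Path(root_dir, path).resolve()
    return path

# resolve_model_path("models/my-finetune", "/home/user/invokeai")
#   -> PosixPath('/home/user/invokeai/models/my-finetune')   (on a POSIX system)
# resolve_model_path("/srv/models/sd", "/home/user/invokeai")
#   -> PosixPath('/srv/models/sd')   (already absolute, left alone)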
@@ -335,7 +339,16 @@ class ModelCache(object):
 
         if self.precision == 'float16':
             print(' | Using faster float16 precision')
-            pipeline_args.update(revision="fp16", torch_dtype=torch.float16)
+
+            if not isinstance(name_or_path, Path):
+                try:
+                    hf_hub_download(name_or_path, "model_index.json", revision="fp16")
+                except RevisionNotFoundError as e:
+                    pass
+                else:
+                    pipeline_args.update(revision="fp16")
+
+            pipeline_args.update(torch_dtype=torch.float16)
         else:
             # TODO: more accurately, "using the model's default precision."
             #    How do we find out what that is?
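The `fp16` handling can also be read in isolation: for hub models (anything that isn't already a local Path), the loader now probes whether the repo actually publishes an `fp16` revision and only requests that revision if it exists; in either case the weights are still cast to float16. A rough sketch of the same probe, assuming only the public huggingface_hub calls already imported above (the function name and the repo id in the usage comment are illustrative):

import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import RevisionNotFoundError

def fp16_pipeline_args(repo_id: str) -> dict:
    # Kwargs for DiffusionPipeline.from_pretrained: ask for the fp16
    # revision only when the hub repo actually has one.
    args = {"torch_dtype": torch.float16}
    try:
        # Cheap probe: fetch one small file from the fp16 revision.
        hf_hub_download(repo_id, "model_index.json", revision="fp16")
        args["revision"] = "fp16"
    except RevisionNotFoundError:
        # No fp16 branch on the hub; fall back to the default revision,
        # still casting to float16 locally.
        pass
    return args

# fp16_pipeline_args("CompVis/stable-diffusion-v1-4")
#   -> {'torch_dtype': torch.float16, 'revision': 'fp16'}  (assuming the repo has an fp16 branch)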
@@ -363,7 +376,10 @@ class ModelCache(object):
         if 'repo_name' in mconfig:
             return mconfig['repo_name']
         elif 'path' in mconfig:
-            return Path(mconfig['path'])
+            path = Path(mconfig['path'])
+            if not path.is_absolute():
+                path = Path(Globals.root, path).resolve()
+            return path
         else:
             raise ValueError("Model config must specify either repo_name or path.")
 
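Taken together with the earlier hunks, this helper now returns either a hub repo id (a plain string) or an absolute local Path, which is exactly the distinction the `isinstance(name_or_path, Path)` check in the float16 branch relies on. A small illustration with made-up config entries and a hypothetical root:

from pathlib import Path

root = Path("/home/user/invokeai")   # stand-in for Globals.root

for mconfig in ({"repo_name": "stabilityai/stable-diffusion-2"},
                {"path": "models/my-finetune"},
                {"path": "/srv/models/sd"}):
    if "repo_name" in mconfig:
        name_or_path = mconfig["repo_name"]      # hub repo id stays a str
    else:
        name_or_path = Path(mconfig["path"])
        if not name_or_path.is_absolute():
            name_or_path = Path(root, name_or_path).resolve()
    print(type(name_or_path).__name__, name_or_path)

# On Linux this prints something like:
#   str stabilityai/stable-diffusion-2              (downloaded from the hub)
#   PosixPath /home/user/invokeai/models/my-finetune (loaded from local disk)
#   PosixPath /srv/models/sd                         (already absolute, used as-is)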