Merge branch 'main' into lstein-import-safetensors

This commit is contained in:
Lincoln Stein 2023-01-18 17:09:48 -05:00
commit e11f15cf78
2 changed files with 38 additions and 16 deletions

View File

@ -62,11 +62,21 @@ def global_cache_dir(subdir:Union[str,Path]='')->Path:
'''
Returns Path to the model cache directory. If a subdirectory
is provided, it will be appended to the end of the path, allowing
for huggingface-style conventions:
for huggingface-style conventions:
global_cache_dir('diffusers')
global_cache_dir('transformers')
'''
if (home := os.environ.get('HF_HOME')):
home: str = os.getenv('HF_HOME')
if home is None:
home = os.getenv('XDG_CACHE_HOME')
if home is not None:
# Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in HuggingFace Hub Client Library.
# See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
home += os.sep + 'huggingface'
if home is not None:
return Path(home,subdir)
else:
return Path(Globals.root,'models',subdir)

View File

@ -166,7 +166,7 @@ class ModelManager(object):
# don't include VAEs in listing (legacy style)
if 'config' in stanza and '/VAE/' in stanza['config']:
continue
models[name] = dict()
format = stanza.get('format','ckpt') # Determine Format
@ -183,7 +183,7 @@ class ModelManager(object):
format = format,
status = status,
)
# Checkpoint Config Parse
if format == 'ckpt':
models[name].update(
@ -193,7 +193,7 @@ class ModelManager(object):
width = str(stanza.get('width', 512)),
height = str(stanza.get('height', 512)),
)
# Diffusers Config Parse
if (vae := stanza.get('vae',None)):
if isinstance(vae,DictConfig):
@ -202,14 +202,14 @@ class ModelManager(object):
path = str(vae.get('path',None)),
subfolder = str(vae.get('subfolder',None))
)
if format == 'diffusers':
models[name].update(
vae = vae,
repo_id = str(stanza.get('repo_id', None)),
path = str(stanza.get('path',None)),
)
return models
def print_models(self) -> None:
@ -257,7 +257,7 @@ class ModelManager(object):
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
omega[model_name] = model_attributes
if 'weights' in omega[model_name]:
omega[model_name]['weights'].replace('\\','/')
@ -560,12 +560,12 @@ class ModelManager(object):
'''
Attempts to install the indicated ckpt file and returns True if successful.
"weights" can be either a path-like object corresponding to a local .ckpt file
"weights" can be either a path-like object corresponding to a local .ckpt file
or a http/https URL pointing to a remote model.
"config" is the model config file to use with this ckpt file. It defaults to
v1-inference.yaml. If a URL is provided, the config will be downloaded.
You can optionally provide a model name and/or description. If not provided,
then these will be derived from the weight file name. If you provide a commit_to_conf
path to the configuration file, then the new entry will be committed to the
@ -578,7 +578,7 @@ class ModelManager(object):
return False
if config_path is None or not config_path.exists():
return False
model_name = model_name or Path(weights).stem
model_description = model_description or f'imported stable diffusion weights file {model_name}'
new_config = dict(
@ -593,7 +593,7 @@ class ModelManager(object):
if commit_to_conf:
self.commit(commit_to_conf)
return True
def autoconvert_weights(
self,
conf_path:Path,
@ -666,7 +666,7 @@ class ModelManager(object):
except Exception as e:
print(f'** Conversion failed: {str(e)}')
traceback.print_exc()
print('done.')
return new_config
@ -762,9 +762,13 @@ class ModelManager(object):
print('** Legacy version <= 2.2.5 model directory layout detected. Reorganizing.')
print('** This is a quick one-time operation.')
from shutil import move, rmtree
# transformer files get moved into the hub directory
hub = models_dir / 'hub'
if cls._is_huggingface_hub_directory_present():
hub = global_cache_dir('hub')
else:
hub = models_dir / 'hub'
os.makedirs(hub, exist_ok=True)
for model in legacy_locations:
source = models_dir / model
@ -777,7 +781,11 @@ class ModelManager(object):
move(source, dest)
# anything else gets moved into the diffusers directory
diffusers = models_dir / 'diffusers'
if cls._is_huggingface_hub_directory_present():
diffusers = global_cache_dir('diffusers')
else:
diffusers = models_dir / 'diffusers'
os.makedirs(diffusers, exist_ok=True)
for root, dirs, _ in os.walk(models_dir, topdown=False):
for dir in dirs:
@ -968,3 +976,7 @@ class ModelManager(object):
print(f'** Could not load VAE {name_or_path}: {str(deferred_error)}')
return vae
@staticmethod
def _is_huggingface_hub_directory_present() -> bool:
return os.getenv('HF_HOME') is not None or os.getenv('XDG_CACHE_HOME') is not None