Merge branch 'main' into lstein-import-safetensors

This commit is contained in:
Lincoln Stein 2023-01-18 17:09:48 -05:00
commit e11f15cf78
2 changed files with 38 additions and 16 deletions

View File

@ -62,11 +62,21 @@ def global_cache_dir(subdir:Union[str,Path]='')->Path:
''' '''
Returns Path to the model cache directory. If a subdirectory Returns Path to the model cache directory. If a subdirectory
is provided, it will be appended to the end of the path, allowing is provided, it will be appended to the end of the path, allowing
for huggingface-style conventions: for huggingface-style conventions:
global_cache_dir('diffusers') global_cache_dir('diffusers')
global_cache_dir('transformers') global_cache_dir('transformers')
''' '''
if (home := os.environ.get('HF_HOME')): home: str = os.getenv('HF_HOME')
if home is None:
home = os.getenv('XDG_CACHE_HOME')
if home is not None:
# Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in HuggingFace Hub Client Library.
# See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
home += os.sep + 'huggingface'
if home is not None:
return Path(home,subdir) return Path(home,subdir)
else: else:
return Path(Globals.root,'models',subdir) return Path(Globals.root,'models',subdir)

View File

@ -166,7 +166,7 @@ class ModelManager(object):
# don't include VAEs in listing (legacy style) # don't include VAEs in listing (legacy style)
if 'config' in stanza and '/VAE/' in stanza['config']: if 'config' in stanza and '/VAE/' in stanza['config']:
continue continue
models[name] = dict() models[name] = dict()
format = stanza.get('format','ckpt') # Determine Format format = stanza.get('format','ckpt') # Determine Format
@ -183,7 +183,7 @@ class ModelManager(object):
format = format, format = format,
status = status, status = status,
) )
# Checkpoint Config Parse # Checkpoint Config Parse
if format == 'ckpt': if format == 'ckpt':
models[name].update( models[name].update(
@ -193,7 +193,7 @@ class ModelManager(object):
width = str(stanza.get('width', 512)), width = str(stanza.get('width', 512)),
height = str(stanza.get('height', 512)), height = str(stanza.get('height', 512)),
) )
# Diffusers Config Parse # Diffusers Config Parse
if (vae := stanza.get('vae',None)): if (vae := stanza.get('vae',None)):
if isinstance(vae,DictConfig): if isinstance(vae,DictConfig):
@ -202,14 +202,14 @@ class ModelManager(object):
path = str(vae.get('path',None)), path = str(vae.get('path',None)),
subfolder = str(vae.get('subfolder',None)) subfolder = str(vae.get('subfolder',None))
) )
if format == 'diffusers': if format == 'diffusers':
models[name].update( models[name].update(
vae = vae, vae = vae,
repo_id = str(stanza.get('repo_id', None)), repo_id = str(stanza.get('repo_id', None)),
path = str(stanza.get('path',None)), path = str(stanza.get('path',None)),
) )
return models return models
def print_models(self) -> None: def print_models(self) -> None:
@ -257,7 +257,7 @@ class ModelManager(object):
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"' assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
omega[model_name] = model_attributes omega[model_name] = model_attributes
if 'weights' in omega[model_name]: if 'weights' in omega[model_name]:
omega[model_name]['weights'].replace('\\','/') omega[model_name]['weights'].replace('\\','/')
@ -560,12 +560,12 @@ class ModelManager(object):
''' '''
Attempts to install the indicated ckpt file and returns True if successful. Attempts to install the indicated ckpt file and returns True if successful.
"weights" can be either a path-like object corresponding to a local .ckpt file "weights" can be either a path-like object corresponding to a local .ckpt file
or a http/https URL pointing to a remote model. or a http/https URL pointing to a remote model.
"config" is the model config file to use with this ckpt file. It defaults to "config" is the model config file to use with this ckpt file. It defaults to
v1-inference.yaml. If a URL is provided, the config will be downloaded. v1-inference.yaml. If a URL is provided, the config will be downloaded.
You can optionally provide a model name and/or description. If not provided, You can optionally provide a model name and/or description. If not provided,
then these will be derived from the weight file name. If you provide a commit_to_conf then these will be derived from the weight file name. If you provide a commit_to_conf
path to the configuration file, then the new entry will be committed to the path to the configuration file, then the new entry will be committed to the
@ -578,7 +578,7 @@ class ModelManager(object):
return False return False
if config_path is None or not config_path.exists(): if config_path is None or not config_path.exists():
return False return False
model_name = model_name or Path(weights).stem model_name = model_name or Path(weights).stem
model_description = model_description or f'imported stable diffusion weights file {model_name}' model_description = model_description or f'imported stable diffusion weights file {model_name}'
new_config = dict( new_config = dict(
@ -593,7 +593,7 @@ class ModelManager(object):
if commit_to_conf: if commit_to_conf:
self.commit(commit_to_conf) self.commit(commit_to_conf)
return True return True
def autoconvert_weights( def autoconvert_weights(
self, self,
conf_path:Path, conf_path:Path,
@ -666,7 +666,7 @@ class ModelManager(object):
except Exception as e: except Exception as e:
print(f'** Conversion failed: {str(e)}') print(f'** Conversion failed: {str(e)}')
traceback.print_exc() traceback.print_exc()
print('done.') print('done.')
return new_config return new_config
@ -762,9 +762,13 @@ class ModelManager(object):
print('** Legacy version <= 2.2.5 model directory layout detected. Reorganizing.') print('** Legacy version <= 2.2.5 model directory layout detected. Reorganizing.')
print('** This is a quick one-time operation.') print('** This is a quick one-time operation.')
from shutil import move, rmtree from shutil import move, rmtree
# transformer files get moved into the hub directory # transformer files get moved into the hub directory
hub = models_dir / 'hub' if cls._is_huggingface_hub_directory_present():
hub = global_cache_dir('hub')
else:
hub = models_dir / 'hub'
os.makedirs(hub, exist_ok=True) os.makedirs(hub, exist_ok=True)
for model in legacy_locations: for model in legacy_locations:
source = models_dir / model source = models_dir / model
@ -777,7 +781,11 @@ class ModelManager(object):
move(source, dest) move(source, dest)
# anything else gets moved into the diffusers directory # anything else gets moved into the diffusers directory
diffusers = models_dir / 'diffusers' if cls._is_huggingface_hub_directory_present():
diffusers = global_cache_dir('diffusers')
else:
diffusers = models_dir / 'diffusers'
os.makedirs(diffusers, exist_ok=True) os.makedirs(diffusers, exist_ok=True)
for root, dirs, _ in os.walk(models_dir, topdown=False): for root, dirs, _ in os.walk(models_dir, topdown=False):
for dir in dirs: for dir in dirs:
@ -968,3 +976,7 @@ class ModelManager(object):
print(f'** Could not load VAE {name_or_path}: {str(deferred_error)}') print(f'** Could not load VAE {name_or_path}: {str(deferred_error)}')
return vae return vae
@staticmethod
def _is_huggingface_hub_directory_present() -> bool:
return os.getenv('HF_HOME') is not None or os.getenv('XDG_CACHE_HOME') is not None