Migrate legacy models (pre-2.3.0) to 🤗 cache directory if exists

This commit is contained in:
Daya Adianto
2023-01-18 21:02:31 +07:00
parent 2282e681f7
commit aa4e8d8cf3

View File

@ -166,7 +166,7 @@ class ModelManager(object):
# don't include VAEs in listing (legacy style) # don't include VAEs in listing (legacy style)
if 'config' in stanza and '/VAE/' in stanza['config']: if 'config' in stanza and '/VAE/' in stanza['config']:
continue continue
models[name] = dict() models[name] = dict()
format = stanza.get('format','ckpt') # Determine Format format = stanza.get('format','ckpt') # Determine Format
@ -183,7 +183,7 @@ class ModelManager(object):
format = format, format = format,
status = status, status = status,
) )
# Checkpoint Config Parse # Checkpoint Config Parse
if format == 'ckpt': if format == 'ckpt':
models[name].update( models[name].update(
@ -193,7 +193,7 @@ class ModelManager(object):
width = str(stanza.get('width', 512)), width = str(stanza.get('width', 512)),
height = str(stanza.get('height', 512)), height = str(stanza.get('height', 512)),
) )
# Diffusers Config Parse # Diffusers Config Parse
if (vae := stanza.get('vae',None)): if (vae := stanza.get('vae',None)):
if isinstance(vae,DictConfig): if isinstance(vae,DictConfig):
@ -202,14 +202,14 @@ class ModelManager(object):
path = str(vae.get('path',None)), path = str(vae.get('path',None)),
subfolder = str(vae.get('subfolder',None)) subfolder = str(vae.get('subfolder',None))
) )
if format == 'diffusers': if format == 'diffusers':
models[name].update( models[name].update(
vae = vae, vae = vae,
repo_id = str(stanza.get('repo_id', None)), repo_id = str(stanza.get('repo_id', None)),
path = str(stanza.get('path',None)), path = str(stanza.get('path',None)),
) )
return models return models
def print_models(self) -> None: def print_models(self) -> None:
@ -257,7 +257,7 @@ class ModelManager(object):
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"' assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
omega[model_name] = model_attributes omega[model_name] = model_attributes
if 'weights' in omega[model_name]: if 'weights' in omega[model_name]:
omega[model_name]['weights'].replace('\\','/') omega[model_name]['weights'].replace('\\','/')
@ -554,12 +554,12 @@ class ModelManager(object):
''' '''
Attempts to install the indicated ckpt file and returns True if successful. Attempts to install the indicated ckpt file and returns True if successful.
"weights" can be either a path-like object corresponding to a local .ckpt file "weights" can be either a path-like object corresponding to a local .ckpt file
or a http/https URL pointing to a remote model. or a http/https URL pointing to a remote model.
"config" is the model config file to use with this ckpt file. It defaults to "config" is the model config file to use with this ckpt file. It defaults to
v1-inference.yaml. If a URL is provided, the config will be downloaded. v1-inference.yaml. If a URL is provided, the config will be downloaded.
You can optionally provide a model name and/or description. If not provided, You can optionally provide a model name and/or description. If not provided,
then these will be derived from the weight file name. If you provide a commit_to_conf then these will be derived from the weight file name. If you provide a commit_to_conf
path to the configuration file, then the new entry will be committed to the path to the configuration file, then the new entry will be committed to the
@ -572,7 +572,7 @@ class ModelManager(object):
return False return False
if config_path is None or not config_path.exists(): if config_path is None or not config_path.exists():
return False return False
model_name = model_name or Path(weights).stem model_name = model_name or Path(weights).stem
model_description = model_description or f'imported stable diffusion weights file {model_name}' model_description = model_description or f'imported stable diffusion weights file {model_name}'
new_config = dict( new_config = dict(
@ -587,7 +587,7 @@ class ModelManager(object):
if commit_to_conf: if commit_to_conf:
self.commit(commit_to_conf) self.commit(commit_to_conf)
return True return True
def autoconvert_weights( def autoconvert_weights(
self, self,
conf_path:Path, conf_path:Path,
@ -660,7 +660,7 @@ class ModelManager(object):
except Exception as e: except Exception as e:
print(f'** Conversion failed: {str(e)}') print(f'** Conversion failed: {str(e)}')
traceback.print_exc() traceback.print_exc()
print('done.') print('done.')
return new_config return new_config
@ -756,9 +756,13 @@ class ModelManager(object):
print('** Legacy version <= 2.2.5 model directory layout detected. Reorganizing.') print('** Legacy version <= 2.2.5 model directory layout detected. Reorganizing.')
print('** This is a quick one-time operation.') print('** This is a quick one-time operation.')
from shutil import move, rmtree from shutil import move, rmtree
# transformer files get moved into the hub directory # transformer files get moved into the hub directory
hub = models_dir / 'hub' if cls._is_huggingface_hub_directory_present():
hub = global_cache_dir() / 'hub'
else:
hub = models_dir / 'hub'
os.makedirs(hub, exist_ok=True) os.makedirs(hub, exist_ok=True)
for model in legacy_locations: for model in legacy_locations:
source = models_dir / model source = models_dir / model
@ -771,7 +775,11 @@ class ModelManager(object):
move(source, dest) move(source, dest)
# anything else gets moved into the diffusers directory # anything else gets moved into the diffusers directory
diffusers = models_dir / 'diffusers' if cls._is_huggingface_hub_directory_present():
diffusers = global_cache_dir() / 'diffusers'
else:
diffusers = models_dir / 'diffusers'
os.makedirs(diffusers, exist_ok=True) os.makedirs(diffusers, exist_ok=True)
for root, dirs, _ in os.walk(models_dir, topdown=False): for root, dirs, _ in os.walk(models_dir, topdown=False):
for dir in dirs: for dir in dirs:
@ -962,3 +970,7 @@ class ModelManager(object):
print(f'** Could not load VAE {name_or_path}: {str(deferred_error)}') print(f'** Could not load VAE {name_or_path}: {str(deferred_error)}')
return vae return vae
@staticmethod
def _is_huggingface_hub_directory_present() -> bool:
return os.getenv('HF_HOME') is not None or os.getenv('XDG_CACHE_HOME') is not None