Mirror of https://github.com/invoke-ai/InvokeAI
Merge branch 'main' into lstein-import-safetensors

Commit e11f15cf78
@@ -62,11 +62,21 @@ def global_cache_dir(subdir:Union[str,Path]='')->Path:
     '''
     Returns Path to the model cache directory. If a subdirectory
     is provided, it will be appended to the end of the path, allowing
     for huggingface-style conventions:
          global_cache_dir('diffusers')
          global_cache_dir('transformers')
     '''
-    if (home := os.environ.get('HF_HOME')):
+    home: str = os.getenv('HF_HOME')
+
+    if home is None:
+        home = os.getenv('XDG_CACHE_HOME')
+
+        if home is not None:
+            # Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in HuggingFace Hub Client Library.
+            # See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
+            home += os.sep + 'huggingface'
+
+    if home is not None:
         return Path(home,subdir)
     else:
         return Path(Globals.root,'models',subdir)
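Read as a whole, the patched global_cache_dir() resolves the cache root in this order: $HF_HOME if set, otherwise $XDG_CACHE_HOME/huggingface if XDG_CACHE_HOME is set, otherwise <Globals.root>/models. A minimal standalone sketch of that lookup order follows; the helper name resolve_cache_dir and the example paths are illustrative and not part of the diff.

# Illustrative sketch of the lookup order in the patched global_cache_dir();
# resolve_cache_dir() and the example paths below are hypothetical.
import os
from pathlib import Path

def resolve_cache_dir(invoke_root: str, subdir: str = '') -> Path:
    home = os.getenv('HF_HOME')                # 1) explicit HuggingFace home
    if home is None:
        home = os.getenv('XDG_CACHE_HOME')     # 2) XDG cache directory, if present
        if home is not None:
            # HuggingFace Hub defaults to $XDG_CACHE_HOME/huggingface
            home += os.sep + 'huggingface'
    if home is not None:
        return Path(home, subdir)
    return Path(invoke_root, 'models', subdir) # 3) fall back to <root>/models

# HF_HOME=/data/hf               -> /data/hf/diffusers
# XDG_CACHE_HOME=/home/me/.cache -> /home/me/.cache/huggingface/diffusers
# neither set                    -> /opt/invokeai/models/diffusers
print(resolve_cache_dir('/opt/invokeai', 'diffusers'))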
@@ -166,7 +166,7 @@ class ModelManager(object):
             # don't include VAEs in listing (legacy style)
             if 'config' in stanza and '/VAE/' in stanza['config']:
                 continue
 
             models[name] = dict()
             format = stanza.get('format','ckpt') # Determine Format
 
@@ -183,7 +183,7 @@ class ModelManager(object):
                 format = format,
                 status = status,
             )
 
             # Checkpoint Config Parse
             if format == 'ckpt':
                 models[name].update(
@@ -193,7 +193,7 @@ class ModelManager(object):
                     width = str(stanza.get('width', 512)),
                     height = str(stanza.get('height', 512)),
                 )
 
             # Diffusers Config Parse
             if (vae := stanza.get('vae',None)):
                 if isinstance(vae,DictConfig):
@@ -202,14 +202,14 @@ class ModelManager(object):
                         path = str(vae.get('path',None)),
                         subfolder = str(vae.get('subfolder',None))
                     )
 
             if format == 'diffusers':
                 models[name].update(
                     vae = vae,
                     repo_id = str(stanza.get('repo_id', None)),
                     path = str(stanza.get('path',None)),
                 )
 
         return models
 
     def print_models(self) -> None:
@@ -257,7 +257,7 @@ class ModelManager(object):
         assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
 
         omega[model_name] = model_attributes
 
         if 'weights' in omega[model_name]:
             omega[model_name]['weights'].replace('\\','/')
 
@@ -560,12 +560,12 @@ class ModelManager(object):
         '''
         Attempts to install the indicated ckpt file and returns True if successful.
 
         "weights" can be either a path-like object corresponding to a local .ckpt file
         or a http/https URL pointing to a remote model.
 
         "config" is the model config file to use with this ckpt file. It defaults to
         v1-inference.yaml. If a URL is provided, the config will be downloaded.
 
         You can optionally provide a model name and/or description. If not provided,
         then these will be derived from the weight file name. If you provide a commit_to_conf
         path to the configuration file, then the new entry will be committed to the
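The docstring above describes the checkpoint importer this branch adds. A hedged usage sketch follows; the method name import_ckpt_model, the module path, and the config file locations are assumptions drawn from the surrounding text and may differ from the actual API.

# Hypothetical usage; names and paths are assumed, not confirmed by this diff.
from omegaconf import OmegaConf
from ldm.invoke.model_manager import ModelManager   # module path assumed

mgr = ModelManager(OmegaConf.load('configs/models.yaml'))
ok = mgr.import_ckpt_model(
    'https://example.com/my-finetune.ckpt',                # local .ckpt path or http(s) URL
    config='configs/stable-diffusion/v1-inference.yaml',   # default per the docstring
    model_name='my-finetune',                              # optional; derived from file name if omitted
    model_description='imported stable diffusion weights',
    commit_to_conf='configs/models.yaml',                  # persist the new entry to this config file
)
print('imported' if ok else 'import failed')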
@@ -578,7 +578,7 @@ class ModelManager(object):
             return False
         if config_path is None or not config_path.exists():
             return False
 
         model_name = model_name or Path(weights).stem
         model_description = model_description or f'imported stable diffusion weights file {model_name}'
         new_config = dict(
@@ -593,7 +593,7 @@ class ModelManager(object):
         if commit_to_conf:
             self.commit(commit_to_conf)
         return True
 
     def autoconvert_weights(
         self,
         conf_path:Path,
@@ -666,7 +666,7 @@ class ModelManager(object):
         except Exception as e:
             print(f'** Conversion failed: {str(e)}')
             traceback.print_exc()
 
         print('done.')
         return new_config
 
@@ -762,9 +762,13 @@ class ModelManager(object):
         print('** Legacy version <= 2.2.5 model directory layout detected. Reorganizing.')
         print('** This is a quick one-time operation.')
         from shutil import move, rmtree
 
         # transformer files get moved into the hub directory
-        hub = models_dir / 'hub'
+        if cls._is_huggingface_hub_directory_present():
+            hub = global_cache_dir('hub')
+        else:
+            hub = models_dir / 'hub'
+
         os.makedirs(hub, exist_ok=True)
         for model in legacy_locations:
             source = models_dir / model
@@ -777,7 +781,11 @@ class ModelManager(object):
                     move(source, dest)
 
         # anything else gets moved into the diffusers directory
-        diffusers = models_dir / 'diffusers'
+        if cls._is_huggingface_hub_directory_present():
+            diffusers = global_cache_dir('diffusers')
+        else:
+            diffusers = models_dir / 'diffusers'
+
         os.makedirs(diffusers, exist_ok=True)
         for root, dirs, _ in os.walk(models_dir, topdown=False):
             for dir in dirs:
@@ -968,3 +976,7 @@ class ModelManager(object):
             print(f'** Could not load VAE {name_or_path}: {str(deferred_error)}')
 
         return vae
+
+    @staticmethod
+    def _is_huggingface_hub_directory_present() -> bool:
+        return os.getenv('HF_HOME') is not None or os.getenv('XDG_CACHE_HOME') is not None
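Taken together, the last three hunks make the one-time legacy reorganization honor the same environment variables as global_cache_dir(): when either HF_HOME or XDG_CACHE_HOME is set, the migrated files land in the shared HuggingFace cache rather than under the local models directory. A small illustration, with placeholder paths:

# Placeholder paths; illustrates which target directory the migration picks.
import os

os.environ['HF_HOME'] = '/data/hf-cache'
# _is_huggingface_hub_directory_present() -> True
#   transformer files -> global_cache_dir('hub')       == /data/hf-cache/hub
#   diffusers models  -> global_cache_dir('diffusers') == /data/hf-cache/diffusers

os.environ.pop('HF_HOME')
os.environ.pop('XDG_CACHE_HOME', None)
# _is_huggingface_hub_directory_present() -> False
#   transformer files -> <models_dir>/hub
#   diffusers models  -> <models_dir>/diffusers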