fix two bugs in model import

1. !import_model did not allow user to specify VAE file. This is now fixed.
2. !del_model did not offer the user the opportunity to delete the underlying
   weights file or diffusers directory. This is now fixed.
This commit is contained in:
Lincoln Stein 2023-01-19 01:30:58 -05:00
parent e11f15cf78
commit 7bd2220a24
2 changed files with 84 additions and 33 deletions

View File

@ -124,7 +124,7 @@ def main():
# preload the model # preload the model
try: try:
gen.load_model() gen.load_model()
except KeyError as e: except KeyError:
pass pass
except Exception as e: except Exception as e:
report_model_error(opt, e) report_model_error(opt, e)
@ -589,7 +589,7 @@ def import_model(model_path:str, gen, opt, completer):
gen.model_manager.del_model(model_name) gen.model_manager.del_model(model_name)
return return
if input('Make this the default model? [n] ') in ('y','Y'): if input('Make this the default model? [n] ').strip() in ('y','Y'):
gen.model_manager.set_default_model(model_name) gen.model_manager.set_default_model(model_name)
gen.model_manager.commit(opt.conf) gen.model_manager.commit(opt.conf)
@ -606,10 +606,14 @@ def import_diffuser_model(path_or_repo:str, gen, opt, completer)->str:
model_name=default_name, model_name=default_name,
model_description=default_description model_description=default_description
) )
vae = None
if input('Replace this model\'s VAE with "stabilityai/sd-vae-ft-se"? [n] ').strip() in ('y','Y'):
vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
if not manager.import_diffuser_model( if not manager.import_diffuser_model(
path_or_repo, path_or_repo,
model_name = model_name, model_name = model_name,
vae = vae,
description = model_description): description = model_description):
print('** model failed to import') print('** model failed to import')
return None return None
@ -627,17 +631,28 @@ def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
) )
config_file = None config_file = None
default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml') default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
completer.complete_extensions(('.yaml','.yml')) completer.complete_extensions(('.yaml','.yml'))
completer.set_line(str(default)) completer.set_line(str(default))
done = False done = False
while not done: while not done:
config_file = input('Configuration file for this model: ').strip() config_file = input('Configuration file for this model: ').strip()
done = os.path.exists(config_file) done = os.path.exists(config_file)
completer.complete_extensions(('.ckpt','.safetensors'))
vae = None
default = Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')
completer.set_line(str(default))
done = False
while not done:
vae = input('VAE file for this model (leave blank for none): ').strip() or None
done = (not vae) or os.path.exists(vae)
completer.complete_extensions(None) completer.complete_extensions(None)
if not manager.import_ckpt_model( if not manager.import_ckpt_model(
path_or_url, path_or_url,
config = config_file, config = config_file,
vae = vae,
model_name = model_name, model_name = model_name,
model_description = model_description, model_description = model_description,
commit_to_conf = opt.conf, commit_to_conf = opt.conf,
@ -709,7 +724,7 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
return return
completer.update_models(gen.model_manager.list_models()) completer.update_models(gen.model_manager.list_models())
if input(f'Load optimized model {model_name}? [y] ') not in ('n','N'): if input(f'Load optimized model {model_name}? [y] ').strip() not in ('n','N'):
gen.set_model(model_name) gen.set_model(model_name)
response = input(f'Delete the original .ckpt file at ({ckpt_path} ? [n] ') response = input(f'Delete the original .ckpt file at ({ckpt_path} ? [n] ')
@ -725,17 +740,17 @@ def del_config(model_name:str, gen, opt, completer):
if model_name not in gen.model_manager.config: if model_name not in gen.model_manager.config:
print(f"** Unknown model {model_name}") print(f"** Unknown model {model_name}")
return return
gen.model_manager.del_model(model_name)
if input(f'Remove {model_name} from the list of models known to InvokeAI? [y] ').strip().startswith(('n','N')):
return
delete_completely = input('Completely remove the model file or directory from disk? [n] ').startswith(('y','Y'))
gen.model_manager.del_model(model_name,delete_files=delete_completely)
gen.model_manager.commit(opt.conf) gen.model_manager.commit(opt.conf)
print(f'** {model_name} deleted') print(f'** {model_name} deleted')
completer.update_models(gen.model_manager.list_models()) completer.update_models(gen.model_manager.list_models())
def edit_model(model_name:str, gen, opt, completer): def edit_model(model_name:str, gen, opt, completer):
current_model = gen.model_name
# if model_name == current_model:
# print("** Can't edit the active model. !switch to another model first. **")
# return
manager = gen.model_manager manager = gen.model_manager
if not (info := manager.model_info(model_name)): if not (info := manager.model_info(model_name)):
print(f'** Unknown model {model_name}') print(f'** Unknown model {model_name}')

View File

@ -18,7 +18,9 @@ import traceback
import warnings import warnings
import safetensors.torch import safetensors.torch
from pathlib import Path from pathlib import Path
from shutil import move, rmtree
from typing import Union, Any from typing import Union, Any
from huggingface_hub import scan_cache_dir
from ldm.util import download_with_progress_bar from ldm.util import download_with_progress_bar
import torch import torch
@ -225,7 +227,7 @@ class ModelManager(object):
line = f'\033[1m{line}\033[0m' line = f'\033[1m{line}\033[0m'
print(line) print(line)
def del_model(self, model_name:str) -> None: def del_model(self, model_name:str, delete_files:bool=False) -> None:
''' '''
Delete the named model. Delete the named model.
''' '''
@ -233,9 +235,25 @@ class ModelManager(object):
if model_name not in omega: if model_name not in omega:
print(f'** Unknown model {model_name}') print(f'** Unknown model {model_name}')
return return
# save these for use in deletion later
conf = omega[model_name]
repo_id = conf.get('repo_id',None)
path = self._relativize(conf.get('path',None))
weights = self._relativize(conf.get('weights',None))
del omega[model_name] del omega[model_name]
if model_name in self.stack: if model_name in self.stack:
self.stack.remove(model_name) self.stack.remove(model_name)
if delete_files:
if weights:
print(f'** deleting file {weights}')
Path(weights).unlink(missing_ok=True)
elif path:
print(f'** deleting directory {path}')
rmtree(path,ignore_errors=True)
elif repo_id:
print(f'** deleting the cached model directory for {repo_id}')
self._delete_model_from_cache(repo_id)
def add_model(self, model_name:str, model_attributes:dict, clobber:bool=False) -> None: def add_model(self, model_name:str, model_attributes:dict, clobber:bool=False) -> None:
''' '''
@ -412,7 +430,7 @@ class ModelManager(object):
safety_checker=None, safety_checker=None,
local_files_only=not Globals.internet_available local_files_only=not Globals.internet_available
) )
if 'vae' in mconfig: if 'vae' in mconfig and mconfig['vae'] is not None:
vae = self._load_vae(mconfig['vae']) vae = self._load_vae(mconfig['vae'])
pipeline_args.update(vae=vae) pipeline_args.update(vae=vae)
if not isinstance(name_or_path,Path): if not isinstance(name_or_path,Path):
@ -518,11 +536,12 @@ class ModelManager(object):
print('>> Model scanned ok!') print('>> Model scanned ok!')
def import_diffuser_model(self, def import_diffuser_model(self,
repo_or_path:Union[str,Path], repo_or_path:Union[str,Path],
model_name:str=None, model_name:str=None,
description:str=None, description:str=None,
commit_to_conf:Path=None, vae:dict=None,
)->bool: commit_to_conf:Path=None,
)->bool:
''' '''
Attempts to install the indicated diffuser model and returns True if successful. Attempts to install the indicated diffuser model and returns True if successful.
@ -538,6 +557,7 @@ class ModelManager(object):
description = description or f'imported diffusers model {model_name}' description = description or f'imported diffusers model {model_name}'
new_config = dict( new_config = dict(
description=description, description=description,
vae=vae,
format='diffusers', format='diffusers',
) )
if isinstance(repo_or_path,Path) and repo_or_path.exists(): if isinstance(repo_or_path,Path) and repo_or_path.exists():
@ -551,18 +571,22 @@ class ModelManager(object):
return True return True
def import_ckpt_model(self, def import_ckpt_model(self,
weights:Union[str,Path], weights:Union[str,Path],
config:Union[str,Path]='configs/stable-diffusion/v1-inference.yaml', config:Union[str,Path]='configs/stable-diffusion/v1-inference.yaml',
model_name:str=None, vae:Union[str,Path]=None,
model_description:str=None, model_name:str=None,
commit_to_conf:Path=None, model_description:str=None,
)->bool: commit_to_conf:Path=None,
)->bool:
''' '''
Attempts to install the indicated ckpt file and returns True if successful. Attempts to install the indicated ckpt file and returns True if successful.
"weights" can be either a path-like object corresponding to a local .ckpt file "weights" can be either a path-like object corresponding to a local .ckpt file
or a http/https URL pointing to a remote model. or a http/https URL pointing to a remote model.
"vae" is a Path or str object pointing to a ckpt or safetensors file to be used
as the VAE for this model.
"config" is the model config file to use with this ckpt file. It defaults to "config" is the model config file to use with this ckpt file. It defaults to
v1-inference.yaml. If a URL is provided, the config will be downloaded. v1-inference.yaml. If a URL is provided, the config will be downloaded.
@ -589,6 +613,8 @@ class ModelManager(object):
width=512, width=512,
height=512 height=512
) )
if vae:
new_config['vae'] = vae
self.add_model(model_name, new_config, True) self.add_model(model_name, new_config, True)
if commit_to_conf: if commit_to_conf:
self.commit(commit_to_conf) self.commit(commit_to_conf)
@ -670,16 +696,6 @@ class ModelManager(object):
print('done.') print('done.')
return new_config return new_config
def del_config(self, model_name:str, gen, opt, completer):
current_model = gen.model_name
if model_name == current_model:
print("** Can't delete active model. !switch to another model first. **")
return
gen.model_manager.del_model(model_name)
gen.model_manager.commit(opt.conf)
print(f'** {model_name} deleted')
completer.del_model(model_name)
def search_models(self, search_folder): def search_models(self, search_folder):
print(f'>> Finding Models In: {search_folder}') print(f'>> Finding Models In: {search_folder}')
models_folder_ckpt = Path(search_folder).glob('**/*.ckpt') models_folder_ckpt = Path(search_folder).glob('**/*.ckpt')
@ -761,7 +777,6 @@ class ModelManager(object):
print('** Legacy version <= 2.2.5 model directory layout detected. Reorganizing.') print('** Legacy version <= 2.2.5 model directory layout detected. Reorganizing.')
print('** This is a quick one-time operation.') print('** This is a quick one-time operation.')
from shutil import move, rmtree
# transformer files get moved into the hub directory # transformer files get moved into the hub directory
if cls._is_huggingface_hub_directory_present(): if cls._is_huggingface_hub_directory_present():
@ -977,6 +992,27 @@ class ModelManager(object):
return vae return vae
@staticmethod
def _delete_model_from_cache(repo_id):
cache_info = scan_cache_dir(global_cache_dir('diffusers'))
# I'm sure there is a way to do this with comprehensions
# but the code quickly became incomprehensible!
hashes_to_delete = set()
for repo in cache_info.repos:
if repo.repo_id==repo_id:
for revision in repo.revisions:
hashes_to_delete.add(revision.commit_hash)
strategy = cache_info.delete_revisions(*hashes_to_delete)
print(f'** deletion of this model is expected to free {strategy.expected_freed_size_str}')
strategy.execute()
@staticmethod
def _relativize(path:Union(str,Path))->Path:
if path is None or Path(path).is_absolute():
return path
return Path(Globals.root,path).resolve()
@staticmethod @staticmethod
def _is_huggingface_hub_directory_present() -> bool: def _is_huggingface_hub_directory_present() -> bool:
return os.getenv('HF_HOME') is not None or os.getenv('XDG_CACHE_HOME') is not None return os.getenv('HF_HOME') is not None or os.getenv('XDG_CACHE_HOME') is not None