mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix two bugs in model import
1. !import_model did not allow user to specify VAE file. This is now fixed. 2. !del_model did not offer the user the opportunity to delete the underlying weights file or diffusers directory. This is now fixed.
This commit is contained in:
parent
e11f15cf78
commit
7bd2220a24
@ -124,7 +124,7 @@ def main():
|
||||
# preload the model
|
||||
try:
|
||||
gen.load_model()
|
||||
except KeyError as e:
|
||||
except KeyError:
|
||||
pass
|
||||
except Exception as e:
|
||||
report_model_error(opt, e)
|
||||
@ -589,7 +589,7 @@ def import_model(model_path:str, gen, opt, completer):
|
||||
gen.model_manager.del_model(model_name)
|
||||
return
|
||||
|
||||
if input('Make this the default model? [n] ') in ('y','Y'):
|
||||
if input('Make this the default model? [n] ').strip() in ('y','Y'):
|
||||
gen.model_manager.set_default_model(model_name)
|
||||
|
||||
gen.model_manager.commit(opt.conf)
|
||||
@ -606,10 +606,14 @@ def import_diffuser_model(path_or_repo:str, gen, opt, completer)->str:
|
||||
model_name=default_name,
|
||||
model_description=default_description
|
||||
)
|
||||
vae = None
|
||||
if input('Replace this model\'s VAE with "stabilityai/sd-vae-ft-se"? [n] ').strip() in ('y','Y'):
|
||||
vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
|
||||
|
||||
if not manager.import_diffuser_model(
|
||||
path_or_repo,
|
||||
model_name = model_name,
|
||||
vae = vae,
|
||||
description = model_description):
|
||||
print('** model failed to import')
|
||||
return None
|
||||
@ -627,17 +631,28 @@ def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
|
||||
)
|
||||
config_file = None
|
||||
default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
|
||||
|
||||
completer.complete_extensions(('.yaml','.yml'))
|
||||
completer.set_line(str(default))
|
||||
done = False
|
||||
while not done:
|
||||
config_file = input('Configuration file for this model: ').strip()
|
||||
done = os.path.exists(config_file)
|
||||
|
||||
completer.complete_extensions(('.ckpt','.safetensors'))
|
||||
vae = None
|
||||
default = Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')
|
||||
completer.set_line(str(default))
|
||||
done = False
|
||||
while not done:
|
||||
vae = input('VAE file for this model (leave blank for none): ').strip() or None
|
||||
done = (not vae) or os.path.exists(vae)
|
||||
completer.complete_extensions(None)
|
||||
|
||||
if not manager.import_ckpt_model(
|
||||
path_or_url,
|
||||
config = config_file,
|
||||
vae = vae,
|
||||
model_name = model_name,
|
||||
model_description = model_description,
|
||||
commit_to_conf = opt.conf,
|
||||
@ -709,7 +724,7 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
|
||||
return
|
||||
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
if input(f'Load optimized model {model_name}? [y] ') not in ('n','N'):
|
||||
if input(f'Load optimized model {model_name}? [y] ').strip() not in ('n','N'):
|
||||
gen.set_model(model_name)
|
||||
|
||||
response = input(f'Delete the original .ckpt file at ({ckpt_path} ? [n] ')
|
||||
@ -725,17 +740,17 @@ def del_config(model_name:str, gen, opt, completer):
|
||||
if model_name not in gen.model_manager.config:
|
||||
print(f"** Unknown model {model_name}")
|
||||
return
|
||||
gen.model_manager.del_model(model_name)
|
||||
|
||||
if input(f'Remove {model_name} from the list of models known to InvokeAI? [y] ').strip().startswith(('n','N')):
|
||||
return
|
||||
|
||||
delete_completely = input('Completely remove the model file or directory from disk? [n] ').startswith(('y','Y'))
|
||||
gen.model_manager.del_model(model_name,delete_files=delete_completely)
|
||||
gen.model_manager.commit(opt.conf)
|
||||
print(f'** {model_name} deleted')
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
|
||||
def edit_model(model_name:str, gen, opt, completer):
|
||||
current_model = gen.model_name
|
||||
# if model_name == current_model:
|
||||
# print("** Can't edit the active model. !switch to another model first. **")
|
||||
# return
|
||||
|
||||
manager = gen.model_manager
|
||||
if not (info := manager.model_info(model_name)):
|
||||
print(f'** Unknown model {model_name}')
|
||||
|
@ -18,7 +18,9 @@ import traceback
|
||||
import warnings
|
||||
import safetensors.torch
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import Union, Any
|
||||
from huggingface_hub import scan_cache_dir
|
||||
from ldm.util import download_with_progress_bar
|
||||
|
||||
import torch
|
||||
@ -225,7 +227,7 @@ class ModelManager(object):
|
||||
line = f'\033[1m{line}\033[0m'
|
||||
print(line)
|
||||
|
||||
def del_model(self, model_name:str) -> None:
|
||||
def del_model(self, model_name:str, delete_files:bool=False) -> None:
|
||||
'''
|
||||
Delete the named model.
|
||||
'''
|
||||
@ -233,9 +235,25 @@ class ModelManager(object):
|
||||
if model_name not in omega:
|
||||
print(f'** Unknown model {model_name}')
|
||||
return
|
||||
# save these for use in deletion later
|
||||
conf = omega[model_name]
|
||||
repo_id = conf.get('repo_id',None)
|
||||
path = self._relativize(conf.get('path',None))
|
||||
weights = self._relativize(conf.get('weights',None))
|
||||
|
||||
del omega[model_name]
|
||||
if model_name in self.stack:
|
||||
self.stack.remove(model_name)
|
||||
if delete_files:
|
||||
if weights:
|
||||
print(f'** deleting file {weights}')
|
||||
Path(weights).unlink(missing_ok=True)
|
||||
elif path:
|
||||
print(f'** deleting directory {path}')
|
||||
rmtree(path,ignore_errors=True)
|
||||
elif repo_id:
|
||||
print(f'** deleting the cached model directory for {repo_id}')
|
||||
self._delete_model_from_cache(repo_id)
|
||||
|
||||
def add_model(self, model_name:str, model_attributes:dict, clobber:bool=False) -> None:
|
||||
'''
|
||||
@ -412,7 +430,7 @@ class ModelManager(object):
|
||||
safety_checker=None,
|
||||
local_files_only=not Globals.internet_available
|
||||
)
|
||||
if 'vae' in mconfig:
|
||||
if 'vae' in mconfig and mconfig['vae'] is not None:
|
||||
vae = self._load_vae(mconfig['vae'])
|
||||
pipeline_args.update(vae=vae)
|
||||
if not isinstance(name_or_path,Path):
|
||||
@ -518,11 +536,12 @@ class ModelManager(object):
|
||||
print('>> Model scanned ok!')
|
||||
|
||||
def import_diffuser_model(self,
|
||||
repo_or_path:Union[str,Path],
|
||||
model_name:str=None,
|
||||
description:str=None,
|
||||
commit_to_conf:Path=None,
|
||||
)->bool:
|
||||
repo_or_path:Union[str,Path],
|
||||
model_name:str=None,
|
||||
description:str=None,
|
||||
vae:dict=None,
|
||||
commit_to_conf:Path=None,
|
||||
)->bool:
|
||||
'''
|
||||
Attempts to install the indicated diffuser model and returns True if successful.
|
||||
|
||||
@ -538,6 +557,7 @@ class ModelManager(object):
|
||||
description = description or f'imported diffusers model {model_name}'
|
||||
new_config = dict(
|
||||
description=description,
|
||||
vae=vae,
|
||||
format='diffusers',
|
||||
)
|
||||
if isinstance(repo_or_path,Path) and repo_or_path.exists():
|
||||
@ -551,18 +571,22 @@ class ModelManager(object):
|
||||
return True
|
||||
|
||||
def import_ckpt_model(self,
|
||||
weights:Union[str,Path],
|
||||
config:Union[str,Path]='configs/stable-diffusion/v1-inference.yaml',
|
||||
model_name:str=None,
|
||||
model_description:str=None,
|
||||
commit_to_conf:Path=None,
|
||||
)->bool:
|
||||
weights:Union[str,Path],
|
||||
config:Union[str,Path]='configs/stable-diffusion/v1-inference.yaml',
|
||||
vae:Union[str,Path]=None,
|
||||
model_name:str=None,
|
||||
model_description:str=None,
|
||||
commit_to_conf:Path=None,
|
||||
)->bool:
|
||||
'''
|
||||
Attempts to install the indicated ckpt file and returns True if successful.
|
||||
|
||||
"weights" can be either a path-like object corresponding to a local .ckpt file
|
||||
or a http/https URL pointing to a remote model.
|
||||
|
||||
"vae" is a Path or str object pointing to a ckpt or safetensors file to be used
|
||||
as the VAE for this model.
|
||||
|
||||
"config" is the model config file to use with this ckpt file. It defaults to
|
||||
v1-inference.yaml. If a URL is provided, the config will be downloaded.
|
||||
|
||||
@ -589,6 +613,8 @@ class ModelManager(object):
|
||||
width=512,
|
||||
height=512
|
||||
)
|
||||
if vae:
|
||||
new_config['vae'] = vae
|
||||
self.add_model(model_name, new_config, True)
|
||||
if commit_to_conf:
|
||||
self.commit(commit_to_conf)
|
||||
@ -670,16 +696,6 @@ class ModelManager(object):
|
||||
print('done.')
|
||||
return new_config
|
||||
|
||||
def del_config(self, model_name:str, gen, opt, completer):
|
||||
current_model = gen.model_name
|
||||
if model_name == current_model:
|
||||
print("** Can't delete active model. !switch to another model first. **")
|
||||
return
|
||||
gen.model_manager.del_model(model_name)
|
||||
gen.model_manager.commit(opt.conf)
|
||||
print(f'** {model_name} deleted')
|
||||
completer.del_model(model_name)
|
||||
|
||||
def search_models(self, search_folder):
|
||||
print(f'>> Finding Models In: {search_folder}')
|
||||
models_folder_ckpt = Path(search_folder).glob('**/*.ckpt')
|
||||
@ -761,7 +777,6 @@ class ModelManager(object):
|
||||
|
||||
print('** Legacy version <= 2.2.5 model directory layout detected. Reorganizing.')
|
||||
print('** This is a quick one-time operation.')
|
||||
from shutil import move, rmtree
|
||||
|
||||
# transformer files get moved into the hub directory
|
||||
if cls._is_huggingface_hub_directory_present():
|
||||
@ -977,6 +992,27 @@ class ModelManager(object):
|
||||
|
||||
return vae
|
||||
|
||||
@staticmethod
|
||||
def _delete_model_from_cache(repo_id):
|
||||
cache_info = scan_cache_dir(global_cache_dir('diffusers'))
|
||||
|
||||
# I'm sure there is a way to do this with comprehensions
|
||||
# but the code quickly became incomprehensible!
|
||||
hashes_to_delete = set()
|
||||
for repo in cache_info.repos:
|
||||
if repo.repo_id==repo_id:
|
||||
for revision in repo.revisions:
|
||||
hashes_to_delete.add(revision.commit_hash)
|
||||
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
||||
print(f'** deletion of this model is expected to free {strategy.expected_freed_size_str}')
|
||||
strategy.execute()
|
||||
|
||||
@staticmethod
|
||||
def _relativize(path:Union(str,Path))->Path:
|
||||
if path is None or Path(path).is_absolute():
|
||||
return path
|
||||
return Path(Globals.root,path).resolve()
|
||||
|
||||
@staticmethod
|
||||
def _is_huggingface_hub_directory_present() -> bool:
|
||||
return os.getenv('HF_HOME') is not None or os.getenv('XDG_CACHE_HOME') is not None
|
||||
|
Loading…
Reference in New Issue
Block a user