fix two bugs in model import

1. !import_model did not allow user to specify VAE file. This is now fixed.
2. !del_model did not offer the user the opportunity to delete the underlying
   weights file or diffusers directory. This is now fixed.
This commit is contained in:
Lincoln Stein
2023-01-19 01:30:58 -05:00
parent e11f15cf78
commit 7bd2220a24
2 changed files with 84 additions and 33 deletions

View File

@ -18,7 +18,9 @@ import traceback
import warnings
import safetensors.torch
from pathlib import Path
from shutil import move, rmtree
from typing import Union, Any
from huggingface_hub import scan_cache_dir
from ldm.util import download_with_progress_bar
import torch
@ -225,7 +227,7 @@ class ModelManager(object):
line = f'\033[1m{line}\033[0m'
print(line)
def del_model(self, model_name:str) -> None:
def del_model(self, model_name:str, delete_files:bool=False) -> None:
'''
Delete the named model.
'''
@ -233,9 +235,25 @@ class ModelManager(object):
if model_name not in omega:
print(f'** Unknown model {model_name}')
return
# save these for use in deletion later
conf = omega[model_name]
repo_id = conf.get('repo_id',None)
path = self._relativize(conf.get('path',None))
weights = self._relativize(conf.get('weights',None))
del omega[model_name]
if model_name in self.stack:
self.stack.remove(model_name)
if delete_files:
if weights:
print(f'** deleting file {weights}')
Path(weights).unlink(missing_ok=True)
elif path:
print(f'** deleting directory {path}')
rmtree(path,ignore_errors=True)
elif repo_id:
print(f'** deleting the cached model directory for {repo_id}')
self._delete_model_from_cache(repo_id)
def add_model(self, model_name:str, model_attributes:dict, clobber:bool=False) -> None:
'''
@ -412,7 +430,7 @@ class ModelManager(object):
safety_checker=None,
local_files_only=not Globals.internet_available
)
if 'vae' in mconfig:
if 'vae' in mconfig and mconfig['vae'] is not None:
vae = self._load_vae(mconfig['vae'])
pipeline_args.update(vae=vae)
if not isinstance(name_or_path,Path):
@ -518,11 +536,12 @@ class ModelManager(object):
print('>> Model scanned ok!')
def import_diffuser_model(self,
repo_or_path:Union[str,Path],
model_name:str=None,
description:str=None,
commit_to_conf:Path=None,
)->bool:
repo_or_path:Union[str,Path],
model_name:str=None,
description:str=None,
vae:dict=None,
commit_to_conf:Path=None,
)->bool:
'''
Attempts to install the indicated diffuser model and returns True if successful.
@ -538,6 +557,7 @@ class ModelManager(object):
description = description or f'imported diffusers model {model_name}'
new_config = dict(
description=description,
vae=vae,
format='diffusers',
)
if isinstance(repo_or_path,Path) and repo_or_path.exists():
@ -551,18 +571,22 @@ class ModelManager(object):
return True
def import_ckpt_model(self,
weights:Union[str,Path],
config:Union[str,Path]='configs/stable-diffusion/v1-inference.yaml',
model_name:str=None,
model_description:str=None,
commit_to_conf:Path=None,
)->bool:
weights:Union[str,Path],
config:Union[str,Path]='configs/stable-diffusion/v1-inference.yaml',
vae:Union[str,Path]=None,
model_name:str=None,
model_description:str=None,
commit_to_conf:Path=None,
)->bool:
'''
Attempts to install the indicated ckpt file and returns True if successful.
"weights" can be either a path-like object corresponding to a local .ckpt file
or a http/https URL pointing to a remote model.
"vae" is a Path or str object pointing to a ckpt or safetensors file to be used
as the VAE for this model.
"config" is the model config file to use with this ckpt file. It defaults to
v1-inference.yaml. If a URL is provided, the config will be downloaded.
@ -589,6 +613,8 @@ class ModelManager(object):
width=512,
height=512
)
if vae:
new_config['vae'] = vae
self.add_model(model_name, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
@ -670,16 +696,6 @@ class ModelManager(object):
print('done.')
return new_config
def del_config(self, model_name:str, gen, opt, completer):
current_model = gen.model_name
if model_name == current_model:
print("** Can't delete active model. !switch to another model first. **")
return
gen.model_manager.del_model(model_name)
gen.model_manager.commit(opt.conf)
print(f'** {model_name} deleted')
completer.del_model(model_name)
def search_models(self, search_folder):
print(f'>> Finding Models In: {search_folder}')
models_folder_ckpt = Path(search_folder).glob('**/*.ckpt')
@ -761,7 +777,6 @@ class ModelManager(object):
print('** Legacy version <= 2.2.5 model directory layout detected. Reorganizing.')
print('** This is a quick one-time operation.')
from shutil import move, rmtree
# transformer files get moved into the hub directory
if cls._is_huggingface_hub_directory_present():
@ -977,6 +992,27 @@ class ModelManager(object):
return vae
@staticmethod
def _delete_model_from_cache(repo_id):
cache_info = scan_cache_dir(global_cache_dir('diffusers'))
# I'm sure there is a way to do this with comprehensions
# but the code quickly became incomprehensible!
hashes_to_delete = set()
for repo in cache_info.repos:
if repo.repo_id==repo_id:
for revision in repo.revisions:
hashes_to_delete.add(revision.commit_hash)
strategy = cache_info.delete_revisions(*hashes_to_delete)
print(f'** deletion of this model is expected to free {strategy.expected_freed_size_str}')
strategy.execute()
@staticmethod
def _relativize(path:Union(str,Path))->Path:
if path is None or Path(path).is_absolute():
return path
return Path(Globals.root,path).resolve()
@staticmethod
def _is_huggingface_hub_directory_present() -> bool:
return os.getenv('HF_HOME') is not None or os.getenv('XDG_CACHE_HOME') is not None