A big refactor of model manager(according to IMHO)

This commit is contained in:
Sergey Borisov 2023-05-12 23:13:34 +03:00
parent 4492044d29
commit 131145eab1

View File

@ -133,6 +133,7 @@ from enum import Enum, auto
from pathlib import Path
from shutil import rmtree
from typing import Union, Callable, types
from contextlib import suppress
import safetensors
import safetensors.torch
@ -225,18 +226,32 @@ class ModelManager(object):
self.cache_keys = dict()
self.logger = logger
def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
# TODO: rename to smth like - is_model_exists
def valid_model(
self,
model_name: str,
model_type: SDModelType = SDModelType.diffusers,
) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
"""
try:
self._disambiguate_name(model_name, model_type)
return True
except InvalidModelError:
return False
model_key = self.create_key(model_name, model_class)
return model_key in self.config
def get_model(self,
def create_key(self, model_name: str, model_type: SDModelType) -> str:
return f"{model_type.name}/{model_name}"
def parse_key(self, model_key: str) -> Tuple[str, SDModelType]:
model_type_str, model_name = model_key.split('/', 1)
if model_type_str not in SDModelType.__members__:
# TODO:
raise Exception(f"Unkown model type: {model_type_str}")
return (model_name, SDModelType[model_type_str])
def get_model(
self,
model_name: str,
model_type: SDModelType=None,
submodel: SDModelType=None,
@ -254,14 +269,12 @@ class ModelManager(object):
assume a diffusers pipeline. The behavior is illustrated here:
[models.yaml]
test1/diffusers:
diffusers/test1:
repo_id: foo/bar
format: diffusers
description: Typical diffusers pipeline
test1/lora:
lora/test1:
repo_id: /tmp/loras/test1.safetensors
format: lora
description: Typical lora file
test1_pipeline = mgr.get_model('test1')
@ -281,39 +294,34 @@ class ModelManager(object):
# raises an InvalidModelError
"""
if not model_name:
model_name = self.default_model()
# TODO: delete default model or add check that this stable diffusion model
# if not model_name:
# model_name = self.default_model()
model_key = self._disambiguate_name(model_name, model_type)
model_key = self.create_key(model_name, model_type)
if model_key not in self.config:
raise InvalidModelError(
f'"{model_key}" is not a known model name. Please check your models.yaml file'
)
# get the required loading info out of the config file
mconfig = self.config[model_key]
format = mconfig.get('format','diffusers')
if model_type and model_type.name != format:
raise InvalidModelError(
f'Inconsistent model definition; {model_key} has format {format}, but type {model_type.name} was requested'
)
model_parts = dict([(x.name,x) for x in SDModelType])
if format == 'diffusers':
# type already checked as it's part of key
if model_type == SDModelType.diffusers:
# intercept stanzas that point to checkpoint weights and replace them
# with the equivalent diffusers model
if 'weights' in mconfig:
location = self.convert_ckpt_and_cache(mconfig)
else:
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
elif format in model_parts:
location = global_resolve_path(mconfig.get('path')) \
or mconfig.get('repo_id') \
or global_resolve_path(mconfig.get('weights'))
else:
raise InvalidModelError(
f'"{model_key}" has an unknown format {format}'
location = global_resolve_path(
mconfig.get('path')) \
or mconfig.get('repo_id') \
or global_resolve_path(mconfig.get('weights')
)
model_type = model_parts[format]
subfolder = mconfig.get('subfolder')
revision = mconfig.get('revision')
hash = self.cache.model_hash(location, revision)
@ -321,11 +329,10 @@ class ModelManager(object):
# to support the traditional way of attaching a VAE
# to a model, we hacked in `attach_model_part`
vae = (None, None)
try:
with suppress(Exception):
vae_id = mconfig.vae.repo_id
vae = (SDModelType.vae, vae_id)
except Exception:
pass
model_context = self.cache.get_model(
location,
model_type = model_type,
@ -402,27 +409,28 @@ class ModelManager(object):
def list_models(self) -> dict:
"""
Return a dict of models in the format:
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
'description': description,
'format': ('ckpt'|'diffusers'|'vae'),
},
model_name2: { etc }
Return a dict of models
Please use model_manager.models() to get all the model names,
model_manager.model_info('model-name') to get the stanza for the model
named 'model-name', and model_manager.config to get the full OmegaConf
object derived from models.yaml
"""
models = {}
for name in sorted(self.config, key=str.casefold):
stanza = self.config[name]
for model_key in sorted(self.config, key=str.casefold):
stanza = self.config[model_key]
# don't include VAEs in listing (legacy style)
if "config" in stanza and "/VAE/" in stanza["config"]:
continue
models[name] = dict()
format = stanza.get("format", "ckpt") # Determine Format
model_name, model_type = self.parse_key(model_key)
models[model_name] = dict()
# TODO: return all models in future
if model_type != SDModelType.diffusers:
continue
model_format = "ckpt" if "weights" in stanza else "diffusers"
# Common Attribs
status = self.cache.status(
@ -431,16 +439,17 @@ class ModelManager(object):
subfolder=stanza.get('subfolder')
)
description = stanza.get("description", None)
models[name].update(
models[model_name].update(
description=description,
format=format,
type=model_type,
format=model_format,
status=status.value
)
# Checkpoint Config Parse
if format == "ckpt":
models[name].update(
if model_format == "ckpt":
models[model_name].update(
config = str(stanza.get("config", None)),
weights = str(stanza.get("weights", None)),
vae = str(stanza.get("vae", None)),
@ -449,6 +458,7 @@ class ModelManager(object):
)
# Diffusers Config Parse
elif model_format == "diffusers":
if vae := stanza.get("vae", None):
if isinstance(vae, DictConfig):
vae = dict(
@ -457,8 +467,7 @@ class ModelManager(object):
subfolder = str(vae.get("subfolder", None)),
)
if format == "diffusers":
models[name].update(
models[model_name].update(
vae = vae,
repo_id = str(stanza.get("repo_id", None)),
path = str(stanza.get("path", None)),
@ -472,44 +481,60 @@ class ModelManager(object):
"""
models = self.list_models()
for name in models:
if models[name]["format"] == "vae":
if models[name]["type"] == "vae":
continue
line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["format"]:10s} {models[name]["description"]}'
line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["type"]:10s} {models[name]["description"]}'
if models[name]["status"] == "active":
line = f"\033[1m{line}\033[0m"
print(line)
def del_model(self, model_name: str, model_type: SDModelType.diffusers, delete_files: bool = False):
def del_model(
self,
model_name: str,
model_type: SDModelType.diffusers,
delete_files: bool = False
):
"""
Delete the named model.
"""
model_name = self._disambiguate_name(model_name, model_type)
omega = self.config
if model_name not in omega:
self.logger.error(f"Unknown model {model_name}")
model_key = self.create_key(model_name, model_type)
model_cfg = self.pop(model_key, None)
if model_cfg is None:
self.logger.error(
f"Unknown model {model_key}"
)
return
# save these for use in deletion later
conf = omega[model_name]
# TODO: some legacy?
#if model_name in self.stack:
# self.stack.remove(model_name)
if delete_files:
repo_id = conf.get("repo_id", None)
path = self._abs_path(conf.get("path", None))
weights = self._abs_path(conf.get("weights", None))
del omega[model_name]
if model_name in self.stack:
self.stack.remove(model_name)
if delete_files:
if weights:
if "weights" in model_cfg:
weights = self._abs_path(model_cfg["weights"])
self.logger.info(f"Deleting file {weights}")
Path(weights).unlink(missing_ok=True)
elif path:
elif "path" in model_cfg:
path = self._abs_path(model_cfg["path"])
self.logger.info(f"Deleting directory {path}")
rmtree(path, ignore_errors=True)
elif repo_id:
elif "repo_id" in model_cfg:
repo_id = model_cfg["repo_id"]
self.logger.info(f"Deleting the cached model directory for {repo_id}")
self._delete_model_from_cache(repo_id)
def add_model(
self, model_name: str, model_attributes: dict, clobber: bool = False
self,
model_name: str,
model_type: SDModelType,
model_attributes: dict,
clobber: bool = False
) -> None:
"""
Update the named model with a dictionary of attributes. Will fail with an
@ -518,37 +543,47 @@ class ModelManager(object):
method will return True. Will fail with an assertion error if provided
attributes are incorrect or the model name is missing.
"""
omega = self.config
assert "format" in model_attributes, 'missing required field "format"'
if model_attributes["format"] == "diffusers":
if model_type == SDModelType.diffusers:
# TODO: automaticaly or manualy?
#assert "format" in model_attributes, 'missing required field "format"'
model_format = "ckpt" if "weights" in model_attributes else "diffusers"
if model_format == "diffusers":
assert (
"description" in model_attributes
), 'required field "description" is missing'
assert (
"path" in model_attributes or "repo_id" in model_attributes
), 'model must have either the "path" or "repo_id" fields defined'
elif model_attributes["format"] == "ckpt":
elif model_format == "ckpt":
for field in ("description", "weights", "height", "width", "config"):
assert field in model_attributes, f"required field {field} is missing"
else:
assert "weights" in model_attributes and "description" in model_attributes
model_key = f'{model_name}/{model_attributes["format"]}'
model_key = self.create_key(model_name, model_type)
assert (
clobber or model_key not in omega
clobber or model_key not in self.config
), f'attempt to overwrite existing model definition "{model_key}"'
omega[model_key] = model_attributes
self.config[model_key] = model_attributes
if "weights" in omega[model_key]:
omega[model_key]["weights"].replace("\\", "/")
if "weights" in self.config[model_key]:
self.config[model_key]["weights"].replace("\\", "/")
if clobber and model_key in self.cache_keys:
self.cache.uncache_model(self.cache_keys[model_key])
del self.cache_keys[model_key]
def import_diffuser_model(
self,
repo_or_path: Union[str, Path],
@ -599,7 +634,8 @@ class ModelManager(object):
path = Path(path)
model_name = model_name or path.stem
model_description = description or f"LoRA model {model_name}"
self.add_model(f'{model_name}/{SDModelType.lora.name}',
self.add_model(
f'{model_name}/{SDModelType.lora.name}',
dict(
format="lora",
weights=str(path),
@ -626,7 +662,8 @@ class ModelManager(object):
model_name = model_name or path.stem
model_description = description or f"Textual embedding model {model_name}"
self.add_model(f'{model_name}/{SDModelType.textual_inversion.name}',
self.add_model(
f'{model_name}/{SDModelType.textual_inversion.name}',
dict(
format="textual_inversion",
weights=str(weights),
@ -872,6 +909,7 @@ class ModelManager(object):
return diffusers_path
vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
# to avoid circular import errors
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings():
@ -881,15 +919,16 @@ class ModelManager(object):
extract_ema=True,
original_config_file=config_file,
vae=vae_model,
vae_path=str(global_resolve_path(vae_ckpt_path)),
vae_path=str(global_resolve_path(vae_ckpt_path)) if vae_ckpt_path else None,
scan_needed=True,
)
return diffusers_path
def _get_vae_for_conversion(self,
def _get_vae_for_conversion(
self,
weights: Path,
mconfig: DictConfig
)->tuple(Path,SDModelType.vae):
) -> Tuple[Path, SDModelType.vae]:
# VAE handling is convoluted
# 1. If there is a .vae.ckpt file sharing same stem as weights, then use
# it as the vae_path passed to convert
@ -1050,20 +1089,6 @@ class ModelManager(object):
"""
)
def _disambiguate_name(self, model_name:str, model_type:SDModelType)->str:
model_type = model_type or SDModelType.diffusers
full_name = f"{model_name}/{model_type.name}"
if full_name in self.config:
return full_name
# special case - if diffusers requested, then allow name without type appended
if model_type==SDModelType.diffusers \
and model_name in self.config \
and self.config[model_name].format=='diffusers':
return model_name
raise InvalidModelError(
f'"{full_name}" is not a known model name. Please check your models.yaml file'
)
@classmethod
def _delete_model_from_cache(cls,repo_id):