A big refactor of model manager(according to IMHO)

This commit is contained in:
Sergey Borisov 2023-05-12 23:13:34 +03:00
parent 4492044d29
commit 131145eab1

View File

@ -133,6 +133,7 @@ from enum import Enum, auto
from pathlib import Path from pathlib import Path
from shutil import rmtree from shutil import rmtree
from typing import Union, Callable, types from typing import Union, Callable, types
from contextlib import suppress
import safetensors import safetensors
import safetensors.torch import safetensors.torch
@ -225,18 +226,32 @@ class ModelManager(object):
self.cache_keys = dict() self.cache_keys = dict()
self.logger = logger self.logger = logger
def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool: # TODO: rename to smth like - is_model_exists
def valid_model(
self,
model_name: str,
model_type: SDModelType = SDModelType.diffusers,
) -> bool:
""" """
Given a model name, returns True if it is a valid Given a model name, returns True if it is a valid
identifier. identifier.
""" """
try: model_key = self.create_key(model_name, model_class)
self._disambiguate_name(model_name, model_type) return model_key in self.config
return True
except InvalidModelError:
return False
def get_model(self, def create_key(self, model_name: str, model_type: SDModelType) -> str:
return f"{model_type.name}/{model_name}"
def parse_key(self, model_key: str) -> Tuple[str, SDModelType]:
model_type_str, model_name = model_key.split('/', 1)
if model_type_str not in SDModelType.__members__:
# TODO:
raise Exception(f"Unkown model type: {model_type_str}")
return (model_name, SDModelType[model_type_str])
def get_model(
self,
model_name: str, model_name: str,
model_type: SDModelType=None, model_type: SDModelType=None,
submodel: SDModelType=None, submodel: SDModelType=None,
@ -254,85 +269,77 @@ class ModelManager(object):
assume a diffusers pipeline. The behavior is illustrated here: assume a diffusers pipeline. The behavior is illustrated here:
[models.yaml] [models.yaml]
test1/diffusers: diffusers/test1:
repo_id: foo/bar repo_id: foo/bar
format: diffusers
description: Typical diffusers pipeline description: Typical diffusers pipeline
test1/lora: lora/test1:
repo_id: /tmp/loras/test1.safetensors repo_id: /tmp/loras/test1.safetensors
format: lora
description: Typical lora file description: Typical lora file
test1_pipeline = mgr.get_model('test1') test1_pipeline = mgr.get_model('test1')
# returns a StableDiffusionGeneratorPipeline # returns a StableDiffusionGeneratorPipeline
test1_vae1 = mgr.get_model('test1',submodel=SDModelType.vae) test1_vae1 = mgr.get_model('test1', submodel=SDModelType.vae)
# returns the VAE part of a diffusers model as an AutoencoderKL # returns the VAE part of a diffusers model as an AutoencoderKL
test1_vae2 = mgr.get_model('test1',model_type=SDModelType.diffusers,submodel=SDModelType.vae) test1_vae2 = mgr.get_model('test1', model_type=SDModelType.diffusers, submodel=SDModelType.vae)
# does the same thing as the previous statement. Note that model_type # does the same thing as the previous statement. Note that model_type
# is for the parent model, and submodel is for the part # is for the parent model, and submodel is for the part
test1_lora = mgr.get_model('test1',model_type=SDModelType.lora) test1_lora = mgr.get_model('test1', model_type=SDModelType.lora)
# returns a LoRA embed (as a 'dict' of tensors) # returns a LoRA embed (as a 'dict' of tensors)
test1_encoder = mgr.get_modelI('test1',model_type=SDModelType.textencoder) test1_encoder = mgr.get_modelI('test1', model_type=SDModelType.textencoder)
# raises an InvalidModelError # raises an InvalidModelError
""" """
if not model_name: # TODO: delete default model or add check that this stable diffusion model
model_name = self.default_model() # if not model_name:
# model_name = self.default_model()
model_key = self._disambiguate_name(model_name, model_type) model_key = self.create_key(model_name, model_type)
if model_key not in self.config:
raise InvalidModelError(
f'"{model_key}" is not a known model name. Please check your models.yaml file'
)
# get the required loading info out of the config file # get the required loading info out of the config file
mconfig = self.config[model_key] mconfig = self.config[model_key]
format = mconfig.get('format','diffusers') # type already checked as it's part of key
if model_type and model_type.name != format: if model_type == SDModelType.diffusers:
raise InvalidModelError(
f'Inconsistent model definition; {model_key} has format {format}, but type {model_type.name} was requested'
)
model_parts = dict([(x.name,x) for x in SDModelType])
if format == 'diffusers':
# intercept stanzas that point to checkpoint weights and replace them # intercept stanzas that point to checkpoint weights and replace them
# with the equivalent diffusers model # with the equivalent diffusers model
if 'weights' in mconfig: if 'weights' in mconfig:
location = self.convert_ckpt_and_cache(mconfig) location = self.convert_ckpt_and_cache(mconfig)
else: else:
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id') location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
elif format in model_parts:
location = global_resolve_path(mconfig.get('path')) \
or mconfig.get('repo_id') \
or global_resolve_path(mconfig.get('weights'))
else: else:
raise InvalidModelError( location = global_resolve_path(
f'"{model_key}" has an unknown format {format}' mconfig.get('path')) \
or mconfig.get('repo_id') \
or global_resolve_path(mconfig.get('weights')
) )
model_type = model_parts[format]
subfolder = mconfig.get('subfolder') subfolder = mconfig.get('subfolder')
revision = mconfig.get('revision') revision = mconfig.get('revision')
hash = self.cache.model_hash(location,revision) hash = self.cache.model_hash(location, revision)
# to support the traditional way of attaching a VAE # to support the traditional way of attaching a VAE
# to a model, we hacked in `attach_model_part` # to a model, we hacked in `attach_model_part`
vae = (None,None) vae = (None, None)
try: with suppress(Exception):
vae_id = mconfig.vae.repo_id vae_id = mconfig.vae.repo_id
vae = (SDModelType.vae,vae_id) vae = (SDModelType.vae, vae_id)
except Exception:
pass
model_context = self.cache.get_model( model_context = self.cache.get_model(
location, location,
model_type = model_type, model_type = model_type,
revision = revision, revision = revision,
subfolder = subfolder, subfolder = subfolder,
submodel = submodel, submodel = submodel,
attach_model_part=vae, attach_model_part = vae,
) )
# in case we need to communicate information about this # in case we need to communicate information about this
@ -402,27 +409,28 @@ class ModelManager(object):
def list_models(self) -> dict: def list_models(self) -> dict:
""" """
Return a dict of models in the format: Return a dict of models
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
'description': description,
'format': ('ckpt'|'diffusers'|'vae'),
},
model_name2: { etc }
Please use model_manager.models() to get all the model names, Please use model_manager.models() to get all the model names,
model_manager.model_info('model-name') to get the stanza for the model model_manager.model_info('model-name') to get the stanza for the model
named 'model-name', and model_manager.config to get the full OmegaConf named 'model-name', and model_manager.config to get the full OmegaConf
object derived from models.yaml object derived from models.yaml
""" """
models = {} models = {}
for name in sorted(self.config, key=str.casefold): for model_key in sorted(self.config, key=str.casefold):
stanza = self.config[name] stanza = self.config[model_key]
# don't include VAEs in listing (legacy style) # don't include VAEs in listing (legacy style)
if "config" in stanza and "/VAE/" in stanza["config"]: if "config" in stanza and "/VAE/" in stanza["config"]:
continue continue
models[name] = dict() model_name, model_type = self.parse_key(model_key)
format = stanza.get("format", "ckpt") # Determine Format models[model_name] = dict()
# TODO: return all models in future
if model_type != SDModelType.diffusers:
continue
model_format = "ckpt" if "weights" in stanza else "diffusers"
# Common Attribs # Common Attribs
status = self.cache.status( status = self.cache.status(
@ -431,37 +439,38 @@ class ModelManager(object):
subfolder=stanza.get('subfolder') subfolder=stanza.get('subfolder')
) )
description = stanza.get("description", None) description = stanza.get("description", None)
models[name].update( models[model_name].update(
description=description, description=description,
format=format, type=model_type,
format=model_format,
status=status.value status=status.value
) )
# Checkpoint Config Parse # Checkpoint Config Parse
if format == "ckpt": if model_format == "ckpt":
models[name].update( models[model_name].update(
config=str(stanza.get("config", None)), config = str(stanza.get("config", None)),
weights=str(stanza.get("weights", None)), weights = str(stanza.get("weights", None)),
vae=str(stanza.get("vae", None)), vae = str(stanza.get("vae", None)),
width=str(stanza.get("width", 512)), width = str(stanza.get("width", 512)),
height=str(stanza.get("height", 512)), height = str(stanza.get("height", 512)),
) )
# Diffusers Config Parse # Diffusers Config Parse
elif model_format == "diffusers":
if vae := stanza.get("vae", None): if vae := stanza.get("vae", None):
if isinstance(vae, DictConfig): if isinstance(vae, DictConfig):
vae = dict( vae = dict(
repo_id=str(vae.get("repo_id", None)), repo_id = str(vae.get("repo_id", None)),
path=str(vae.get("path", None)), path = str(vae.get("path", None)),
subfolder=str(vae.get("subfolder", None)), subfolder = str(vae.get("subfolder", None)),
) )
if format == "diffusers": models[model_name].update(
models[name].update( vae = vae,
vae=vae, repo_id = str(stanza.get("repo_id", None)),
repo_id=str(stanza.get("repo_id", None)), path = str(stanza.get("path", None)),
path=str(stanza.get("path", None)),
) )
return models return models
@ -472,44 +481,60 @@ class ModelManager(object):
""" """
models = self.list_models() models = self.list_models()
for name in models: for name in models:
if models[name]["format"] == "vae": if models[name]["type"] == "vae":
continue continue
line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["format"]:10s} {models[name]["description"]}' line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["type"]:10s} {models[name]["description"]}'
if models[name]["status"] == "active": if models[name]["status"] == "active":
line = f"\033[1m{line}\033[0m" line = f"\033[1m{line}\033[0m"
print(line) print(line)
def del_model(self, model_name: str, model_type: SDModelType.diffusers, delete_files: bool = False): def del_model(
self,
model_name: str,
model_type: SDModelType.diffusers,
delete_files: bool = False
):
""" """
Delete the named model. Delete the named model.
""" """
model_name = self._disambiguate_name(model_name, model_type) model_key = self.create_key(model_name, model_type)
omega = self.config model_cfg = self.pop(model_key, None)
if model_name not in omega:
self.logger.error(f"Unknown model {model_name}") if model_cfg is None:
self.logger.error(
f"Unknown model {model_key}"
)
return return
# save these for use in deletion later
conf = omega[model_name] # TODO: some legacy?
#if model_name in self.stack:
# self.stack.remove(model_name)
if delete_files:
repo_id = conf.get("repo_id", None) repo_id = conf.get("repo_id", None)
path = self._abs_path(conf.get("path", None)) path = self._abs_path(conf.get("path", None))
weights = self._abs_path(conf.get("weights", None)) weights = self._abs_path(conf.get("weights", None))
if "weights" in model_cfg:
del omega[model_name] weights = self._abs_path(model_cfg["weights"])
if model_name in self.stack:
self.stack.remove(model_name)
if delete_files:
if weights:
self.logger.info(f"Deleting file {weights}") self.logger.info(f"Deleting file {weights}")
Path(weights).unlink(missing_ok=True) Path(weights).unlink(missing_ok=True)
elif path:
elif "path" in model_cfg:
path = self._abs_path(model_cfg["path"])
self.logger.info(f"Deleting directory {path}") self.logger.info(f"Deleting directory {path}")
rmtree(path, ignore_errors=True) rmtree(path, ignore_errors=True)
elif repo_id:
elif "repo_id" in model_cfg:
repo_id = model_cfg["repo_id"]
self.logger.info(f"Deleting the cached model directory for {repo_id}") self.logger.info(f"Deleting the cached model directory for {repo_id}")
self._delete_model_from_cache(repo_id) self._delete_model_from_cache(repo_id)
def add_model( def add_model(
self, model_name: str, model_attributes: dict, clobber: bool = False self,
model_name: str,
model_type: SDModelType,
model_attributes: dict,
clobber: bool = False
) -> None: ) -> None:
""" """
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
@ -518,37 +543,47 @@ class ModelManager(object):
method will return True. Will fail with an assertion error if provided method will return True. Will fail with an assertion error if provided
attributes are incorrect or the model name is missing. attributes are incorrect or the model name is missing.
""" """
omega = self.config
assert "format" in model_attributes, 'missing required field "format"' if model_type == SDModelType.diffusers:
if model_attributes["format"] == "diffusers": # TODO: automaticaly or manualy?
#assert "format" in model_attributes, 'missing required field "format"'
model_format = "ckpt" if "weights" in model_attributes else "diffusers"
if model_format == "diffusers":
assert ( assert (
"description" in model_attributes "description" in model_attributes
), 'required field "description" is missing' ), 'required field "description" is missing'
assert ( assert (
"path" in model_attributes or "repo_id" in model_attributes "path" in model_attributes or "repo_id" in model_attributes
), 'model must have either the "path" or "repo_id" fields defined' ), 'model must have either the "path" or "repo_id" fields defined'
elif model_attributes["format"] == "ckpt":
elif model_format == "ckpt":
for field in ("description", "weights", "height", "width", "config"): for field in ("description", "weights", "height", "width", "config"):
assert field in model_attributes, f"required field {field} is missing" assert field in model_attributes, f"required field {field} is missing"
else: else:
assert "weights" in model_attributes and "description" in model_attributes assert "weights" in model_attributes and "description" in model_attributes
model_key = f'{model_name}/{model_attributes["format"]}' model_key = self.create_key(model_name, model_type)
assert ( assert (
clobber or model_key not in omega clobber or model_key not in self.config
), f'attempt to overwrite existing model definition "{model_key}"' ), f'attempt to overwrite existing model definition "{model_key}"'
omega[model_key] = model_attributes self.config[model_key] = model_attributes
if "weights" in omega[model_key]: if "weights" in self.config[model_key]:
omega[model_key]["weights"].replace("\\", "/") self.config[model_key]["weights"].replace("\\", "/")
if clobber and model_key in self.cache_keys: if clobber and model_key in self.cache_keys:
self.cache.uncache_model(self.cache_keys[model_key]) self.cache.uncache_model(self.cache_keys[model_key])
del self.cache_keys[model_key] del self.cache_keys[model_key]
def import_diffuser_model( def import_diffuser_model(
self, self,
repo_or_path: Union[str, Path], repo_or_path: Union[str, Path],
@ -599,7 +634,8 @@ class ModelManager(object):
path = Path(path) path = Path(path)
model_name = model_name or path.stem model_name = model_name or path.stem
model_description = description or f"LoRA model {model_name}" model_description = description or f"LoRA model {model_name}"
self.add_model(f'{model_name}/{SDModelType.lora.name}', self.add_model(
f'{model_name}/{SDModelType.lora.name}',
dict( dict(
format="lora", format="lora",
weights=str(path), weights=str(path),
@ -626,7 +662,8 @@ class ModelManager(object):
model_name = model_name or path.stem model_name = model_name or path.stem
model_description = description or f"Textual embedding model {model_name}" model_description = description or f"Textual embedding model {model_name}"
self.add_model(f'{model_name}/{SDModelType.textual_inversion.name}', self.add_model(
f'{model_name}/{SDModelType.textual_inversion.name}',
dict( dict(
format="textual_inversion", format="textual_inversion",
weights=str(weights), weights=str(weights),
@ -857,7 +894,7 @@ class ModelManager(object):
) )
return model_name return model_name
def convert_ckpt_and_cache(self, mconfig:DictConfig)->Path: def convert_ckpt_and_cache(self, mconfig: DictConfig)->Path:
""" """
Convert the checkpoint model indicated in mconfig into a Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted diffusers, cache it to disk, and return Path to converted
@ -872,6 +909,7 @@ class ModelManager(object):
return diffusers_path return diffusers_path
vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig) vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
# to avoid circular import errors # to avoid circular import errors
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings(): with SilenceWarnings():
@ -881,15 +919,16 @@ class ModelManager(object):
extract_ema=True, extract_ema=True,
original_config_file=config_file, original_config_file=config_file,
vae=vae_model, vae=vae_model,
vae_path=str(global_resolve_path(vae_ckpt_path)), vae_path=str(global_resolve_path(vae_ckpt_path)) if vae_ckpt_path else None,
scan_needed=True, scan_needed=True,
) )
return diffusers_path return diffusers_path
def _get_vae_for_conversion(self, def _get_vae_for_conversion(
self,
weights: Path, weights: Path,
mconfig: DictConfig mconfig: DictConfig
)->tuple(Path,SDModelType.vae): ) -> Tuple[Path, SDModelType.vae]:
# VAE handling is convoluted # VAE handling is convoluted
# 1. If there is a .vae.ckpt file sharing same stem as weights, then use # 1. If there is a .vae.ckpt file sharing same stem as weights, then use
# it as the vae_path passed to convert # it as the vae_path passed to convert
@ -1050,20 +1089,6 @@ class ModelManager(object):
""" """
) )
def _disambiguate_name(self, model_name:str, model_type:SDModelType)->str:
model_type = model_type or SDModelType.diffusers
full_name = f"{model_name}/{model_type.name}"
if full_name in self.config:
return full_name
# special case - if diffusers requested, then allow name without type appended
if model_type==SDModelType.diffusers \
and model_name in self.config \
and self.config[model_name].format=='diffusers':
return model_name
raise InvalidModelError(
f'"{full_name}" is not a known model name. Please check your models.yaml file'
)
@classmethod @classmethod
def _delete_model_from_cache(cls,repo_id): def _delete_model_from_cache(cls,repo_id):