A big refactor of the model manager (IMHO)

This commit is contained in:
Sergey Borisov 2023-05-12 23:13:34 +03:00
parent 4492044d29
commit 131145eab1

View File

@ -133,6 +133,7 @@ from enum import Enum, auto
from contextlib import suppress
from pathlib import Path
from shutil import rmtree
from typing import Callable, Tuple, Union, types

import safetensors
import safetensors.torch
@ -192,13 +193,13 @@ class ModelManager(object):
logger: types.ModuleType = logger
def __init__(
self,
config: Union[Path, DictConfig, str],
device_type: torch.device = CUDA_DEVICE,
precision: torch.dtype = torch.float16,
max_cache_size=MAX_CACHE_SIZE,
sequential_offload=False,
logger: types.ModuleType = logger,
self,
config: Union[Path, DictConfig, str],
device_type: torch.device = CUDA_DEVICE,
precision: torch.dtype = torch.float16,
max_cache_size=MAX_CACHE_SIZE,
sequential_offload=False,
logger: types.ModuleType = logger,
):
"""
Initialize with the path to the models.yaml config file.
@ -225,22 +226,36 @@ class ModelManager(object):
self.cache_keys = dict()
self.logger = logger
# TODO: rename to smth like - is_model_exists
def valid_model(
    self,
    model_name: str,
    model_type: "SDModelType" = SDModelType.diffusers,
) -> bool:
    """
    Return True if the (model_name, model_type) pair identifies a model
    stanza present in the models.yaml config, False otherwise.

    :param model_name: symbolic name of the model in models.yaml
    :param model_type: the SDModelType that forms the key prefix
    """
    # BUG FIX: the original referenced the undefined name `model_class`,
    # raising NameError on every call; the parameter is `model_type`.
    model_key = self.create_key(model_name, model_type)
    return model_key in self.config
def get_model(self,
model_name: str,
model_type: SDModelType=None,
submodel: SDModelType=None,
) -> SDModelInfo:
def create_key(self, model_name: str, model_type: "SDModelType") -> str:
    """Compose the models.yaml key '<type-name>/<model-name>' for a model."""
    return "/".join((model_type.name, model_name))
def parse_key(self, model_key: str) -> "Tuple[str, SDModelType]":
    """
    Split a models.yaml key of the form '<type-name>/<model-name>' back
    into a (model_name, SDModelType) pair — the inverse of create_key().

    :param model_key: a key such as 'diffusers/test1'
    :raises Exception: if the type prefix is not a member of SDModelType
    """
    # split on the FIRST slash only, so model names may themselves
    # contain '/' characters
    model_type_str, model_name = model_key.split('/', 1)
    if model_type_str not in SDModelType.__members__:
        # TODO: raise a more specific exception type (e.g. InvalidModelError)
        # BUG FIX: corrected the 'Unkown' typo in the error message
        raise Exception(f"Unknown model type: {model_type_str}")
    return (model_name, SDModelType[model_type_str])
def get_model(
self,
model_name: str,
model_type: SDModelType=None,
submodel: SDModelType=None,
) -> SDModelInfo:
"""Given a model named identified in models.yaml, return
an SDModelInfo object describing it.
:param model_name: symbolic name of the model in models.yaml
@ -254,85 +269,77 @@ class ModelManager(object):
assume a diffusers pipeline. The behavior is illustrated here:
[models.yaml]
test1/diffusers:
diffusers/test1:
repo_id: foo/bar
format: diffusers
description: Typical diffusers pipeline
test1/lora:
lora/test1:
repo_id: /tmp/loras/test1.safetensors
format: lora
description: Typical lora file
test1_pipeline = mgr.get_model('test1')
# returns a StableDiffusionGeneratorPipeline
test1_vae1 = mgr.get_model('test1',submodel=SDModelType.vae)
test1_vae1 = mgr.get_model('test1', submodel=SDModelType.vae)
# returns the VAE part of a diffusers model as an AutoencoderKL
test1_vae2 = mgr.get_model('test1',model_type=SDModelType.diffusers,submodel=SDModelType.vae)
test1_vae2 = mgr.get_model('test1', model_type=SDModelType.diffusers, submodel=SDModelType.vae)
# does the same thing as the previous statement. Note that model_type
# is for the parent model, and submodel is for the part
test1_lora = mgr.get_model('test1',model_type=SDModelType.lora)
test1_lora = mgr.get_model('test1', model_type=SDModelType.lora)
# returns a LoRA embed (as a 'dict' of tensors)
test1_encoder = mgr.get_modelI('test1',model_type=SDModelType.textencoder)
test1_encoder = mgr.get_modelI('test1', model_type=SDModelType.textencoder)
# raises an InvalidModelError
"""
if not model_name:
model_name = self.default_model()
# TODO: delete default model or add check that this stable diffusion model
# if not model_name:
# model_name = self.default_model()
model_key = self._disambiguate_name(model_name, model_type)
model_key = self.create_key(model_name, model_type)
if model_key not in self.config:
raise InvalidModelError(
f'"{model_key}" is not a known model name. Please check your models.yaml file'
)
# get the required loading info out of the config file
mconfig = self.config[model_key]
format = mconfig.get('format','diffusers')
if model_type and model_type.name != format:
raise InvalidModelError(
f'Inconsistent model definition; {model_key} has format {format}, but type {model_type.name} was requested'
)
model_parts = dict([(x.name,x) for x in SDModelType])
if format == 'diffusers':
# type already checked as it's part of key
if model_type == SDModelType.diffusers:
# intercept stanzas that point to checkpoint weights and replace them
# with the equivalent diffusers model
if 'weights' in mconfig:
location = self.convert_ckpt_and_cache(mconfig)
else:
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
elif format in model_parts:
location = global_resolve_path(mconfig.get('path')) \
or mconfig.get('repo_id') \
or global_resolve_path(mconfig.get('weights'))
else:
raise InvalidModelError(
f'"{model_key}" has an unknown format {format}'
location = global_resolve_path(
mconfig.get('path')) \
or mconfig.get('repo_id') \
or global_resolve_path(mconfig.get('weights')
)
model_type = model_parts[format]
subfolder = mconfig.get('subfolder')
revision = mconfig.get('revision')
hash = self.cache.model_hash(location,revision)
hash = self.cache.model_hash(location, revision)
# to support the traditional way of attaching a VAE
# to a model, we hacked in `attach_model_part`
vae = (None,None)
try:
vae = (None, None)
with suppress(Exception):
vae_id = mconfig.vae.repo_id
vae = (SDModelType.vae,vae_id)
except Exception:
pass
vae = (SDModelType.vae, vae_id)
model_context = self.cache.get_model(
location,
model_type = model_type,
revision = revision,
subfolder = subfolder,
submodel = submodel,
attach_model_part=vae,
attach_model_part = vae,
)
# in case we need to communicate information about this
@ -402,27 +409,28 @@ class ModelManager(object):
def list_models(self) -> dict:
"""
Return a dict of models in the format:
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
'description': description,
'format': ('ckpt'|'diffusers'|'vae'),
},
model_name2: { etc }
Return a dict of models
Please use model_manager.models() to get all the model names,
model_manager.model_info('model-name') to get the stanza for the model
named 'model-name', and model_manager.config to get the full OmegaConf
object derived from models.yaml
"""
models = {}
for name in sorted(self.config, key=str.casefold):
stanza = self.config[name]
for model_key in sorted(self.config, key=str.casefold):
stanza = self.config[model_key]
# don't include VAEs in listing (legacy style)
if "config" in stanza and "/VAE/" in stanza["config"]:
continue
models[name] = dict()
format = stanza.get("format", "ckpt") # Determine Format
model_name, model_type = self.parse_key(model_key)
models[model_name] = dict()
# TODO: return all models in future
if model_type != SDModelType.diffusers:
continue
model_format = "ckpt" if "weights" in stanza else "diffusers"
# Common Attribs
status = self.cache.status(
@ -431,37 +439,38 @@ class ModelManager(object):
subfolder=stanza.get('subfolder')
)
description = stanza.get("description", None)
models[name].update(
models[model_name].update(
description=description,
format=format,
type=model_type,
format=model_format,
status=status.value
)
# Checkpoint Config Parse
if format == "ckpt":
models[name].update(
config=str(stanza.get("config", None)),
weights=str(stanza.get("weights", None)),
vae=str(stanza.get("vae", None)),
width=str(stanza.get("width", 512)),
height=str(stanza.get("height", 512)),
if model_format == "ckpt":
models[model_name].update(
config = str(stanza.get("config", None)),
weights = str(stanza.get("weights", None)),
vae = str(stanza.get("vae", None)),
width = str(stanza.get("width", 512)),
height = str(stanza.get("height", 512)),
)
# Diffusers Config Parse
if vae := stanza.get("vae", None):
if isinstance(vae, DictConfig):
vae = dict(
repo_id=str(vae.get("repo_id", None)),
path=str(vae.get("path", None)),
subfolder=str(vae.get("subfolder", None)),
)
elif model_format == "diffusers":
if vae := stanza.get("vae", None):
if isinstance(vae, DictConfig):
vae = dict(
repo_id = str(vae.get("repo_id", None)),
path = str(vae.get("path", None)),
subfolder = str(vae.get("subfolder", None)),
)
if format == "diffusers":
models[name].update(
vae=vae,
repo_id=str(stanza.get("repo_id", None)),
path=str(stanza.get("path", None)),
models[model_name].update(
vae = vae,
repo_id = str(stanza.get("repo_id", None)),
path = str(stanza.get("path", None)),
)
return models
@ -472,44 +481,60 @@ class ModelManager(object):
"""
models = self.list_models()
for name in models:
if models[name]["format"] == "vae":
if models[name]["type"] == "vae":
continue
line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["format"]:10s} {models[name]["description"]}'
line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["type"]:10s} {models[name]["description"]}'
if models[name]["status"] == "active":
line = f"\033[1m{line}\033[0m"
print(line)
def del_model(
    self,
    model_name: str,
    model_type: "SDModelType",
    delete_files: bool = False,
):
    """
    Delete the named model from the in-memory config, and optionally
    remove its files from disk.

    :param model_name: symbolic name of the model in models.yaml
    :param model_type: the SDModelType of the model (forms part of the key)
    :param delete_files: when True, also delete the model's weights file,
        its directory, or its HuggingFace cache entry, as appropriate

    Note: the original annotated model_type as `SDModelType.diffusers`,
    which is an enum *value*, not a type — corrected here.
    """
    model_key = self.create_key(model_name, model_type)
    # BUG FIX: the stanza lives in self.config; the manager object itself
    # has no pop() method, so `self.pop(...)` raised AttributeError.
    model_cfg = self.config.pop(model_key, None)
    if model_cfg is None:
        self.logger.error(f"Unknown model {model_key}")
        return

    # TODO: some legacy?
    # if model_name in self.stack:
    #     self.stack.remove(model_name)

    if delete_files:
        # exactly one of weights / path / repo_id identifies the on-disk form
        if "weights" in model_cfg:
            weights = self._abs_path(model_cfg["weights"])
            self.logger.info(f"Deleting file {weights}")
            Path(weights).unlink(missing_ok=True)
        elif "path" in model_cfg:
            path = self._abs_path(model_cfg["path"])
            self.logger.info(f"Deleting directory {path}")
            rmtree(path, ignore_errors=True)
        elif "repo_id" in model_cfg:
            repo_id = model_cfg["repo_id"]
            self.logger.info(f"Deleting the cached model directory for {repo_id}")
            self._delete_model_from_cache(repo_id)
def add_model(
    self,
    model_name: str,
    model_type: "SDModelType",
    model_attributes: dict,
    clobber: bool = False,
) -> None:
    """
    Add or update the named model's stanza in the in-memory config.

    Will fail with an assertion error if required attributes are missing,
    or if the model already exists and clobber is False. Call commit()
    to persist the change to models.yaml.

    :param model_name: symbolic name of the model
    :param model_type: the SDModelType of the model (forms part of the key)
    :param model_attributes: stanza contents (description, weights, ...)
    :param clobber: when True, silently overwrite an existing stanza and
        evict the stale entry from the model cache
    """
    if model_type == SDModelType.diffusers:
        # TODO: automatically or manually?
        # assert "format" in model_attributes, 'missing required field "format"'
        # a diffusers-type stanza that carries 'weights' is really a
        # legacy checkpoint awaiting conversion
        model_format = "ckpt" if "weights" in model_attributes else "diffusers"

        if model_format == "diffusers":
            assert (
                "description" in model_attributes
            ), 'required field "description" is missing'
            assert (
                "path" in model_attributes or "repo_id" in model_attributes
            ), 'model must have either the "path" or "repo_id" fields defined'
        elif model_format == "ckpt":
            for field in ("description", "weights", "height", "width", "config"):
                assert field in model_attributes, f"required field {field} is missing"
    else:
        # non-diffusers models (lora, textual inversion, ...) only need
        # weights and a description
        assert "weights" in model_attributes and "description" in model_attributes

    model_key = self.create_key(model_name, model_type)

    assert (
        clobber or model_key not in self.config
    ), f'attempt to overwrite existing model definition "{model_key}"'

    self.config[model_key] = model_attributes

    if "weights" in self.config[model_key]:
        # BUG FIX: str.replace() returns a new string; the original
        # discarded the result, so Windows paths were never normalized.
        self.config[model_key]["weights"] = self.config[model_key]["weights"].replace("\\", "/")

    if clobber and model_key in self.cache_keys:
        self.cache.uncache_model(self.cache_keys[model_key])
        del self.cache_keys[model_key]
def import_diffuser_model(
self,
repo_or_path: Union[str, Path],
@ -587,10 +622,10 @@ class ModelManager(object):
return model_key
def import_lora(
    self,
    path: Path,
    model_name: str = None,
    description: str = None,
):
    """
    Create a models.yaml entry for the indicated LoRA file.
    Call commit() to write the changes to disk.

    :param path: filesystem path to the .safetensors LoRA file
    :param model_name: symbolic name; defaults to the file's stem
    :param description: human-readable description; defaults to a generic one
    """
    path = Path(path)
    model_name = model_name or path.stem
    model_description = description or f"LoRA model {model_name}"
    # BUG FIX: add_model() now takes (model_name, model_type, attributes,
    # clobber); the original still passed a pre-composed '<name>/<type>'
    # key as model_name and the attribute dict as model_type.
    self.add_model(
        model_name,
        SDModelType.lora,
        dict(
            format="lora",
            weights=str(path),
            description=model_description,
        ),
        True,
    )
def import_embedding(
self,
path: Path,
model_name: str=None,
description: str=None,
self,
path: Path,
model_name: str=None,
description: str=None,
):
"""
Creates an entry for the indicated lora file. Call
@ -626,14 +662,15 @@ class ModelManager(object):
model_name = model_name or path.stem
model_description = description or f"Textual embedding model {model_name}"
self.add_model(f'{model_name}/{SDModelType.textual_inversion.name}',
dict(
format="textual_inversion",
weights=str(weights),
description=model_description,
),
True
)
self.add_model(
f'{model_name}/{SDModelType.textual_inversion.name}',
dict(
format="textual_inversion",
weights=str(weights),
description=model_description,
),
True
)
@classmethod
def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
@ -857,7 +894,7 @@ class ModelManager(object):
)
return model_name
def convert_ckpt_and_cache(self, mconfig:DictConfig)->Path:
def convert_ckpt_and_cache(self, mconfig: DictConfig)->Path:
"""
Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted
@ -872,6 +909,7 @@ class ModelManager(object):
return diffusers_path
vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
# to avoid circular import errors
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings():
@ -881,15 +919,16 @@ class ModelManager(object):
extract_ema=True,
original_config_file=config_file,
vae=vae_model,
vae_path=str(global_resolve_path(vae_ckpt_path)),
vae_path=str(global_resolve_path(vae_ckpt_path)) if vae_ckpt_path else None,
scan_needed=True,
)
return diffusers_path
def _get_vae_for_conversion(self,
weights: Path,
mconfig: DictConfig
)->tuple(Path,SDModelType.vae):
def _get_vae_for_conversion(
self,
weights: Path,
mconfig: DictConfig
) -> Tuple[Path, SDModelType.vae]:
# VAE handling is convoluted
# 1. If there is a .vae.ckpt file sharing same stem as weights, then use
# it as the vae_path passed to convert
@ -1047,21 +1086,7 @@ class ModelManager(object):
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
"""
)
def _disambiguate_name(self, model_name: str, model_type: "SDModelType") -> str:
    """
    Resolve a bare model name to its fully-qualified '<name>/<type>'
    config key, defaulting the type to diffusers when none is given.

    :raises InvalidModelError: if no matching stanza exists in the config
    """
    model_type = model_type or SDModelType.diffusers
    full_name = f"{model_name}/{model_type.name}"
    if full_name in self.config:
        return full_name
    # special case - a diffusers model may be registered under its bare
    # name with no type suffix appended
    is_bare_diffusers = (
        model_type == SDModelType.diffusers
        and model_name in self.config
        and self.config[model_name].format == 'diffusers'
    )
    if is_bare_diffusers:
        return model_name
    raise InvalidModelError(
        f'"{full_name}" is not a known model name. Please check your models.yaml file'
    )
"""
)