mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
A big refactor of model manager(according to IMHO)
This commit is contained in:
parent
4492044d29
commit
131145eab1
@ -133,6 +133,7 @@ from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Union, Callable, types
|
||||
from contextlib import suppress
|
||||
|
||||
import safetensors
|
||||
import safetensors.torch
|
||||
@ -192,13 +193,13 @@ class ModelManager(object):
|
||||
logger: types.ModuleType = logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Union[Path, DictConfig, str],
|
||||
device_type: torch.device = CUDA_DEVICE,
|
||||
precision: torch.dtype = torch.float16,
|
||||
max_cache_size=MAX_CACHE_SIZE,
|
||||
sequential_offload=False,
|
||||
logger: types.ModuleType = logger,
|
||||
self,
|
||||
config: Union[Path, DictConfig, str],
|
||||
device_type: torch.device = CUDA_DEVICE,
|
||||
precision: torch.dtype = torch.float16,
|
||||
max_cache_size=MAX_CACHE_SIZE,
|
||||
sequential_offload=False,
|
||||
logger: types.ModuleType = logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
@ -225,22 +226,36 @@ class ModelManager(object):
|
||||
self.cache_keys = dict()
|
||||
self.logger = logger
|
||||
|
||||
def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
|
||||
# TODO: rename to smth like - is_model_exists
|
||||
def valid_model(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType = SDModelType.diffusers,
|
||||
) -> bool:
|
||||
"""
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
"""
|
||||
try:
|
||||
self._disambiguate_name(model_name, model_type)
|
||||
return True
|
||||
except InvalidModelError:
|
||||
return False
|
||||
model_key = self.create_key(model_name, model_class)
|
||||
return model_key in self.config
|
||||
|
||||
def get_model(self,
|
||||
model_name: str,
|
||||
model_type: SDModelType=None,
|
||||
submodel: SDModelType=None,
|
||||
) -> SDModelInfo:
|
||||
def create_key(self, model_name: str, model_type: SDModelType) -> str:
|
||||
return f"{model_type.name}/{model_name}"
|
||||
|
||||
def parse_key(self, model_key: str) -> Tuple[str, SDModelType]:
|
||||
model_type_str, model_name = model_key.split('/', 1)
|
||||
if model_type_str not in SDModelType.__members__:
|
||||
# TODO:
|
||||
raise Exception(f"Unkown model type: {model_type_str}")
|
||||
|
||||
return (model_name, SDModelType[model_type_str])
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType=None,
|
||||
submodel: SDModelType=None,
|
||||
) -> SDModelInfo:
|
||||
"""Given a model named identified in models.yaml, return
|
||||
an SDModelInfo object describing it.
|
||||
:param model_name: symbolic name of the model in models.yaml
|
||||
@ -254,85 +269,77 @@ class ModelManager(object):
|
||||
assume a diffusers pipeline. The behavior is illustrated here:
|
||||
|
||||
[models.yaml]
|
||||
test1/diffusers:
|
||||
diffusers/test1:
|
||||
repo_id: foo/bar
|
||||
format: diffusers
|
||||
description: Typical diffusers pipeline
|
||||
|
||||
test1/lora:
|
||||
lora/test1:
|
||||
repo_id: /tmp/loras/test1.safetensors
|
||||
format: lora
|
||||
description: Typical lora file
|
||||
|
||||
test1_pipeline = mgr.get_model('test1')
|
||||
# returns a StableDiffusionGeneratorPipeline
|
||||
|
||||
test1_vae1 = mgr.get_model('test1',submodel=SDModelType.vae)
|
||||
test1_vae1 = mgr.get_model('test1', submodel=SDModelType.vae)
|
||||
# returns the VAE part of a diffusers model as an AutoencoderKL
|
||||
|
||||
test1_vae2 = mgr.get_model('test1',model_type=SDModelType.diffusers,submodel=SDModelType.vae)
|
||||
test1_vae2 = mgr.get_model('test1', model_type=SDModelType.diffusers, submodel=SDModelType.vae)
|
||||
# does the same thing as the previous statement. Note that model_type
|
||||
# is for the parent model, and submodel is for the part
|
||||
|
||||
test1_lora = mgr.get_model('test1',model_type=SDModelType.lora)
|
||||
test1_lora = mgr.get_model('test1', model_type=SDModelType.lora)
|
||||
# returns a LoRA embed (as a 'dict' of tensors)
|
||||
|
||||
test1_encoder = mgr.get_modelI('test1',model_type=SDModelType.textencoder)
|
||||
test1_encoder = mgr.get_modelI('test1', model_type=SDModelType.textencoder)
|
||||
# raises an InvalidModelError
|
||||
|
||||
"""
|
||||
if not model_name:
|
||||
model_name = self.default_model()
|
||||
# TODO: delete default model or add check that this stable diffusion model
|
||||
# if not model_name:
|
||||
# model_name = self.default_model()
|
||||
|
||||
model_key = self._disambiguate_name(model_name, model_type)
|
||||
model_key = self.create_key(model_name, model_type)
|
||||
if model_key not in self.config:
|
||||
raise InvalidModelError(
|
||||
f'"{model_key}" is not a known model name. Please check your models.yaml file'
|
||||
)
|
||||
|
||||
# get the required loading info out of the config file
|
||||
mconfig = self.config[model_key]
|
||||
|
||||
format = mconfig.get('format','diffusers')
|
||||
if model_type and model_type.name != format:
|
||||
raise InvalidModelError(
|
||||
f'Inconsistent model definition; {model_key} has format {format}, but type {model_type.name} was requested'
|
||||
)
|
||||
|
||||
model_parts = dict([(x.name,x) for x in SDModelType])
|
||||
|
||||
if format == 'diffusers':
|
||||
# type already checked as it's part of key
|
||||
if model_type == SDModelType.diffusers:
|
||||
# intercept stanzas that point to checkpoint weights and replace them
|
||||
# with the equivalent diffusers model
|
||||
if 'weights' in mconfig:
|
||||
location = self.convert_ckpt_and_cache(mconfig)
|
||||
else:
|
||||
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
|
||||
elif format in model_parts:
|
||||
location = global_resolve_path(mconfig.get('path')) \
|
||||
or mconfig.get('repo_id') \
|
||||
or global_resolve_path(mconfig.get('weights'))
|
||||
else:
|
||||
raise InvalidModelError(
|
||||
f'"{model_key}" has an unknown format {format}'
|
||||
location = global_resolve_path(
|
||||
mconfig.get('path')) \
|
||||
or mconfig.get('repo_id') \
|
||||
or global_resolve_path(mconfig.get('weights')
|
||||
)
|
||||
|
||||
model_type = model_parts[format]
|
||||
subfolder = mconfig.get('subfolder')
|
||||
revision = mconfig.get('revision')
|
||||
hash = self.cache.model_hash(location,revision)
|
||||
hash = self.cache.model_hash(location, revision)
|
||||
|
||||
# to support the traditional way of attaching a VAE
|
||||
# to a model, we hacked in `attach_model_part`
|
||||
vae = (None,None)
|
||||
try:
|
||||
vae = (None, None)
|
||||
with suppress(Exception):
|
||||
vae_id = mconfig.vae.repo_id
|
||||
vae = (SDModelType.vae,vae_id)
|
||||
except Exception:
|
||||
pass
|
||||
vae = (SDModelType.vae, vae_id)
|
||||
|
||||
model_context = self.cache.get_model(
|
||||
location,
|
||||
model_type = model_type,
|
||||
revision = revision,
|
||||
subfolder = subfolder,
|
||||
submodel = submodel,
|
||||
attach_model_part=vae,
|
||||
attach_model_part = vae,
|
||||
)
|
||||
|
||||
# in case we need to communicate information about this
|
||||
@ -402,27 +409,28 @@ class ModelManager(object):
|
||||
|
||||
def list_models(self) -> dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
||||
'description': description,
|
||||
'format': ('ckpt'|'diffusers'|'vae'),
|
||||
},
|
||||
model_name2: { etc }
|
||||
Return a dict of models
|
||||
Please use model_manager.models() to get all the model names,
|
||||
model_manager.model_info('model-name') to get the stanza for the model
|
||||
named 'model-name', and model_manager.config to get the full OmegaConf
|
||||
object derived from models.yaml
|
||||
"""
|
||||
models = {}
|
||||
for name in sorted(self.config, key=str.casefold):
|
||||
stanza = self.config[name]
|
||||
for model_key in sorted(self.config, key=str.casefold):
|
||||
stanza = self.config[model_key]
|
||||
|
||||
# don't include VAEs in listing (legacy style)
|
||||
if "config" in stanza and "/VAE/" in stanza["config"]:
|
||||
continue
|
||||
|
||||
models[name] = dict()
|
||||
format = stanza.get("format", "ckpt") # Determine Format
|
||||
model_name, model_type = self.parse_key(model_key)
|
||||
models[model_name] = dict()
|
||||
|
||||
# TODO: return all models in future
|
||||
if model_type != SDModelType.diffusers:
|
||||
continue
|
||||
|
||||
model_format = "ckpt" if "weights" in stanza else "diffusers"
|
||||
|
||||
# Common Attribs
|
||||
status = self.cache.status(
|
||||
@ -431,37 +439,38 @@ class ModelManager(object):
|
||||
subfolder=stanza.get('subfolder')
|
||||
)
|
||||
description = stanza.get("description", None)
|
||||
models[name].update(
|
||||
models[model_name].update(
|
||||
description=description,
|
||||
format=format,
|
||||
type=model_type,
|
||||
format=model_format,
|
||||
status=status.value
|
||||
)
|
||||
|
||||
|
||||
# Checkpoint Config Parse
|
||||
if format == "ckpt":
|
||||
models[name].update(
|
||||
config=str(stanza.get("config", None)),
|
||||
weights=str(stanza.get("weights", None)),
|
||||
vae=str(stanza.get("vae", None)),
|
||||
width=str(stanza.get("width", 512)),
|
||||
height=str(stanza.get("height", 512)),
|
||||
if model_format == "ckpt":
|
||||
models[model_name].update(
|
||||
config = str(stanza.get("config", None)),
|
||||
weights = str(stanza.get("weights", None)),
|
||||
vae = str(stanza.get("vae", None)),
|
||||
width = str(stanza.get("width", 512)),
|
||||
height = str(stanza.get("height", 512)),
|
||||
)
|
||||
|
||||
# Diffusers Config Parse
|
||||
if vae := stanza.get("vae", None):
|
||||
if isinstance(vae, DictConfig):
|
||||
vae = dict(
|
||||
repo_id=str(vae.get("repo_id", None)),
|
||||
path=str(vae.get("path", None)),
|
||||
subfolder=str(vae.get("subfolder", None)),
|
||||
)
|
||||
elif model_format == "diffusers":
|
||||
if vae := stanza.get("vae", None):
|
||||
if isinstance(vae, DictConfig):
|
||||
vae = dict(
|
||||
repo_id = str(vae.get("repo_id", None)),
|
||||
path = str(vae.get("path", None)),
|
||||
subfolder = str(vae.get("subfolder", None)),
|
||||
)
|
||||
|
||||
if format == "diffusers":
|
||||
models[name].update(
|
||||
vae=vae,
|
||||
repo_id=str(stanza.get("repo_id", None)),
|
||||
path=str(stanza.get("path", None)),
|
||||
models[model_name].update(
|
||||
vae = vae,
|
||||
repo_id = str(stanza.get("repo_id", None)),
|
||||
path = str(stanza.get("path", None)),
|
||||
)
|
||||
|
||||
return models
|
||||
@ -472,44 +481,60 @@ class ModelManager(object):
|
||||
"""
|
||||
models = self.list_models()
|
||||
for name in models:
|
||||
if models[name]["format"] == "vae":
|
||||
if models[name]["type"] == "vae":
|
||||
continue
|
||||
line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["format"]:10s} {models[name]["description"]}'
|
||||
line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["type"]:10s} {models[name]["description"]}'
|
||||
if models[name]["status"] == "active":
|
||||
line = f"\033[1m{line}\033[0m"
|
||||
print(line)
|
||||
|
||||
def del_model(self, model_name: str, model_type: SDModelType.diffusers, delete_files: bool = False):
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType.diffusers,
|
||||
delete_files: bool = False
|
||||
):
|
||||
"""
|
||||
Delete the named model.
|
||||
"""
|
||||
model_name = self._disambiguate_name(model_name, model_type)
|
||||
omega = self.config
|
||||
if model_name not in omega:
|
||||
self.logger.error(f"Unknown model {model_name}")
|
||||
return
|
||||
# save these for use in deletion later
|
||||
conf = omega[model_name]
|
||||
repo_id = conf.get("repo_id", None)
|
||||
path = self._abs_path(conf.get("path", None))
|
||||
weights = self._abs_path(conf.get("weights", None))
|
||||
model_key = self.create_key(model_name, model_type)
|
||||
model_cfg = self.pop(model_key, None)
|
||||
|
||||
if model_cfg is None:
|
||||
self.logger.error(
|
||||
f"Unknown model {model_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# TODO: some legacy?
|
||||
#if model_name in self.stack:
|
||||
# self.stack.remove(model_name)
|
||||
|
||||
del omega[model_name]
|
||||
if model_name in self.stack:
|
||||
self.stack.remove(model_name)
|
||||
if delete_files:
|
||||
if weights:
|
||||
repo_id = conf.get("repo_id", None)
|
||||
path = self._abs_path(conf.get("path", None))
|
||||
weights = self._abs_path(conf.get("weights", None))
|
||||
if "weights" in model_cfg:
|
||||
weights = self._abs_path(model_cfg["weights"])
|
||||
self.logger.info(f"Deleting file {weights}")
|
||||
Path(weights).unlink(missing_ok=True)
|
||||
elif path:
|
||||
|
||||
elif "path" in model_cfg:
|
||||
path = self._abs_path(model_cfg["path"])
|
||||
self.logger.info(f"Deleting directory {path}")
|
||||
rmtree(path, ignore_errors=True)
|
||||
elif repo_id:
|
||||
|
||||
elif "repo_id" in model_cfg:
|
||||
repo_id = model_cfg["repo_id"]
|
||||
self.logger.info(f"Deleting the cached model directory for {repo_id}")
|
||||
self._delete_model_from_cache(repo_id)
|
||||
|
||||
def add_model(
|
||||
self, model_name: str, model_attributes: dict, clobber: bool = False
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: SDModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
@ -518,37 +543,47 @@ class ModelManager(object):
|
||||
method will return True. Will fail with an assertion error if provided
|
||||
attributes are incorrect or the model name is missing.
|
||||
"""
|
||||
omega = self.config
|
||||
|
||||
assert "format" in model_attributes, 'missing required field "format"'
|
||||
if model_attributes["format"] == "diffusers":
|
||||
assert (
|
||||
"description" in model_attributes
|
||||
), 'required field "description" is missing'
|
||||
assert (
|
||||
"path" in model_attributes or "repo_id" in model_attributes
|
||||
), 'model must have either the "path" or "repo_id" fields defined'
|
||||
elif model_attributes["format"] == "ckpt":
|
||||
for field in ("description", "weights", "height", "width", "config"):
|
||||
assert field in model_attributes, f"required field {field} is missing"
|
||||
if model_type == SDModelType.diffusers:
|
||||
# TODO: automaticaly or manualy?
|
||||
#assert "format" in model_attributes, 'missing required field "format"'
|
||||
model_format = "ckpt" if "weights" in model_attributes else "diffusers"
|
||||
|
||||
if model_format == "diffusers":
|
||||
assert (
|
||||
"description" in model_attributes
|
||||
), 'required field "description" is missing'
|
||||
assert (
|
||||
"path" in model_attributes or "repo_id" in model_attributes
|
||||
), 'model must have either the "path" or "repo_id" fields defined'
|
||||
|
||||
elif model_format == "ckpt":
|
||||
for field in ("description", "weights", "height", "width", "config"):
|
||||
assert field in model_attributes, f"required field {field} is missing"
|
||||
|
||||
else:
|
||||
assert "weights" in model_attributes and "description" in model_attributes
|
||||
|
||||
model_key = f'{model_name}/{model_attributes["format"]}'
|
||||
model_key = self.create_key(model_name, model_type)
|
||||
|
||||
assert (
|
||||
clobber or model_key not in omega
|
||||
clobber or model_key not in self.config
|
||||
), f'attempt to overwrite existing model definition "{model_key}"'
|
||||
|
||||
omega[model_key] = model_attributes
|
||||
self.config[model_key] = model_attributes
|
||||
|
||||
if "weights" in omega[model_key]:
|
||||
omega[model_key]["weights"].replace("\\", "/")
|
||||
if "weights" in self.config[model_key]:
|
||||
self.config[model_key]["weights"].replace("\\", "/")
|
||||
|
||||
if clobber and model_key in self.cache_keys:
|
||||
self.cache.uncache_model(self.cache_keys[model_key])
|
||||
del self.cache_keys[model_key]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def import_diffuser_model(
|
||||
self,
|
||||
repo_or_path: Union[str, Path],
|
||||
@ -587,10 +622,10 @@ class ModelManager(object):
|
||||
return model_key
|
||||
|
||||
def import_lora(
|
||||
self,
|
||||
path: Path,
|
||||
model_name: str=None,
|
||||
description: str=None,
|
||||
self,
|
||||
path: Path,
|
||||
model_name: str=None,
|
||||
description: str=None,
|
||||
):
|
||||
"""
|
||||
Creates an entry for the indicated lora file. Call
|
||||
@ -599,20 +634,21 @@ class ModelManager(object):
|
||||
path = Path(path)
|
||||
model_name = model_name or path.stem
|
||||
model_description = description or f"LoRA model {model_name}"
|
||||
self.add_model(f'{model_name}/{SDModelType.lora.name}',
|
||||
dict(
|
||||
format="lora",
|
||||
weights=str(path),
|
||||
description=model_description,
|
||||
),
|
||||
True
|
||||
)
|
||||
self.add_model(
|
||||
f'{model_name}/{SDModelType.lora.name}',
|
||||
dict(
|
||||
format="lora",
|
||||
weights=str(path),
|
||||
description=model_description,
|
||||
),
|
||||
True
|
||||
)
|
||||
|
||||
def import_embedding(
|
||||
self,
|
||||
path: Path,
|
||||
model_name: str=None,
|
||||
description: str=None,
|
||||
self,
|
||||
path: Path,
|
||||
model_name: str=None,
|
||||
description: str=None,
|
||||
):
|
||||
"""
|
||||
Creates an entry for the indicated lora file. Call
|
||||
@ -626,14 +662,15 @@ class ModelManager(object):
|
||||
|
||||
model_name = model_name or path.stem
|
||||
model_description = description or f"Textual embedding model {model_name}"
|
||||
self.add_model(f'{model_name}/{SDModelType.textual_inversion.name}',
|
||||
dict(
|
||||
format="textual_inversion",
|
||||
weights=str(weights),
|
||||
description=model_description,
|
||||
),
|
||||
True
|
||||
)
|
||||
self.add_model(
|
||||
f'{model_name}/{SDModelType.textual_inversion.name}',
|
||||
dict(
|
||||
format="textual_inversion",
|
||||
weights=str(weights),
|
||||
description=model_description,
|
||||
),
|
||||
True
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
|
||||
@ -857,7 +894,7 @@ class ModelManager(object):
|
||||
)
|
||||
return model_name
|
||||
|
||||
def convert_ckpt_and_cache(self, mconfig:DictConfig)->Path:
|
||||
def convert_ckpt_and_cache(self, mconfig: DictConfig)->Path:
|
||||
"""
|
||||
Convert the checkpoint model indicated in mconfig into a
|
||||
diffusers, cache it to disk, and return Path to converted
|
||||
@ -872,6 +909,7 @@ class ModelManager(object):
|
||||
return diffusers_path
|
||||
|
||||
vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
|
||||
|
||||
# to avoid circular import errors
|
||||
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
with SilenceWarnings():
|
||||
@ -881,15 +919,16 @@ class ModelManager(object):
|
||||
extract_ema=True,
|
||||
original_config_file=config_file,
|
||||
vae=vae_model,
|
||||
vae_path=str(global_resolve_path(vae_ckpt_path)),
|
||||
vae_path=str(global_resolve_path(vae_ckpt_path)) if vae_ckpt_path else None,
|
||||
scan_needed=True,
|
||||
)
|
||||
return diffusers_path
|
||||
|
||||
def _get_vae_for_conversion(self,
|
||||
weights: Path,
|
||||
mconfig: DictConfig
|
||||
)->tuple(Path,SDModelType.vae):
|
||||
def _get_vae_for_conversion(
|
||||
self,
|
||||
weights: Path,
|
||||
mconfig: DictConfig
|
||||
) -> Tuple[Path, SDModelType.vae]:
|
||||
# VAE handling is convoluted
|
||||
# 1. If there is a .vae.ckpt file sharing same stem as weights, then use
|
||||
# it as the vae_path passed to convert
|
||||
@ -1047,21 +1086,7 @@ class ModelManager(object):
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
"""
|
||||
)
|
||||
|
||||
def _disambiguate_name(self, model_name:str, model_type:SDModelType)->str:
|
||||
model_type = model_type or SDModelType.diffusers
|
||||
full_name = f"{model_name}/{model_type.name}"
|
||||
if full_name in self.config:
|
||||
return full_name
|
||||
# special case - if diffusers requested, then allow name without type appended
|
||||
if model_type==SDModelType.diffusers \
|
||||
and model_name in self.config \
|
||||
and self.config[model_name].format=='diffusers':
|
||||
return model_name
|
||||
raise InvalidModelError(
|
||||
f'"{full_name}" is not a known model name. Please check your models.yaml file'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user