A big refactor of the model manager (IMHO)
This commit is contained in:
parent 4492044d29
commit 131145eab1
@@ -133,6 +133,7 @@ from enum import Enum, auto
 from pathlib import Path
 from shutil import rmtree
 from typing import Union, Callable, types
+from contextlib import suppress

 import safetensors
 import safetensors.torch
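The contextlib import added above is put to work later in this commit: get_model replaces a try/except/pass guard around the optional VAE stanza with a suppress(Exception) block. A minimal sketch of the equivalence, using only the standard library (the cfg dict and tuple values here are invented for illustration, not taken from the commit):

from contextlib import suppress

cfg = {}  # stand-in for an mconfig stanza that has no "vae" section

# old pattern: try / except Exception / pass
vae = (None, None)
try:
    vae = ("vae", cfg["vae"]["repo_id"])
except Exception:
    pass

# new pattern: same behavior, one level flatter
vae2 = (None, None)
with suppress(Exception):
    vae2 = ("vae", cfg["vae"]["repo_id"])

assert vae == vae2 == (None, None)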
@@ -192,13 +193,13 @@ class ModelManager(object):
     logger: types.ModuleType = logger

     def __init__(
         self,
         config: Union[Path, DictConfig, str],
         device_type: torch.device = CUDA_DEVICE,
         precision: torch.dtype = torch.float16,
         max_cache_size=MAX_CACHE_SIZE,
         sequential_offload=False,
         logger: types.ModuleType = logger,
     ):
         """
         Initialize with the path to the models.yaml config file.
@@ -225,22 +226,36 @@ class ModelManager(object):
         self.cache_keys = dict()
         self.logger = logger

-    def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
+    # TODO: rename to smth like - is_model_exists
+    def valid_model(
+        self,
+        model_name: str,
+        model_type: SDModelType = SDModelType.diffusers,
+    ) -> bool:
         """
         Given a model name, returns True if it is a valid
         identifier.
         """
-        try:
-            self._disambiguate_name(model_name, model_type)
-            return True
-        except InvalidModelError:
-            return False
+        model_key = self.create_key(model_name, model_class)
+        return model_key in self.config

-    def get_model(self,
-                  model_name: str,
-                  model_type: SDModelType=None,
-                  submodel: SDModelType=None,
-                  ) -> SDModelInfo:
+    def create_key(self, model_name: str, model_type: SDModelType) -> str:
+        return f"{model_type.name}/{model_name}"
+
+    def parse_key(self, model_key: str) -> Tuple[str, SDModelType]:
+        model_type_str, model_name = model_key.split('/', 1)
+        if model_type_str not in SDModelType.__members__:
+            # TODO:
+            raise Exception(f"Unkown model type: {model_type_str}")
+
+        return (model_name, SDModelType[model_type_str])
+
+    def get_model(
+        self,
+        model_name: str,
+        model_type: SDModelType=None,
+        submodel: SDModelType=None,
+    ) -> SDModelInfo:
         """Given a model named identified in models.yaml, return
         an SDModelInfo object describing it.
         :param model_name: symbolic name of the model in models.yaml
@@ -254,85 +269,77 @@ class ModelManager(object):
         assume a diffusers pipeline. The behavior is illustrated here:

         [models.yaml]
-        test1/diffusers:
+        diffusers/test1:
             repo_id: foo/bar
-            format: diffusers
             description: Typical diffusers pipeline

-        test1/lora:
+        lora/test1:
             repo_id: /tmp/loras/test1.safetensors
-            format: lora
             description: Typical lora file

         test1_pipeline = mgr.get_model('test1')
         # returns a StableDiffusionGeneratorPipeline

-        test1_vae1 = mgr.get_model('test1',submodel=SDModelType.vae)
+        test1_vae1 = mgr.get_model('test1', submodel=SDModelType.vae)
         # returns the VAE part of a diffusers model as an AutoencoderKL

-        test1_vae2 = mgr.get_model('test1',model_type=SDModelType.diffusers,submodel=SDModelType.vae)
+        test1_vae2 = mgr.get_model('test1', model_type=SDModelType.diffusers, submodel=SDModelType.vae)
         # does the same thing as the previous statement. Note that model_type
         # is for the parent model, and submodel is for the part

-        test1_lora = mgr.get_model('test1',model_type=SDModelType.lora)
+        test1_lora = mgr.get_model('test1', model_type=SDModelType.lora)
         # returns a LoRA embed (as a 'dict' of tensors)

-        test1_encoder = mgr.get_modelI('test1',model_type=SDModelType.textencoder)
+        test1_encoder = mgr.get_modelI('test1', model_type=SDModelType.textencoder)
         # raises an InvalidModelError

         """
-        if not model_name:
-            model_name = self.default_model()
+        # TODO: delete default model or add check that this stable diffusion model
+        # if not model_name:
+        #     model_name = self.default_model()

-        model_key = self._disambiguate_name(model_name, model_type)
+        model_key = self.create_key(model_name, model_type)
+        if model_key not in self.config:
+            raise InvalidModelError(
+                f'"{model_key}" is not a known model name. Please check your models.yaml file'
+            )

         # get the required loading info out of the config file
         mconfig = self.config[model_key]

-        format = mconfig.get('format','diffusers')
-        if model_type and model_type.name != format:
-            raise InvalidModelError(
-                f'Inconsistent model definition; {model_key} has format {format}, but type {model_type.name} was requested'
-            )
+        # type already checked as it's part of key
+        if model_type == SDModelType.diffusers:

-        model_parts = dict([(x.name,x) for x in SDModelType])
-
-        if format == 'diffusers':
             # intercept stanzas that point to checkpoint weights and replace them
             # with the equivalent diffusers model
             if 'weights' in mconfig:
                 location = self.convert_ckpt_and_cache(mconfig)
             else:
                 location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
-        elif format in model_parts:
-            location = global_resolve_path(mconfig.get('path')) \
-                or mconfig.get('repo_id') \
-                or global_resolve_path(mconfig.get('weights'))
         else:
-            raise InvalidModelError(
-                f'"{model_key}" has an unknown format {format}'
+            location = global_resolve_path(
+                mconfig.get('path')) \
+                or mconfig.get('repo_id') \
+                or global_resolve_path(mconfig.get('weights')
             )

-        model_type = model_parts[format]
         subfolder = mconfig.get('subfolder')
         revision = mconfig.get('revision')
-        hash = self.cache.model_hash(location,revision)
+        hash = self.cache.model_hash(location, revision)

         # to support the traditional way of attaching a VAE
         # to a model, we hacked in `attach_model_part`
-        vae = (None,None)
-        try:
+        vae = (None, None)
+        with suppress(Exception):
             vae_id = mconfig.vae.repo_id
-            vae = (SDModelType.vae,vae_id)
-        except Exception:
-            pass
+            vae = (SDModelType.vae, vae_id)
         model_context = self.cache.get_model(
             location,
             model_type = model_type,
             revision = revision,
             subfolder = subfolder,
             submodel = submodel,
-            attach_model_part=vae,
+            attach_model_part = vae,
         )

         # in case we need to communicate information about this
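The config key scheme introduced above is the heart of the refactor: a model is now addressed as "<model_type.name>/<model_name>" via create_key, and parse_key reverses it. A minimal, self-contained sketch of the round trip; the SDModelType enum below is a stand-in defined here only so the snippet runs on its own, not the project's class:

from enum import Enum
from typing import Tuple

class SDModelType(Enum):  # illustrative stand-in, not the real enum
    diffusers = 1
    lora = 2
    vae = 3

def create_key(model_name: str, model_type: SDModelType) -> str:
    return f"{model_type.name}/{model_name}"

def parse_key(model_key: str) -> Tuple[str, SDModelType]:
    # split on the first "/" only, so model names may themselves contain slashes
    model_type_str, model_name = model_key.split("/", 1)
    if model_type_str not in SDModelType.__members__:
        raise ValueError(f"Unknown model type: {model_type_str}")
    return model_name, SDModelType[model_type_str]

key = create_key("test1", SDModelType.lora)
print(key)             # lora/test1
print(parse_key(key))  # ('test1', <SDModelType.lora: 2>)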
@@ -402,27 +409,28 @@ class ModelManager(object):

     def list_models(self) -> dict:
         """
-        Return a dict of models in the format:
-        { model_name1: {'status': ('active'|'cached'|'not loaded'),
-                        'description': description,
-                        'format': ('ckpt'|'diffusers'|'vae'),
-                        },
-          model_name2: { etc }
+        Return a dict of models
         Please use model_manager.models() to get all the model names,
         model_manager.model_info('model-name') to get the stanza for the model
         named 'model-name', and model_manager.config to get the full OmegaConf
         object derived from models.yaml
         """
         models = {}
-        for name in sorted(self.config, key=str.casefold):
-            stanza = self.config[name]
+        for model_key in sorted(self.config, key=str.casefold):
+            stanza = self.config[model_key]

             # don't include VAEs in listing (legacy style)
             if "config" in stanza and "/VAE/" in stanza["config"]:
                 continue

-            models[name] = dict()
-            format = stanza.get("format", "ckpt") # Determine Format
+            model_name, model_type = self.parse_key(model_key)
+            models[model_name] = dict()
+
+            # TODO: return all models in future
+            if model_type != SDModelType.diffusers:
+                continue
+
+            model_format = "ckpt" if "weights" in stanza else "diffusers"

             # Common Attribs
             status = self.cache.status(
@@ -431,37 +439,38 @@ class ModelManager(object):
                 subfolder=stanza.get('subfolder')
             )
             description = stanza.get("description", None)
-            models[name].update(
+            models[model_name].update(
                 description=description,
-                format=format,
+                type=model_type,
+                format=model_format,
                 status=status.value
             )


             # Checkpoint Config Parse
-            if format == "ckpt":
-                models[name].update(
-                    config=str(stanza.get("config", None)),
-                    weights=str(stanza.get("weights", None)),
-                    vae=str(stanza.get("vae", None)),
-                    width=str(stanza.get("width", 512)),
-                    height=str(stanza.get("height", 512)),
+            if model_format == "ckpt":
+                models[model_name].update(
+                    config = str(stanza.get("config", None)),
+                    weights = str(stanza.get("weights", None)),
+                    vae = str(stanza.get("vae", None)),
+                    width = str(stanza.get("width", 512)),
+                    height = str(stanza.get("height", 512)),
                 )

             # Diffusers Config Parse
-            if vae := stanza.get("vae", None):
-                if isinstance(vae, DictConfig):
-                    vae = dict(
-                        repo_id=str(vae.get("repo_id", None)),
-                        path=str(vae.get("path", None)),
-                        subfolder=str(vae.get("subfolder", None)),
-                    )
+            elif model_format == "diffusers":
+                if vae := stanza.get("vae", None):
+                    if isinstance(vae, DictConfig):
+                        vae = dict(
+                            repo_id = str(vae.get("repo_id", None)),
+                            path = str(vae.get("path", None)),
+                            subfolder = str(vae.get("subfolder", None)),
+                        )

-            if format == "diffusers":
-                models[name].update(
-                    vae=vae,
-                    repo_id=str(stanza.get("repo_id", None)),
-                    path=str(stanza.get("path", None)),
+                models[model_name].update(
+                    vae = vae,
+                    repo_id = str(stanza.get("repo_id", None)),
+                    path = str(stanza.get("path", None)),
                 )

         return models
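With the new keys, list_models derives both the model name and its type from the config key and infers ckpt versus diffusers from the presence of a weights entry. A hedged sketch of that listing logic against a toy models.yaml; OmegaConf is the library the manager already uses, but the file content and variable names here are invented:

from omegaconf import OmegaConf

config = OmegaConf.create("""
diffusers/test1:
  repo_id: foo/bar
  description: Typical diffusers pipeline
lora/test1:
  repo_id: /tmp/loras/test1.safetensors
  description: Typical lora file
""")

models = {}
for model_key in sorted(config, key=str.casefold):
    stanza = config[model_key]
    model_type, model_name = model_key.split("/", 1)
    if model_type != "diffusers":  # the refactored list_models reports only diffusers for now
        continue
    models[model_name] = dict(
        description=stanza.get("description"),
        type=model_type,
        format="ckpt" if "weights" in stanza else "diffusers",
    )

print(models)
# {'test1': {'description': 'Typical diffusers pipeline', 'type': 'diffusers', 'format': 'diffusers'}}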
@@ -472,44 +481,60 @@ class ModelManager(object):
         """
         models = self.list_models()
         for name in models:
-            if models[name]["format"] == "vae":
+            if models[name]["type"] == "vae":
                 continue
-            line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["format"]:10s} {models[name]["description"]}'
+            line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["type"]:10s} {models[name]["description"]}'
             if models[name]["status"] == "active":
                 line = f"\033[1m{line}\033[0m"
             print(line)

-    def del_model(self, model_name: str, model_type: SDModelType.diffusers, delete_files: bool = False):
+    def del_model(
+        self,
+        model_name: str,
+        model_type: SDModelType.diffusers,
+        delete_files: bool = False
+    ):
         """
         Delete the named model.
         """
-        model_name = self._disambiguate_name(model_name, model_type)
-        omega = self.config
-        if model_name not in omega:
-            self.logger.error(f"Unknown model {model_name}")
-            return
-        # save these for use in deletion later
-        conf = omega[model_name]
-        repo_id = conf.get("repo_id", None)
-        path = self._abs_path(conf.get("path", None))
-        weights = self._abs_path(conf.get("weights", None))
+        model_key = self.create_key(model_name, model_type)
+        model_cfg = self.pop(model_key, None)
+
+        if model_cfg is None:
+            self.logger.error(
+                f"Unknown model {model_key}"
+            )
+            return
+
+        # TODO: some legacy?
+        #if model_name in self.stack:
+        #    self.stack.remove(model_name)

-        del omega[model_name]
-        if model_name in self.stack:
-            self.stack.remove(model_name)
         if delete_files:
-            if weights:
+            repo_id = conf.get("repo_id", None)
+            path = self._abs_path(conf.get("path", None))
+            weights = self._abs_path(conf.get("weights", None))
+            if "weights" in model_cfg:
+                weights = self._abs_path(model_cfg["weights"])
                 self.logger.info(f"Deleting file {weights}")
                 Path(weights).unlink(missing_ok=True)
-            elif path:
+
+            elif "path" in model_cfg:
+                path = self._abs_path(model_cfg["path"])
                 self.logger.info(f"Deleting directory {path}")
                 rmtree(path, ignore_errors=True)
-            elif repo_id:
+
+            elif "repo_id" in model_cfg:
+                repo_id = model_cfg["repo_id"]
                 self.logger.info(f"Deleting the cached model directory for {repo_id}")
                 self._delete_model_from_cache(repo_id)

     def add_model(
-        self, model_name: str, model_attributes: dict, clobber: bool = False
+        self,
+        model_name: str,
+        model_type: SDModelType,
+        model_attributes: dict,
+        clobber: bool = False
     ) -> None:
         """
         Update the named model with a dictionary of attributes. Will fail with an
@@ -518,37 +543,47 @@ class ModelManager(object):
         method will return True. Will fail with an assertion error if provided
         attributes are incorrect or the model name is missing.
         """
-        omega = self.config

-        assert "format" in model_attributes, 'missing required field "format"'
-        if model_attributes["format"] == "diffusers":
-            assert (
-                "description" in model_attributes
-            ), 'required field "description" is missing'
-            assert (
-                "path" in model_attributes or "repo_id" in model_attributes
-            ), 'model must have either the "path" or "repo_id" fields defined'
-        elif model_attributes["format"] == "ckpt":
-            for field in ("description", "weights", "height", "width", "config"):
-                assert field in model_attributes, f"required field {field} is missing"
+        if model_type == SDModelType.diffusers:
+            # TODO: automaticaly or manualy?
+            #assert "format" in model_attributes, 'missing required field "format"'
+            model_format = "ckpt" if "weights" in model_attributes else "diffusers"
+
+            if model_format == "diffusers":
+                assert (
+                    "description" in model_attributes
+                ), 'required field "description" is missing'
+                assert (
+                    "path" in model_attributes or "repo_id" in model_attributes
+                ), 'model must have either the "path" or "repo_id" fields defined'
+
+            elif model_format == "ckpt":
+                for field in ("description", "weights", "height", "width", "config"):
+                    assert field in model_attributes, f"required field {field} is missing"
+
         else:
             assert "weights" in model_attributes and "description" in model_attributes

-        model_key = f'{model_name}/{model_attributes["format"]}'
+        model_key = self.create_key(model_name, model_type)

         assert (
-            clobber or model_key not in omega
+            clobber or model_key not in self.config
         ), f'attempt to overwrite existing model definition "{model_key}"'

-        omega[model_key] = model_attributes
+        self.config[model_key] = model_attributes

-        if "weights" in omega[model_key]:
-            omega[model_key]["weights"].replace("\\", "/")
+        if "weights" in self.config[model_key]:
+            self.config[model_key]["weights"].replace("\\", "/")

         if clobber and model_key in self.cache_keys:
             self.cache.uncache_model(self.cache_keys[model_key])
             del self.cache_keys[model_key]

+
+
+
+
+
     def import_diffuser_model(
         self,
         repo_or_path: Union[str, Path],
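add_model now takes the model type as an explicit argument and no longer demands a format field; whether a stanza is validated as a checkpoint or a diffusers model is inferred from whether it carries weights. A stand-alone sketch of that inference and the assertions above (the function name and sample attributes are invented for illustration):

def check_model_attributes(model_attributes: dict) -> str:
    # mirror of the refactored validation: infer the format instead of reading a "format" field
    model_format = "ckpt" if "weights" in model_attributes else "diffusers"
    if model_format == "diffusers":
        assert "description" in model_attributes, 'required field "description" is missing'
        assert (
            "path" in model_attributes or "repo_id" in model_attributes
        ), 'model must have either the "path" or "repo_id" fields defined'
    else:
        for field in ("description", "weights", "height", "width", "config"):
            assert field in model_attributes, f"required field {field} is missing"
    return model_format

print(check_model_attributes(dict(description="Typical diffusers pipeline", repo_id="foo/bar")))
# diffusers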
@@ -587,10 +622,10 @@ class ModelManager(object):
         return model_key

     def import_lora(
         self,
         path: Path,
         model_name: str=None,
         description: str=None,
     ):
         """
         Creates an entry for the indicated lora file. Call
@@ -599,20 +634,21 @@ class ModelManager(object):
         path = Path(path)
         model_name = model_name or path.stem
         model_description = description or f"LoRA model {model_name}"
-        self.add_model(f'{model_name}/{SDModelType.lora.name}',
-                       dict(
-                           format="lora",
-                           weights=str(path),
-                           description=model_description,
-                       ),
-                       True
-        )
+        self.add_model(
+            f'{model_name}/{SDModelType.lora.name}',
+            dict(
+                format="lora",
+                weights=str(path),
+                description=model_description,
+            ),
+            True
+        )

     def import_embedding(
         self,
         path: Path,
         model_name: str=None,
         description: str=None,
     ):
         """
         Creates an entry for the indicated lora file. Call
@@ -626,14 +662,15 @@ class ModelManager(object):

         model_name = model_name or path.stem
         model_description = description or f"Textual embedding model {model_name}"
-        self.add_model(f'{model_name}/{SDModelType.textual_inversion.name}',
-                       dict(
-                           format="textual_inversion",
-                           weights=str(weights),
-                           description=model_description,
-                       ),
-                       True
-        )
+        self.add_model(
+            f'{model_name}/{SDModelType.textual_inversion.name}',
+            dict(
+                format="textual_inversion",
+                weights=str(weights),
+                description=model_description,
+            ),
+            True
+        )

     @classmethod
     def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
@@ -857,7 +894,7 @@ class ModelManager(object):
         )
         return model_name

-    def convert_ckpt_and_cache(self, mconfig:DictConfig)->Path:
+    def convert_ckpt_and_cache(self, mconfig: DictConfig)->Path:
         """
         Convert the checkpoint model indicated in mconfig into a
         diffusers, cache it to disk, and return Path to converted
@@ -872,6 +909,7 @@ class ModelManager(object):
             return diffusers_path

         vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
+
         # to avoid circular import errors
         from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
         with SilenceWarnings():
@@ -881,15 +919,16 @@ class ModelManager(object):
                 extract_ema=True,
                 original_config_file=config_file,
                 vae=vae_model,
-                vae_path=str(global_resolve_path(vae_ckpt_path)),
+                vae_path=str(global_resolve_path(vae_ckpt_path)) if vae_ckpt_path else None,
                 scan_needed=True,
             )
         return diffusers_path

-    def _get_vae_for_conversion(self,
-                                weights: Path,
-                                mconfig: DictConfig
-                                )->tuple(Path,SDModelType.vae):
+    def _get_vae_for_conversion(
+        self,
+        weights: Path,
+        mconfig: DictConfig
+    ) -> Tuple[Path, SDModelType.vae]:
         # VAE handling is convoluted
         # 1. If there is a .vae.ckpt file sharing same stem as weights, then use
         # it as the vae_path passed to convert
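The return-annotation change above is more than cosmetic: the old spelling tuple(Path, SDModelType.vae) calls the tuple builtin while the def statement is evaluated and raises a TypeError, whereas typing.Tuple[...] is just a subscription. A stand-alone illustration; the stub body is invented, only the annotation behavior is the point:

from pathlib import Path
from typing import Tuple

# tuple(Path, str) would raise "TypeError: tuple expected at most 1 argument, got 2"
# as soon as this def runs; Tuple[Path, str] is evaluated safely.
def vae_for_conversion_stub(weights: Path) -> Tuple[Path, str]:
    vae_path = weights.with_name(weights.stem + ".vae.ckpt")
    return vae_path, "vae"

path, kind = vae_for_conversion_stub(Path("v1-5-pruned.ckpt"))
print(path.name, kind)  # v1-5-pruned.vae.ckpt vae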
@@ -1047,21 +1086,7 @@ class ModelManager(object):
             # model requires a model config file, a weights file,
             # and the width and height of the images it
             # was trained on.
             """
         )
-
-    def _disambiguate_name(self, model_name:str, model_type:SDModelType)->str:
-        model_type = model_type or SDModelType.diffusers
-        full_name = f"{model_name}/{model_type.name}"
-        if full_name in self.config:
-            return full_name
-        # special case - if diffusers requested, then allow name without type appended
-        if model_type==SDModelType.diffusers \
-            and model_name in self.config \
-            and self.config[model_name].format=='diffusers':
-            return model_name
-        raise InvalidModelError(
-            f'"{full_name}" is not a known model name. Please check your models.yaml file'
-        )