This commit is contained in:
Lincoln Stein
2023-09-14 21:12:41 -04:00
parent 716a1b6423
commit a033ccc776
9 changed files with 116 additions and 135 deletions

View File

@ -18,7 +18,7 @@ from invokeai.backend.model_manager import (
ModelType,
SubModelType,
UnknownModelException,
DuplicateModelException
DuplicateModelException,
)
from invokeai.backend.model_manager.cache import CacheStats
from typing import TYPE_CHECKING, List, Optional, Union, Dict, Any
@ -51,10 +51,10 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def get_model(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""Retrieve the indicated model identified by key.
@ -71,8 +71,8 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def model_exists(
self,
key: str,
self,
key: str,
) -> bool:
pass
@ -85,10 +85,11 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def list_models(self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
def list_models(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[ModelConfigBase]:
"""
Return a list of ModelConfigBases that match the base, type and name criteria.
@ -104,13 +105,11 @@ class ModelManagerServiceBase(ABC):
If there are more than one model that match, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException
"""
model_configs = self.list_models(
model_name=model_name,
base_model=base_model,
model_type=model_type
)
model_configs = self.list_models(model_name=model_name, base_model=base_model, model_type=model_type)
if len(model_configs) > 1:
raise DuplicateModelException("More than one model share the same name and type: {base_model}/{model_type}/{model_name}")
raise DuplicateModelException(
"More than one model share the same name and type: {base_model}/{model_type}/{model_name}"
)
if len(model_configs) == 0:
raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}")
return model_configs[0]
@ -123,22 +122,19 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def add_model(
self,
model_path: Path,
probe_overrides: Optional[Dict[str, Any]] = None,
wait: bool = False
self, model_path: Path, probe_overrides: Optional[Dict[str, Any]] = None, wait: bool = False
) -> ModelInstallJob:
"""
Add a model using its path, with a dictionary of attributes. Will fail with an
assertion error if the name already exists.
assertion error if the name already exists.
"""
pass
@abstractmethod
def update_model(
self,
key: str,
new_config: Union[dict, ModelConfigBase],
self,
key: str,
new_config: Union[dict, ModelConfigBase],
) -> ModelConfigBase:
"""
Update the named model with a dictionary of attributes. Will fail with a
@ -151,11 +147,7 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def del_model(
self,
key: str,
delete_files: bool = False
):
def del_model(self, key: str, delete_files: bool = False):
"""
Delete the named model from configuration. If delete_files
is true, then the underlying file or directory will be
@ -164,9 +156,9 @@ class ModelManagerServiceBase(ABC):
pass
def rename_model(
self,
key: str,
new_name: str,
self,
key: str,
new_name: str,
) -> ModelConfigBase:
"""
Rename the indicated model.
@ -182,14 +174,14 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def convert_model(
self,
key: str,
convert_dest_directory: Path,
self,
key: str,
convert_dest_directory: Path,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder.
This will delete the cached version if there is any and delete the original
This will delete the cached version if there is any and delete the original
checkpoint file if it is in the models directory.
:param key: Unique key for the model to convert.
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
@ -201,10 +193,10 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def install_model (
self,
source: Union[str, Path, AnyHttpUrl],
model_attributes: Optional[Dict[str, Any]] = None,
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
model_attributes: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""Import a path, repo_id or URL. Returns an ModelInstallJob.
@ -270,7 +262,7 @@ class ModelManagerServiceBase(ABC):
pass
# implementation
# implementation
class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory"""
@ -289,18 +281,18 @@ class ModelManagerService(ModelManagerServiceBase):
self._loader = ModelLoader(config)
def get_model(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""
Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers mode.
"""
model_info: ModelInfo = self._loader.get_model(key, submodel_type)
# we can emit model loading events if we are executing with access to the invocation context
if context:
self._emit_load_event(
@ -313,8 +305,8 @@ class ModelManagerService(ModelManagerServiceBase):
return model_info
def model_exists(
self,
key: str,
self,
key: str,
) -> bool:
"""
Given a model key, returns True if it is a valid
@ -331,10 +323,11 @@ class ModelManagerService(ModelManagerServiceBase):
# def all_models(self) -> List[ModelConfigBase] -- defined in base class, same as list_models()
# def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -- defined in base class
def list_models(self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
def list_models(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[ModelConfigBase]:
"""
Return a ModelConfigBase object for each model in the database.
@ -347,14 +340,11 @@ class ModelManagerService(ModelManagerServiceBase):
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
def add_model(
self,
model_path: Path,
model_attributes: Optional[dict] = None,
wait: bool = False
self, model_path: Path, model_attributes: Optional[dict] = None, wait: bool = False
) -> ModelInstallJob:
"""
Add a model using its path, with a dictionary of attributes. Will fail with an
assertion error if the name already exists.
assertion error if the name already exists.
"""
self.logger.debug(f"add/update model {model_path}")
return self._loader.installer.install(
@ -363,16 +353,16 @@ class ModelManagerService(ModelManagerServiceBase):
)
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
model_attributes: Optional[Dict[str, Any]] = None,
self,
source: Union[str, Path, AnyHttpUrl],
model_attributes: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""
Add a model using its path, with a dictionary of attributes. Will fail with an
assertion error if the name already exists.
assertion error if the name already exists.
"""
self.logger.debug(f"add/update model {source}")
variant = 'fp16' if self._loader.precision == 'float16' else None
variant = "fp16" if self._loader.precision == "float16" else None
return self._loader.installer.install(
source,
probe_override=model_attributes,
@ -380,9 +370,9 @@ class ModelManagerService(ModelManagerServiceBase):
)
def update_model(
self,
key: str,
new_config: Union[dict, ModelConfigBase],
self,
key: str,
new_config: Union[dict, ModelConfigBase],
) -> ModelConfigBase:
"""
Update the named model with a dictionary of attributes. Will fail with a
@ -398,9 +388,9 @@ class ModelManagerService(ModelManagerServiceBase):
return self._loader.store.update_model(key, new_config)
def del_model(
self,
key: str,
delete_files: bool = False,
self,
key: str,
delete_files: bool = False,
):
"""
Delete the named model from configuration. If delete_files is true,
@ -418,9 +408,9 @@ class ModelManagerService(ModelManagerServiceBase):
path.unlink()
def convert_model(
self,
key: str,
convert_dest_directory: Path,
self,
key: str,
convert_dest_directory: Path,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
@ -436,7 +426,7 @@ class ModelManagerService(ModelManagerServiceBase):
"""
model_info = self.model_info(key)
self.logger.debug(f"convert model {model_info.name}")
self.logger.warning('This is not implemented yet')
self.logger.warning("This is not implemented yet")
return self._loader.convert_model(key, convert_dest_directory)
def collect_cache_stats(self, cache_stats: CacheStats):
@ -478,17 +468,17 @@ class ModelManagerService(ModelManagerServiceBase):
@property
def logger(self):
return self._loader.logger
def merge_models(
self,
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
self,
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Merge two to three diffusrs pipeline models and save as a new model.
@ -500,7 +490,7 @@ class ModelManagerService(ModelManagerServiceBase):
"""
merger = ModelMerger(self.mgr)
try:
self.logger.error('ModelMerger needs to be rewritten.')
self.logger.error("ModelMerger needs to be rewritten.")
result = merger.merge_diffusion_models_and_save(
model_keys=model_keys,
merged_model_name=merged_model_name,

View File

@ -172,11 +172,13 @@ class VaeDiffusersConfig(ModelConfigBase):
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class TextualInversionConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""

View File

@ -344,14 +344,12 @@ class DownloadQueue(DownloadQueueBase):
if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", metadata_url):
version = match.group(1)
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
metadata.thumbnail_url = metadata.thumbnail_url \
or resp["images"][0]["url"]
metadata.description = metadata.description \
or (
f"Trigger terms: {(', ').join(resp['trainedWords'])}"
if resp["trainedWords"]
else resp["description"]
)
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
metadata.description = metadata.description or (
f"Trigger terms: {(', ').join(resp['trainedWords'])}"
if resp["trainedWords"]
else resp["description"]
)
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"])
# a Civitai model page
@ -364,10 +362,11 @@ class DownloadQueue(DownloadQueueBase):
metadata.author = metadata.author or resp["creator"]["username"]
metadata.tags = metadata.tags or resp["tags"]
metadata.thumbnail_url = metadata.thumbnail_url \
or resp["modelVersions"][0]["images"][0]["url"]
metadata.license = metadata.license \
metadata.thumbnail_url = metadata.thumbnail_url or resp["modelVersions"][0]["images"][0]["url"]
metadata.license = (
metadata.license
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
)
except (HTTPError, KeyError, TypeError, JSONDecodeError) as excp:
self._logger.warn(excp)

View File

@ -316,9 +316,8 @@ class ModelInstall(ModelInstallBase):
"""Return the queue."""
return self._download_queue
def register_path(self,
model_path: Union[Path, str],
overrides: Optional[Dict[str, Any]] = None
def register_path(
self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None
) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = self._probe_model(model_path, overrides)
@ -354,13 +353,13 @@ class ModelInstall(ModelInstallBase):
return key
def install_path(
self,
model_path: Union[Path, str],
overrides: Optional[Dict[str, Any]] = None,
self,
model_path: Union[Path, str],
overrides: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = self._probe_model(model_path, overrides)
dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
dest_path.parent.mkdir(parents=True, exist_ok=True)
@ -375,14 +374,11 @@ class ModelInstall(ModelInstallBase):
info,
)
def _probe_model(self,
model_path: Union[Path, str],
overrides: Optional[Dict[str,Any]] = None
) -> ModelProbeInfo:
def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
info: ModelProbeInfo = ModelProbe.probe(model_path)
if overrides: # used to override probe fields
for key, value in overrides.items():
setattr(info, key, value) # may generate a pydantic validation error
setattr(info, key, value) # may generate a pydantic validation error
return info
def unregister(self, key: str): # noqa D102
@ -394,12 +390,12 @@ class ModelInstall(ModelInstallBase):
self.unregister(key)
def install(
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
variant: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
variant: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
) -> DownloadJobBase: # noqa D102
queue = self._download_queue

View File

@ -75,13 +75,10 @@ class ModelLoaderBase(ABC):
"""Return the current logger."""
pass
@abstractmethod
def collect_cache_stats(
self,
cache_stats: CacheStats
):
def collect_cache_stats(self, cache_stats: CacheStats):
"""Replace cache statistics."""
pass
@property
@ -98,6 +95,7 @@ class ModelLoaderBase(ABC):
"""
pass
class ModelLoader(ModelLoaderBase):
"""Implementation of ModelLoaderBase."""
@ -228,10 +226,7 @@ class ModelLoader(ModelLoaderBase):
_cache=self._cache,
)
def collect_cache_stats(
self,
cache_stats: CacheStats
):
def collect_cache_stats(self, cache_stats: CacheStats):
self._cache.stats = cache_stats
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
@ -275,7 +270,6 @@ class ModelLoader(ModelLoaderBase):
installed = set()
with Chdir(self._app_config.models_path):
self._logger.info("Checking for models that have been moved or deleted from disk.")
for model_config in self._store.all_models():
path = self._resolve_model_path(model_config.path)

View File

@ -128,7 +128,7 @@ class ModelProbe(ModelProbeBase):
format_type = "onnx" if model_type == ModelType.ONNX else "diffusers" if model.is_dir() else "checkpoint"
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
return None
@ -138,7 +138,7 @@ class ModelProbe(ModelProbeBase):
variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type()
format = probe.get_format()
model_info = ModelProbeInfo(
model_type=model_type,
base_type=base_type,

View File

@ -11,6 +11,7 @@ from ..config import ModelConfigBase, BaseModelType, ModelType
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.1.1"
class DuplicateModelException(Exception):
"""Raised on an attempt to add a model with the same key twice."""
@ -122,4 +123,3 @@ class ModelConfigStore(ABC):
Return all the model configs in the database.
"""
return self.search_by_name()

View File

@ -89,8 +89,9 @@ class ModelConfigStoreSQL(ModelConfigStore):
self._conn.commit()
finally:
self._lock.release()
assert self.version == CONFIG_FILE_VERSION, \
f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
assert (
self.version == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
def _create_tables(self) -> None:
"""Create sqlite3 tables."""
@ -182,9 +183,8 @@ class ModelConfigStoreSQL(ModelConfigStore):
)
VALUES (?,?);
""",
("version",CONFIG_FILE_VERSION),
("version", CONFIG_FILE_VERSION),
)
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
"""
@ -256,7 +256,6 @@ class ModelConfigStoreSQL(ModelConfigStore):
finally:
self._lock.release()
def _update_tags(self, key: str, tags: List[str]) -> None:
"""Update tags for model with key."""
# remove previous tags from this model

View File

@ -78,8 +78,9 @@ class ModelConfigStoreYAML(ModelConfigStore):
if not self._filename.exists():
self._initialize_yaml()
self._config = OmegaConf.load(self._filename)
assert self.version == CONFIG_FILE_VERSION, \
f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
assert (
self.version == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
def _initialize_yaml(self):
try:
@ -104,7 +105,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
@property
def version(self) -> str:
"""Return version of this config file/database."""
return self._config["__metadata__"].get('version')
return self._config["__metadata__"].get("version")
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
"""