pytests mostly working; model_manager_service needs rewriting

Lincoln Stein
2023-09-11 23:47:24 -04:00
parent 7430d87301
commit 6d8b2a7385
26 changed files with 187 additions and 129 deletions

View File

@ -12,9 +12,8 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
SDXLConditioningInfo,
)
from ...backend.model_management.models import ModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException
from ...backend.model_manager import ModelType, UnknownModelException
from ...backend.model_manager.lora import ModelPatcher
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.util.devices import torch_dtype
from .baseinvocation import (
@ -94,7 +93,7 @@ class CompelInvocation(BaseInvocation):
).context.model,
)
)
except ModelNotFoundException:
except UnknownModelException:
# print(e)
# import traceback
# print(traceback.format_exc())
@ -208,7 +207,7 @@ class SDXLPromptInvocationBase:
).context.model,
)
)
except ModelNotFoundException:
except UnknownModelException:
# print(e)
# import traceback
# print(traceback.format_exc())
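The call sites above now catch the renamed exception. A minimal sketch of the new pattern, assuming the get_model() signature implied by the surrounding hunks:

from invokeai.backend.model_manager import ModelType, UnknownModelException

def load_ti_embedding(context, name: str, base_model):
    # hypothetical helper mirroring the patched try/except blocks
    try:
        return context.services.model_manager.get_model(
            model_name=name,
            base_model=base_model,
            model_type=ModelType.TextualInversion,
        ).context.model
    except UnknownModelException:  # formerly ModelNotFoundException
        return None  # a missing embedding is skipped, not fatal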

View File

@ -29,7 +29,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from ...backend.model_management import BaseModelType
from ...backend.model_manager import BaseModelType
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,

View File

@ -31,11 +31,10 @@ from invokeai.app.invocations.primitives import (
)
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.model_manager import BaseModelType, ModelType, SilenceWarnings
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.seamless import set_seamless
from ...backend.model_management.models import BaseModelType
from ...backend.model_manager.lora import ModelPatcher
from ...backend.model_manager.seamless import set_seamless
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,

View File

@ -3,7 +3,7 @@ from typing import List, Optional
from pydantic import BaseModel, Field
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_manager import BaseModelType, ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,

View File

@ -17,7 +17,7 @@ from invokeai.app.invocations.primitives import ConditioningField, ConditioningO
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend import BaseModelType, ModelType, SubModelType
from ...backend.model_management import ONNXModelPatcher
from ...backend.model_manager import ONNXModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util import choose_torch_device
from ..models.image import ImageCategory, ResourceOrigin

View File

@ -1,4 +1,4 @@
from ...backend.model_management import ModelType, SubModelType
from ...backend.model_manager import ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,

View File

@ -43,7 +43,7 @@ from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
from .model_manager_service import ModelManagerService
from invokeai.backend.model_management.model_cache import CacheStats
from invokeai.backend.model_manager.cache import CacheStats
# size of GIG in bytes
GIG = 1073741824
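GIG is used to report CacheStats in human-readable units; a one-line sketch (cache_stats and its high_watermark field are assumptions, not shown in this hunk):

logger.info(f"RAM cache high water mark: {cache_stats.high_watermark / GIG:.2f}G")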

View File

@ -9,20 +9,21 @@ from pydantic import Field
from typing import Literal, Optional, Union, Callable, List, Tuple, TYPE_CHECKING
from types import ModuleType
from invokeai.backend.model_management import (
ModelManager,
from invokeai.backend.model_manager import (
BaseModelType,
ModelType,
SubModelType,
ModelInfo,
AddModelResult,
SchedulerPredictionType,
ModelMerger,
DownloadJobBase,
MergeInterpolationMethod,
ModelNotFoundException,
ModelConfigBase,
ModelInfo,
ModelLoader,
ModelMerger,
ModelType,
SchedulerPredictionType,
SubModelType,
UnknownModelException,
)
from invokeai.backend.model_management.model_search import FindModels
from invokeai.backend.model_management.model_cache import CacheStats
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.cache import CacheStats
import torch
from invokeai.app.models.exceptions import CanceledException
@ -128,7 +129,7 @@ class ModelManagerServiceBase(ABC):
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> AddModelResult:
) -> InstallJobBase:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
@ -145,10 +146,10 @@ class ModelManagerServiceBase(ABC):
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
) -> ModelConfigBase:
"""
Update the named model with a dictionary of attributes. Will fail with a
ModelNotFoundException if the name does not already exist.
UnknownModelException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
@ -196,7 +197,7 @@ class ModelManagerServiceBase(ABC):
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
) -> AddModelResult:
) -> InstallJobBase:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
@ -216,7 +217,7 @@ class ModelManagerServiceBase(ABC):
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, AddModelResult]:
) -> InstallJobBase:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
@ -249,7 +250,7 @@ class ModelManagerServiceBase(ABC):
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> AddModelResult:
) -> ModelConfigBase:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
@ -438,7 +439,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> AddModelResult:
) -> InstallJobBase:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
@ -455,17 +456,17 @@ class ModelManagerService(ModelManagerServiceBase):
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
) -> InstallJobBase:
"""
Update the named model with a dictionary of attributes. Will fail with a
ModelNotFoundException exception if the name does not already exist.
UnknownModelException exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f"update model {model_name}")
if not self.model_exists(model_name, base_model, model_type):
raise ModelNotFoundException(f"Unknown model {model_name}")
raise UnknownModelException(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
def del_model(
@ -491,7 +492,7 @@ class ModelManagerService(ModelManagerServiceBase):
convert_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> AddModelResult:
) -> InstallJobBase:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
@ -560,7 +561,7 @@ class ModelManagerService(ModelManagerServiceBase):
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, AddModelResult]:
) -> dict[str, InstallJobBase]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
@ -594,7 +595,7 @@ class ModelManagerService(ModelManagerServiceBase):
merge_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> AddModelResult:
) -> str:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
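With AddModelResult gone, callers of the reworked service receive the install jobs or config records directly. A hedged sketch of the new import_models() flow, using only fields visible in this diff (status, model_key); the "completed" literal is an assumption:

jobs = model_manager.import_models({"runwayml/stable-diffusion-v1-5"})
for source, job in jobs.items():
    if job.status == "completed" and job.model_key:
        print(f"{source} installed under key {job.model_key}")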

View File

@ -5,7 +5,7 @@ from invokeai.app.models.image import ProgressImage
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.model_management.models import BaseModelType
from ...backend.model_manager import BaseModelType
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):

View File

@ -3,11 +3,14 @@ Initialization file for invokeai.backend
"""
from .model_manager import ( # noqa F401
ModelLoader,
ModelInstall,
ModelConfigStore,
SilenceWarnings,
DuplicateModelException,
InvalidModelException,
BaseModelType,
ModelType,
SubModelType,
SchedulerPredictionType,
ModelVariantType,
)

View File

@ -14,8 +14,16 @@ from .config import ( # noqa F401
SubModelType,
SilenceWarnings,
)
from .loader import ModelLoader # noqa F401
from .install import ModelInstall # noqa F401
from .lora import ONNXModelPatcher, ModelPatcher
from .loader import ModelLoader, ModelInfo # noqa F401
from .install import ModelInstall, DownloadJobBase # noqa F401
from .probe import ModelProbe, InvalidModelException # noqa F401
from .storage import DuplicateModelException # noqa F401
from .storage import (
UnknownModelException,
DuplicateModelException,
ModelConfigStore,
ModelConfigStoreYAML,
ModelConfigStoreSQL,
) # noqa F401
from .search import ModelSearch # noqa F401
from .merge import MergeInterpolationMethod, ModelMerger

View File

@ -114,7 +114,7 @@ class ModelConfigBase(BaseModel):
base_model: BaseModelType
model_type: ModelType
model_format: ModelFormat
id: Optional[str] = Field(None) # this may get added by the store
key: Optional[str] = Field(None) # this will get added by the store
description: Optional[str] = Field(None)
author: Optional[str] = Field(description="Model author")
license: Optional[str] = Field(description="License string")
@ -244,6 +244,7 @@ class ModelConfigFactory(object):
def make_config(
cls,
model_data: Union[dict, ModelConfigBase],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
) -> Union[
MainCheckpointConfig,
@ -263,6 +264,8 @@ class ModelConfigFactory(object):
be selected automatically.
"""
if isinstance(model_data, ModelConfigBase):
if key:
model_data.key = key
return model_data
try:
model_format = model_data.get("model_format")
@ -271,7 +274,10 @@ class ModelConfigFactory(object):
class_to_return = dest_class or cls._class_map[model_format][model_type]
if isinstance(class_to_return, dict): # additional level allowed
class_to_return = class_to_return[model_base]
return class_to_return.parse_obj(model_data)
model = class_to_return.parse_obj(model_data)
if key:
model.key = key # ensure consistency
return model
except KeyError as exc:
raise InvalidModelConfigException(
f"Unknown combination of model_format '{model_format}' and model_type '{model_type}'"

View File

@ -49,6 +49,9 @@ class DownloadJobBase(BaseModel):
id: int = Field(description="Numeric ID of this job")
source: str = Field(description="URL or repo_id to download")
destination: Path = Field(description="Destination of URL on local disk")
model_key: Optional[str] = Field(
description="After model installation, this field will hold its primary key", default=None
)
metadata: Optional[ModelSourceMetadata] = Field(description="Model metadata (source-specific)", default=None)
access_token: Optional[str] = Field(description="access token needed to access this resource")
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")

View File

@ -20,7 +20,7 @@ Typical usage:
# register config, and install model in `models`
id: str = installer.install_path('/path/to/model')
# download some remote models and install them in the background
# download some remote models and install them in the background
installer.install('stabilityai/stable-diffusion-2-1')
installer.install('https://civitai.com/api/download/models/154208')
installer.install('runwayml/stable-diffusion-v1-5')
@ -58,7 +58,7 @@ from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from .search import ModelSearch
from .storage import ModelConfigStore, ModelConfigStoreYAML, DuplicateModelException
from .storage import ModelConfigStore, DuplicateModelException, get_config_store
from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase, ModelSourceMetadata
from .hash import FastModelHash
from .probe import ModelProbe, ModelProbeInfo, InvalidModelException
@ -272,7 +272,7 @@ class ModelInstall(ModelInstallBase):
): # noqa D107 - use base class docstrings
self._config = config or InvokeAIAppConfig.get_config()
self._logger = logger or InvokeAILogger.getLogger(config=self._config)
self._store = store or ModelConfigStoreYAML(self._config.model_conf_path)
self._store = store or get_config_store(self._config.model_conf_path)
self._download_queue = download or DownloadQueue(config=self._config)
self._async_installs = dict()
self._installed = set()
@ -289,7 +289,7 @@ class ModelInstall(ModelInstallBase):
return self._register(model_path, info)
def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
id: str = FastModelHash.hash(model_path)
key: str = FastModelHash.hash(model_path)
registration_data = dict(
path=model_path.as_posix(),
name=model_path.stem,
@ -309,13 +309,13 @@ class ModelInstall(ModelInstallBase):
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
)
config_file = config_file[SchedulerPredictionType.VPrediction]
registration_data.update(
config=Path(self._config.legacy_conf_dir, config_file).as_posix(),
)
except KeyError as exc:
raise InvalidModelException("Configuration file for this checkpoint could not be determined") from exc
registration_data.update(
config=Path(self._config.legacy_conf_dir, config_file).as_posix(),
)
self._store.add_model(id, registration_data)
return id
self._store.add_model(key, registration_data)
return key
def install_path(self, model_path: Union[Path, str]) -> str: # noqa D102
model_path = Path(model_path)
@ -334,13 +334,13 @@ class ModelInstall(ModelInstallBase):
info,
)
def unregister(self, id: str): # noqa D102
self._store.del_model(id)
def unregister(self, key: str): # noqa D102
self._store.del_model(key)
def delete(self, id: str): # noqa D102
model = self._store.get_model(id)
def delete(self, key: str): # noqa D102
model = self._store.get_model(key)
rmtree(model.path)
self.unregister(id)
self.unregister(key)
def install(
self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, variant: Optional[str] = None
@ -381,6 +381,7 @@ class ModelInstall(ModelInstallBase):
info.description = f"Imported model {info.name}"
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
job.model_key = model_id
elif job.status == "error":
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
elif job.status == "cancelled":
@ -421,8 +422,8 @@ class ModelInstall(ModelInstallBase):
for model in self._store.all_models():
path = Path(model.path)
if not path.exists():
self._store.del_model(model.id)
unregistered.append(model.id)
self._store.del_model(model.key)
unregistered.append(model.key)
return unregistered
def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
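Taken together, install.py now keys everything on the FastModelHash digest. A sketch of the key-based lifecycle using the methods shown above (the store keyword follows the __init__ signature in this file):

from pathlib import Path
from invokeai.backend.model_manager import ModelInstall
from invokeai.backend.model_manager.storage import get_config_store

store = get_config_store(Path("models.yaml"))
installer = ModelInstall(store=store)
key = installer.install_path("/path/to/model")  # returns the FastModelHash key
print(store.get_model(key).name)
installer.delete(key)  # rmtree()s the files, then unregisters the key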

View File

@ -26,7 +26,7 @@ class ModelInfo:
name: str
base_model: BaseModelType
type: ModelType
id: str
key: str
location: Union[Path, str]
precision: torch.dtype
_cache: Optional[ModelCache] = None
@ -186,7 +186,7 @@ class ModelLoader(ModelLoaderBase):
name=model_config.name,
base_model=model_config.base_model,
type=submodel_type or model_type,
id=model_config.id,
key=model_config.key,
location=model_path,
precision=self._cache.precision,
_cache=self._cache,

View File

@ -1,5 +1,5 @@
"""
invokeai.backend.model_management.model_merge exports:
invokeai.backend.model_manager.merge exports:
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
@ -15,7 +15,7 @@ from typing import List, Union, Optional
import invokeai.backend.util.logging as logger
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
from . import ModelLoader, ModelType, BaseModelType, ModelVariantType, ModelConfigBase
class MergeInterpolationMethod(str, Enum):
@ -26,7 +26,7 @@ class MergeInterpolationMethod(str, Enum):
class ModelMerger(object):
def __init__(self, manager: ModelManager):
def __init__(self, manager: ModelLoader):
self.manager = manager
def merge_diffusion_models(
@ -77,7 +77,7 @@ class ModelMerger(object):
force: bool = False,
merge_dest_directory: Optional[Path] = None,
**kwargs,
) -> AddModelResult:
) -> ModelConfigBase:
"""
:param models: up to three models, designated by their InvokeAI models.yaml model name
:param base_model: base model (must be the same for all merged models!)

View File

@ -1,6 +1,19 @@
"""
Initialization file for invokeai.backend.model_manager.storage
"""
import pathlib
from .base import ModelConfigStore, UnknownModelException, DuplicateModelException # noqa F401
from .yaml import ModelConfigStoreYAML # noqa F401
from .sql import ModelConfigStoreSQL # noqa F401
def get_config_store(location: pathlib.Path) -> ModelConfigStore:
"""Return the type of ModelConfigStore appropriate to the path."""
location = pathlib.Path(location)
if location.suffix == ".yaml":
return ModelConfigStoreYAML(location)
elif location.suffix == ".db":
return ModelConfigStoreSQL(location)
else:
raise Exception(f"Unable to determine type of configuration file '{location}'")
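A quick usage sketch, following the suffix dispatch above (strings are accepted because the argument is coerced with pathlib.Path):

store = get_config_store("models.yaml")  # -> ModelConfigStoreYAML
store = get_config_store("models.db")    # -> ModelConfigStoreSQL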

View File

@ -19,7 +19,7 @@ class InvalidModelException(Exception):
class UnknownModelException(Exception):
"""Raised on an attempt to delete a model with a nonexistent key."""
"""Raised on an attempt to fetch or delete a model with a nonexistent key."""
class ModelConfigStore(ABC):
@ -90,7 +90,7 @@ class ModelConfigStore(ABC):
pass
@abstractmethod
def search_by_type(
def search_by_name(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
@ -112,4 +112,4 @@ class ModelConfigStore(ABC):
"""
Return all the model configs in the database.
"""
return self.search_by_type()
return self.search_by_name()

View File

@ -16,7 +16,7 @@ Typical usage:
tags=['sfw','cartoon']
)
# adding - the key becomes the model's "id" field
# adding - the key becomes the model's "key" field
store.add_model('key1', config)
# updating
@ -30,14 +30,14 @@ Typical usage:
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base_model)
assert new_config.id == 'key1'
assert new_config.key == 'key1'
# deleting
store.del_model('key1')
# searching
configs = store.search_by_tag({'sfw','oss license'})
configs = store.search_by_type(base_model='sd-2', model_type='main')
configs = store.search_by_name(base_model='sd-2', model_type='main')
"""
import threading
@ -173,8 +173,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
Can raise DuplicateModelException and InvalidModelConfig exceptions.
"""
record = ModelConfigFactory.make_config(config) # ensure it is a valid config object.
record.id = key # add the unique storage key to object
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object.
json_serialized = json.dumps(record.dict()) # and turn it into a json string.
try:
self._lock.acquire()
@ -293,7 +292,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config) # ensure it is a valid config object
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object
json_serialized = json.dumps(record.dict()) # and turn it into a json string.
try:
self._lock.acquire()
@ -309,7 +308,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
""",
(record.base_model, record.model_type, record.name, record.path, json_serialized, key),
)
if self._cursor.rowcount < 1:
if self._cursor.rowcount == 0:
raise UnknownModelException
if record.tags:
self._update_tags(key, record.tags)
@ -404,7 +403,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
self._lock.release()
return results
def search_by_type(
def search_by_name(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,

View File

@ -16,7 +16,7 @@ Typical usage:
tags=['sfw','cartoon']
)
# adding - the key becomes the model's "id" field
# adding - the key becomes the model's "key" field
store.add_model('key1', config)
# updating
@ -30,18 +30,19 @@ Typical usage:
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base_model)
assert new_config.id == 'key1'
assert new_config.key == 'key1'
# deleting
store.del_model('key1')
# searching
configs = store.search_by_tag({'sfw','oss license'})
configs = store.search_by_type(base_model='sd-2', model_type='main')
configs = store.search_by_name(base_model='sd-2', model_type='main')
"""
import threading
import yaml
from enum import Enum
from pathlib import Path
from typing import Union, Set, List, Optional
from omegaconf import OmegaConf
@ -110,8 +111,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
Can raise DuplicateModelException and InvalidModelConfig exceptions.
"""
record = ModelConfigFactory.make_config(config) # ensure it is a valid config object
record.id = key # add the key used to store the object
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config object
dict_fields = record.dict() # and back to a dict with valid fields
try:
self._lock.acquire()
@ -120,11 +120,18 @@ class ModelConfigStoreYAML(ModelConfigStore):
raise DuplicateModelException(
f"Can't save {record.name} because a model named '{existing_model.name}' is already stored with the same key '{key}'"
)
self._config[key] = dict_fields
self._config[key] = self._fix_enums(dict_fields)
self._commit()
finally:
self._lock.release()
def _fix_enums(self, original: dict) -> dict:
"""In python 3.9, omegaconf stores incorrectly stringified enums"""
fixed_dict = {}
for key, value in original.items():
fixed_dict[key] = value.value if isinstance(value, Enum) else value
return fixed_dict
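The workaround reduces Enum members to their raw values before the dict reaches OmegaConf. A standalone illustration (this ModelFormat definition is an assumption, not the real config class):

from enum import Enum

class ModelFormat(str, Enum):
    Diffusers = "diffusers"

fields = {"model_format": ModelFormat.Diffusers, "name": "sdxl-base-1-0"}
fixed = {k: (v.value if isinstance(v, Enum) else v) for k, v in fields.items()}
assert fixed["model_format"] == "diffusers"  # plain str, safe to serialize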
def del_model(self, key: str) -> None:
"""
Delete a model.
@ -150,13 +157,13 @@ class ModelConfigStoreYAML(ModelConfigStore):
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config) # ensure it is a valid config object
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config object
dict_fields = record.dict() # and back to a dict with valid fields
try:
self._lock.acquire()
if key not in self._config:
raise UnknownModelException(f"Unknown key '{key}' for model config")
self._config[key] = dict_fields
self._config[key] = self._fix_enums(dict_fields)
self._commit()
finally:
self._lock.release()
@ -171,7 +178,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
"""
try:
record = self._config[key]
return ModelConfigFactory.make_config(record)
return ModelConfigFactory.make_config(record, key)
except KeyError as e:
raise UnknownModelException(f"Unknown key '{key}' for model config") from e
@ -202,7 +209,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
self._lock.release()
return results
def search_by_type(
def search_by_name(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
@ -224,7 +231,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
for key, record in self._config.items():
if key == "__metadata__":
continue
model = ModelConfigFactory.make_config(record)
model = ModelConfigFactory.make_config(record, key)
if model_name and model.name != model_name:
continue
if base_model and model.base_model != base_model:

View File

@ -3,45 +3,47 @@ from pathlib import Path
import pytest
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType
from invokeai.backend import ModelConfigStore, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager import ModelLoader
BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
BASIC_MODEL_NAME = "sdxl-base-1-0"
VAE_OVERRIDE_MODEL_NAME = "sdxl-base-with-custom-vae-1-0"
VAE_NULL_OVERRIDE_MODEL_NAME = "sdxl-base-with-empty-vae-1-0"
@pytest.fixture
def model_manager(datadir) -> ModelManager:
InvokeAIAppConfig.get_config(root=datadir)
return ModelManager(datadir / "configs" / "relative_sub.models.yaml")
def model_manager(datadir) -> ModelLoader:
config = InvokeAIAppConfig(root=datadir, conf_path="configs/relative_sub.models.yaml")
return ModelLoader(config=config)
def test_get_model_names(model_manager: ModelManager):
names = model_manager.model_names()
def test_get_model_names(model_manager: ModelLoader):
store = model_manager.store
names = [x.name for x in store.all_models()]
assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]
def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2])
def test_get_model_path_for_diffusers(model_manager: ModelLoader, datadir: Path):
models = model_manager.store.search_by_name(model_name=BASIC_MODEL_NAME)
assert len(models) == 1
model_config = models[0]
top_model_path, is_override = model_manager._get_model_path(model_config)
expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
assert top_model_path == expected_model_path
assert not is_override
def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(
VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2]
)
def test_get_model_path_for_overridden_vae(model_manager: ModelLoader, datadir: Path):
models = model_manager.store.search_by_name(model_name=VAE_OVERRIDE_MODEL_NAME)
assert len(models) == 1
model_config = models[0]
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
assert vae_model_path == expected_vae_path
assert is_override
def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(
VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2]
)
def test_get_model_path_for_null_overridden_vae(model_manager: ModelLoader, datadir: Path):
model_config = model_manager.store.search_by_name(model_name=VAE_NULL_OVERRIDE_MODEL_NAME)[0]
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
assert not is_override

View File

@ -1,22 +1,30 @@
__metadata__:
version: 3.0.0
sdxl/main/SDXL base:
version: 3.1.0
ed799245c762f6d0a9ddfd4e31fdb010:
name: sdxl-base-1-0
path: sdxl/main/SDXL base 1_0
base_model: sdxl
model_type: main
model_format: diffusers
model_variant: normal
description: SDXL base v1.0
variant: normal
format: diffusers
sdxl/main/SDXL with VAE:
fa78e05dbf51c540ff9256eb65446fd6:
name: sdxl-base-with-custom-vae-1-0
path: sdxl/main/SDXL base 1_0
base_model: sdxl
model_type: main
model_variant: normal
model_format: diffusers
description: SDXL with customized VAE
vae: sdxl/vae/sdxl-vae-fp16-fix/
variant: normal
format: diffusers
sdxl/main/SDXL with empty VAE:
8a79e05d9f51c5ffff9256eb65446fd6:
name: sdxl-base-with-empty-vae-1-0
path: sdxl/main/SDXL base 1_0
base_model: sdxl
model_type: main
model_variant: normal
model_format: diffusers
description: SDXL with customized VAE
vae: ''
variant: normal
format: diffusers

View File

@ -12,6 +12,7 @@ from invokeai.backend.model_manager.storage import (
UnknownModelException,
)
from invokeai.backend.model_manager.config import (
ModelType,
TextualInversionConfig,
DiffusersConfig,
VaeDiffusersConfig,
@ -113,14 +114,15 @@ def test_filter(store: ModelConfigStore):
config3 = VaeDiffusersConfig(path="/tmp/config3", name="config3", base_model="sd-1", model_type="vae", tags=["sfw"])
for c in config1, config2, config3:
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
matches = store.search_by_type(model_type="main")
matches = store.search_by_name(model_type="main")
assert len(matches) == 2
assert matches[0].name in {"config1", "config2"}
matches = store.search_by_type(model_type="vae")
matches = store.search_by_name(model_type="vae")
assert len(matches) == 1
assert matches[0].name == "config3"
assert matches[0].id == sha256("config3".encode("utf-8")).hexdigest()
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
assert isinstance(matches[0].model_type, ModelType) # This tests that we get proper enums back
matches = store.search_by_tag(["sfw"])
assert len(matches) == 3

View File

@ -3,6 +3,7 @@ Test the refactored model config classes.
"""
import pytest
import sys
from hashlib import sha256
from invokeai.app.services.config import InvokeAIAppConfig
@ -12,6 +13,7 @@ from invokeai.backend.model_manager.storage import (
UnknownModelException,
)
from invokeai.backend.model_manager.config import (
ModelType,
TextualInversionConfig,
DiffusersConfig,
VaeDiffusersConfig,
@ -20,6 +22,7 @@ from invokeai.backend.model_manager.config import (
@pytest.fixture
def store(datadir) -> ModelConfigStore:
print(f"DEBUG: datadir={datadir}")
InvokeAIAppConfig.get_config(root=datadir)
return ModelConfigStoreSQL(datadir / "databases" / "models.db")
@ -89,11 +92,14 @@ def test_delete(store: ModelConfigStore):
except UnknownModelException:
assert True
try:
store.del_model("unknown")
assert False, "expected delete of unknown model to raise exception"
except UnknownModelException:
assert True
# a bug in sqlite3 in Python 3.9 prevents DELETE from returning the
# number of deleted rows!
if sys.version_info.major == 3 and sys.version_info.minor > 9:
try:
store.del_model("unknown")
assert False, "expected delete of unknown model to raise exception"
except UnknownModelException:
assert True
def test_exists(store: ModelConfigStore):
@ -113,14 +119,15 @@ def test_filter(store: ModelConfigStore):
config3 = VaeDiffusersConfig(path="/tmp/config3", name="config3", base_model="sd-1", model_type="vae", tags=["sfw"])
for c in config1, config2, config3:
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
matches = store.search_by_type(model_type="main")
matches = store.search_by_name(model_type="main")
assert len(matches) == 2
assert matches[0].name in {"config1", "config2"}
matches = store.search_by_type(model_type="vae")
matches = store.search_by_name(model_type="vae")
assert len(matches) == 1
assert matches[0].name == "config3"
assert matches[0].id == sha256("config3".encode("utf-8")).hexdigest()
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
assert isinstance(matches[0].model_type, ModelType) # This tests that we get proper enums back
matches = store.search_by_tag(["sfw"])
assert len(matches) == 3