mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
pytests mostly working; model_manager_service needs rewriting
This commit is contained in:
@ -12,9 +12,8 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion impor
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
|
||||
from ...backend.model_management.models import ModelType
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import ModelNotFoundException
|
||||
from ...backend.model_manager import ModelType, UnknownModelException
|
||||
from ...backend.model_manager.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from .baseinvocation import (
|
||||
@ -94,7 +93,7 @@ class CompelInvocation(BaseInvocation):
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
@ -208,7 +207,7 @@ class SDXLPromptInvocationBase:
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
|
@ -29,7 +29,7 @@ from pydantic import BaseModel, Field, validator
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
|
||||
|
||||
from ...backend.model_management import BaseModelType
|
||||
from ...backend.model_manager import BaseModelType
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
|
@ -31,11 +31,10 @@ from invokeai.app.invocations.primitives import (
|
||||
)
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType, SilenceWarnings
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.seamless import set_seamless
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_manager.lora import ModelPatcher
|
||||
from ...backend.model_manager.seamless import set_seamless
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ConditioningData,
|
||||
|
@ -3,7 +3,7 @@ from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from ...backend.model_manager import BaseModelType, ModelType, SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
|
@ -17,7 +17,7 @@ from invokeai.app.invocations.primitives import ConditioningField, ConditioningO
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
|
||||
from ...backend.model_management import ONNXModelPatcher
|
||||
from ...backend.model_manager import ONNXModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util import choose_torch_device
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
|
@ -1,4 +1,4 @@
|
||||
from ...backend.model_management import ModelType, SubModelType
|
||||
from ...backend.model_manager import ModelType, SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
|
@ -43,7 +43,7 @@ from ..invocations.baseinvocation import BaseInvocation
|
||||
from .graph import GraphExecutionState
|
||||
from .item_storage import ItemStorageABC
|
||||
from .model_manager_service import ModelManagerService
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
from invokeai.backend.model_manager.cache import CacheStats
|
||||
|
||||
# size of GIG in bytes
|
||||
GIG = 1073741824
|
||||
|
@ -9,20 +9,21 @@ from pydantic import Field
|
||||
from typing import Literal, Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
||||
from types import ModuleType
|
||||
|
||||
from invokeai.backend.model_management import (
|
||||
ModelManager,
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
ModelInfo,
|
||||
AddModelResult,
|
||||
SchedulerPredictionType,
|
||||
ModelMerger,
|
||||
DownloadJobBase,
|
||||
MergeInterpolationMethod,
|
||||
ModelNotFoundException,
|
||||
ModelConfigBase,
|
||||
ModelInfo,
|
||||
ModelLoader,
|
||||
ModelMerger,
|
||||
ModelType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.backend.model_management.model_search import FindModels
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.model_manager.cache import CacheStats
|
||||
|
||||
import torch
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
@ -128,7 +129,7 @@ class ModelManagerServiceBase(ABC):
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
) -> InstallJobBase:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
@ -145,10 +146,10 @@ class ModelManagerServiceBase(ABC):
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException if the name does not already exist.
|
||||
UnknownModelException if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
@ -196,7 +197,7 @@ class ModelManagerServiceBase(ABC):
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
) -> InstallJobBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
@ -216,7 +217,7 @@ class ModelManagerServiceBase(ABC):
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
) -> InstallJobBase:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
@ -249,7 +250,7 @@ class ModelManagerServiceBase(ABC):
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> AddModelResult:
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
@ -438,7 +439,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
) -> InstallJobBase:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
@ -455,17 +456,17 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
) -> InstallJobBase:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException exception if the name does not already exist.
|
||||
UnknownModelException exception if the name does not already exist.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f"update model {model_name}")
|
||||
if not self.model_exists(model_name, base_model, model_type):
|
||||
raise ModelNotFoundException(f"Unknown model {model_name}")
|
||||
raise UnknownModelException(f"Unknown model {model_name}")
|
||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||
|
||||
def del_model(
|
||||
@ -491,7 +492,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
convert_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
) -> InstallJobBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
@ -560,7 +561,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
) -> dict[str, InstallJobBase]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
@ -594,7 +595,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
merge_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
) -> str:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
|
@ -5,7 +5,7 @@ from invokeai.app.models.image import ProgressImage
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_manager import BaseModelType
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
||||
|
@ -3,11 +3,14 @@ Initialization file for invokeai.backend
|
||||
"""
|
||||
from .model_manager import ( # noqa F401
|
||||
ModelLoader,
|
||||
ModelInstall,
|
||||
ModelConfigStore,
|
||||
SilenceWarnings,
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
SchedulerPredictionType,
|
||||
ModelVariantType,
|
||||
)
|
||||
|
@ -14,8 +14,16 @@ from .config import ( # noqa F401
|
||||
SubModelType,
|
||||
SilenceWarnings,
|
||||
)
|
||||
from .loader import ModelLoader # noqa F401
|
||||
from .install import ModelInstall # noqa F401
|
||||
from .lora import ONNXModelPatcher, ModelPatcher
|
||||
from .loader import ModelLoader, ModelInfo # noqa F401
|
||||
from .install import ModelInstall, DownloadJobBase # noqa F401
|
||||
from .probe import ModelProbe, InvalidModelException # noqa F401
|
||||
from .storage import DuplicateModelException # noqa F401
|
||||
from .storage import (
|
||||
UnknownModelException,
|
||||
DuplicateModelException,
|
||||
ModelConfigStore,
|
||||
ModelConfigStoreYAML,
|
||||
ModelConfigStoreSQL,
|
||||
) # noqa F401
|
||||
from .search import ModelSearch # noqa F401
|
||||
from .merge import MergeInterpolationMethod, ModelMerger
|
||||
|
@ -114,7 +114,7 @@ class ModelConfigBase(BaseModel):
|
||||
base_model: BaseModelType
|
||||
model_type: ModelType
|
||||
model_format: ModelFormat
|
||||
id: Optional[str] = Field(None) # this may get added by the store
|
||||
key: Optional[str] = Field(None) # this will get added by the store
|
||||
description: Optional[str] = Field(None)
|
||||
author: Optional[str] = Field(description="Model author")
|
||||
license: Optional[str] = Field(description="License string")
|
||||
@ -244,6 +244,7 @@ class ModelConfigFactory(object):
|
||||
def make_config(
|
||||
cls,
|
||||
model_data: Union[dict, ModelConfigBase],
|
||||
key: Optional[str] = None,
|
||||
dest_class: Optional[Type] = None,
|
||||
) -> Union[
|
||||
MainCheckpointConfig,
|
||||
@ -263,6 +264,8 @@ class ModelConfigFactory(object):
|
||||
be selected automatically.
|
||||
"""
|
||||
if isinstance(model_data, ModelConfigBase):
|
||||
if key:
|
||||
model_data.key = key
|
||||
return model_data
|
||||
try:
|
||||
model_format = model_data.get("model_format")
|
||||
@ -271,7 +274,10 @@ class ModelConfigFactory(object):
|
||||
class_to_return = dest_class or cls._class_map[model_format][model_type]
|
||||
if isinstance(class_to_return, dict): # additional level allowed
|
||||
class_to_return = class_to_return[model_base]
|
||||
return class_to_return.parse_obj(model_data)
|
||||
model = class_to_return.parse_obj(model_data)
|
||||
if key:
|
||||
model.key = key # ensure consistency
|
||||
return model
|
||||
except KeyError as exc:
|
||||
raise InvalidModelConfigException(
|
||||
f"Unknown combination of model_format '{model_format}' and model_type '{model_type}'"
|
||||
|
@ -49,6 +49,9 @@ class DownloadJobBase(BaseModel):
|
||||
id: int = Field(description="Numeric ID of this job")
|
||||
source: str = Field(description="URL or repo_id to download")
|
||||
destination: Path = Field(description="Destination of URL on local disk")
|
||||
model_key: Optional[str] = Field(
|
||||
description="After model installation, this field will hold its primary key", default=None
|
||||
)
|
||||
metadata: Optional[ModelSourceMetadata] = Field(description="Model metadata (source-specific)", default=None)
|
||||
access_token: Optional[str] = Field(description="access token needed to access this resource")
|
||||
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
|
||||
|
@ -20,7 +20,7 @@ Typical usage:
|
||||
# register config, and install model in `models`
|
||||
id: str = installer.install_path('/path/to/model')
|
||||
|
||||
# download some remote models and install them in the background
|
||||
1 # download some remote models and install them in the background
|
||||
installer.install('stabilityai/stable-diffusion-2-1')
|
||||
installer.install('https://civitai.com/api/download/models/154208')
|
||||
installer.install('runwayml/stable-diffusion-v1-5')
|
||||
@ -58,7 +58,7 @@ from pydantic.networks import AnyHttpUrl
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from .search import ModelSearch
|
||||
from .storage import ModelConfigStore, ModelConfigStoreYAML, DuplicateModelException
|
||||
from .storage import ModelConfigStore, DuplicateModelException, get_config_store
|
||||
from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase, ModelSourceMetadata
|
||||
from .hash import FastModelHash
|
||||
from .probe import ModelProbe, ModelProbeInfo, InvalidModelException
|
||||
@ -272,7 +272,7 @@ class ModelInstall(ModelInstallBase):
|
||||
): # noqa D107 - use base class docstrings
|
||||
self._config = config or InvokeAIAppConfig.get_config()
|
||||
self._logger = logger or InvokeAILogger.getLogger(config=self._config)
|
||||
self._store = store or ModelConfigStoreYAML(self._config.model_conf_path)
|
||||
self._store = store or get_config_store(self._config.model_conf_path)
|
||||
self._download_queue = download or DownloadQueue(config=self._config)
|
||||
self._async_installs = dict()
|
||||
self._installed = set()
|
||||
@ -289,7 +289,7 @@ class ModelInstall(ModelInstallBase):
|
||||
return self._register(model_path, info)
|
||||
|
||||
def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
|
||||
id: str = FastModelHash.hash(model_path)
|
||||
key: str = FastModelHash.hash(model_path)
|
||||
registration_data = dict(
|
||||
path=model_path.as_posix(),
|
||||
name=model_path.stem,
|
||||
@ -309,13 +309,13 @@ class ModelInstall(ModelInstallBase):
|
||||
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
|
||||
)
|
||||
config_file = config_file[SchedulerPredictionType.VPrediction]
|
||||
registration_data.update(
|
||||
config=Path(self._config.legacy_conf_dir, config_file).as_posix(),
|
||||
)
|
||||
except KeyError as exc:
|
||||
raise InvalidModelException("Configuration file for this checkpoint could not be determined") from exc
|
||||
registration_data.update(
|
||||
config=Path(self._config.legacy_conf_dir, config_file).as_posix(),
|
||||
)
|
||||
self._store.add_model(id, registration_data)
|
||||
return id
|
||||
self._store.add_model(key, registration_data)
|
||||
return key
|
||||
|
||||
def install_path(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
@ -334,13 +334,13 @@ class ModelInstall(ModelInstallBase):
|
||||
info,
|
||||
)
|
||||
|
||||
def unregister(self, id: str): # noqa D102
|
||||
self._store.del_model(id)
|
||||
def unregister(self, key: str): # noqa D102
|
||||
self._store.del_model(key)
|
||||
|
||||
def delete(self, id: str): # noqa D102
|
||||
model = self._store.get_model(id)
|
||||
def delete(self, key: str): # noqa D102
|
||||
model = self._store.get_model(key)
|
||||
rmtree(model.path)
|
||||
self.unregister(id)
|
||||
self.unregister(key)
|
||||
|
||||
def install(
|
||||
self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, variant: Optional[str] = None
|
||||
@ -381,6 +381,7 @@ class ModelInstall(ModelInstallBase):
|
||||
info.description = f"Imported model {info.name}"
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
job.model_key = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
@ -421,8 +422,8 @@ class ModelInstall(ModelInstallBase):
|
||||
for model in self._store.all_models():
|
||||
path = Path(model.path)
|
||||
if not path.exists():
|
||||
self._store.del_model(model.id)
|
||||
unregistered.append(model.id)
|
||||
self._store.del_model(model.key)
|
||||
unregistered.append(model.key)
|
||||
return unregistered
|
||||
|
||||
def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
|
@ -26,7 +26,7 @@ class ModelInfo:
|
||||
name: str
|
||||
base_model: BaseModelType
|
||||
type: ModelType
|
||||
id: str
|
||||
key: str
|
||||
location: Union[Path, str]
|
||||
precision: torch.dtype
|
||||
_cache: Optional[ModelCache] = None
|
||||
@ -186,7 +186,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
name=model_config.name,
|
||||
base_model=model_config.base_model,
|
||||
type=submodel_type or model_type,
|
||||
id=model_config.id,
|
||||
key=model_config.key,
|
||||
location=model_path,
|
||||
precision=self._cache.precision,
|
||||
_cache=self._cache,
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
invokeai.backend.model_management.model_merge exports:
|
||||
invokeai.backend.model_manager.merge exports:
|
||||
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
|
||||
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
|
||||
|
||||
@ -15,7 +15,7 @@ from typing import List, Union, Optional
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||
from . import ModelLoader, ModelType, BaseModelType, ModelVariantType, ModelConfigBase
|
||||
|
||||
|
||||
class MergeInterpolationMethod(str, Enum):
|
||||
@ -26,7 +26,7 @@ class MergeInterpolationMethod(str, Enum):
|
||||
|
||||
|
||||
class ModelMerger(object):
|
||||
def __init__(self, manager: ModelManager):
|
||||
def __init__(self, manager: ModelLoader):
|
||||
self.manager = manager
|
||||
|
||||
def merge_diffusion_models(
|
||||
@ -77,7 +77,7 @@ class ModelMerger(object):
|
||||
force: bool = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
**kwargs,
|
||||
) -> AddModelResult:
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
||||
:param base_model: base model (must be the same for all merged models!)
|
@ -1,6 +1,19 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_manager.storage
|
||||
"""
|
||||
import pathlib
|
||||
|
||||
from .base import ModelConfigStore, UnknownModelException, DuplicateModelException # noqa F401
|
||||
from .yaml import ModelConfigStoreYAML # noqa F401
|
||||
from .sql import ModelConfigStoreSQL # noqa F401
|
||||
|
||||
|
||||
def get_config_store(location: pathlib.Path) -> ModelConfigStore:
|
||||
"""Return the type of ModelConfigStore appropriate to the path."""
|
||||
location = pathlib.Path(location)
|
||||
if location.suffix == ".yaml":
|
||||
return ModelConfigStoreYAML(location)
|
||||
elif location.suffix == ".db":
|
||||
return ModelConfigStoreSQL(location)
|
||||
else:
|
||||
raise Exception("Unable to determine type of configuration file '{location}'")
|
||||
|
@ -19,7 +19,7 @@ class InvalidModelException(Exception):
|
||||
|
||||
|
||||
class UnknownModelException(Exception):
|
||||
"""Raised on an attempt to delete a model with a nonexistent key."""
|
||||
"""Raised on an attempt to fetch or delete a model with a nonexistent key."""
|
||||
|
||||
|
||||
class ModelConfigStore(ABC):
|
||||
@ -90,7 +90,7 @@ class ModelConfigStore(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_type(
|
||||
def search_by_name(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
@ -112,4 +112,4 @@ class ModelConfigStore(ABC):
|
||||
"""
|
||||
Return all the model configs in the database.
|
||||
"""
|
||||
return self.search_by_type()
|
||||
return self.search_by_name()
|
||||
|
@ -16,7 +16,7 @@ Typical usage:
|
||||
tags=['sfw','cartoon']
|
||||
)
|
||||
|
||||
# adding - the key becomes the model's "id" field
|
||||
# adding - the key becomes the model's "key" field
|
||||
store.add_model('key1', config)
|
||||
|
||||
# updating
|
||||
@ -30,14 +30,14 @@ Typical usage:
|
||||
# fetching config
|
||||
new_config = store.get_model('key1')
|
||||
print(new_config.name, new_config.base_model)
|
||||
assert new_config.id == 'key1'
|
||||
assert new_config.key == 'key1'
|
||||
|
||||
# deleting
|
||||
store.del_model('key1')
|
||||
|
||||
# searching
|
||||
configs = store.search_by_tag({'sfw','oss license'})
|
||||
configs = store.search_by_type(base_model='sd-2', model_type='main')
|
||||
configs = store.search_by_name(base_model='sd-2', model_type='main')
|
||||
"""
|
||||
|
||||
import threading
|
||||
@ -173,8 +173,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfig exceptions.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config) # ensure it is a valid config obect.
|
||||
record.id = key # add the unique storage key to object
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||
json_serialized = json.dumps(record.dict()) # and turn it into a json string.
|
||||
try:
|
||||
self._lock.acquire()
|
||||
@ -293,7 +292,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config) # ensure it is a valid config obect
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||
json_serialized = json.dumps(record.dict()) # and turn it into a json string.
|
||||
try:
|
||||
self._lock.acquire()
|
||||
@ -309,7 +308,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
""",
|
||||
(record.base_model, record.model_type, record.name, record.path, json_serialized, key),
|
||||
)
|
||||
if self._cursor.rowcount < 1:
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException
|
||||
if record.tags:
|
||||
self._update_tags(key, record.tags)
|
||||
@ -404,7 +403,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
|
||||
self._lock.release()
|
||||
return results
|
||||
|
||||
def search_by_type(
|
||||
def search_by_name(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
|
@ -16,7 +16,7 @@ Typical usage:
|
||||
tags=['sfw','cartoon']
|
||||
)
|
||||
|
||||
# adding - the key becomes the model's "id" field
|
||||
# adding - the key becomes the model's "key" field
|
||||
store.add_model('key1', config)
|
||||
|
||||
# updating
|
||||
@ -30,18 +30,19 @@ Typical usage:
|
||||
# fetching config
|
||||
new_config = store.get_model('key1')
|
||||
print(new_config.name, new_config.base_model)
|
||||
assert new_config.id == 'key1'
|
||||
assert new_config.key == 'key1'
|
||||
|
||||
# deleting
|
||||
store.del_model('key1')
|
||||
|
||||
# searching
|
||||
configs = store.search_by_tag({'sfw','oss license'})
|
||||
configs = store.search_by_type(base_model='sd-2', model_type='main')
|
||||
configs = store.search_by_name(base_model='sd-2', model_type='main')
|
||||
"""
|
||||
|
||||
import threading
|
||||
import yaml
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Union, Set, List, Optional
|
||||
from omegaconf import OmegaConf
|
||||
@ -110,8 +111,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfig exceptions.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config) # ensure it is a valid config obect
|
||||
record.id = key # add the key used to store the object
|
||||
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect
|
||||
dict_fields = record.dict() # and back to a dict with valid fields
|
||||
try:
|
||||
self._lock.acquire()
|
||||
@ -120,11 +120,18 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
raise DuplicateModelException(
|
||||
f"Can't save {record.name} because a model named '{existing_model.name}' is already stored with the same key '{key}'"
|
||||
)
|
||||
self._config[key] = dict_fields
|
||||
self._config[key] = self._fix_enums(dict_fields)
|
||||
self._commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _fix_enums(self, original: dict) -> dict:
|
||||
"""In python 3.9, omegaconf stores incorrectly stringified enums"""
|
||||
fixed_dict = {}
|
||||
for key, value in original.items():
|
||||
fixed_dict[key] = value.value if isinstance(value, Enum) else value
|
||||
return fixed_dict
|
||||
|
||||
def del_model(self, key: str) -> None:
|
||||
"""
|
||||
Delete a model.
|
||||
@ -150,13 +157,13 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config) # ensure it is a valid config obect
|
||||
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config obect
|
||||
dict_fields = record.dict() # and back to a dict with valid fields
|
||||
try:
|
||||
self._lock.acquire()
|
||||
if key not in self._config:
|
||||
raise UnknownModelException(f"Unknown key '{key}' for model config")
|
||||
self._config[key] = dict_fields
|
||||
self._config[key] = self._fix_enums(dict_fields)
|
||||
self._commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
@ -171,7 +178,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
"""
|
||||
try:
|
||||
record = self._config[key]
|
||||
return ModelConfigFactory.make_config(record)
|
||||
return ModelConfigFactory.make_config(record, key)
|
||||
except KeyError as e:
|
||||
raise UnknownModelException(f"Unknown key '{key}' for model config") from e
|
||||
|
||||
@ -202,7 +209,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
self._lock.release()
|
||||
return results
|
||||
|
||||
def search_by_type(
|
||||
def search_by_name(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
@ -224,7 +231,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
|
||||
for key, record in self._config.items():
|
||||
if key == "__metadata__":
|
||||
continue
|
||||
model = ModelConfigFactory.make_config(record)
|
||||
model = ModelConfigFactory.make_config(record, key)
|
||||
if model_name and model.name != model_name:
|
||||
continue
|
||||
if base_model and model.base_model != base_model:
|
||||
|
@ -3,45 +3,47 @@ from pathlib import Path
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend import ModelConfigStore, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import ModelLoader
|
||||
|
||||
BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
|
||||
VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
|
||||
VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
|
||||
BASIC_MODEL_NAME = "sdxl-base-1-0"
|
||||
VAE_OVERRIDE_MODEL_NAME = "sdxl-base-with-custom-vae-1-0"
|
||||
VAE_NULL_OVERRIDE_MODEL_NAME = "sdxl-base-with-empty-vae-1-0"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_manager(datadir) -> ModelManager:
|
||||
InvokeAIAppConfig.get_config(root=datadir)
|
||||
return ModelManager(datadir / "configs" / "relative_sub.models.yaml")
|
||||
def model_manager(datadir) -> ModelLoader:
|
||||
config = InvokeAIAppConfig(root=datadir, conf_path="configs/relative_sub.models.yaml")
|
||||
return ModelLoader(config=config)
|
||||
|
||||
|
||||
def test_get_model_names(model_manager: ModelManager):
|
||||
names = model_manager.model_names()
|
||||
def test_get_model_names(model_manager: ModelLoader):
|
||||
store = model_manager.store
|
||||
names = [x.name for x in store.all_models()]
|
||||
assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]
|
||||
|
||||
|
||||
def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
|
||||
model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2])
|
||||
def test_get_model_path_for_diffusers(model_manager: ModelLoader, datadir: Path):
|
||||
models = model_manager.store.search_by_name(model_name=BASIC_MODEL_NAME)
|
||||
assert len(models) == 1
|
||||
model_config = models[0]
|
||||
top_model_path, is_override = model_manager._get_model_path(model_config)
|
||||
expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
|
||||
assert top_model_path == expected_model_path
|
||||
assert not is_override
|
||||
|
||||
|
||||
def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
|
||||
model_config = model_manager._get_model_config(
|
||||
VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2]
|
||||
)
|
||||
def test_get_model_path_for_overridden_vae(model_manager: ModelLoader, datadir: Path):
|
||||
models = model_manager.store.search_by_name(model_name=VAE_OVERRIDE_MODEL_NAME)
|
||||
assert len(models) == 1
|
||||
model_config = models[0]
|
||||
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
|
||||
expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
|
||||
assert vae_model_path == expected_vae_path
|
||||
assert is_override
|
||||
|
||||
|
||||
def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path):
|
||||
model_config = model_manager._get_model_config(
|
||||
VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2]
|
||||
)
|
||||
def test_get_model_path_for_null_overridden_vae(model_manager: ModelLoader, datadir: Path):
|
||||
model_config = model_manager.store.search_by_name(model_name=VAE_NULL_OVERRIDE_MODEL_NAME)[0]
|
||||
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
|
||||
assert not is_override
|
||||
|
@ -1,22 +1,30 @@
|
||||
__metadata__:
|
||||
version: 3.0.0
|
||||
|
||||
sdxl/main/SDXL base:
|
||||
version: 3.1.0
|
||||
ed799245c762f6d0a9ddfd4e31fdb010:
|
||||
name: sdxl-base-1-0
|
||||
path: sdxl/main/SDXL base 1_0
|
||||
base_model: sdxl
|
||||
model_type: main
|
||||
model_format: diffusers
|
||||
model_variant: normal
|
||||
description: SDXL base v1.0
|
||||
variant: normal
|
||||
format: diffusers
|
||||
|
||||
sdxl/main/SDXL with VAE:
|
||||
fa78e05dbf51c540ff9256eb65446fd6:
|
||||
name: sdxl-base-with-custom-vae-1-0
|
||||
path: sdxl/main/SDXL base 1_0
|
||||
base_model: sdxl
|
||||
model_type: main
|
||||
model_variant: normal
|
||||
model_format: diffusers
|
||||
description: SDXL with customized VAE
|
||||
vae: sdxl/vae/sdxl-vae-fp16-fix/
|
||||
variant: normal
|
||||
format: diffusers
|
||||
|
||||
sdxl/main/SDXL with empty VAE:
|
||||
8a79e05d9f51c5ffff9256eb65446fd6:
|
||||
name: sdxl-base-with-empty-vae-1-0
|
||||
path: sdxl/main/SDXL base 1_0
|
||||
base_model: sdxl
|
||||
model_type: main
|
||||
model_variant: normal
|
||||
model_format: diffusers
|
||||
description: SDXL with customized VAE
|
||||
vae: ''
|
||||
variant: normal
|
||||
format: diffusers
|
||||
|
@ -12,6 +12,7 @@ from invokeai.backend.model_manager.storage import (
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
TextualInversionConfig,
|
||||
DiffusersConfig,
|
||||
VaeDiffusersConfig,
|
||||
@ -113,14 +114,15 @@ def test_filter(store: ModelConfigStore):
|
||||
config3 = VaeDiffusersConfig(path="/tmp/config3", name="config3", base_model="sd-1", model_type="vae", tags=["sfw"])
|
||||
for c in config1, config2, config3:
|
||||
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
||||
matches = store.search_by_type(model_type="main")
|
||||
matches = store.search_by_name(model_type="main")
|
||||
assert len(matches) == 2
|
||||
assert matches[0].name in {"config1", "config2"}
|
||||
|
||||
matches = store.search_by_type(model_type="vae")
|
||||
matches = store.search_by_name(model_type="vae")
|
||||
assert len(matches) == 1
|
||||
assert matches[0].name == "config3"
|
||||
assert matches[0].id == sha256("config3".encode("utf-8")).hexdigest()
|
||||
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
||||
assert isinstance(matches[0].model_type, ModelType) # This tests that we get proper enums back
|
||||
|
||||
matches = store.search_by_tag(["sfw"])
|
||||
assert len(matches) == 3
|
||||
|
@ -3,6 +3,7 @@ Test the refactored model config classes.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
from hashlib import sha256
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
@ -12,6 +13,7 @@ from invokeai.backend.model_manager.storage import (
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
TextualInversionConfig,
|
||||
DiffusersConfig,
|
||||
VaeDiffusersConfig,
|
||||
@ -20,6 +22,7 @@ from invokeai.backend.model_manager.config import (
|
||||
|
||||
@pytest.fixture
|
||||
def store(datadir) -> ModelConfigStore:
|
||||
print(f"DEBUG: datadir={datadir}")
|
||||
InvokeAIAppConfig.get_config(root=datadir)
|
||||
return ModelConfigStoreSQL(datadir / "databases" / "models.db")
|
||||
|
||||
@ -89,11 +92,14 @@ def test_delete(store: ModelConfigStore):
|
||||
except UnknownModelException:
|
||||
assert True
|
||||
|
||||
try:
|
||||
store.del_model("unknown")
|
||||
assert False, "expected delete of unknown model to raise exception"
|
||||
except UnknownModelException:
|
||||
assert True
|
||||
# a bug in sqlite3 in python 3.9 prevents DEL from returning number of
|
||||
# deleted rows!
|
||||
if sys.version_info.major == 3 and sys.version_info.minor > 9:
|
||||
try:
|
||||
store.del_model("unknown")
|
||||
assert False, "expected delete of unknown model to raise exception"
|
||||
except UnknownModelException:
|
||||
assert True
|
||||
|
||||
|
||||
def test_exists(store: ModelConfigStore):
|
||||
@ -113,14 +119,15 @@ def test_filter(store: ModelConfigStore):
|
||||
config3 = VaeDiffusersConfig(path="/tmp/config3", name="config3", base_model="sd-1", model_type="vae", tags=["sfw"])
|
||||
for c in config1, config2, config3:
|
||||
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
||||
matches = store.search_by_type(model_type="main")
|
||||
matches = store.search_by_name(model_type="main")
|
||||
assert len(matches) == 2
|
||||
assert matches[0].name in {"config1", "config2"}
|
||||
|
||||
matches = store.search_by_type(model_type="vae")
|
||||
matches = store.search_by_name(model_type="vae")
|
||||
assert len(matches) == 1
|
||||
assert matches[0].name == "config3"
|
||||
assert matches[0].id == sha256("config3".encode("utf-8")).hexdigest()
|
||||
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
||||
assert isinstance(matches[0].model_type, ModelType) # This tests that we get proper enums back
|
||||
|
||||
matches = store.search_by_tag(["sfw"])
|
||||
assert len(matches) == 3
|
||||
|
Reference in New Issue
Block a user