refactor model_manager_service.py into small functional modules

This commit is contained in:
Lincoln Stein
2023-10-04 23:45:58 -04:00
parent cb0fdf3394
commit cd5d3e30c7
14 changed files with 940 additions and 69 deletions

View File

@ -5,7 +5,8 @@ from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import ModelInfo, SubModelType
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.loader import ModelInfo
from invokeai.backend.model_manager.download import DownloadJobBase
from invokeai.backend.util.logging import InvokeAILogger

View File

@ -0,0 +1,449 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
import shutil
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, Set, Literal
from pydantic import Field, parse_obj_as
from pydantic.networks import AnyHttpUrl
from invokeai.backend import get_precision
from invokeai.backend.model_manager.install import ModelInstallBase, ModelInstall, ModelInstallJob
from invokeai.backend.model_manager import (
ModelConfigBase,
ModelSearch,
)
from invokeai.backend.model_manager.storage import ModelConfigStore
from invokeai.backend.model_manager.download import DownloadJobBase
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.util.logging import InvokeAILogger, Logger
from .config import InvokeAIAppConfig
from .events import EventServiceBase
from .model_record_service import ModelRecordServiceBase
class ModelInstallServiceBase(ABC):
"""Responsible for downloading, installing and deleting models."""
@abstractmethod
def __init__(self,
config: InvokeAIAppConfig,
store: Union[ModelConfigStore, ModelRecordServiceBase],
event_bus: Optional[EventServiceBase] = None
):
"""
Initialize a ModelInstallService instance.
:param config: InvokeAIAppConfig object
:param store: Either a ModelRecordServiceBase object or a ModelConfigStore
:param event_bus: Optional EventServiceBase object. If provided,
Installation and download events will be sent to the event bus as "model_event".
"""
pass
@abstractmethod
def convert_model(
self,
key: str,
convert_dest_directory: Path,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder.
This will delete the cached version if there is any and delete the original
checkpoint file if it is in the models directory.
:param key: Unique key for the model to convert.
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
pass
@abstractmethod
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
priority: int = 10,
model_attributes: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""Import a path, repo_id or URL. Returns an ModelInstallJob.
:param model_attributes: Additional attributes to supplement/override
the model information gained from automated probing.
:param priority: Queue priority. Lower values have higher priority.
Typical usage:
job = model_manager.install(
'stabilityai/stable-diffusion-2-1',
model_attributes={'prediction_type": 'v-prediction'}
)
The result is an ModelInstallJob object, which provides
information on the asynchronous model download and install
process. A series of "install_model_event" events will be emitted
until the install is completed, cancelled or errors out.
"""
pass
@abstractmethod
def list_install_jobs(self) -> List[ModelInstallJob]:
"""Return a series of active or enqueued ModelInstallJobs."""
pass
@abstractmethod
def id_to_job(self, id: int) -> ModelInstallJob:
"""Return the ModelInstallJob instance corresponding to the given job ID."""
pass
@abstractmethod
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""
Wait for all pending installs to complete.
This will block until all pending downloads have
completed, been cancelled, or errored out. It will
block indefinitely if one or more jobs are in the
paused state.
It will return a dict that maps the source model
path, URL or repo_id to the ID of the installed model.
"""
pass
@abstractmethod
def start_job(self, job_id: int):
"""Start the given install job if it is paused or idle."""
pass
@abstractmethod
def pause_job(self, job_id: int):
"""Pause the given install job if it is paused or idle."""
pass
@abstractmethod
def cancel_job(self, job_id: int):
"""Cancel the given install job."""
pass
@abstractmethod
def cancel_all_jobs(self):
"""Cancel all active jobs."""
pass
@abstractmethod
def prune_jobs(self):
"""Remove completed or errored install jobs."""
pass
@abstractmethod
def change_job_priority(self, job_id: int, delta: int):
"""
Change an install job's priority.
:param job_id: Job to change
:param delta: Value to increment or decrement priority.
Lower values are higher priority. The default starting value is 10.
Thus to make this a really high priority job:
manager.change_job_priority(-10).
"""
pass
@abstractmethod
def merge_models(
self,
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_keys: List of 2-3 model unique keys to merge
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
pass
@abstractmethod
def list_checkpoint_configs(self) -> List[Path]:
"""List the checkpoint config paths from ROOT/configs/stable-diffusion."""
pass
@abstractmethod
def search_for_models(self, directory: Path) -> Set[Path]:
"""Return list of all models found in the designated directory."""
pass
@abstractmethod
def sync_to_config(self):
"""
Synchronize the in-memory models with on-disk.
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
pass
# implementation
class ModelInstallService(ModelInstallServiceBase):
"""Responsible for managing models on disk and in memory."""
_installer: ModelInstallBase = Field(description="ModelInstall object for the current process")
_config: InvokeAIAppConfig = Field(description="App configuration object")
_precision: Literal['float16', 'float32'] = Field(description="Floating point precision, string form")
_event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None)
_logger: Logger = Field(description="logger module")
def __init__(self,
config: InvokeAIAppConfig,
store: Union[ModelConfigStore, ModelRecordServiceBase],
event_bus: Optional[EventServiceBase] = None
):
"""
Initialize a ModelInstallService instance.
:param config: InvokeAIAppConfig object
:param store: Either a ModelRecordService object or a ModelConfigStore
:param event_bus: Optional EventServiceBase object. If provided,
Installation and download events will be sent to the event bus as "model_event".
"""
self._event_bus = event_bus
self._config = config
kwargs: Dict[str, Any] = {}
if self._event_bus:
kwargs.update(event_handlers=[self._event_bus.emit_model_event])
self._precision = get_precision()
self._installer = ModelInstall(store, config, **kwargs)
self._logger = InvokeAILogger.get_logger()
def start(self, invoker: Any): # Because .processor is giving circular import errors, declaring invoker an 'Any'
"""Call automatically at process start."""
self._installer.scan_models_directory() # synchronize new/deleted models found in models directory
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
priority: int = 10,
model_attributes: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""
Add a model using a path, repo_id or URL.
:param model_attributes: Dictionary of ModelConfigBase fields to
attach to the model. When installing a URL or repo_id, some metadata
values, such as `tags` will be automagically added.
:param priority: Queue priority for this install job. Lower value jobs
will run before higher value ones.
"""
self.logger.debug(f"add model {source}")
variant = "fp16" if self._precision == "float16" else None
job = self._installer.install(
source,
probe_override=model_attributes,
variant=variant,
priority=priority,
)
assert isinstance(job, ModelInstallJob)
return job
def list_install_jobs(self) -> List[ModelInstallJob]:
"""Return a series of active or enqueued ModelInstallJobs."""
queue = self._installer.queue
jobs: List[DownloadJobBase] = queue.list_jobs()
return [parse_obj_as(ModelInstallJob, x) for x in jobs] # downcast to proper type
def id_to_job(self, id: int) -> ModelInstallJob:
"""Return the ModelInstallJob instance corresponding to the given job ID."""
job = self._installer.queue.id_to_job(id)
assert isinstance(job, ModelInstallJob)
return job
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""
Wait for all pending installs to complete.
This will block until all pending downloads have
completed, been cancelled, or errored out. It will
block indefinitely if one or more jobs are in the
paused state.
It will return a dict that maps the source model
path, URL or repo_id to the ID of the installed model.
"""
return self._installer.wait_for_installs()
def start_job(self, job_id: int):
"""Start the given install job if it is paused or idle."""
queue = self._installer.queue
queue.start_job(queue.id_to_job(job_id))
def pause_job(self, job_id: int):
"""Pause the given install job if it is paused or idle."""
queue = self._installer.queue
queue.pause_job(queue.id_to_job(job_id))
def cancel_job(self, job_id: int):
"""Cancel the given install job."""
queue = self._installer.queue
queue.cancel_job(queue.id_to_job(job_id))
def cancel_all_jobs(self):
"""Cancel all active install job."""
queue = self._loader.queue
queue.cancel_all_jobs()
def prune_jobs(self):
"""Cancel all active install job."""
queue = self._loader.queue
queue.prune_jobs()
def change_job_priority(self, job_id: int, delta: int):
"""
Change an install job's priority.
:param job_id: Job to change
:param delta: Value to increment or decrement priority.
Lower values are higher priority. The default starting value is 10.
Thus to make this a really high priority job:
manager.change_job_priority(-10).
"""
queue = self._installer.queue
queue.change_priority(queue.id_to_job(job_id), delta)
def del_model(
self,
key: str,
delete_files: bool = False,
):
"""
Delete the named model from configuration.
If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well.
"""
model_info = self.store.get_model(key)
self.logger.debug(f"delete model {model_info.name}")
self.store.del_model(key)
if delete_files and Path(self._config.models_path / model_info.path).exists():
path = Path(model_info.path)
if path.is_dir():
shutil.rmtree(path)
else:
path.unlink()
def convert_model(
self,
key: str,
convert_dest_directory: Path,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder.
Delete the cached
version and delete the original checkpoint file if it is in the models
directory.
:param key: Key of the model to convert
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
model_info = self.store.get_model(key)
self.logger.info(f"Converting model {model_info.name} into a diffusers")
return self._installer.convert_model(key, convert_dest_directory)
@property
def logger(self):
"""Get the logger associated with this instance."""
return self._loader.logger
@property
def store(self):
"""Get the store associated with this instance."""
return self._installer.store
def merge_models(
self,
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_keys: List of 2-3 model unique keys to merge
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
merger = ModelMerger(self.store)
try:
if not merged_model_name:
merged_model_name = "+".join([self.store.get_model(x).name for x in model_keys])
raise Exception("not implemented")
result = merger.merge_diffusion_models_and_save(
model_keys=model_keys,
merged_model_name=merged_model_name,
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=merge_dest_directory,
)
except AssertionError as e:
raise ValueError(e)
return result
def search_for_models(self, directory: Path) -> Set[Path]:
"""
Return list of all models found in the designated directory.
:param directory: Path to the directory to recursively search.
returns a list of model paths
"""
return ModelSearch().search(directory)
def sync_to_config(self):
"""
Synchronize the model manager to the database.
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
return self._installer.sync_to_config()
def list_checkpoint_configs(self) -> List[Path]:
"""List the checkpoint config paths from ROOT/configs/stable-diffusion."""
config = self._config
conf_path = config.legacy_conf_path
root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]

View File

@ -0,0 +1,148 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Union, Optional
from pydantic import Field
from invokeai.app.models.exceptions import CanceledException
from invokeai.backend.model_manager import (
ModelConfigStore,
SubModelType,
)
from invokeai.backend.model_manager.cache import CacheStats
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad
from .config import InvokeAIAppConfig
from .events import EventServiceBase
from .model_record_service import ModelRecordServiceBase
if TYPE_CHECKING:
from ..invocations.baseinvocation import InvocationContext
class ModelLoadServiceBase(ABC):
"""Load models into memory."""
@abstractmethod
def __init__(self,
config: InvokeAIAppConfig,
store: Union[ModelConfigStore, ModelRecordServiceBase],
event_bus: Optional[EventServiceBase] = None):
"""
Initialize a ModelLoadService
:param config: InvokeAIAppConfig object
:param store: ModelConfigStore object for fetching configuration information
:param event_bus: Optional EventServiceBase object. If provided,
installation and download events will be sent to the event bus.
"""
pass
@abstractmethod
def get_model(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""Retrieve the indicated model identified by key.
:param key: Unique key returned by the ModelConfigStore module.
:param submodel_type: Submodel to return (required for main models)
:param context" Optional InvocationContext, used in event reporting.
"""
pass
@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""Reset model cache statistics for graph with graph_id."""
pass
# implementation
class ModelLoadService(ModelLoadServiceBase):
"""Responsible for managing models on disk and in memory."""
_loader: ModelLoad = Field(description="InvokeAIAppConfig object for the current process")
_event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None)
def __init__(self,
config: InvokeAIAppConfig,
store: Union[ModelConfigStore, ModelRecordServiceBase],
event_bus: Optional[EventServiceBase] = None
):
"""
Initialize a ModelManagerService.
:param config: InvokeAIAppConfig object
:param store: ModelRecordServiceBase or ModelConfigStore object for fetching configuration information
:param event_bus: Optional EventServiceBase object. If provided,
installation and download events will be sent to the event bus.
"""
self._event_bus = event_bus
kwargs: Dict[str, Any] = {}
if self._event_bus:
kwargs.update(event_handlers=[self._event_bus.emit_model_event])
self._loader = ModelLoad(config, store, **kwargs)
def get_model(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""
Retrieve the indicated model.
The submodel is required when fetching a main model.
"""
model_info: ModelInfo = self._loader.get_model(key, submodel_type)
# we can emit model loading events if we are executing with access to the invocation context
if context:
self._emit_load_event(
context=context,
model_key=key,
submodel=submodel_type,
model_info=model_info,
)
return model_info
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics. Is this used?
"""
self._loader.collect_cache_stats(cache_stats)
def _emit_load_event(
self,
context: InvocationContext,
model_key: str,
submodel: Optional[SubModelType] = None,
model_info: Optional[ModelInfo] = None,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
if model_info:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_key=model_key,
submodel=submodel,
model_info=model_info,
)
else:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_key=model_key,
submodel=submodel,
)

View File

@ -0,0 +1,301 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
import shutil
import sqlite3
import threading
from abc import ABC, abstractmethod
from pathlib import Path
from pydantic import Field
from typing import List, Optional, Union
from invokeai.backend.model_manager.storage import (
AnyModelConfig,
ModelConfigStore,
ModelConfigStoreSQL,
ModelConfigStoreYAML,
)
from invokeai.backend.model_manager import (
ModelConfigBase,
BaseModelType,
ModelType,
DuplicateModelException,
UnknownModelException,
)
class ModelRecordServiceBase(ABC):
"""Responsible for managing model configuration records."""
@abstractmethod
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the configuration for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
pass
@abstractmethod
def model_exists(
self,
key: str,
) -> bool:
"""Return true if the model configuration identified by key exists in the database."""
pass
@abstractmethod
def model_info(self, key: str) -> ModelConfigBase:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
Uses the exact format as the omegaconf stanza.
"""
pass
@abstractmethod
def list_models(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[ModelConfigBase]:
"""
Return a list of ModelConfigBases that match the base, type and name criteria.
:param base_model: Filter by the base model type.
:param model_type: Filter by the model type.
:param model_name: Filter by the model name.
"""
pass
@abstractmethod
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
"""
Return information about the model using the same format as list_models().
If there are more than one model that match, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException
"""
pass
model_configs = self.list_models(model_name=model_name, base_model=base_model, model_type=model_type)
if len(model_configs) > 1:
raise DuplicateModelException(
"More than one model share the same name and type: {base_model}/{model_type}/{model_name}"
)
if len(model_configs) == 0:
raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}")
return model_configs[0]
@abstractmethod
def all_models(self) -> List[ModelConfigBase]:
"""Return a list of all the models."""
pass
return self.list_models()
@abstractmethod
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfig exceptions.
"""
pass
@abstractmethod
def update_model(
self,
key: str,
new_config: Union[dict, ModelConfigBase],
) -> ModelConfigBase:
"""
Update the named model with a dictionary of attributes.
Will fail with a
UnknownModelException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model key is unknown.
"""
pass
@abstractmethod
def del_model(self, key: str, delete_files: bool = False):
"""
Delete the named model from configuration. If delete_files
is true, then the underlying file or directory will be
deleted as well.
"""
pass
def rename_model(
self,
key: str,
new_name: str,
) -> ModelConfigBase:
"""
Rename the indicated model.
"""
return self.update_model(key, {"name": new_name})
# implementation base class
class ModelRecordService(ModelRecordServiceBase):
"""Responsible for managing models on disk and in memory."""
_store: ModelConfigStore = Field(description="Config record storage backend")
@abstractmethod
def __init__(self):
"""Initialize object -- abstract method."""
pass
def get_model(
self,
key: str,
) -> AnyModelConfig:
"""
Retrieve the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
return self._store.get_model(key)
def model_exists(
self,
key: str,
) -> bool:
"""
Verify that a model with the given key exists.
Given a model key, returns True if it is a valid
identifier.
"""
return self._store.exists(key)
def model_info(self, key: str) -> ModelConfigBase:
"""
Return configuration information about a model.
Given a model key returns the ModelConfigBase describing it.
"""
return self._store.get_model(key)
def list_models(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[ModelConfigBase]:
"""
Return a ModelConfigBase object for each model in the database.
"""
return self._store.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type)
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
"""
Return information about the model using the same format as list_models().
If there are more than one model that match, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException
"""
model_configs = self.list_models(model_name=model_name, base_model=base_model, model_type=model_type)
if len(model_configs) > 1:
raise DuplicateModelException(
"More than one model share the same name and type: {base_model}/{model_type}/{model_name}"
)
if len(model_configs) == 0:
raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}")
return model_configs[0]
def all_models(self) -> List[ModelConfigBase]:
"""Return a list of all the models."""
return self.list_models()
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfig exceptions.
"""
self._store.add_model(key, config)
def update_model(
self,
key: str,
new_config: Union[dict, ModelConfigBase],
) -> ModelConfigBase:
"""
Update the named model with a dictionary of attributes.
Will fail with a
UnknownModelException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model key is unknown.
"""
new_info = self._store.update_model(key, new_config)
print('FIX ME!!! need to call sync_model_path() somewhere... maybe router?')
# self._loader.installer.sync_model_path(new_info.key) Maybe this goes into the router call?
return new_info
def del_model(
self,
key: str,
):
"""
Delete the named model from configuration.
"""
model_info = self.model_info(key)
self._store.del_model(key)
def rename_model(
self,
key: str,
new_name: str,
):
"""
Rename the indicated model to the new name.
:param key: Unique key for the model.
:param new_name: New name for the model
"""
return self.update_model(key, {"name": new_name})
class ModelRecordServiceSQL(ModelRecordService):
"""ModelRecordService that uses Sqlite for its backend."""
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock):
"""
Initialize a ModelRecordService that uses a SQLITE3 database backend.
:param conn: sqlite3 Connection object
:param lock: Thread lock object
"""
self._store = ModelConfigStoreSQL(conn, lock)
class ModelRecordServiceFile(ModelRecordService):
"""ModelRecordService that uses a YAML file for its backend."""
def __init__(self, models_file: Path):
"""
Initialize a ModelRecordService that uses a YAML file as the backend.
:param models_file: Path to the YAML file backend.
"""
self._store = ModelConfigStoreYAML(models_file)

View File

@ -6,11 +6,12 @@ from .model_manager import ( # noqa F401
DuplicateModelException,
InvalidModelException,
ModelConfigStore,
ModelInstall,
ModelLoad,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SilenceWarnings,
SubModelType,
)
from .model_manager.install import ModelInstall # noqa F401
from .model_manager.loader import ModelLoad # noqa F401
from .util.devices import get_precision # noqa F401

View File

@ -11,9 +11,9 @@ from .config import ( # noqa F401
SilenceWarnings,
SubModelType,
)
from .install import ModelInstall, ModelInstallJob # noqa F401
from .loader import ModelInfo, ModelLoad # noqa F401
from .lora import ModelPatcher, ONNXModelPatcher # noqa F401
# from .install import ModelInstall, ModelInstallJob # noqa F401
# from .loader import ModelInfo, ModelLoad # noqa F401
# from .lora import ModelPatcher, ONNXModelPatcher # noqa F401
from .models import OPENAPI_MODEL_CONFIGS, InvalidModelException, read_checkpoint_meta # noqa F401
from .probe import ModelProbe, ModelProbeInfo # noqa F401
from .search import ModelSearch # noqa F401

View File

@ -60,6 +60,7 @@ from pydantic import Field
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
from .config import (
@ -319,6 +320,11 @@ class ModelInstallBase(ABC):
"""
pass
@abstractmethod
def sync_to_config(self):
"""Synchronize models on disk to those in memory."""
pass
@abstractmethod
def scan_models_directory(self):
"""
@ -367,7 +373,7 @@ class ModelInstall(ModelInstallBase):
def __init__(
self,
store: Optional[ModelConfigStore] = None,
store: Optional[Union[ModelConfigStore, ModelRecordServiceBase]] = None,
config: Optional[InvokeAIAppConfig] = None,
logger: Optional[Logger] = None,
download: Optional[DownloadQueueBase] = None,
@ -381,10 +387,6 @@ class ModelInstall(ModelInstallBase):
self._installed = set()
self._tmpdir = None
# this step synchronizes the `models` directory with the models db
# do NOT do this automatically, but only on app startup
# self.scan_models_directory()
@property
def queue(self) -> DownloadQueueBase:
"""Return the queue."""
@ -609,7 +611,7 @@ class ModelInstall(ModelInstallBase):
variant: Optional[str] = None,
access_token: Optional[str] = None,
priority: Optional[int] = 10,
) -> DownloadJobBase:
) -> ModelInstallJob:
# Clean up a common source of error. Doesn't work with Paths.
if isinstance(source, str):
source = source.strip()
@ -747,6 +749,10 @@ class ModelInstall(ModelInstallBase):
pass
return True
def sync_to_config(self):
"""Synchronize models on disk to those in memory."""
self.scan_models_directory()
def scan_models_directory(self):
"""
Scan the models directory for new and missing models.

View File

@ -11,11 +11,10 @@ import torch
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import InvokeAILogger, Logger, choose_precision, choose_torch_device
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from .cache import CacheStats, ModelCache
from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType
from .download import DownloadEventHandler, DownloadQueueBase
from .install import ModelInstall, ModelInstallBase
from .models import MODEL_CLASSES, InvalidModelException, ModelBase
from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_config_store, migrate_models_store
@ -65,12 +64,6 @@ class ModelLoadBase(ABC):
"""Return the ModelConfigStore object that supports this loader."""
pass
@property
@abstractmethod
def installer(self) -> ModelInstallBase:
"""Return the ModelInstallBase object that supports this loader."""
pass
@property
@abstractmethod
def logger(self) -> Logger:
@ -80,13 +73,7 @@ class ModelLoadBase(ABC):
@property
@abstractmethod
def config(self) -> InvokeAIAppConfig:
"""Return the config object used by this installer."""
pass
@property
@abstractmethod
def queue(self) -> DownloadQueueBase:
"""Return the download queue object used by this object."""
"""Return the config object used by the loader."""
pass
@abstractmethod
@ -105,18 +92,12 @@ class ModelLoadBase(ABC):
"""Return torch.fp16 or torch.fp32."""
pass
@abstractmethod
def sync_to_config(self):
"""Reinitialize the store to sync in-memory and in-disk versions."""
pass
class ModelLoad(ModelLoadBase):
"""Implementation of ModelLoadBase."""
_app_config: InvokeAIAppConfig
_store: ModelConfigStore
_installer: ModelInstallBase
_store: Union[ModelConfigStore, ModelRecordServiceBase]
_cache: ModelCache
_logger: Logger
_cache_keys: dict
@ -126,7 +107,6 @@ class ModelLoad(ModelLoadBase):
self,
config: InvokeAIAppConfig,
store: Optional[ModelConfigStore] = None,
event_handlers: List[DownloadEventHandler] = [],
):
"""
Initialize ModelLoad object.
@ -138,12 +118,6 @@ class ModelLoad(ModelLoadBase):
self._app_config = config
self._store = store
self._logger = InvokeAILogger.get_logger()
self._installer = ModelInstall(
store=self._store,
logger=self._logger,
config=self._app_config,
event_handlers=event_handlers,
)
self._cache_keys = dict()
self._models_file = config.model_conf_path
device = torch.device(choose_torch_device())
@ -190,11 +164,6 @@ class ModelLoad(ModelLoadBase):
"""Return torch.fp16 or torch.fp32."""
return self._cache.precision
@property
def installer(self) -> ModelInstallBase:
"""Return the ModelInstallBase instance used by this class."""
return self._installer
@property
def logger(self) -> Logger:
"""Return the current logger."""
@ -202,14 +171,9 @@ class ModelLoad(ModelLoadBase):
@property
def config(self) -> InvokeAIAppConfig:
"""Return the config object used by the installer."""
"""Return the config object."""
return self._app_config
@property
def queue(self) -> DownloadQueueBase:
"""Return the download queue object used by this object."""
return self._installer.queue
def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo:
"""
Get the ModelInfo corresponding to the model with key "key".
@ -303,7 +267,3 @@ class ModelLoad(ModelLoadBase):
model_path = self.resolve_model_path(model_path)
return model_path, is_submodel_override
def sync_to_config(self):
"""Synchronize models on disk to those in memory."""
self._store = get_config_store(self._models_file)
self.installer.scan_models_directory()

View File

@ -84,9 +84,9 @@ class ModelMerger(object):
self,
model_keys: List[str],
merged_model_name: str,
alpha: float = 0.5,
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
**kwargs,
) -> ModelConfigBase:

View File

@ -10,7 +10,7 @@ from .base import ( # noqa F401
from .migrate import migrate_models_store # noqa F401
from .sql import ModelConfigStoreSQL # noqa F401
from .yaml import ModelConfigStoreYAML # noqa F401
from ..config import AnyModelConfig # noqa F401
def get_config_store(location: pathlib.Path) -> ModelConfigStore:
"""Return the type of ModelConfigStore appropriate to the path."""

View File

@ -39,7 +39,7 @@ class ModelConfigStore(ABC):
pass
@abstractmethod
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> None:
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> ModelConfigBase:
"""
Add a model to the database.

View File

@ -53,21 +53,18 @@ from .base import CONFIG_FILE_VERSION, DuplicateModelException, ModelConfigStore
class ModelConfigStoreSQL(ModelConfigStore):
"""Implementation of the ModelConfigStore ABC using a YAML file."""
_filename: Path
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: Path):
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock):
"""Initialize ModelConfigStore object with a sqlite3 database."""
super().__init__()
self._filename = Path(filename).absolute() # don't let chdir mess us up!
self._filename.parent.mkdir(parents=True, exist_ok=True)
self._conn = sqlite3.connect(filename, check_same_thread=False)
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
self._lock = lock
try:
self._lock.acquire()
@ -174,7 +171,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
("version", CONFIG_FILE_VERSION),
)
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
"""
Add a model to the database.
@ -225,6 +222,8 @@ class ModelConfigStoreSQL(ModelConfigStore):
finally:
self._lock.release()
return self.get_model(key)
@property
def version(self) -> str:
"""Return the version of the database schema."""

View File

@ -104,7 +104,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
"""Return version of this config file/database."""
return self._config.__metadata__.get("version")
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
"""
Add a model to the database.
@ -127,6 +127,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
self._commit()
finally:
self._lock.release()
return self.get_model(key)
def _fix_enums(self, original: dict) -> dict:
"""In python 3.9, omegaconf stores incorrectly stringified enums."""

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import platform
from contextlib import nullcontext
from typing import Union
from typing import Union, Literal
import torch
from packaging import version
@ -41,6 +41,11 @@ def choose_precision(device: torch.device) -> str:
return "float16"
return "float32"
def get_precision() -> Literal['float16', 'float32']:
device = torch.device(choose_torch_device())
precision = choose_precision(device) if config.precision == "auto" else config.precision
assert precision in ['float16', 'float32']
return precision
def torch_dtype(device: torch.device) -> torch.dtype:
if config.full_precision: