From cd5d3e30c71df41e08e0d470cdaad2af6e42303f Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 4 Oct 2023 23:45:58 -0400 Subject: [PATCH] refactor model_manager_service.py into small functional modules --- invokeai/app/services/events.py | 3 +- .../app/services/model_install_service.py | 449 ++++++++++++++++++ invokeai/app/services/model_loader_service.py | 148 ++++++ invokeai/app/services/model_record_service.py | 301 ++++++++++++ invokeai/backend/__init__.py | 5 +- invokeai/backend/model_manager/__init__.py | 6 +- invokeai/backend/model_manager/install.py | 18 +- invokeai/backend/model_manager/loader.py | 48 +- invokeai/backend/model_manager/merge.py | 4 +- .../backend/model_manager/storage/__init__.py | 2 +- .../backend/model_manager/storage/base.py | 2 +- invokeai/backend/model_manager/storage/sql.py | 13 +- .../backend/model_manager/storage/yaml.py | 3 +- invokeai/backend/util/devices.py | 7 +- 14 files changed, 940 insertions(+), 69 deletions(-) create mode 100644 invokeai/app/services/model_install_service.py create mode 100644 invokeai/app/services/model_loader_service.py create mode 100644 invokeai/app/services/model_record_service.py diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index b07dc79237..f9732851db 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -5,7 +5,8 @@ from typing import Any, Optional from invokeai.app.models.image import ProgressImage from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_manager import ModelInfo, SubModelType +from invokeai.backend.model_manager import SubModelType +from invokeai.backend.model_manager.loader import ModelInfo from invokeai.backend.model_manager.download import DownloadJobBase from invokeai.backend.util.logging import InvokeAILogger diff --git a/invokeai/app/services/model_install_service.py b/invokeai/app/services/model_install_service.py new file mode 100644 index 0000000000..18d568724e --- /dev/null +++ b/invokeai/app/services/model_install_service.py @@ -0,0 +1,449 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team + +from __future__ import annotations + +import shutil +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Dict, List, Optional, Union, Set, Literal + +from pydantic import Field, parse_obj_as +from pydantic.networks import AnyHttpUrl + +from invokeai.backend import get_precision +from invokeai.backend.model_manager.install import ModelInstallBase, ModelInstall, ModelInstallJob +from invokeai.backend.model_manager import ( + ModelConfigBase, + ModelSearch, +) +from invokeai.backend.model_manager.storage import ModelConfigStore +from invokeai.backend.model_manager.download import DownloadJobBase +from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger +from invokeai.backend.util.logging import InvokeAILogger, Logger + +from .config import InvokeAIAppConfig +from .events import EventServiceBase +from .model_record_service import ModelRecordServiceBase + + +class ModelInstallServiceBase(ABC): + """Responsible for downloading, installing and deleting models.""" + + @abstractmethod + def __init__(self, + config: InvokeAIAppConfig, + store: Union[ModelConfigStore, ModelRecordServiceBase], + event_bus: Optional[EventServiceBase] = None + ): + """ + Initialize a ModelInstallService instance. + + :param config: InvokeAIAppConfig object + :param store: Either a ModelRecordServiceBase object or a ModelConfigStore + :param event_bus: Optional EventServiceBase object. If provided, + + Installation and download events will be sent to the event bus as "model_event". + """ + pass + + @abstractmethod + def convert_model( + self, + key: str, + convert_dest_directory: Path, + ) -> ModelConfigBase: + """ + Convert a checkpoint file into a diffusers folder. + + This will delete the cached version if there is any and delete the original + checkpoint file if it is in the models directory. + :param key: Unique key for the model to convert. + :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default) + + This will raise a ValueError unless the model is not a checkpoint. It will + also raise a ValueError in the event that there is a similarly-named diffusers + directory already in place. + """ + pass + + @abstractmethod + def install_model( + self, + source: Union[str, Path, AnyHttpUrl], + priority: int = 10, + model_attributes: Optional[Dict[str, Any]] = None, + ) -> ModelInstallJob: + """Import a path, repo_id or URL. Returns an ModelInstallJob. + + :param model_attributes: Additional attributes to supplement/override + the model information gained from automated probing. + :param priority: Queue priority. Lower values have higher priority. + + Typical usage: + job = model_manager.install( + 'stabilityai/stable-diffusion-2-1', + model_attributes={'prediction_type": 'v-prediction'} + ) + + The result is an ModelInstallJob object, which provides + information on the asynchronous model download and install + process. A series of "install_model_event" events will be emitted + until the install is completed, cancelled or errors out. + """ + pass + + @abstractmethod + def list_install_jobs(self) -> List[ModelInstallJob]: + """Return a series of active or enqueued ModelInstallJobs.""" + pass + + @abstractmethod + def id_to_job(self, id: int) -> ModelInstallJob: + """Return the ModelInstallJob instance corresponding to the given job ID.""" + pass + + @abstractmethod + def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]: + """ + Wait for all pending installs to complete. + + This will block until all pending downloads have + completed, been cancelled, or errored out. It will + block indefinitely if one or more jobs are in the + paused state. + + It will return a dict that maps the source model + path, URL or repo_id to the ID of the installed model. + """ + pass + + @abstractmethod + def start_job(self, job_id: int): + """Start the given install job if it is paused or idle.""" + pass + + @abstractmethod + def pause_job(self, job_id: int): + """Pause the given install job if it is paused or idle.""" + pass + + @abstractmethod + def cancel_job(self, job_id: int): + """Cancel the given install job.""" + pass + + @abstractmethod + def cancel_all_jobs(self): + """Cancel all active jobs.""" + pass + + @abstractmethod + def prune_jobs(self): + """Remove completed or errored install jobs.""" + pass + + @abstractmethod + def change_job_priority(self, job_id: int, delta: int): + """ + Change an install job's priority. + + :param job_id: Job to change + :param delta: Value to increment or decrement priority. + + Lower values are higher priority. The default starting value is 10. + Thus to make this a really high priority job: + manager.change_job_priority(-10). + """ + pass + + @abstractmethod + def merge_models( + self, + model_keys: List[str] = Field( + default=None, min_items=2, max_items=3, description="List of model keys to merge" + ), + merged_model_name: str = Field(default=None, description="Name of destination model after merging"), + alpha: Optional[float] = 0.5, + interp: Optional[MergeInterpolationMethod] = None, + force: Optional[bool] = False, + merge_dest_directory: Optional[Path] = None, + ) -> ModelConfigBase: + """ + Merge two to three diffusrs pipeline models and save as a new model. + + :param model_keys: List of 2-3 model unique keys to merge + :param merged_model_name: Name of destination merged model + :param alpha: Alpha strength to apply to 2d and 3d model + :param interp: Interpolation method. None (default) + :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) + """ + pass + + @abstractmethod + def list_checkpoint_configs(self) -> List[Path]: + """List the checkpoint config paths from ROOT/configs/stable-diffusion.""" + pass + + @abstractmethod + def search_for_models(self, directory: Path) -> Set[Path]: + """Return list of all models found in the designated directory.""" + pass + + @abstractmethod + def sync_to_config(self): + """ + Synchronize the in-memory models with on-disk. + + Re-read models.yaml, rescan the models directory, and reimport models + in the autoimport directories. Call after making changes outside the + model manager API. + """ + pass + + +# implementation +class ModelInstallService(ModelInstallServiceBase): + """Responsible for managing models on disk and in memory.""" + + _installer: ModelInstallBase = Field(description="ModelInstall object for the current process") + _config: InvokeAIAppConfig = Field(description="App configuration object") + _precision: Literal['float16', 'float32'] = Field(description="Floating point precision, string form") + _event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None) + _logger: Logger = Field(description="logger module") + + def __init__(self, + config: InvokeAIAppConfig, + store: Union[ModelConfigStore, ModelRecordServiceBase], + event_bus: Optional[EventServiceBase] = None + ): + """ + Initialize a ModelInstallService instance. + + :param config: InvokeAIAppConfig object + :param store: Either a ModelRecordService object or a ModelConfigStore + :param event_bus: Optional EventServiceBase object. If provided, + + Installation and download events will be sent to the event bus as "model_event". + """ + self._event_bus = event_bus + self._config = config + kwargs: Dict[str, Any] = {} + if self._event_bus: + kwargs.update(event_handlers=[self._event_bus.emit_model_event]) + self._precision = get_precision() + self._installer = ModelInstall(store, config, **kwargs) + self._logger = InvokeAILogger.get_logger() + + def start(self, invoker: Any): # Because .processor is giving circular import errors, declaring invoker an 'Any' + """Call automatically at process start.""" + self._installer.scan_models_directory() # synchronize new/deleted models found in models directory + + def install_model( + self, + source: Union[str, Path, AnyHttpUrl], + priority: int = 10, + model_attributes: Optional[Dict[str, Any]] = None, + ) -> ModelInstallJob: + """ + Add a model using a path, repo_id or URL. + + :param model_attributes: Dictionary of ModelConfigBase fields to + attach to the model. When installing a URL or repo_id, some metadata + values, such as `tags` will be automagically added. + :param priority: Queue priority for this install job. Lower value jobs + will run before higher value ones. + """ + self.logger.debug(f"add model {source}") + variant = "fp16" if self._precision == "float16" else None + job = self._installer.install( + source, + probe_override=model_attributes, + variant=variant, + priority=priority, + ) + assert isinstance(job, ModelInstallJob) + return job + + def list_install_jobs(self) -> List[ModelInstallJob]: + """Return a series of active or enqueued ModelInstallJobs.""" + queue = self._installer.queue + jobs: List[DownloadJobBase] = queue.list_jobs() + return [parse_obj_as(ModelInstallJob, x) for x in jobs] # downcast to proper type + + def id_to_job(self, id: int) -> ModelInstallJob: + """Return the ModelInstallJob instance corresponding to the given job ID.""" + job = self._installer.queue.id_to_job(id) + assert isinstance(job, ModelInstallJob) + return job + + def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]: + """ + Wait for all pending installs to complete. + + This will block until all pending downloads have + completed, been cancelled, or errored out. It will + block indefinitely if one or more jobs are in the + paused state. + + It will return a dict that maps the source model + path, URL or repo_id to the ID of the installed model. + """ + return self._installer.wait_for_installs() + + def start_job(self, job_id: int): + """Start the given install job if it is paused or idle.""" + queue = self._installer.queue + queue.start_job(queue.id_to_job(job_id)) + + def pause_job(self, job_id: int): + """Pause the given install job if it is paused or idle.""" + queue = self._installer.queue + queue.pause_job(queue.id_to_job(job_id)) + + def cancel_job(self, job_id: int): + """Cancel the given install job.""" + queue = self._installer.queue + queue.cancel_job(queue.id_to_job(job_id)) + + def cancel_all_jobs(self): + """Cancel all active install job.""" + queue = self._loader.queue + queue.cancel_all_jobs() + + def prune_jobs(self): + """Cancel all active install job.""" + queue = self._loader.queue + queue.prune_jobs() + + def change_job_priority(self, job_id: int, delta: int): + """ + Change an install job's priority. + + :param job_id: Job to change + :param delta: Value to increment or decrement priority. + + Lower values are higher priority. The default starting value is 10. + Thus to make this a really high priority job: + manager.change_job_priority(-10). + """ + queue = self._installer.queue + queue.change_priority(queue.id_to_job(job_id), delta) + + def del_model( + self, + key: str, + delete_files: bool = False, + ): + """ + Delete the named model from configuration. + + If delete_files is true, + then the underlying weight file or diffusers directory will be deleted + as well. + """ + model_info = self.store.get_model(key) + self.logger.debug(f"delete model {model_info.name}") + self.store.del_model(key) + if delete_files and Path(self._config.models_path / model_info.path).exists(): + path = Path(model_info.path) + if path.is_dir(): + shutil.rmtree(path) + else: + path.unlink() + + def convert_model( + self, + key: str, + convert_dest_directory: Path, + ) -> ModelConfigBase: + """ + Convert a checkpoint file into a diffusers folder. + + Delete the cached + version and delete the original checkpoint file if it is in the models + directory. + + :param key: Key of the model to convert + :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default) + + This will raise a ValueError unless the model is not a checkpoint. It will + also raise a ValueError in the event that there is a similarly-named diffusers + directory already in place. + """ + model_info = self.store.get_model(key) + self.logger.info(f"Converting model {model_info.name} into a diffusers") + return self._installer.convert_model(key, convert_dest_directory) + + @property + def logger(self): + """Get the logger associated with this instance.""" + return self._loader.logger + + @property + def store(self): + """Get the store associated with this instance.""" + return self._installer.store + + def merge_models( + self, + model_keys: List[str] = Field( + default=None, min_items=2, max_items=3, description="List of model keys to merge" + ), + merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"), + alpha: Optional[float] = 0.5, + interp: Optional[MergeInterpolationMethod] = None, + force: Optional[bool] = False, + merge_dest_directory: Optional[Path] = None, + ) -> ModelConfigBase: + """ + Merge two to three diffusrs pipeline models and save as a new model. + + :param model_keys: List of 2-3 model unique keys to merge + :param merged_model_name: Name of destination merged model + :param alpha: Alpha strength to apply to 2d and 3d model + :param interp: Interpolation method. None (default) + :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) + """ + merger = ModelMerger(self.store) + try: + if not merged_model_name: + merged_model_name = "+".join([self.store.get_model(x).name for x in model_keys]) + raise Exception("not implemented") + + result = merger.merge_diffusion_models_and_save( + model_keys=model_keys, + merged_model_name=merged_model_name, + alpha=alpha, + interp=interp, + force=force, + merge_dest_directory=merge_dest_directory, + ) + except AssertionError as e: + raise ValueError(e) + return result + + def search_for_models(self, directory: Path) -> Set[Path]: + """ + Return list of all models found in the designated directory. + + :param directory: Path to the directory to recursively search. + returns a list of model paths + """ + return ModelSearch().search(directory) + + def sync_to_config(self): + """ + Synchronize the model manager to the database. + + Re-read models.yaml, rescan the models directory, and reimport models + in the autoimport directories. Call after making changes outside the + model manager API. + """ + return self._installer.sync_to_config() + + def list_checkpoint_configs(self) -> List[Path]: + """List the checkpoint config paths from ROOT/configs/stable-diffusion.""" + config = self._config + conf_path = config.legacy_conf_path + root_path = config.root_path + return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")] diff --git a/invokeai/app/services/model_loader_service.py b/invokeai/app/services/model_loader_service.py new file mode 100644 index 0000000000..7941fc62e5 --- /dev/null +++ b/invokeai/app/services/model_loader_service.py @@ -0,0 +1,148 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, Union, Optional +from pydantic import Field + +from invokeai.app.models.exceptions import CanceledException +from invokeai.backend.model_manager import ( + ModelConfigStore, + SubModelType, +) +from invokeai.backend.model_manager.cache import CacheStats +from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad + +from .config import InvokeAIAppConfig +from .events import EventServiceBase +from .model_record_service import ModelRecordServiceBase + +if TYPE_CHECKING: + from ..invocations.baseinvocation import InvocationContext + + +class ModelLoadServiceBase(ABC): + """Load models into memory.""" + + @abstractmethod + def __init__(self, + config: InvokeAIAppConfig, + store: Union[ModelConfigStore, ModelRecordServiceBase], + event_bus: Optional[EventServiceBase] = None): + """ + Initialize a ModelLoadService + + :param config: InvokeAIAppConfig object + :param store: ModelConfigStore object for fetching configuration information + :param event_bus: Optional EventServiceBase object. If provided, + installation and download events will be sent to the event bus. + """ + pass + + @abstractmethod + def get_model( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> ModelInfo: + """Retrieve the indicated model identified by key. + + :param key: Unique key returned by the ModelConfigStore module. + :param submodel_type: Submodel to return (required for main models) + :param context" Optional InvocationContext, used in event reporting. + """ + pass + + @abstractmethod + def collect_cache_stats(self, cache_stats: CacheStats): + """Reset model cache statistics for graph with graph_id.""" + pass + + +# implementation +class ModelLoadService(ModelLoadServiceBase): + """Responsible for managing models on disk and in memory.""" + + _loader: ModelLoad = Field(description="InvokeAIAppConfig object for the current process") + _event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None) + + def __init__(self, + config: InvokeAIAppConfig, + store: Union[ModelConfigStore, ModelRecordServiceBase], + event_bus: Optional[EventServiceBase] = None + ): + """ + Initialize a ModelManagerService. + + :param config: InvokeAIAppConfig object + :param store: ModelRecordServiceBase or ModelConfigStore object for fetching configuration information + :param event_bus: Optional EventServiceBase object. If provided, + installation and download events will be sent to the event bus. + """ + self._event_bus = event_bus + kwargs: Dict[str, Any] = {} + if self._event_bus: + kwargs.update(event_handlers=[self._event_bus.emit_model_event]) + self._loader = ModelLoad(config, store, **kwargs) + + def get_model( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> ModelInfo: + """ + Retrieve the indicated model. + + The submodel is required when fetching a main model. + """ + model_info: ModelInfo = self._loader.get_model(key, submodel_type) + + # we can emit model loading events if we are executing with access to the invocation context + if context: + self._emit_load_event( + context=context, + model_key=key, + submodel=submodel_type, + model_info=model_info, + ) + + return model_info + + def collect_cache_stats(self, cache_stats: CacheStats): + """ + Reset model cache statistics. Is this used? + """ + self._loader.collect_cache_stats(cache_stats) + + def _emit_load_event( + self, + context: InvocationContext, + model_key: str, + submodel: Optional[SubModelType] = None, + model_info: Optional[ModelInfo] = None, + ): + if context.services.queue.is_canceled(context.graph_execution_state_id): + raise CanceledException() + + if model_info: + context.services.events.emit_model_load_completed( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_key=model_key, + submodel=submodel, + model_info=model_info, + ) + else: + context.services.events.emit_model_load_started( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_key=model_key, + submodel=submodel, + ) diff --git a/invokeai/app/services/model_record_service.py b/invokeai/app/services/model_record_service.py new file mode 100644 index 0000000000..e600f24005 --- /dev/null +++ b/invokeai/app/services/model_record_service.py @@ -0,0 +1,301 @@ +# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team + +from __future__ import annotations + +import shutil +import sqlite3 +import threading +from abc import ABC, abstractmethod +from pathlib import Path +from pydantic import Field +from typing import List, Optional, Union + +from invokeai.backend.model_manager.storage import ( + AnyModelConfig, + ModelConfigStore, + ModelConfigStoreSQL, + ModelConfigStoreYAML, +) +from invokeai.backend.model_manager import ( + ModelConfigBase, + BaseModelType, + ModelType, + DuplicateModelException, + UnknownModelException, +) + + +class ModelRecordServiceBase(ABC): + """Responsible for managing model configuration records.""" + + @abstractmethod + def get_model(self, key: str) -> AnyModelConfig: + """ + Retrieve the configuration for the indicated model. + + :param key: Key of model config to be fetched. + + Exceptions: UnknownModelException + """ + pass + + @abstractmethod + def model_exists( + self, + key: str, + ) -> bool: + """Return true if the model configuration identified by key exists in the database.""" + pass + + @abstractmethod + def model_info(self, key: str) -> ModelConfigBase: + """ + Given a model name returns a dict-like (OmegaConf) object describing it. + Uses the exact format as the omegaconf stanza. + """ + pass + + @abstractmethod + def list_models( + self, + model_name: Optional[str] = None, + base_model: Optional[BaseModelType] = None, + model_type: Optional[ModelType] = None, + ) -> List[ModelConfigBase]: + """ + Return a list of ModelConfigBases that match the base, type and name criteria. + :param base_model: Filter by the base model type. + :param model_type: Filter by the model type. + :param model_name: Filter by the model name. + """ + pass + + @abstractmethod + def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase: + """ + Return information about the model using the same format as list_models(). + + If there are more than one model that match, raises a DuplicateModelException. + If no model matches, raises an UnknownModelException + """ + pass + model_configs = self.list_models(model_name=model_name, base_model=base_model, model_type=model_type) + if len(model_configs) > 1: + raise DuplicateModelException( + "More than one model share the same name and type: {base_model}/{model_type}/{model_name}" + ) + if len(model_configs) == 0: + raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}") + return model_configs[0] + + @abstractmethod + def all_models(self) -> List[ModelConfigBase]: + """Return a list of all the models.""" + pass + return self.list_models() + + @abstractmethod + def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None: + """ + Add a model to the database. + + :param key: Unique key for the model + :param config: Model configuration record, either a dict with the + required fields or a ModelConfigBase instance. + + Can raise DuplicateModelException and InvalidModelConfig exceptions. + """ + pass + + @abstractmethod + def update_model( + self, + key: str, + new_config: Union[dict, ModelConfigBase], + ) -> ModelConfigBase: + """ + Update the named model with a dictionary of attributes. + + Will fail with a + UnknownModelException if the name does not already exist. + + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or + the model key is unknown. + """ + pass + + @abstractmethod + def del_model(self, key: str, delete_files: bool = False): + """ + Delete the named model from configuration. If delete_files + is true, then the underlying file or directory will be + deleted as well. + """ + pass + + def rename_model( + self, + key: str, + new_name: str, + ) -> ModelConfigBase: + """ + Rename the indicated model. + """ + return self.update_model(key, {"name": new_name}) + + +# implementation base class +class ModelRecordService(ModelRecordServiceBase): + """Responsible for managing models on disk and in memory.""" + + _store: ModelConfigStore = Field(description="Config record storage backend") + + @abstractmethod + def __init__(self): + """Initialize object -- abstract method.""" + pass + + def get_model( + self, + key: str, + ) -> AnyModelConfig: + """ + Retrieve the indicated model. + + :param key: Key of model config to be fetched. + + Exceptions: UnknownModelException + """ + return self._store.get_model(key) + + def model_exists( + self, + key: str, + ) -> bool: + """ + Verify that a model with the given key exists. + + Given a model key, returns True if it is a valid + identifier. + """ + return self._store.exists(key) + + def model_info(self, key: str) -> ModelConfigBase: + """ + Return configuration information about a model. + + Given a model key returns the ModelConfigBase describing it. + """ + return self._store.get_model(key) + + def list_models( + self, + model_name: Optional[str] = None, + base_model: Optional[BaseModelType] = None, + model_type: Optional[ModelType] = None, + ) -> List[ModelConfigBase]: + """ + Return a ModelConfigBase object for each model in the database. + """ + return self._store.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type) + + def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase: + """ + Return information about the model using the same format as list_models(). + + If there are more than one model that match, raises a DuplicateModelException. + If no model matches, raises an UnknownModelException + """ + model_configs = self.list_models(model_name=model_name, base_model=base_model, model_type=model_type) + if len(model_configs) > 1: + raise DuplicateModelException( + "More than one model share the same name and type: {base_model}/{model_type}/{model_name}" + ) + if len(model_configs) == 0: + raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}") + return model_configs[0] + + def all_models(self) -> List[ModelConfigBase]: + """Return a list of all the models.""" + return self.list_models() + + def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None: + """ + Add a model to the database. + + :param key: Unique key for the model + :param config: Model configuration record, either a dict with the + required fields or a ModelConfigBase instance. + + Can raise DuplicateModelException and InvalidModelConfig exceptions. + """ + self._store.add_model(key, config) + + def update_model( + self, + key: str, + new_config: Union[dict, ModelConfigBase], + ) -> ModelConfigBase: + """ + Update the named model with a dictionary of attributes. + + Will fail with a + UnknownModelException if the name does not already exist. + + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or + the model key is unknown. + """ + new_info = self._store.update_model(key, new_config) + print('FIX ME!!! need to call sync_model_path() somewhere... maybe router?') + # self._loader.installer.sync_model_path(new_info.key) Maybe this goes into the router call? + return new_info + + def del_model( + self, + key: str, + ): + """ + Delete the named model from configuration. + """ + model_info = self.model_info(key) + self._store.del_model(key) + + def rename_model( + self, + key: str, + new_name: str, + ): + """ + Rename the indicated model to the new name. + + :param key: Unique key for the model. + :param new_name: New name for the model + """ + return self.update_model(key, {"name": new_name}) + + +class ModelRecordServiceSQL(ModelRecordService): + """ModelRecordService that uses Sqlite for its backend.""" + + def __init__(self, conn: sqlite3.Connection, lock: threading.Lock): + """ + Initialize a ModelRecordService that uses a SQLITE3 database backend. + + :param conn: sqlite3 Connection object + :param lock: Thread lock object + """ + self._store = ModelConfigStoreSQL(conn, lock) + + +class ModelRecordServiceFile(ModelRecordService): + """ModelRecordService that uses a YAML file for its backend.""" + + def __init__(self, models_file: Path): + """ + Initialize a ModelRecordService that uses a YAML file as the backend. + + :param models_file: Path to the YAML file backend. + """ + self._store = ModelConfigStoreYAML(models_file) diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index 997867dc04..49c71f024e 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -6,11 +6,12 @@ from .model_manager import ( # noqa F401 DuplicateModelException, InvalidModelException, ModelConfigStore, - ModelInstall, - ModelLoad, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings, SubModelType, ) +from .model_manager.install import ModelInstall # noqa F401 +from .model_manager.loader import ModelLoad # noqa F401 +from .util.devices import get_precision # noqa F401 diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index 8a51d75a5e..45067378c8 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -11,9 +11,9 @@ from .config import ( # noqa F401 SilenceWarnings, SubModelType, ) -from .install import ModelInstall, ModelInstallJob # noqa F401 -from .loader import ModelInfo, ModelLoad # noqa F401 -from .lora import ModelPatcher, ONNXModelPatcher # noqa F401 +# from .install import ModelInstall, ModelInstallJob # noqa F401 +# from .loader import ModelInfo, ModelLoad # noqa F401 +# from .lora import ModelPatcher, ONNXModelPatcher # noqa F401 from .models import OPENAPI_MODEL_CONFIGS, InvalidModelException, read_checkpoint_meta # noqa F401 from .probe import ModelProbe, ModelProbeInfo # noqa F401 from .search import ModelSearch # noqa F401 diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py index 4110f941c9..1d932c2836 100644 --- a/invokeai/backend/model_manager/install.py +++ b/invokeai/backend/model_manager/install.py @@ -60,6 +60,7 @@ from pydantic import Field from pydantic.networks import AnyHttpUrl from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_record_service import ModelRecordServiceBase from invokeai.backend.util import Chdir, InvokeAILogger, Logger from .config import ( @@ -319,6 +320,11 @@ class ModelInstallBase(ABC): """ pass + @abstractmethod + def sync_to_config(self): + """Synchronize models on disk to those in memory.""" + pass + @abstractmethod def scan_models_directory(self): """ @@ -367,7 +373,7 @@ class ModelInstall(ModelInstallBase): def __init__( self, - store: Optional[ModelConfigStore] = None, + store: Optional[Union[ModelConfigStore, ModelRecordServiceBase]] = None, config: Optional[InvokeAIAppConfig] = None, logger: Optional[Logger] = None, download: Optional[DownloadQueueBase] = None, @@ -381,10 +387,6 @@ class ModelInstall(ModelInstallBase): self._installed = set() self._tmpdir = None - # this step synchronizes the `models` directory with the models db - # do NOT do this automatically, but only on app startup - # self.scan_models_directory() - @property def queue(self) -> DownloadQueueBase: """Return the queue.""" @@ -609,7 +611,7 @@ class ModelInstall(ModelInstallBase): variant: Optional[str] = None, access_token: Optional[str] = None, priority: Optional[int] = 10, - ) -> DownloadJobBase: + ) -> ModelInstallJob: # Clean up a common source of error. Doesn't work with Paths. if isinstance(source, str): source = source.strip() @@ -747,6 +749,10 @@ class ModelInstall(ModelInstallBase): pass return True + def sync_to_config(self): + """Synchronize models on disk to those in memory.""" + self.scan_models_directory() + def scan_models_directory(self): """ Scan the models directory for new and missing models. diff --git a/invokeai/backend/model_manager/loader.py b/invokeai/backend/model_manager/loader.py index 07acfc97e3..1b7abc6894 100644 --- a/invokeai/backend/model_manager/loader.py +++ b/invokeai/backend/model_manager/loader.py @@ -11,11 +11,10 @@ import torch from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util import InvokeAILogger, Logger, choose_precision, choose_torch_device +from invokeai.app.services.model_record_service import ModelRecordServiceBase from .cache import CacheStats, ModelCache from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType -from .download import DownloadEventHandler, DownloadQueueBase -from .install import ModelInstall, ModelInstallBase from .models import MODEL_CLASSES, InvalidModelException, ModelBase from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_config_store, migrate_models_store @@ -65,12 +64,6 @@ class ModelLoadBase(ABC): """Return the ModelConfigStore object that supports this loader.""" pass - @property - @abstractmethod - def installer(self) -> ModelInstallBase: - """Return the ModelInstallBase object that supports this loader.""" - pass - @property @abstractmethod def logger(self) -> Logger: @@ -80,13 +73,7 @@ class ModelLoadBase(ABC): @property @abstractmethod def config(self) -> InvokeAIAppConfig: - """Return the config object used by this installer.""" - pass - - @property - @abstractmethod - def queue(self) -> DownloadQueueBase: - """Return the download queue object used by this object.""" + """Return the config object used by the loader.""" pass @abstractmethod @@ -105,18 +92,12 @@ class ModelLoadBase(ABC): """Return torch.fp16 or torch.fp32.""" pass - @abstractmethod - def sync_to_config(self): - """Reinitialize the store to sync in-memory and in-disk versions.""" - pass - class ModelLoad(ModelLoadBase): """Implementation of ModelLoadBase.""" _app_config: InvokeAIAppConfig - _store: ModelConfigStore - _installer: ModelInstallBase + _store: Union[ModelConfigStore, ModelRecordServiceBase] _cache: ModelCache _logger: Logger _cache_keys: dict @@ -126,7 +107,6 @@ class ModelLoad(ModelLoadBase): self, config: InvokeAIAppConfig, store: Optional[ModelConfigStore] = None, - event_handlers: List[DownloadEventHandler] = [], ): """ Initialize ModelLoad object. @@ -138,12 +118,6 @@ class ModelLoad(ModelLoadBase): self._app_config = config self._store = store self._logger = InvokeAILogger.get_logger() - self._installer = ModelInstall( - store=self._store, - logger=self._logger, - config=self._app_config, - event_handlers=event_handlers, - ) self._cache_keys = dict() self._models_file = config.model_conf_path device = torch.device(choose_torch_device()) @@ -190,11 +164,6 @@ class ModelLoad(ModelLoadBase): """Return torch.fp16 or torch.fp32.""" return self._cache.precision - @property - def installer(self) -> ModelInstallBase: - """Return the ModelInstallBase instance used by this class.""" - return self._installer - @property def logger(self) -> Logger: """Return the current logger.""" @@ -202,14 +171,9 @@ class ModelLoad(ModelLoadBase): @property def config(self) -> InvokeAIAppConfig: - """Return the config object used by the installer.""" + """Return the config object.""" return self._app_config - @property - def queue(self) -> DownloadQueueBase: - """Return the download queue object used by this object.""" - return self._installer.queue - def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo: """ Get the ModelInfo corresponding to the model with key "key". @@ -303,7 +267,3 @@ class ModelLoad(ModelLoadBase): model_path = self.resolve_model_path(model_path) return model_path, is_submodel_override - def sync_to_config(self): - """Synchronize models on disk to those in memory.""" - self._store = get_config_store(self._models_file) - self.installer.scan_models_directory() diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index c0d8c90b3c..63c7a000cd 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -84,9 +84,9 @@ class ModelMerger(object): self, model_keys: List[str], merged_model_name: str, - alpha: float = 0.5, + alpha: Optional[float] = 0.5, interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, + force: Optional[bool] = False, merge_dest_directory: Optional[Path] = None, **kwargs, ) -> ModelConfigBase: diff --git a/invokeai/backend/model_manager/storage/__init__.py b/invokeai/backend/model_manager/storage/__init__.py index 51675c3c4c..8541814648 100644 --- a/invokeai/backend/model_manager/storage/__init__.py +++ b/invokeai/backend/model_manager/storage/__init__.py @@ -10,7 +10,7 @@ from .base import ( # noqa F401 from .migrate import migrate_models_store # noqa F401 from .sql import ModelConfigStoreSQL # noqa F401 from .yaml import ModelConfigStoreYAML # noqa F401 - +from ..config import AnyModelConfig # noqa F401 def get_config_store(location: pathlib.Path) -> ModelConfigStore: """Return the type of ModelConfigStore appropriate to the path.""" diff --git a/invokeai/backend/model_manager/storage/base.py b/invokeai/backend/model_manager/storage/base.py index d7585957aa..42cd707f2b 100644 --- a/invokeai/backend/model_manager/storage/base.py +++ b/invokeai/backend/model_manager/storage/base.py @@ -39,7 +39,7 @@ class ModelConfigStore(ABC): pass @abstractmethod - def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> None: + def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> ModelConfigBase: """ Add a model to the database. diff --git a/invokeai/backend/model_manager/storage/sql.py b/invokeai/backend/model_manager/storage/sql.py index 93328979ca..10602a56d5 100644 --- a/invokeai/backend/model_manager/storage/sql.py +++ b/invokeai/backend/model_manager/storage/sql.py @@ -53,21 +53,18 @@ from .base import CONFIG_FILE_VERSION, DuplicateModelException, ModelConfigStore class ModelConfigStoreSQL(ModelConfigStore): """Implementation of the ModelConfigStore ABC using a YAML file.""" - _filename: Path _conn: sqlite3.Connection _cursor: sqlite3.Cursor _lock: threading.Lock - def __init__(self, filename: Path): + def __init__(self, conn: sqlite3.Connection, lock: threading.Lock): """Initialize ModelConfigStore object with a sqlite3 database.""" super().__init__() - self._filename = Path(filename).absolute() # don't let chdir mess us up! - self._filename.parent.mkdir(parents=True, exist_ok=True) - self._conn = sqlite3.connect(filename, check_same_thread=False) + self._conn = conn # Enable row factory to get rows as dictionaries (must be done before making the cursor!) self._conn.row_factory = sqlite3.Row self._cursor = self._conn.cursor() - self._lock = threading.Lock() + self._lock = lock try: self._lock.acquire() @@ -174,7 +171,7 @@ class ModelConfigStoreSQL(ModelConfigStore): ("version", CONFIG_FILE_VERSION), ) - def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None: + def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: """ Add a model to the database. @@ -225,6 +222,8 @@ class ModelConfigStoreSQL(ModelConfigStore): finally: self._lock.release() + return self.get_model(key) + @property def version(self) -> str: """Return the version of the database schema.""" diff --git a/invokeai/backend/model_manager/storage/yaml.py b/invokeai/backend/model_manager/storage/yaml.py index 5a4a42a250..a59c934d42 100644 --- a/invokeai/backend/model_manager/storage/yaml.py +++ b/invokeai/backend/model_manager/storage/yaml.py @@ -104,7 +104,7 @@ class ModelConfigStoreYAML(ModelConfigStore): """Return version of this config file/database.""" return self._config.__metadata__.get("version") - def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None: + def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: """ Add a model to the database. @@ -127,6 +127,7 @@ class ModelConfigStoreYAML(ModelConfigStore): self._commit() finally: self._lock.release() + return self.get_model(key) def _fix_enums(self, original: dict) -> dict: """In python 3.9, omegaconf stores incorrectly stringified enums.""" diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 84ca7ee02b..4f0dad91ad 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -2,7 +2,7 @@ from __future__ import annotations import platform from contextlib import nullcontext -from typing import Union +from typing import Union, Literal import torch from packaging import version @@ -41,6 +41,11 @@ def choose_precision(device: torch.device) -> str: return "float16" return "float32" +def get_precision() -> Literal['float16', 'float32']: + device = torch.device(choose_torch_device()) + precision = choose_precision(device) if config.precision == "auto" else config.precision + assert precision in ['float16', 'float32'] + return precision def torch_dtype(device: torch.device) -> torch.dtype: if config.full_precision: