diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 880c8b2480..39220f4ba8 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -28,7 +28,7 @@ model. These are the: Hugging Face, as well as discriminating among model versions in Civitai, but can be used for arbitrary content. - * _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**) + * _ModelLoadServiceBase_ Responsible for loading a model from disk into RAM and VRAM and getting it ready for inference. @@ -41,10 +41,10 @@ The four main services can be found in * `invokeai/app/services/model_records/` * `invokeai/app/services/model_install/` * `invokeai/app/services/downloads/` -* `invokeai/app/services/model_loader/` (**under development**) +* `invokeai/app/services/model_load/` Code related to the FastAPI web API can be found in -`invokeai/app/api/routers/model_records.py`. +`invokeai/app/api/routers/model_manager_v2.py`. *** @@ -84,10 +84,10 @@ diffusers model. When this happens, `original_hash` is unchanged, but `ModelType`, `ModelFormat` and `BaseModelType` are string enums that are defined in `invokeai.backend.model_manager.config`. They are also imported by, and can be reexported from, -`invokeai.app.services.model_record_service`: +`invokeai.app.services.model_manager.model_records`: ``` -from invokeai.app.services.model_record_service import ModelType, ModelFormat, BaseModelType +from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType ``` The `path` field can be absolute or relative. If relative, it is taken @@ -123,7 +123,7 @@ taken to be the `models_dir` directory. `variant` is an enumerated string class with values `normal`, `inpaint` and `depth`. If needed, it can be imported if needed from -either `invokeai.app.services.model_record_service` or +either `invokeai.app.services.model_records` or `invokeai.backend.model_manager.config`. ### ONNXSD2Config @@ -134,7 +134,7 @@ either `invokeai.app.services.model_record_service` or | `upcast_attention` | bool | Model requires its attention module to be upcast | The `SchedulerPredictionType` enum can be imported from either -`invokeai.app.services.model_record_service` or +`invokeai.app.services.model_records` or `invokeai.backend.model_manager.config`. ### Other config classes @@ -157,15 +157,6 @@ indicates that the model is compatible with any of the base models. This works OK for some models, such as the IP Adapter image encoders, but is an all-or-nothing proposition. -Another issue is that the config class hierarchy is paralleled to some -extent by a `ModelBase` class hierarchy defined in -`invokeai.backend.model_manager.models.base` and its subclasses. These -are classes representing the models after they are loaded into RAM and -include runtime information such as load status and bytes used. Some -of the fields, including `name`, `model_type` and `base_model`, are -shared between `ModelConfigBase` and `ModelBase`, and this is a -potential source of confusion. - ## Reading and Writing Model Configuration Records The `ModelRecordService` provides the ability to retrieve model @@ -177,11 +168,11 @@ initialization and can be retrieved within an invocation from the `InvocationContext` object: ``` -store = context.services.model_record_store +store = context.services.model_manager.store ``` or from elsewhere in the code by accessing -`ApiDependencies.invoker.services.model_record_store`. +`ApiDependencies.invoker.services.model_manager.store`. 
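For example, a node that needs a model's configuration record can look it up by key through the store. The snippet below is a minimal sketch only: the model key is hypothetical, and the import path for `UnknownModelException` is assumed to follow the `model_records` service package described in this document.

```
# Fetch a model's configuration record from within an invocation.
from invokeai.app.services.model_records import UnknownModelException

store = context.services.model_manager.store  # a ModelRecordServiceBase
try:
    config = store.get_model("f13dd932c0c35c22dcb8d6cda4203764")
    # `config` is one of the model configuration classes described above;
    # fields such as `name`, `base` and `path` can be read from it directly.
except UnknownModelException:
    pass  # no record with this key exists in the database
```
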
### Creating a `ModelRecordService` @@ -190,7 +181,7 @@ you can directly create either a `ModelRecordServiceSQL` or a `ModelRecordServiceFile` object: ``` -from invokeai.app.services.model_record_service import ModelRecordServiceSQL, ModelRecordServiceFile +from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceFile store = ModelRecordServiceSQL.from_connection(connection, lock) store = ModelRecordServiceSQL.from_db_file('/path/to/sqlite_database.db') @@ -252,7 +243,7 @@ So a typical startup pattern would be: ``` import sqlite3 from invokeai.app.services.thread import lock -from invokeai.app.services.model_record_service import ModelRecordServiceBase +from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.config import InvokeAIAppConfig config = InvokeAIAppConfig.get_config() @@ -260,19 +251,6 @@ db_conn = sqlite3.connect(config.db_path.as_posix(), check_same_thread=False) store = ModelRecordServiceBase.open(config, db_conn, lock) ``` -_A note on simultaneous access to `invokeai.db`_: The current InvokeAI -service architecture for the image and graph databases is careful to -use a shared sqlite3 connection and a thread lock to ensure that two -threads don't attempt to access the database simultaneously. However, -the default `sqlite3` library used by Python reports using -**Serialized** mode, which allows multiple threads to access the -database simultaneously using multiple database connections (see -https://www.sqlite.org/threadsafe.html and -https://ricardoanderegg.com/posts/python-sqlite-thread-safety/). Therefore -it should be safe to allow the record service to open its own SQLite -database connection. Opening a model record service should then be as -simple as `ModelRecordServiceBase.open(config)`. - ### Fetching a Model's Configuration from `ModelRecordServiceBase` Configurations can be retrieved in several ways. @@ -1465,7 +1443,7 @@ create alternative instances if you wish. ### Creating a ModelLoadService object The class is defined in -`invokeai.app.services.model_loader_service`. It is initialized with +`invokeai.app.services.model_load`. It is initialized with an InvokeAIAppConfig object, from which it gets configuration information such as the user's desired GPU and precision, and with a previously-created `ModelRecordServiceBase` object, from which it @@ -1475,8 +1453,8 @@ Here is a typical initialization pattern: ``` from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_record_service import ModelRecordServiceBase -from invokeai.app.services.model_loader_service import ModelLoadService +from invokeai.app.services.model_records import ModelRecordServiceBase +from invokeai.app.services.model_load import ModelLoadService config = InvokeAIAppConfig.get_config() store = ModelRecordServiceBase.open(config) @@ -1487,14 +1465,11 @@ Note that we are relying on the contents of the application configuration to choose the implementation of `ModelRecordServiceBase`. -### get_model(key, [submodel_type], [context]) -> ModelInfo: +### load_model_by_key(key, [submodel_type], [context]) -> LoadedModel -*** TO DO: change to get_model(key, context=None, **kwargs) - -The `get_model()` method, like its similarly-named cousin in -`ModelRecordService`, receives the unique key that identifies the -model. It loads the model into memory, gets the model ready for use, -and returns a `ModelInfo` object. +The `load_model_by_key()` method receives the unique key that +identifies the model. 
It loads the model into memory, gets the model +ready for use, and returns a `LoadedModel` object. The optional second argument, `subtype`, is a `SubModelType` string enum, such as "vae". It is mandatory when used with a main model, and @@ -1504,46 +1479,64 @@ The optional third argument, `context` can be provided by an invocation to trigger model load event reporting. See below for details. -The returned `ModelInfo` object shares some fields in common with -`ModelConfigBase`, but is otherwise a completely different beast: +The returned `LoadedModel` object contains a copy of the configuration +record returned by the model record `get_model()` method, as well as +the in-memory loaded model: -| **Field Name** | **Type** | **Description** | + +| **Attribute Name** | **Type** | **Description** | |----------------|-----------------|------------------| -| `key` | str | The model key derived from the ModelRecordService database | -| `name` | str | Name of this model | -| `base_model` | BaseModelType | Base model for this model | -| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)| -| `location` | Path or str | Location of the model on the filesystem | -| `precision` | torch.dtype | The torch.precision to use for inference | -| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use | +| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. | +| `model` | AnyModel | The instantiated model (details below) | +| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM | -The types for `ModelInfo` and `SubModelType` can be imported from -`invokeai.app.services.model_loader_service`. +Because the loader can return multiple model types, it is typed to +return `AnyModel`, a Union of `ModelMixin`, `torch.nn.Module`, +`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and +`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers +models; `EmbeddingModelRaw` is used for LoRA and TextualInversion +models; `IAIOnnxRuntimeModel` is used for ONNX models; and `IPAdapter` and `IPAdapterPlus` are used for IP-Adapter models. -To use the model, you use the `ModelInfo` as a context manager using -the following pattern: + +`LoadedModel` acts as a context manager. The context loads the model +into the execution device (e.g. VRAM on CUDA systems), locks the model +in the execution device for the duration of the context, and returns +the model. Use it like this: ``` -model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae')) +model_info = loader.load_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae')) with model_info as vae: image = vae.decode(latents)[0] ``` -The `vae` model will stay locked in the GPU during the period of time -it is in the context manager's scope. +`load_model_by_key()` may raise any of the following exceptions: -`get_model()` may raise any of the following exceptions: - -- `UnknownModelException` -- key not in database -- `ModelNotFoundException` -- key in database but model not found at path -- `InvalidModelException` -- the model is guilty of a variety of sins +- `UnknownModelException` -- key not in database +- `ModelNotFoundException` -- key in database but model not found at path +- `NotImplementedException` -- the loader doesn't know how to load this type of model -** TO DO: ** Resolve discrepancy between ModelInfo.location and -ModelConfig.path.
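Inside an invocation, the same loader is reached through the model manager service, and passing the invocation `context` enables the load events described below. Here is a short sketch of the pattern used by the invocations touched in this PR; the `self.vae.vae` field and `image_tensor` are assumed to come from the node's inputs.

```
# Load a VAE and use it while it is locked into the execution device.
vae_info = context.services.model_manager.load.load_model_by_key(
    **self.vae.vae.model_dump(),
    context=context,
)
with vae_info as vae:
    # The model stays locked in VRAM only while this block is active.
    latents = vae.encode(image_tensor).latent_dist.sample()
```
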
+### load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel + +This is similar to `load_model_by_key`, but instead it accepts the +combination of the model's name, type and base, which it passes to the +model record config store for retrieval. If successful, this method +returns a `LoadedModel`. It can raise the following exceptions: + +``` +UnknownModelException -- model with these attributes not known +NotImplementedException -- the loader doesn't know how to load this type of model +ValueError -- more than one model matches this combination of base/type/name +``` + +### load_model_by_config(config, [submodel], [context]) -> LoadedModel + +This method takes an `AnyModelConfig` returned by +ModelRecordService.get_model() and returns the corresponding loaded +model. It may raise a `NotImplementedException`. ### Emitting model loading events -When the `context` argument is passed to `get_model()`, it will +When the `context` argument is passed to `load_model_*()`, it will retrieve the invocation event bus from the passed `InvocationContext` object to emit events on the invocation bus. The two events are "model_load_started" and "model_load_completed". Both carry the @@ -1563,3 +1556,97 @@ payload=dict( ) ``` +### Adding Model Loaders + +Model loaders are small classes that inherit from the `ModelLoader` +base class. They typically implement one method `_load_model()` whose +signature is: + +``` +def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, +) -> AnyModel: +``` + +`_load_model()` will be passed the path to the model on disk, an +optional repository variant (used by the diffusers loaders to select, +e.g. the `fp16` variant, and an optional submodel_type for main and +onnx models. + +To install a new loader, place it in +`invokeai/backend/model_manager/load/model_loaders`. Inherit from +`ModelLoader` and use the `@AnyModelLoader.register()` decorator to +indicate what type of models the loader can handle. + +Here is a complete example from `generic_diffusers.py`, which is able +to load several different diffusers types: + +``` +from pathlib import Path +from typing import Optional + +from invokeai.backend.model_manager import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelRepoVariant, + ModelType, + SubModelType, +) +from ..load_base import AnyModelLoader +from ..load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) +class GenericDiffusersLoader(ModelLoader): + """Class to load simple diffusers models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + model_class = self._get_hf_load_class(model_path) + if submodel_type is not None: + raise Exception(f"There are no submodels in models of type {model_class}") + variant = model_variant.value if model_variant else None + result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore + return result +``` + +Note that a loader can register itself to handle several different +model types. An exception will be raised if more than one loader tries +to register the same model type. 
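If you are writing a loader from scratch, the bare minimum is a registration decorator and a `_load_model()` implementation. The skeleton below is a sketch only: the class name and the base/type/format triple in the decorator are illustrative (choose a combination that no existing loader already claims), and the method body is left for the real implementation.

```
from pathlib import Path
from typing import Optional

from invokeai.backend.model_manager import (
    AnyModel,
    BaseModelType,
    ModelFormat,
    ModelRepoVariant,
    ModelType,
    SubModelType,
)
from ..load_base import AnyModelLoader
from ..load_default import ModelLoader


@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
class MyNewLoader(ModelLoader):
    """Skeleton loader; the registration triple above is illustrative only."""

    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        # Instantiate the model found at model_path and return it. Until the
        # body is filled in, the skeleton raises to make the gap obvious.
        raise NotImplementedError(f"Loading of {model_path} is not implemented yet")
```
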
+ +#### Conversion + +Some models require conversion to diffusers format before they can be +loaded. These loaders should override two additional methods: + +``` +_needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool +_convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: +``` + +The first method accepts the model configuration, the path to where +the unmodified model is currently installed, and a proposed +destination for the converted model. This method returns True if the +model needs to be converted. It typically does this by comparing the +last modification time of the original model file to the modification +time of the converted model. In some cases you will also want to check +the modification date of the configuration record, in the event that +the user has changed something like the scheduler prediction type that +will require the model to be re-converted. See `controlnet.py` for an +example of this logic. + +The second method accepts the model configuration, the path to the +original model on disk, and the desired output path for the converted +model. It does whatever it needs to do to get the model into diffusers +format, and returns the Path of the resulting model. (The path should +ordinarily be the same as `output_path`.) + diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index a54c8c1c3e..8a809cef37 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -4,9 +4,6 @@ from logging import Logger from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory from invokeai.app.services.shared.sqlite.sqlite_util import init_db -from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache -from invokeai.backend.model_manager.load.model_cache import ModelCache -from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger from invokeai.version.invokeai_version import __version__ @@ -27,9 +24,7 @@ from ..services.invocation_stats.invocation_stats_default import InvocationStats from ..services.invoker import Invoker from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage -from ..services.model_install import ModelInstallService from ..services.model_manager.model_manager_default import ModelManagerService -from ..services.model_records import ModelRecordServiceSQL from ..services.names.names_default import SimpleNameService from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue @@ -87,26 +82,10 @@ class ApiDependencies: images = ImageService() invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) - model_loader = AnyModelLoader( - app_config=config, - logger=logger, - ram_cache=ModelCache( - max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger - ), - convert_cache=ModelConvertCache( - cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size - ), - ) - model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader) download_queue_service = DownloadQueueService(event_bus=events) - model_install_service = ModelInstallService( - 
app_config=config, - record_store=model_record_service, - download_queue=download_queue_service, - metadata_store=ModelMetadataStore(db=db), - event_bus=events, + model_manager = ModelManagerService.build_model_manager( + app_config=configuration, db=db, download_queue=download_queue_service, events=events ) - model_manager = ModelManagerService(config, logger) # TO DO: legacy model manager v1. Remove names = SimpleNameService() performance_statistics = InvocationStatsService() processor = DefaultInvocationProcessor() @@ -131,9 +110,7 @@ class ApiDependencies: latents=latents, logger=logger, model_manager=model_manager, - model_records=model_record_service, download_queue=download_queue_service, - model_install=model_install_service, names=names, performance_statistics=performance_statistics, processor=processor, diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_manager_v2.py similarity index 86% rename from invokeai/app/api/routers/model_records.py rename to invokeai/app/api/routers/model_manager_v2.py index f9a3e40898..4fc785e4f7 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_manager_v2.py @@ -32,7 +32,7 @@ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata from ..dependencies import ApiDependencies -model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"]) +model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"]) class ModelsList(BaseModel): @@ -52,7 +52,7 @@ class ModelTagSet(BaseModel): tags: Set[str] -@model_records_router.get( +@model_manager_v2_router.get( "/", operation_id="list_model_records", ) @@ -65,7 +65,7 @@ async def list_model_records( ), ) -> ModelsList: """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store found_models: list[AnyModelConfig] = [] if base_models: for base_model in base_models: @@ -81,7 +81,7 @@ async def list_model_records( return ModelsList(models=found_models) -@model_records_router.get( +@model_manager_v2_router.get( "/i/{key}", operation_id="get_model_record", responses={ @@ -94,24 +94,27 @@ async def get_model_record( key: str = Path(description="Key of the model record to fetch."), ) -> AnyModelConfig: """Get a model record""" - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store try: - return record_store.get_model(key) + config: AnyModelConfig = record_store.get_model(key) + return config except UnknownModelException as e: raise HTTPException(status_code=404, detail=str(e)) -@model_records_router.get("/meta", operation_id="list_model_summary") +@model_manager_v2_router.get("/meta", operation_id="list_model_summary") async def list_model_summary( page: int = Query(default=0, description="The page to get"), per_page: int = Query(default=10, description="The number of models per page"), order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"), ) -> PaginatedResults[ModelSummary]: """Gets a page of model summary data.""" - return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by) + record_store = ApiDependencies.invoker.services.model_manager.store + results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by) + return results 
-@model_records_router.get( +@model_manager_v2_router.get( "/meta/i/{key}", operation_id="get_model_metadata", responses={ @@ -124,24 +127,25 @@ async def get_model_metadata( key: str = Path(description="Key of the model repo metadata to fetch."), ) -> Optional[AnyModelRepoMetadata]: """Get a model metadata object.""" - record_store = ApiDependencies.invoker.services.model_records - result = record_store.get_metadata(key) + record_store = ApiDependencies.invoker.services.model_manager.store + result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key) if not result: raise HTTPException(status_code=404, detail="No metadata for a model with this key") return result -@model_records_router.get( +@model_manager_v2_router.get( "/tags", operation_id="list_tags", ) async def list_tags() -> Set[str]: """Get a unique set of all the model tags.""" - record_store = ApiDependencies.invoker.services.model_records - return record_store.list_tags() + record_store = ApiDependencies.invoker.services.model_manager.store + result: Set[str] = record_store.list_tags() + return result -@model_records_router.get( +@model_manager_v2_router.get( "/tags/search", operation_id="search_by_metadata_tags", ) @@ -149,12 +153,12 @@ async def search_by_metadata_tags( tags: Set[str] = Query(default=None, description="Tags to search for"), ) -> ModelsList: """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store results = record_store.search_by_metadata_tag(tags) return ModelsList(models=results) -@model_records_router.patch( +@model_manager_v2_router.patch( "/i/{key}", operation_id="update_model_record", responses={ @@ -172,9 +176,9 @@ async def update_model_record( ) -> AnyModelConfig: """Update model contents with a new config. 
If the model name or base fields are changed, then the model is renamed.""" logger = ApiDependencies.invoker.services.logger - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store try: - model_response = record_store.update_model(key, config=info) + model_response: AnyModelConfig = record_store.update_model(key, config=info) logger.info(f"Updated model: {key}") except UnknownModelException as e: raise HTTPException(status_code=404, detail=str(e)) @@ -184,7 +188,7 @@ async def update_model_record( return model_response -@model_records_router.delete( +@model_manager_v2_router.delete( "/i/{key}", operation_id="del_model_record", responses={ @@ -205,7 +209,7 @@ async def del_model_record( logger = ApiDependencies.invoker.services.logger try: - installer = ApiDependencies.invoker.services.model_install + installer = ApiDependencies.invoker.services.model_manager.install installer.delete(key) logger.info(f"Deleted model: {key}") return Response(status_code=204) @@ -214,7 +218,7 @@ async def del_model_record( raise HTTPException(status_code=404, detail=str(e)) -@model_records_router.post( +@model_manager_v2_router.post( "/i/", operation_id="add_model_record", responses={ @@ -229,7 +233,7 @@ async def add_model_record( ) -> AnyModelConfig: """Add a model using the configuration information appropriate for its type.""" logger = ApiDependencies.invoker.services.logger - record_store = ApiDependencies.invoker.services.model_records + record_store = ApiDependencies.invoker.services.model_manager.store if config.key == "": config.key = sha1(randbytes(100)).hexdigest() logger.info(f"Created model {config.key} for {config.name}") @@ -243,10 +247,11 @@ async def add_model_record( raise HTTPException(status_code=415) # now fetch it out - return record_store.get_model(config.key) + result: AnyModelConfig = record_store.get_model(config.key) + return result -@model_records_router.post( +@model_manager_v2_router.post( "/import", operation_id="import_model_record", responses={ @@ -322,7 +327,7 @@ async def import_model( logger = ApiDependencies.invoker.services.logger try: - installer = ApiDependencies.invoker.services.model_install + installer = ApiDependencies.invoker.services.model_manager.install result: ModelInstallJob = installer.import_model( source=source, config=config, @@ -340,17 +345,17 @@ async def import_model( return result -@model_records_router.get( +@model_manager_v2_router.get( "/import", operation_id="list_model_install_jobs", ) async def list_model_install_jobs() -> List[ModelInstallJob]: """Return list of model install jobs.""" - jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs() + jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs() return jobs -@model_records_router.get( +@model_manager_v2_router.get( "/import/{id}", operation_id="get_model_install_job", responses={ @@ -361,12 +366,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]: async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: """Return model install job corresponding to the given source.""" try: - return ApiDependencies.invoker.services.model_install.get_job_by_id(id) + result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id) + return result except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) -@model_records_router.delete( 
+@model_manager_v2_router.delete( "/import/{id}", operation_id="cancel_model_install_job", responses={ @@ -377,7 +383,7 @@ async def get_model_install_job(id: int = Path(description="Model install id")) ) async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None: """Cancel the model install job(s) corresponding to the given job ID.""" - installer = ApiDependencies.invoker.services.model_install + installer = ApiDependencies.invoker.services.model_manager.install try: job = installer.get_job_by_id(id) except ValueError as e: @@ -385,7 +391,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job installer.cancel_job(job) -@model_records_router.patch( +@model_manager_v2_router.patch( "/import", operation_id="prune_model_install_jobs", responses={ @@ -395,11 +401,11 @@ async def cancel_model_install_job(id: int = Path(description="Model install job ) async def prune_model_install_jobs() -> Response: """Prune all completed and errored jobs from the install job list.""" - ApiDependencies.invoker.services.model_install.prune_jobs() + ApiDependencies.invoker.services.model_manager.install.prune_jobs() return Response(status_code=204) -@model_records_router.patch( +@model_manager_v2_router.patch( "/sync", operation_id="sync_models_to_config", responses={ @@ -414,11 +420,11 @@ async def sync_models_to_config() -> Response: Model files without a corresponding record in the database are added. Orphan records without a models file are deleted. """ - ApiDependencies.invoker.services.model_install.sync_to_config() + ApiDependencies.invoker.services.model_manager.install.sync_to_config() return Response(status_code=204) -@model_records_router.put( +@model_manager_v2_router.put( "/merge", operation_id="merge", ) @@ -451,7 +457,7 @@ async def merge( try: logger.info(f"Merging models: {keys} into {merge_dest_directory or ''}/{merged_model_name}") dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None - installer = ApiDependencies.invoker.services.model_install + installer = ApiDependencies.invoker.services.model_manager.install merger = ModelMerger(installer) model_names = [installer.record_store.get_model(x).name for x in keys] response = merger.merge_diffusion_models_and_save( diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 8f83820cf8..0aa7aa0ecb 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -8,8 +8,7 @@ from fastapi.routing import APIRouter from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from starlette.exceptions import HTTPException -from invokeai.backend import BaseModelType, ModelType -from invokeai.backend.model_management import MergeInterpolationMethod +from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType from invokeai.backend.model_management.models import ( OPENAPI_MODEL_CONFIGS, InvalidModelException, diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 6294083d0e..6e47e9e30d 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -47,7 +47,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c boards, download_queue, images, - model_records, + model_manager_v2, models, session_queue, sessions, @@ -115,8 +115,7 @@ async def shutdown_event() -> None: app.include_router(sessions.session_router, prefix="/api") app.include_router(utilities.utilities_router, prefix="/api") 
-app.include_router(models.models_router, prefix="/api") -app.include_router(model_records.model_records_router, prefix="/api") +app.include_router(model_manager_v2.model_manager_v2_router, prefix="/api") app.include_router(download_queue.download_queue_router, prefix="/api") app.include_router(images.images_router, prefix="/api") app.include_router(boards.boards_router, prefix="/api") diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 12dcd9f930..df1c3b3245 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -4,6 +4,7 @@ from typing import Iterator, List, Optional, Tuple, Union import torch from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment +from transformers import CLIPTokenizer import invokeai.backend.util.logging as logger from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput @@ -70,18 +71,18 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_records.load_model( + tokenizer_info = context.services.model_manager.load.load_model_by_key( **self.clip.tokenizer.model_dump(), context=context, ) - text_encoder_info = context.services.model_records.load_model( + text_encoder_info = context.services.model_manager.load.load_model_by_key( **self.clip.text_encoder.model_dump(), context=context, ) def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.clip.loras: - lora_info = context.services.model_records.load_model( + lora_info = context.services.model_manager.load.load_model_by_key( **lora.model_dump(exclude={"weight"}), context=context ) assert isinstance(lora_info.model, LoRAModelRaw) @@ -95,7 +96,7 @@ class CompelInvocation(BaseInvocation): for trigger in extract_ti_triggers_from_prompt(self.prompt): name = trigger[1:-1] try: - loaded_model = context.services.model_records.load_model( + loaded_model = context.services.model_manager.load.load_model_by_key( **self.clip.text_encoder.model_dump(), context=context, ).model @@ -171,11 +172,11 @@ class SDXLPromptInvocationBase: lora_prefix: str, zero_on_empty: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]: - tokenizer_info = context.services.model_records.load_model( + tokenizer_info = context.services.model_manager.load.load_model_by_key( **clip_field.tokenizer.model_dump(), context=context, ) - text_encoder_info = context.services.model_records.load_model( + text_encoder_info = context.services.model_manager.load.load_model_by_key( **clip_field.text_encoder.model_dump(), context=context, ) @@ -203,7 +204,7 @@ class SDXLPromptInvocationBase: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in clip_field.loras: - lora_info = context.services.model_records.load_model( + lora_info = context.services.model_manager.load.load_model_by_key( **lora.model_dump(exclude={"weight"}), context=context ) lora_model = lora_info.model @@ -218,7 +219,7 @@ class SDXLPromptInvocationBase: for trigger in extract_ti_triggers_from_prompt(prompt): name = trigger[1:-1] try: - ti_model = context.services.model_records.load_model_by_attr( + ti_model = context.services.model_manager.load.load_model_by_attr( model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion, @@ -465,9 +466,9 @@ class 
ClipSkipInvocation(BaseInvocation): def get_max_token_count( - tokenizer, + tokenizer: CLIPTokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], - truncate_if_too_long=False, + truncate_if_too_long: bool = False, ) -> int: if type(prompt) is Blend: blend: Blend = prompt @@ -479,7 +480,9 @@ def get_max_token_count( return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)) -def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]: +def get_tokens_for_prompt_object( + tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True +) -> List[str]: if type(parsed_prompt) is Blend: raise ValueError("Blend is not supported here - you need to get tokens for each of its .children") @@ -492,24 +495,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun for x in parsed_prompt.children ] text = " ".join(text_fragments) - tokens = tokenizer.tokenize(text) + tokens: List[str] = tokenizer.tokenize(text) if truncate_if_too_long: max_tokens_length = tokenizer.model_max_length - 2 # typically 75 tokens = tokens[0:max_tokens_length] return tokens -def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None): +def log_tokenization_for_conjunction( + c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None +) -> None: display_label_prefix = display_label_prefix or "" for i, p in enumerate(c.prompts): if len(c.prompts) > 1: this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" else: + assert display_label_prefix is not None this_display_label_prefix = display_label_prefix log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix) -def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None): +def log_tokenization_for_prompt_object( + p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None +) -> None: display_label_prefix = display_label_prefix or "" if type(p) is Blend: blend: Blend = p @@ -549,7 +557,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix) -def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): +def log_tokenization_for_text( + text: str, + tokenizer: CLIPTokenizer, + display_label: Optional[str] = None, + truncate_if_too_long: Optional[bool] = False, +) -> None: """shows how the prompt is tokenized # usually tokens have '' to indicate end-of-word, # but for readability it has been replaced with ' ' diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index a621f9fe71..b0419f424f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,13 +3,15 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import Iterator, List, Literal, Optional, Tuple, Union +from typing import Any, Iterator, List, Literal, Optional, Tuple, Union import einops import numpy as np +import numpy.typing as npt import torch import torchvision.transforms as T -from diffusers import AutoencoderKL, AutoencoderTiny, UNet2DConditionModel +from diffusers import AutoencoderKL, AutoencoderTiny +from diffusers.configuration_utils import ConfigMixin from diffusers.image_processor 
import VaeImageProcessor from diffusers.models.adapter import T2IAdapter from diffusers.models.attention_processor import ( @@ -18,8 +20,10 @@ from diffusers.models.attention_processor import ( LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import SchedulerMixin as Scheduler +from PIL import Image from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize @@ -38,9 +42,10 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.step_callback import stable_diffusion_step_callback +from invokeai.backend.embeddings.lora import LoRAModelRaw from invokeai.backend.embeddings.model_patcher import ModelPatcher from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus -from invokeai.backend.model_manager import AnyModel, BaseModelType +from invokeai.backend.model_manager import BaseModelType, LoadedModel from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo from invokeai.backend.util.silence_warnings import SilenceWarnings @@ -76,7 +81,9 @@ if choose_torch_device() == torch.device("mps"): DEFAULT_PRECISION = choose_precision(choose_torch_device()) -SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] # FIXME: "Invalid type alias" +SAMPLER_NAME_VALUES = Literal[ + tuple(SCHEDULER_MAP.keys()) +] # FIXME: "Invalid type alias". This defeats static type checking. # HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to # be addressed if future models use a different latent scale factor. 
Also, note that there may be places where the scale @@ -130,10 +137,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation): ui_order=4, ) - def prep_mask_tensor(self, mask_image): + def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor: if mask_image.mode != "L": mask_image = mask_image.convert("L") - mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) + mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) if mask_tensor.dim() == 3: mask_tensor = mask_tensor.unsqueeze(0) # if shape is not None: @@ -144,24 +151,24 @@ class CreateDenoiseMaskInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: if self.image is not None: image = context.services.images.get_pil_image(self.image.image_name) - image = image_resized_to_grid_as_tensor(image.convert("RGB")) - if image.dim() == 3: - image = image.unsqueeze(0) + image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) + if image_tensor.dim() == 3: + image_tensor = image_tensor.unsqueeze(0) else: - image = None + image_tensor = None mask = self.prep_mask_tensor( context.services.images.get_pil_image(self.mask.image_name), ) - if image is not None: - vae_info = context.services.model_records.load_model( + if image_tensor is not None: + vae_info = context.services.model_manager.load.load_model_by_key( **self.vae.vae.model_dump(), context=context, ) - img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) - masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0) + img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) + masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) # TODO: masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) @@ -188,7 +195,7 @@ def get_scheduler( seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.services.model_records.load_model( + orig_scheduler_info = context.services.model_manager.load.load_model_by_key( **scheduler_info.model_dump(), context=context, ) @@ -199,7 +206,7 @@ def get_scheduler( scheduler_config = scheduler_config["_backup"] scheduler_config = { **scheduler_config, - **scheduler_extra_config, + **scheduler_extra_config, # FIXME "_backup": scheduler_config, } @@ -212,6 +219,7 @@ def get_scheduler( # hack copied over from generate.py if not hasattr(scheduler, "uses_inpainting_model"): scheduler.uses_inpainting_model = lambda: False + assert isinstance(scheduler, Scheduler) return scheduler @@ -295,7 +303,7 @@ class DenoiseLatentsInvocation(BaseInvocation): ) @field_validator("cfg_scale") - def ge_one(cls, v): + def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]: """validate that all cfg_scale values are >= 1""" if isinstance(v, list): for i in v: @@ -325,9 +333,9 @@ class DenoiseLatentsInvocation(BaseInvocation): def get_conditioning_data( self, context: InvocationContext, - scheduler, - unet, - seed, + scheduler: Scheduler, + unet: UNet2DConditionModel, + seed: int, ) -> ConditioningData: positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) @@ -350,7 +358,7 @@ class DenoiseLatentsInvocation(BaseInvocation): ), ) - conditioning_data = conditioning_data.add_scheduler_args_if_applicable( + 
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME scheduler, # for ddim scheduler eta=0.0, # ddim_eta @@ -362,8 +370,8 @@ class DenoiseLatentsInvocation(BaseInvocation): def create_pipeline( self, - unet, - scheduler, + unet: UNet2DConditionModel, + scheduler: Scheduler, ) -> StableDiffusionGeneratorPipeline: # TODO: # configure_model_padding( @@ -374,10 +382,10 @@ class DenoiseLatentsInvocation(BaseInvocation): class FakeVae: class FakeVaeConfig: - def __init__(self): + def __init__(self) -> None: self.block_out_channels = [0] - def __init__(self): + def __init__(self) -> None: self.config = FakeVae.FakeVaeConfig() return StableDiffusionGeneratorPipeline( @@ -394,11 +402,11 @@ class DenoiseLatentsInvocation(BaseInvocation): def prep_control_data( self, context: InvocationContext, - control_input: Union[ControlField, List[ControlField]], + control_input: Optional[Union[ControlField, List[ControlField]]], latents_shape: List[int], exit_stack: ExitStack, do_classifier_free_guidance: bool = True, - ) -> List[ControlNetData]: + ) -> Optional[List[ControlNetData]]: # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR. control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR @@ -421,7 +429,7 @@ class DenoiseLatentsInvocation(BaseInvocation): controlnet_data = [] for control_info in control_list: control_model = exit_stack.enter_context( - context.services.model_records.load_model( + context.services.model_manager.load.load_model_by_key( key=control_info.control_model.key, context=context, ) @@ -487,23 +495,25 @@ class DenoiseLatentsInvocation(BaseInvocation): conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.services.model_records.load_model( + context.services.model_manager.load.load_model_by_key( key=single_ip_adapter.ip_adapter_model.key, context=context, ) ) - image_encoder_model_info = context.services.model_records.load_model( + image_encoder_model_info = context.services.model_manager.load.load_model_by_key( key=single_ip_adapter.image_encoder_model.key, context=context, ) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. - single_ipa_images = single_ip_adapter.image - if not isinstance(single_ipa_images, list): - single_ipa_images = [single_ipa_images] + single_ipa_image_fields = single_ip_adapter.image + if not isinstance(single_ipa_image_fields, list): + single_ipa_image_fields = [single_ipa_image_fields] - single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images] + single_ipa_images = [ + context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields + ] # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. 
@@ -547,21 +557,19 @@ class DenoiseLatentsInvocation(BaseInvocation): t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.services.model_records.load_model( + t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key( key=t2i_adapter_field.t2i_adapter_model.key, context=context, ) image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name) # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. - if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1: + if t2i_adapter_model_info.base == BaseModelType.StableDiffusion1: max_unet_downscale = 8 - elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL: + elif t2i_adapter_model_info.base == BaseModelType.StableDiffusionXL: max_unet_downscale = 4 else: - raise ValueError( - f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'." - ) + raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_info.base}'.") t2i_adapter_model: T2IAdapter with t2i_adapter_model_info as t2i_adapter_model: @@ -609,7 +617,15 @@ class DenoiseLatentsInvocation(BaseInvocation): # original idea by https://github.com/AmericanPresidentJimmyCarter # TODO: research more for second order schedulers timesteps - def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end): + def init_scheduler( + self, + scheduler: Union[Scheduler, ConfigMixin], + device: torch.device, + steps: int, + denoising_start: float, + denoising_end: float, + ) -> Tuple[int, List[int], int]: + assert isinstance(scheduler, ConfigMixin) if scheduler.config.get("cpu_only", False): scheduler.set_timesteps(steps, device="cpu") timesteps = scheduler.timesteps.to(device=device) @@ -621,11 +637,11 @@ class DenoiseLatentsInvocation(BaseInvocation): _timesteps = timesteps[:: scheduler.order] # get start timestep index - t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start))) + t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start))) t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps))) # get end timestep index - t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end))) + t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end))) t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:]))) # apply order to indexes @@ -638,7 +654,9 @@ class DenoiseLatentsInvocation(BaseInvocation): return num_inference_steps, timesteps, init_timestep - def prep_inpaint_mask(self, context, latents): + def prep_inpaint_mask( + self, context: InvocationContext, latents: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: if self.denoise_mask is None: return None, None @@ -691,12 +709,15 @@ class DenoiseLatentsInvocation(BaseInvocation): graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] - def step_callback(state: PipelineIntermediateState): - self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model) + # get the unet's config so that we can pass the base to dispatch_progress() + unet_config = context.services.model_manager.store.get_model(**self.unet.unet.model_dump()) - def _lora_loader() -> Iterator[Tuple[AnyModel, float]]: + def 
step_callback(state: PipelineIntermediateState) -> None: + self.dispatch_progress(context, source_node_id, state, unet_config.base) + + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.unet.loras: - lora_info = context.services.model_records.load_model( + lora_info = context.services.model_manager.load.load_model_by_key( **lora.model_dump(exclude={"weight"}), context=context, ) @@ -704,7 +725,7 @@ class DenoiseLatentsInvocation(BaseInvocation): del lora_info return - unet_info = context.services.model_records.load_model( + unet_info = context.services.model_manager.load.load_model_by_key( **self.unet.unet.model_dump(), context=context, ) @@ -815,7 +836,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) - vae_info = context.services.model_records.load_model( + vae_info = context.services.model_manager.load.load_model_by_key( **self.vae.vae.model_dump(), context=context, ) @@ -1010,8 +1031,9 @@ class ImageToLatentsInvocation(BaseInvocation): fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) @staticmethod - def vae_encode(vae_info, upcast, tiled, image_tensor): + def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor: with vae_info as vae: + assert isinstance(vae, torch.nn.Module) orig_dtype = vae.dtype if upcast: vae.to(dtype=torch.float32) @@ -1057,7 +1079,7 @@ class ImageToLatentsInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.services.images.get_pil_image(self.image.image_name) - vae_info = context.services.model_records.load_model( + vae_info = context.services.model_manager.load.load_model_by_key( **self.vae.vae.model_dump(), context=context, ) @@ -1076,14 +1098,19 @@ class ImageToLatentsInvocation(BaseInvocation): @singledispatchmethod @staticmethod def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor: + assert isinstance(vae, torch.nn.Module) image_tensor_dist = vae.encode(image_tensor).latent_dist - latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible! + latents: torch.Tensor = image_tensor_dist.sample().to( + dtype=vae.dtype + ) # FIXME: uses torch.randn. make reproducible! return latents @_encode_to_tensor.register @staticmethod def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor: - return vae.encode(image_tensor).latents + assert isinstance(vae, torch.nn.Module) + latents: torch.FloatTensor = vae.encode(image_tensor).latents + return latents @invocation( @@ -1116,7 +1143,12 @@ class BlendLatentsInvocation(BaseInvocation): # TODO: device = choose_torch_device() - def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): + def slerp( + t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here? 
+ v0: Union[torch.Tensor, npt.NDArray[Any]], + v1: Union[torch.Tensor, npt.NDArray[Any]], + DOT_THRESHOLD: float = 0.9995, + ) -> Union[torch.Tensor, npt.NDArray[Any]]: """ Spherical linear interpolation Args: @@ -1149,12 +1181,16 @@ class BlendLatentsInvocation(BaseInvocation): v2 = s0 * v0 + s1 * v1 if inputs_are_torch: - v2 = torch.from_numpy(v2).to(device) - - return v2 + v2_torch: torch.Tensor = torch.from_numpy(v2).to(device) + return v2_torch + else: + assert isinstance(v2, np.ndarray) + return v2 # blend - blended_latents = slerp(self.alpha, latents_a, latents_b) + bl = slerp(self.alpha, latents_a, latents_b) + assert isinstance(bl, torch.Tensor) + blended_latents: torch.Tensor = bl # for type checking convenience # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 blended_latents = blended_latents.to("cpu") @@ -1250,15 +1286,19 @@ class IdealSizeInvocation(BaseInvocation): description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)", ) - def trim_to_multiple_of(self, *args, multiple_of=LATENT_SCALE_FACTOR): + def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]: return tuple((x - x % multiple_of) for x in args) def invoke(self, context: InvocationContext) -> IdealSizeOutput: + unet_config = context.services.model_manager.load.load_model_by_key( + **self.unet.unet.model_dump(), + context=context, + ) aspect = self.width / self.height - dimension = 512 - if self.unet.unet.base_model == BaseModelType.StableDiffusion2: + dimension: float = 512 + if unet_config.base == BaseModelType.StableDiffusion2: dimension = 768 - elif self.unet.unet.base_model == BaseModelType.StableDiffusionXL: + elif unet_config.base == BaseModelType.StableDiffusionXL: dimension = 1024 dimension = dimension * self.multiplier min_dimension = math.floor(dimension * 0.5) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index e0e61ea26c..739cd02374 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -20,7 +20,7 @@ from .baseinvocation import ( class ModelInfo(BaseModel): - key: str = Field(description="Info to load submodel") + key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()") submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 5d39a3d7e7..118e48f89e 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -15,8 +15,8 @@ from invokeai.app.invocations.primitives import ConditioningField, ConditioningO from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend import ModelType, SubModelType from invokeai.backend.embeddings.model_patcher import ONNXModelPatcher +from invokeai.backend.model_manager import ModelType, SubModelType from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util import choose_torch_device @@ -62,16 +62,16 @@ class ONNXPromptInvocation(BaseInvocation): clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = 
context.services.model_records.load_model( + tokenizer_info = context.services.model_manager.load.load_model_by_key( **self.clip.tokenizer.model_dump(), ) - text_encoder_info = context.services.model_records.load_model( + text_encoder_info = context.services.model_manager.load.load_model_by_key( **self.clip.text_encoder.model_dump(), ) with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack: loras = [ ( - context.services.model_records.load_model(**lora.model_dump(exclude={"weight"})).model, + context.services.model_manager.load.load_model_by_key(**lora.model_dump(exclude={"weight"})).model, lora.weight, ) for lora in self.clip.loras @@ -84,7 +84,7 @@ class ONNXPromptInvocation(BaseInvocation): ti_list.append( ( name, - context.services.model_records.load_model_by_attr( + context.services.model_manager.load.load_model_by_attr( model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion, @@ -257,13 +257,13 @@ class ONNXTextToLatentsInvocation(BaseInvocation): eta=0.0, ) - unet_info = context.services.model_records.load_model(**self.unet.unet.model_dump()) + unet_info = context.services.model_manager.load.load_model_by_key(**self.unet.unet.model_dump()) with unet_info as unet: # , ExitStack() as stack: # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] loras = [ ( - context.services.model_records.load_model(**lora.model_dump(exclude={"weight"})).model, + context.services.model_manager.load.load_model_by_key(**lora.model_dump(exclude={"weight"})).model, lora.weight, ) for lora in self.unet.loras @@ -346,7 +346,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): if self.vae.vae.submodel != SubModelType.VaeDecoder: raise Exception(f"Expected vae_decoder, found: {self.vae.vae.submodel}") - vae_info = context.services.model_records.load_model( + vae_info = context.services.model_manager.load.load_model_by_key( **self.vae.vae.model_dump(), ) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index afe8ff06d9..09c9b7f3ca 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -368,7 +368,7 @@ class LatentsCollectionInvocation(BaseInvocation): return LatentsCollectionOutput(collection=self.collection) -def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None): +def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> LatentsOutput: return LatentsOutput( latents=LatentsField(latents_name=latents_name, seed=seed), width=latents.size()[3] * 8, diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 11a4de99d6..aa3322a9a0 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -22,9 +22,7 @@ if TYPE_CHECKING: from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase from .item_storage.item_storage_base import ItemStorageABC from .latents_storage.latents_storage_base import LatentsStorageBase - from .model_install import ModelInstallServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase - from .model_records import ModelRecordServiceBase from .names.names_base import NameServiceBase from .session_processor.session_processor_base import SessionProcessorBase from 
.session_queue.session_queue_base import SessionQueueBase @@ -50,9 +48,7 @@ class InvocationServices: latents: "LatentsStorageBase" logger: "Logger" model_manager: "ModelManagerServiceBase" - model_records: "ModelRecordServiceBase" download_queue: "DownloadQueueServiceBase" - model_install: "ModelInstallServiceBase" processor: "InvocationProcessorABC" performance_statistics: "InvocationStatsServiceBase" queue: "InvocationQueueABC" @@ -78,9 +74,7 @@ class InvocationServices: latents: "LatentsStorageBase", logger: "Logger", model_manager: "ModelManagerServiceBase", - model_records: "ModelRecordServiceBase", download_queue: "DownloadQueueServiceBase", - model_install: "ModelInstallServiceBase", processor: "InvocationProcessorABC", performance_statistics: "InvocationStatsServiceBase", queue: "InvocationQueueABC", @@ -104,9 +98,7 @@ class InvocationServices: self.latents = latents self.logger = logger self.model_manager = model_manager - self.model_records = model_records self.download_queue = download_queue - self.model_install = model_install self.processor = processor self.performance_statistics = performance_statistics self.queue = queue diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index 8883ebe295..87f3dd9d06 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -43,8 +43,10 @@ class InvocationStatsService(InvocationStatsServiceBase): @contextmanager def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]: + # This is to handle case of the model manager not being initialized, which happens + # during some tests. services = self._invoker.services - if services.model_records is None or services.model_records.loader is None: + if services.model_manager is None or services.model_manager.load is None: yield None if not self._stats.get(graph_execution_state_id): # First time we're seeing this graph_execution_state_id. @@ -60,9 +62,8 @@ class InvocationStatsService(InvocationStatsServiceBase): if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() - # TO DO [LS]: clean up loader service - shouldn't be an attribute of model records - assert services.model_records.loader is not None - services.model_records.loader.ram_cache.stats = self._cache_stats[graph_execution_state_id] + assert services.model_manager.load is not None + services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id] try: # Let the invocation run. 
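The hunks above all route model loading through the consolidated `context.services.model_manager.load` service rather than the old `model_records` service. A minimal sketch of the new call pattern inside an invocation, assuming a node that carries a `self.unet.unet` ModelInfo field like the ONNX nodes above, and using only calls that appear in this diff (`load_model_by_key` and the `LoadedModel` context manager):

```
# Illustrative sketch only -- not part of this patch.
def invoke(self, context: InvocationContext):
    # Load the UNet through the consolidated model manager service instead of model_records.
    unet_info = context.services.model_manager.load.load_model_by_key(
        **self.unet.unet.model_dump(),
        context=context,  # enables model_load_started/completed events and cancellation checks
    )
    with unet_info as unet:  # LoadedModel.__enter__ locks the model in the cache and returns it
        ...  # run inference with `unet`
```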
diff --git a/invokeai/app/services/latents_storage/latents_storage_base.py b/invokeai/app/services/latents_storage/latents_storage_base.py index 95a0e3e748..2597259126 100644 --- a/invokeai/app/services/latents_storage/latents_storage_base.py +++ b/invokeai/app/services/latents_storage/latents_storage_base.py @@ -5,7 +5,7 @@ from typing import Callable, Union import torch -from ..compel import ConditioningFieldData +from invokeai.app.invocations.compel import ConditioningFieldData class LatentsStorageBase(ABC): diff --git a/invokeai/app/services/latents_storage/latents_storage_disk.py b/invokeai/app/services/latents_storage/latents_storage_disk.py index ba6dbd3a28..cc94a25e5a 100644 --- a/invokeai/app/services/latents_storage/latents_storage_disk.py +++ b/invokeai/app/services/latents_storage/latents_storage_disk.py @@ -5,9 +5,9 @@ from typing import Union import torch +from invokeai.app.invocations.compel import ConditioningFieldData from invokeai.app.services.invoker import Invoker -from ..compel import ConditioningFieldData from .latents_storage_base import LatentsStorageBase diff --git a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py b/invokeai/app/services/latents_storage/latents_storage_forward_cache.py index 1edda736a4..3a0322011d 100644 --- a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py +++ b/invokeai/app/services/latents_storage/latents_storage_forward_cache.py @@ -5,9 +5,9 @@ from typing import Dict, Optional, Union import torch +from invokeai.app.invocations.compel import ConditioningFieldData from invokeai.app.services.invoker import Invoker -from ..compel import ConditioningFieldData from .latents_storage_base import LatentsStorageBase diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 7228806e80..f298d98ce6 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -4,7 +4,8 @@ from abc import ABC, abstractmethod from typing import Optional -from invokeai.backend.model_manager import AnyModelConfig, SubModelType +from invokeai.app.invocations.baseinvocation import InvocationContext +from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load import LoadedModel @@ -12,11 +13,60 @@ class ModelLoadServiceBase(ABC): """Wrapper around AnyModelLoader.""" @abstractmethod - def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: - """Given a model's key, load it and return the LoadedModel object.""" + def load_model_by_key( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: + """ + Given a model's key, load it and return the LoadedModel object. + + :param key: Key of model config to be fetched. + :param submodel: For main (pipeline models), the submodel to fetch. 
+        :param context: Invocation context used for event reporting
+        """
         pass
 
     @abstractmethod
-    def load_model_by_config(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
-        """Given a model's configuration, load it and return the LoadedModel object."""
+    def load_model_by_config(
+        self,
+        model_config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> LoadedModel:
+        """
+        Given a model's configuration, load it and return the LoadedModel object.
+
+        :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
+        :param submodel: For main (pipeline models), the submodel to fetch.
+        :param context: Invocation context used for event reporting
+        """
         pass
+
+    @abstractmethod
+    def load_model_by_attr(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        submodel: Optional[SubModelType] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> LoadedModel:
+        """
+        Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
+
+        This is provided for API compatibility with the get_model() method
+        in the original model manager. However, note that LoadedModel is
+        not the same as the original ModelInfo that was returned.
+
+        :param model_name: Name of the model to be fetched.
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param submodel: For main (pipeline models), the submodel to fetch
+        :param context: The invocation context.
+
+        Exceptions: UnknownModelException -- model with these attributes not known
+                    NotImplementedException -- a model loader was not provided at initialization time
+                    ValueError -- more than one model matches this combination
+        """
diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py
index 80e2fe161d..67107cada6 100644
--- a/invokeai/app/services/model_load/model_load_default.py
+++ b/invokeai/app/services/model_load/model_load_default.py
@@ -3,12 +3,14 @@
 
 from typing import Optional
 
+from invokeai.app.invocations.baseinvocation import InvocationContext
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.model_records import ModelRecordServiceBase
-from invokeai.backend.model_manager import AnyModelConfig, SubModelType
+from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
+from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
+from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
 from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
-from invokeai.backend.model_manager.load.ram_cache import ModelCacheBase
+from invokeai.backend.model_manager.load.model_cache import ModelCacheBase
 from invokeai.backend.util.logging import InvokeAILogger
 
 from .model_load_base import ModelLoadServiceBase
@@ -21,7 +23,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self,
         app_config: InvokeAIAppConfig,
         record_store: ModelRecordServiceBase,
-        ram_cache: Optional[ModelCacheBase] = None,
+        ram_cache: Optional[ModelCacheBase[AnyModel]] = None,
         convert_cache: Optional[ModelConvertCacheBase] = None,
     ):
         """Initialize the model load service."""
@@ -44,11 +46,104 @@ class ModelLoadService(ModelLoadServiceBase):
             ),
         )
 
-    def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
-        """Given a model's key, load it and return the LoadedModel object."""
-        config = self._store.get_model(key)
-        return self.load_model_by_config(config, submodel_type)
+    def load_model_by_key(
+        self,
+        key: str,
+        submodel_type: Optional[SubModelType] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> LoadedModel:
+        """
+        Given a model's key, load it and return the LoadedModel object.
+
+        :param key: Key of model config to be fetched.
+        :param submodel: For main (pipeline models), the submodel to fetch.
+        :param context: Invocation context used for event reporting
+        """
+        config = self._store.get_model(key)
+        return self.load_model_by_config(config, submodel_type, context)
+
+    def load_model_by_attr(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        submodel: Optional[SubModelType] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> LoadedModel:
+        """
+        Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
+
+        This is provided for API compatibility with the get_model() method
+        in the original model manager. However, note that LoadedModel is
+        not the same as the original ModelInfo that was returned.
+
+        :param model_name: Name of the model to be fetched.
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param submodel: For main (pipeline models), the submodel to fetch
+        :param context: The invocation context.
+
+        Exceptions: UnknownModelException -- model with these attributes not known
+                    NotImplementedException -- a model loader was not provided at initialization time
+                    ValueError -- more than one model matches this combination
+        """
+        configs = self._store.search_by_attr(model_name, base_model, model_type)
+        if len(configs) == 0:
+            raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
+        elif len(configs) > 1:
+            raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
+        else:
+            return self.load_model_by_key(configs[0].key, submodel)
+
+    def load_model_by_config(
+        self,
+        model_config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> LoadedModel:
+        """
+        Given a model's configuration, load it and return the LoadedModel object.
+
+        :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
+        :param submodel: For main (pipeline models), the submodel to fetch.
+ :param context: Invocation context used for event reporting + """ + if context: + self._emit_load_event( + context=context, + model_config=model_config, + ) + loaded_model = self._any_loader.load_model(model_config, submodel_type) + if context: + self._emit_load_event( + context=context, + model_config=model_config, + loaded=True, + ) + return loaded_model + + def _emit_load_event( + self, + context: InvocationContext, + model_config: AnyModelConfig, + loaded: Optional[bool] = False, + ) -> None: + if context.services.queue.is_canceled(context.graph_execution_state_id): + raise CanceledException() + + if not loaded: + context.services.events.emit_model_load_started( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_config=model_config, + ) + else: + context.services.events.emit_model_load_completed( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_config=model_config, + ) diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index c339b97617..1116c82ff1 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -2,24 +2,26 @@ from abc import ABC, abstractmethod -from pydantic import BaseModel, Field from typing_extensions import Self +from invokeai.app.services.invoker import Invoker + from ..config import InvokeAIAppConfig -from ..events.events_base import EventServiceBase from ..download import DownloadQueueServiceBase +from ..events.events_base import EventServiceBase from ..model_install import ModelInstallServiceBase from ..model_load import ModelLoadServiceBase from ..model_records import ModelRecordServiceBase from ..shared.sqlite.sqlite_database import SqliteDatabase -class ModelManagerServiceBase(BaseModel, ABC): +class ModelManagerServiceBase(ABC): """Abstract base class for the model manager service.""" - store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.") - install: ModelInstallServiceBase = Field(description="An instance of the model install service.") - load: ModelLoadServiceBase = Field(description="An instance of the model load service.") + # attributes: + # store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.") + # install: ModelInstallServiceBase = Field(description="An instance of the model install service.") + # load: ModelLoadServiceBase = Field(description="An instance of the model load service.") @classmethod @abstractmethod @@ -37,3 +39,29 @@ class ModelManagerServiceBase(BaseModel, ABC): method simplifies the construction considerably. 
""" pass + + @property + @abstractmethod + def store(self) -> ModelRecordServiceBase: + """Return the ModelRecordServiceBase used to store and retrieve configuration records.""" + pass + + @property + @abstractmethod + def load(self) -> ModelLoadServiceBase: + """Return the ModelLoadServiceBase used to load models from their configuration records.""" + pass + + @property + @abstractmethod + def install(self) -> ModelInstallServiceBase: + """Return the ModelInstallServiceBase used to download and manipulate model files.""" + pass + + @abstractmethod + def start(self, invoker: Invoker) -> None: + pass + + @abstractmethod + def stop(self, invoker: Invoker) -> None: + pass diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index ad0fd66dbb..028d4af615 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -3,6 +3,7 @@ from typing_extensions import Self +from invokeai.app.services.invoker import Invoker from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger @@ -10,9 +11,9 @@ from invokeai.backend.util.logging import InvokeAILogger from ..config import InvokeAIAppConfig from ..download import DownloadQueueServiceBase from ..events.events_base import EventServiceBase -from ..model_install import ModelInstallService -from ..model_load import ModelLoadService -from ..model_records import ModelRecordServiceSQL +from ..model_install import ModelInstallService, ModelInstallServiceBase +from ..model_load import ModelLoadService, ModelLoadServiceBase +from ..model_records import ModelRecordServiceBase, ModelRecordServiceSQL from ..shared.sqlite.sqlite_database import SqliteDatabase from .model_manager_base import ModelManagerServiceBase @@ -27,6 +28,38 @@ class ModelManagerService(ModelManagerServiceBase): model_manager.load -- Routines to load models into memory. 
""" + def __init__( + self, + store: ModelRecordServiceBase, + install: ModelInstallServiceBase, + load: ModelLoadServiceBase, + ): + self._store = store + self._install = install + self._load = load + + @property + def store(self) -> ModelRecordServiceBase: + return self._store + + @property + def install(self) -> ModelInstallServiceBase: + return self._install + + @property + def load(self) -> ModelLoadServiceBase: + return self._load + + def start(self, invoker: Invoker) -> None: + for service in [self._store, self._install, self._load]: + if hasattr(service, "start"): + service.start(invoker) + + def stop(self, invoker: Invoker) -> None: + for service in [self._store, self._install, self._load]: + if hasattr(service, "stop"): + service.stop(invoker) + @classmethod def build_model_manager( cls, diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index e00dd4169d..e2e98c7e89 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -10,15 +10,12 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union from pydantic import BaseModel, Field -from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.backend.model_manager import ( AnyModelConfig, BaseModelType, - LoadedModel, ModelFormat, ModelType, - SubModelType, ) from invokeai.backend.model_manager.load import AnyModelLoader from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore @@ -111,52 +108,6 @@ class ModelRecordServiceBase(ABC): """ pass - @abstractmethod - def load_model( - self, - key: str, - submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, - ) -> LoadedModel: - """ - Load the indicated model into memory and return a LoadedModel object. - - :param key: Key of model config to be fetched. - :param submodel: For main (pipeline models), the submodel to fetch - :param context: Invocation context, used for event issuing. - - Exceptions: UnknownModelException -- model with this key not known - NotImplementedException -- a model loader was not provided at initialization time - """ - pass - - @abstractmethod - def load_model_by_attr( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, - ) -> LoadedModel: - """ - Load the indicated model into memory and return a LoadedModel object. - - This is provided for API compatability with the get_model() method - in the original model manager. However, note that LoadedModel is - not the same as the original ModelInfo that ws returned. - - :param model_name: Key of model config to be fetched. - :param base_model: Base model - :param model_type: Type of the model - :param submodel: For main (pipeline models), the submodel to fetch - :param context: The invocation context. 
- - Exceptions: UnknownModelException -- model with this key not known - NotImplementedException -- a model loader was not provided at initialization time - """ - pass - @property @abstractmethod def metadata_store(self) -> ModelMetadataStore: diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 28a77b1b1a..f48175351d 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -46,8 +46,6 @@ from math import ceil from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple, Union -from invokeai.app.invocations.baseinvocation import InvocationContext -from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.backend.model_manager.config import ( AnyModelConfig, @@ -55,9 +53,8 @@ from invokeai.backend.model_manager.config import ( ModelConfigFactory, ModelFormat, ModelType, - SubModelType, ) -from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel +from invokeai.backend.model_manager.load import AnyModelLoader from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException from ..shared.sqlite.sqlite_database import SqliteDatabase @@ -220,74 +217,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) return model - def load_model( - self, - key: str, - submodel: Optional[SubModelType], - context: Optional[InvocationContext] = None, - ) -> LoadedModel: - """ - Load the indicated model into memory and return a LoadedModel object. - - :param key: Key of model config to be fetched. - :param submodel: For main (pipeline models), the submodel to fetch. - :param context: Invocation context used for event reporting - - Exceptions: UnknownModelException -- model with this key not known - NotImplementedException -- a model loader was not provided at initialization time - """ - if not self._loader: - raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader") - # we can emit model loading events if we are executing with access to the invocation context - - model_config = self.get_model(key) - if context: - self._emit_load_event( - context=context, - model_config=model_config, - ) - loaded_model = self._loader.load_model(model_config, submodel) - if context: - self._emit_load_event( - context=context, - model_config=model_config, - loaded=True, - ) - return loaded_model - - def load_model_by_attr( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, - ) -> LoadedModel: - """ - Load the indicated model into memory and return a LoadedModel object. - - This is provided for API compatability with the get_model() method - in the original model manager. However, note that LoadedModel is - not the same as the original ModelInfo that ws returned. - - :param model_name: Key of model config to be fetched. - :param base_model: Base model - :param model_type: Type of the model - :param submodel: For main (pipeline models), the submodel to fetch - :param context: The invocation context. 
- - Exceptions: UnknownModelException -- model with this key not known - NotImplementedException -- a model loader was not provided at initialization time - ValueError -- more than one model matches this combination - """ - configs = self.search_by_attr(model_name, base_model, model_type) - if len(configs) == 0: - raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model") - elif len(configs) > 1: - raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.") - else: - return self.load_model(configs[0].key, submodel) - def exists(self, key: str) -> bool: """ Return True if a model with the indicated key exists in the databse. @@ -476,29 +405,3 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): return PaginatedResults( page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items ) - - def _emit_load_event( - self, - context: InvocationContext, - model_config: AnyModelConfig, - loaded: Optional[bool] = False, - ) -> None: - if context.services.queue.is_canceled(context.graph_execution_state_id): - raise CanceledException() - - if not loaded: - context.services.events.emit_model_load_started( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - model_config=model_config, - ) - else: - context.services.events.emit_model_load_completed( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, - model_config=model_config, - ) diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py index b473444511..1f9ac56518 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_6.py @@ -6,6 +6,7 @@ from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import class Migration6Callback: def __call__(self, cursor: sqlite3.Cursor) -> None: self._recreate_model_triggers(cursor) + self._delete_ip_adapters(cursor) def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None: """ @@ -26,6 +27,22 @@ class Migration6Callback: """ ) + def _delete_ip_adapters(self, cursor: sqlite3.Cursor) -> None: + """ + Delete all the IP adapters. + + The model manager will automatically find and re-add them after the migration + is done. This allows the manager to add the correct image encoder to their + configuration records. + """ + + cursor.execute( + """--sql + DELETE FROM model_config + WHERE type='ip_adapter'; + """ + ) + def build_migration_6() -> Migration: """ @@ -33,6 +50,8 @@ def build_migration_6() -> Migration: This migration does the following: - Adds the model_config_updated_at trigger if it does not exist + - Delete all ip_adapter models so that the model prober can find and + update with the correct image processor model. 
""" migration_6 = Migration( from_version=5, diff --git a/invokeai/backend/embeddings/model_patcher.py b/invokeai/backend/embeddings/model_patcher.py index 4725181b8e..bee8909c31 100644 --- a/invokeai/backend/embeddings/model_patcher.py +++ b/invokeai/backend/embeddings/model_patcher.py @@ -64,7 +64,7 @@ class ModelPatcher: def apply_lora_unet( cls, unet: UNet2DConditionModel, - loras: List[Tuple[LoRAModelRaw, float]], + loras: Iterator[Tuple[LoRAModelRaw, float]], ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield @@ -307,7 +307,7 @@ class ONNXModelPatcher: def apply_lora_unet( cls, unet: OnnxRuntimeModel, - loras: List[Tuple[LoRAModelRaw, float]], + loras: Iterator[Tuple[LoRAModelRaw, float]], ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py index b9649925e1..92ddef5ecc 100644 --- a/invokeai/backend/image_util/safety_checker.py +++ b/invokeai/backend/image_util/safety_checker.py @@ -8,8 +8,8 @@ from PIL import Image import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend import SilenceWarnings from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.silence_warnings import SilenceWarnings config = InvokeAIAppConfig.get_config() diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 9176bf1f49..b4706ea99c 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -8,7 +8,6 @@ from PIL import Image from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights -from invokeai.backend.model_management.models.base import calc_model_size_by_data from .resampler import Resampler @@ -124,6 +123,9 @@ class IPAdapter: self.attn_weights.to(device=self.device, dtype=self.dtype) def calc_size(self): + # workaround for circular import + from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data + return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights) def _init_image_proj_model(self, state_dict): diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 4534a4892f..9f0f774b49 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -21,7 +21,7 @@ Validation errors will raise an InvalidModelConfigException error. """ import time from enum import Enum -from typing import Literal, Optional, Type, Union, Class +from typing import Literal, Optional, Type, Union import torch from diffusers import ModelMixin @@ -335,7 +335,7 @@ class ModelConfigFactory(object): cls, model_data: Union[Dict[str, Any], AnyModelConfig], key: Optional[str] = None, - dest_class: Optional[Type[Class]] = None, + dest_class: Optional[Type[ModelConfigBase]] = None, timestamp: Optional[float] = None, ) -> AnyModelConfig: """ @@ -347,14 +347,17 @@ class ModelConfigFactory(object): :param dest_class: The config class to be returned. If not provided, will be selected automatically. 
""" + model: Optional[ModelConfigBase] = None if isinstance(model_data, ModelConfigBase): model = model_data elif dest_class: - model = dest_class.validate_python(model_data) + model = dest_class.model_validate(model_data) else: - model = AnyModelConfigValidator.validate_python(model_data) + # mypy doesn't typecheck TypeAdapters well? + model = AnyModelConfigValidator.validate_python(model_data) # type: ignore + assert model is not None if key: model.key = key if timestamp: model.last_modified = timestamp - return model + return model # type: ignore diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 9d98ee3053..3d026af226 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -18,8 +18,16 @@ from pathlib import Path from typing import Any, Callable, Dict, Optional, Tuple, Type from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.model_manager.config import AnyModel, VaeCheckpointConfig, VaeDiffusersConfig +from invokeai.backend.model_manager.config import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, + SubModelType, + VaeCheckpointConfig, + VaeDiffusersConfig, +) from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.util.logging import InvokeAILogger @@ -32,7 +40,7 @@ class LoadedModel: config: AnyModelConfig locker: ModelLockerBase - def __enter__(self) -> AnyModel: # I think load_file() always returns a dict + def __enter__(self) -> AnyModel: """Context entry.""" self.locker.lock() return self.model @@ -171,6 +179,10 @@ class AnyModelLoader: def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}") key = cls._to_registry_key(base, type, format) + if key in cls._registry: + raise Exception( + f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}" + ) cls._registry[key] = subclass return subclass diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index c1dfe729af..df83c8320d 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -169,7 +169,7 @@ class ModelLoader(ModelLoaderBase): raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e # This needs to be implemented in subclasses that handle checkpoints - def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: raise NotImplementedError # This needs to be implemented in the subclass diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index b1deb215b2..98d6f34cea 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ 
b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -246,7 +246,7 @@ class ModelCache(ModelCacheBase[AnyModel]): def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: """Move model into the indicated device.""" - # These attributes are not in the base ModelMixin class but in derived classes. + # These attributes are not in the base ModelMixin class but in various derived classes. # Some models don't have these attributes, in which case they run in RAM/CPU. self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")): diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py index e61e2b46a6..d446d07933 100644 --- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py +++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py @@ -35,28 +35,28 @@ class ControlnetLoader(GenericDiffusersLoader): else: return True - def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: raise Exception(f"Vae conversion not supported for model type: {config.base}") else: assert hasattr(config, "config") config_file = config.config - if weights_path.suffix == ".safetensors": - checkpoint = safetensors.torch.load_file(weights_path, device="cpu") + if model_path.suffix == ".safetensors": + checkpoint = safetensors.torch.load_file(model_path, device="cpu") else: - checkpoint = torch.load(weights_path, map_location="cpu") + checkpoint = torch.load(model_path, map_location="cpu") # sometimes weights are hidden under "state_dict", and sometimes not if "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] convert_controlnet_to_diffusers( - weights_path, + model_path, output_path, original_config_file=self._app_config.root_path / config_file, image_size=512, scan_needed=True, - from_safetensors=weights_path.suffix == ".safetensors", + from_safetensors=model_path.suffix == ".safetensors", ) return output_path diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index 03c26f3a0c..114e317f3c 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -12,8 +12,9 @@ from invokeai.backend.model_manager import ( ModelType, SubModelType, ) -from invokeai.backend.model_manager.load.load_base import AnyModelLoader -from invokeai.backend.model_manager.load.load_default import ModelLoader + +from ..load_base import AnyModelLoader +from ..load_default import ModelLoader @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index a963e8403b..23b4e1fccd 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -65,7 +65,7 @@ class StableDiffusionDiffusersModel(ModelLoader): else: return True - def _convert_model(self, 
config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: assert isinstance(config, MainCheckpointConfig) variant = config.variant base = config.base @@ -75,9 +75,9 @@ class StableDiffusionDiffusersModel(ModelLoader): config_file = config.config - self._logger.info(f"Converting {weights_path} to diffusers format") + self._logger.info(f"Converting {model_path} to diffusers format") convert_ckpt_to_diffusers( - weights_path, + model_path, output_path, model_type=self.model_base_to_model_type[base], model_version=base, @@ -86,7 +86,7 @@ class StableDiffusionDiffusersModel(ModelLoader): extract_ema=True, scan_needed=True, pipeline_class=pipeline_class, - from_safetensors=weights_path.suffix == ".safetensors", + from_safetensors=model_path.suffix == ".safetensors", precision=self._torch_dtype, load_safety_checker=False, ) diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index 882ae05577..3983ea7595 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -37,7 +37,7 @@ class VaeLoader(GenericDiffusersLoader): else: return True - def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: + def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: # TO DO: check whether sdxl VAE models convert. if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: raise Exception(f"Vae conversion not supported for model type: {config.base}") @@ -46,10 +46,10 @@ class VaeLoader(GenericDiffusersLoader): "v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml" ) - if weights_path.suffix == ".safetensors": - checkpoint = safetensors.torch.load_file(weights_path, device="cpu") + if model_path.suffix == ".safetensors": + checkpoint = safetensors.torch.load_file(model_path, device="cpu") else: - checkpoint = torch.load(weights_path, map_location="cpu") + checkpoint = torch.load(model_path, map_location="cpu") # sometimes weights are hidden under "state_dict", and sometimes not if "state_dict" in checkpoint: diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 3f2d22595e..c55eee48fa 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -65,7 +65,7 @@ def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, var bit8_files = {f for f in all_files if ".8bit." 
in f.name or ".8bit-" in f.name} other_files = set(all_files) - fp16_files - bit8_files - if variant is None: + if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatability with HF files = other_files elif variant == "fp16": files = fp16_files diff --git a/invokeai/backend/model_manager/search.py b/invokeai/backend/model_manager/search.py index a54938fdd5..f7e1e1bed7 100644 --- a/invokeai/backend/model_manager/search.py +++ b/invokeai/backend/model_manager/search.py @@ -22,11 +22,12 @@ Example usage: import os from abc import ABC, abstractmethod +from logging import Logger from pathlib import Path from typing import Callable, Optional, Set, Union from pydantic import BaseModel, Field -from logging import Logger + from invokeai.backend.util.logging import InvokeAILogger default_logger: Logger = InvokeAILogger.get_logger() diff --git a/invokeai/backend/stable_diffusion/schedulers/__init__.py b/invokeai/backend/stable_diffusion/schedulers/__init__.py index a4e9dbf9da..0b780d3ee2 100644 --- a/invokeai/backend/stable_diffusion/schedulers/__init__.py +++ b/invokeai/backend/stable_diffusion/schedulers/__init__.py @@ -1 +1,3 @@ from .schedulers import SCHEDULER_MAP # noqa: F401 + +__all__ = ["SCHEDULER_MAP"] diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index 22b132370e..20b630dfc6 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -513,7 +513,7 @@ def select_and_download_models(opt: Namespace) -> None: """Prompt user for install/delete selections and execute.""" precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) # unsure how to avoid a typing complaint in the next line: config.precision is an enumerated Literal - config.precision = precision # type: ignore + config.precision = precision install_helper = InstallHelper(config, logger) installer = install_helper.installer diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index fab1fa4598..80308b57af 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -64,9 +64,7 @@ def mock_services() -> InvocationServices: latents=None, # type: ignore logger=logging, # type: ignore model_manager=None, # type: ignore - model_records=None, # type: ignore download_queue=None, # type: ignore - model_install=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), processor=DefaultInvocationProcessor(), diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index 2ae4eab58a..7f89987a81 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -66,9 +66,7 @@ def mock_services() -> InvocationServices: latents=None, # type: ignore logger=logging, # type: ignore model_manager=None, # type: ignore - model_records=None, # type: ignore download_queue=None, # type: ignore - model_install=None, # type: ignore names=None, # type: ignore performance_statistics=InvocationStatsService(), processor=DefaultInvocationProcessor(), diff --git a/tests/backend/model_manager_2/model_loading/test_model_load.py b/tests/backend/model_manager_2/model_loading/test_model_load.py new file mode 100644 index 0000000000..a7a64e91ac --- /dev/null +++ b/tests/backend/model_manager_2/model_loading/test_model_load.py @@ -0,0 +1,22 @@ +""" +Test model loading +""" + +from pathlib import Path + +from 
invokeai.app.services.model_install import ModelInstallServiceBase +from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw +from invokeai.backend.model_manager.load import AnyModelLoader +from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403 + + +def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: AnyModelLoader, embedding_file: Path): + store = mm2_installer.record_store + matches = store.search_by_attr(model_name="test_embedding") + assert len(matches) == 0 + key = mm2_installer.register_path(embedding_file) + loaded_model = mm2_loader.load_model(store.get_model(key)) + assert loaded_model is not None + assert loaded_model.config.key == key + with loaded_model as model: + assert isinstance(model, TextualInversionModelRaw) diff --git a/tests/backend/model_manager_2/model_manager_2_fixtures.py b/tests/backend/model_manager_2/model_manager_2_fixtures.py index d6d091befe..d85eab67dd 100644 --- a/tests/backend/model_manager_2/model_manager_2_fixtures.py +++ b/tests/backend/model_manager_2/model_manager_2_fixtures.py @@ -20,6 +20,7 @@ from invokeai.backend.model_manager.config import ( ModelFormat, ModelType, ) +from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache from invokeai.backend.model_manager.metadata import ModelMetadataStore from invokeai.backend.util.logging import InvokeAILogger from tests.backend.model_manager_2.model_metadata.metadata_examples import ( @@ -89,6 +90,16 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig: return app_config +@pytest.fixture +def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceSQL) -> AnyModelLoader: + logger = InvokeAILogger.get_logger(config=mm2_app_config) + ram_cache = ModelCache( + logger=logger, max_cache_size=mm2_app_config.ram_cache_size, max_vram_cache_size=mm2_app_config.vram_cache_size + ) + convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path) + return AnyModelLoader(app_config=mm2_app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache) + + @pytest.fixture def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL: logger = InvokeAILogger.get_logger(config=mm2_app_config)