From e877198de0387029fdb43281de34d901df9ff77b Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 29 Feb 2024 23:06:08 -0500 Subject: [PATCH] use dependency injector pattern in the models route --- invokeai/app/api/routers/model_manager.py | 111 ++++++++++++++-------- 1 file changed, 70 insertions(+), 41 deletions(-) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 50ebe5ce64..b5ebb6215e 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -7,13 +7,15 @@ from hashlib import sha1 from random import randbytes from typing import Any, Dict, List, Optional, Set -from fastapi import Body, Path, Query, Response +from fastapi import Body, Depends, Path, Query, Response from fastapi.routing import APIRouter from pydantic import BaseModel, ConfigDict, Field from starlette.exceptions import HTTPException from typing_extensions import Annotated +from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.model_install import ModelInstallJob +from invokeai.app.services.model_manager import ModelManagerServiceBase from invokeai.app.services.model_records import ( DuplicateModelException, InvalidModelException, @@ -39,6 +41,22 @@ from ..dependencies import ApiDependencies model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"]) +def get_services() -> InvocationServices: + """DI magic to return services from the ApiDependencies global.""" + return ApiDependencies.invoker.services + + +Services = Annotated[InvocationServices, Depends(get_services)] + + +def get_model_manager(services: Services) -> ModelManagerServiceBase: + """DI magic to return the model manager from the ApiDependencies global.""" + return services.model_manager + + +ModelManager = Annotated[ModelManagerServiceBase, Depends(get_model_manager)] + + class ModelsList(BaseModel): """Return list of configs.""" @@ -141,6 +159,7 @@ example_model_metadata = { operation_id="list_model_records", ) async def list_model_records( + model_manager: ModelManager, base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"), @@ -149,7 +168,7 @@ async def list_model_records( ), ) -> ModelsList: """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_manager.store + record_store = model_manager.store found_models: list[AnyModelConfig] = [] if base_models: for base_model in base_models: @@ -171,15 +190,14 @@ async def list_model_records( response_model=AnyModelConfig, ) async def get_model_records_by_attrs( + model_manager: ModelManager, name: str = Query(description="The name of the model"), type: ModelType = Query(description="The type of the model"), base: BaseModelType = Query(description="The base model of the model"), ) -> AnyModelConfig: """Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old model manager, which identified models by a combination of name, base and type.""" - configs = ApiDependencies.invoker.services.model_manager.store.search_by_attr( - base_model=base, model_type=type, model_name=name - ) + configs = model_manager.store.search_by_attr(base_model=base, model_type=type, model_name=name) if not configs: raise HTTPException(status_code=404, detail="No model found with these attributes") @@ -199,10 +217,11 @@ async def get_model_records_by_attrs( }, ) async def get_model_record( + model_manager: ModelManager, key: str = Path(description="Key of the model record to fetch."), ) -> AnyModelConfig: """Get a model record""" - record_store = ApiDependencies.invoker.services.model_manager.store + record_store = model_manager.store try: config: AnyModelConfig = record_store.get_model(key) return config @@ -212,13 +231,15 @@ async def get_model_record( @model_manager_router.get("/summary", operation_id="list_model_summary") async def list_model_summary( + model_manager: ModelManager, page: int = Query(default=0, description="The page to get"), per_page: int = Query(default=10, description="The number of models per page"), order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"), ) -> PaginatedResults[ModelSummary]: """Gets a page of model summary data.""" - record_store = ApiDependencies.invoker.services.model_manager.store - results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by) + results: PaginatedResults[ModelSummary] = model_manager.store.list_models( + page=page, per_page=per_page, order_by=order_by + ) return results @@ -234,11 +255,11 @@ async def list_model_summary( }, ) async def get_model_metadata( + model_manager: ModelManager, key: str = Path(description="Key of the model repo metadata to fetch."), ) -> Optional[AnyModelRepoMetadata]: """Get a model metadata object.""" - record_store = ApiDependencies.invoker.services.model_manager.store - result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key) + result: Optional[AnyModelRepoMetadata] = model_manager.store.get_metadata(key) return result @@ -247,10 +268,9 @@ async def get_model_metadata( "/tags", operation_id="list_tags", ) -async def list_tags() -> Set[str]: +async def list_tags(model_manager: ModelManager) -> Set[str]: """Get a unique set of all the model tags.""" - record_store = ApiDependencies.invoker.services.model_manager.store - result: Set[str] = record_store.list_tags() + result: Set[str] = model_manager.store.list_tags() return result @@ -270,6 +290,8 @@ class FoundModel(BaseModel): response_model=List[FoundModel], ) async def scan_for_models( + model_manager: ModelManager, + services: Services, scan_path: str = Query(description="Directory path to search for models", default=None), ) -> List[FoundModel]: path = pathlib.Path(scan_path) @@ -282,7 +304,7 @@ async def scan_for_models( search = ModelSearch() try: found_model_paths = search.search(path) - models_path = ApiDependencies.invoker.services.configuration.models_path + models_path = services.configuration.models_path # If the search path includes the main models directory, we need to exclude core models from the list. # TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed @@ -290,7 +312,7 @@ async def scan_for_models( core_models_path = pathlib.Path(models_path, "core").resolve() non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)] - installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr() + installed_models = model_manager.store.search_by_attr() resolved_installed_model_paths: list[str] = [] installed_model_sources: list[str] = [] @@ -328,10 +350,11 @@ async def scan_for_models( operation_id="search_by_metadata_tags", ) async def search_by_metadata_tags( + model_manager: ModelManager, tags: Set[str] = Query(default=None, description="Tags to search for"), ) -> ModelsList: """Get a list of models.""" - record_store = ApiDependencies.invoker.services.model_manager.store + record_store = model_manager.store results = record_store.search_by_metadata_tag(tags) return ModelsList(models=results) @@ -351,14 +374,15 @@ async def search_by_metadata_tags( status_code=200, ) async def update_model_record( + services: Services, key: Annotated[str, Path(description="Unique key of model")], info: Annotated[ AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) ], ) -> AnyModelConfig: """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" - logger = ApiDependencies.invoker.services.logger - record_store = ApiDependencies.invoker.services.model_manager.store + logger = services.logger + record_store = services.model_manager.store try: model_response: AnyModelConfig = record_store.update_model(key, config=info) logger.info(f"Updated model: {key}") @@ -380,6 +404,7 @@ async def update_model_record( status_code=204, ) async def del_model_record( + services: Services, key: str = Path(description="Unique key of model to remove from model registry."), ) -> Response: """ @@ -388,10 +413,10 @@ async def del_model_record( The configuration record will be removed. The corresponding weights files will be deleted as well if they reside within the InvokeAI "models" directory. """ - logger = ApiDependencies.invoker.services.logger + logger = services.logger try: - installer = ApiDependencies.invoker.services.model_manager.install + installer = services.model_manager.install installer.delete(key) logger.info(f"Deleted model: {key}") return Response(status_code=204) @@ -414,13 +439,14 @@ async def del_model_record( status_code=201, ) async def add_model_record( + services: Services, config: Annotated[ AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) ], ) -> AnyModelConfig: """Add a model using the configuration information appropriate for its type.""" - logger = ApiDependencies.invoker.services.logger - record_store = ApiDependencies.invoker.services.model_manager.store + logger = services.logger + record_store = services.model_manager.store if config.key == "": config.key = sha1(randbytes(100)).hexdigest() logger.info(f"Created model {config.key} for {config.name}") @@ -450,6 +476,7 @@ async def add_model_record( status_code=201, ) async def install_model( + services: Services, source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"), # TODO(MM2): Can we type this? config: Optional[Dict[str, Any]] = Body( @@ -485,10 +512,10 @@ async def install_model( See the documentation for `import_model_record` for more information on interpreting the job information returned by this route. """ - logger = ApiDependencies.invoker.services.logger + logger = services.logger try: - installer = ApiDependencies.invoker.services.model_manager.install + installer = services.model_manager.install result: ModelInstallJob = installer.heuristic_import( source=source, config=config, @@ -511,7 +538,7 @@ async def install_model( "/import", operation_id="list_model_install_jobs", ) -async def list_model_install_jobs() -> List[ModelInstallJob]: +async def list_model_install_jobs(services: Services) -> List[ModelInstallJob]: """Return the list of model install jobs. Install jobs have a numeric `id`, a `status`, and other fields that provide information on @@ -531,7 +558,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]: See the example and schema below for more information. """ - jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs() + jobs: List[ModelInstallJob] = services.model_manager.install.list_jobs() return jobs @@ -543,13 +570,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]: 404: {"description": "No such job"}, }, ) -async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: +async def get_model_install_job(services: Services, id: int = Path(description="Model install id")) -> ModelInstallJob: """ Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs' for information on the format of the return value. """ try: - result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id) + result: ModelInstallJob = services.model_manager.install.get_job_by_id(id) return result except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) @@ -564,9 +591,9 @@ async def get_model_install_job(id: int = Path(description="Model install id")) }, status_code=201, ) -async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None: +async def cancel_model_install_job(services: Services, id: int = Path(description="Model install job ID")) -> None: """Cancel the model install job(s) corresponding to the given job ID.""" - installer = ApiDependencies.invoker.services.model_manager.install + installer = services.model_manager.install try: job = installer.get_job_by_id(id) except ValueError as e: @@ -582,9 +609,9 @@ async def cancel_model_install_job(id: int = Path(description="Model install job 400: {"description": "Bad request"}, }, ) -async def prune_model_install_jobs() -> Response: +async def prune_model_install_jobs(model_manager: ModelManager) -> Response: """Prune all completed and errored jobs from the install job list.""" - ApiDependencies.invoker.services.model_manager.install.prune_jobs() + model_manager.install.prune_jobs() return Response(status_code=204) @@ -596,14 +623,14 @@ async def prune_model_install_jobs() -> Response: 400: {"description": "Bad request"}, }, ) -async def sync_models_to_config() -> Response: +async def sync_models_to_config(model_manager: ModelManager) -> Response: """ Traverse the models and autoimport directories. Model files without a corresponding record in the database are added. Orphan records without a models file are deleted. """ - ApiDependencies.invoker.services.model_manager.install.sync_to_config() + model_manager.install.sync_to_config() return Response(status_code=204) @@ -621,6 +648,7 @@ async def sync_models_to_config() -> Response: }, ) async def convert_model( + services: Services, key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."), ) -> AnyModelConfig: """ @@ -628,11 +656,11 @@ async def convert_model( Note that during the conversion process the key and model hash will change. The return value is the model configuration for the converted model. """ - model_manager = ApiDependencies.invoker.services.model_manager - logger = ApiDependencies.invoker.services.logger - loader = ApiDependencies.invoker.services.model_manager.load - store = ApiDependencies.invoker.services.model_manager.store - installer = ApiDependencies.invoker.services.model_manager.install + model_manager = services.model_manager + logger = services.logger + loader = services.model_manager.load + store = services.model_manager.store + installer = services.model_manager.install try: model_config = store.get_model(key) @@ -700,6 +728,7 @@ async def convert_model( }, ) async def merge( + services: Services, keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3), merged_model_name: Optional[str] = Body(description="Name of destination model", default=None), alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), @@ -726,11 +755,11 @@ async def merge( merge_dest_directory Specify a directory to store the merged model in [models directory] ``` """ - logger = ApiDependencies.invoker.services.logger + logger = services.logger try: logger.info(f"Merging models: {keys} into {merge_dest_directory or ''}/{merged_model_name}") dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None - installer = ApiDependencies.invoker.services.model_manager.install + installer = services.model_manager.install merger = ModelMerger(installer) model_names = [installer.record_store.get_model(x).name for x in keys] response = merger.merge_diffusion_models_and_save(