use dependency injector pattern in the models route

This commit is contained in:
Lincoln Stein 2024-02-29 23:06:08 -05:00 committed by psychedelicious
parent f8b54930f0
commit e877198de0

View File

@ -7,13 +7,15 @@ from hashlib import sha1
from random import randbytes from random import randbytes
from typing import Any, Dict, List, Optional, Set from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response from fastapi import Body, Depends, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from typing_extensions import Annotated from typing_extensions import Annotated
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_install import ModelInstallJob from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException, DuplicateModelException,
InvalidModelException, InvalidModelException,
@ -39,6 +41,22 @@ from ..dependencies import ApiDependencies
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"]) model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
def get_services() -> InvocationServices:
"""DI magic to return services from the ApiDependencies global."""
return ApiDependencies.invoker.services
Services = Annotated[InvocationServices, Depends(get_services)]
def get_model_manager(services: Services) -> ModelManagerServiceBase:
"""DI magic to return the model manager from the ApiDependencies global."""
return services.model_manager
ModelManager = Annotated[ModelManagerServiceBase, Depends(get_model_manager)]
class ModelsList(BaseModel): class ModelsList(BaseModel):
"""Return list of configs.""" """Return list of configs."""
@ -141,6 +159,7 @@ example_model_metadata = {
operation_id="list_model_records", operation_id="list_model_records",
) )
async def list_model_records( async def list_model_records(
model_manager: ModelManager,
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"), model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
@ -149,7 +168,7 @@ async def list_model_records(
), ),
) -> ModelsList: ) -> ModelsList:
"""Get a list of models.""" """Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store record_store = model_manager.store
found_models: list[AnyModelConfig] = [] found_models: list[AnyModelConfig] = []
if base_models: if base_models:
for base_model in base_models: for base_model in base_models:
@ -171,15 +190,14 @@ async def list_model_records(
response_model=AnyModelConfig, response_model=AnyModelConfig,
) )
async def get_model_records_by_attrs( async def get_model_records_by_attrs(
model_manager: ModelManager,
name: str = Query(description="The name of the model"), name: str = Query(description="The name of the model"),
type: ModelType = Query(description="The type of the model"), type: ModelType = Query(description="The type of the model"),
base: BaseModelType = Query(description="The base model of the model"), base: BaseModelType = Query(description="The base model of the model"),
) -> AnyModelConfig: ) -> AnyModelConfig:
"""Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old """Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old
model manager, which identified models by a combination of name, base and type.""" model manager, which identified models by a combination of name, base and type."""
configs = ApiDependencies.invoker.services.model_manager.store.search_by_attr( configs = model_manager.store.search_by_attr(base_model=base, model_type=type, model_name=name)
base_model=base, model_type=type, model_name=name
)
if not configs: if not configs:
raise HTTPException(status_code=404, detail="No model found with these attributes") raise HTTPException(status_code=404, detail="No model found with these attributes")
@ -199,10 +217,11 @@ async def get_model_records_by_attrs(
}, },
) )
async def get_model_record( async def get_model_record(
model_manager: ModelManager,
key: str = Path(description="Key of the model record to fetch."), key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig: ) -> AnyModelConfig:
"""Get a model record""" """Get a model record"""
record_store = ApiDependencies.invoker.services.model_manager.store record_store = model_manager.store
try: try:
config: AnyModelConfig = record_store.get_model(key) config: AnyModelConfig = record_store.get_model(key)
return config return config
@ -212,13 +231,15 @@ async def get_model_record(
@model_manager_router.get("/summary", operation_id="list_model_summary") @model_manager_router.get("/summary", operation_id="list_model_summary")
async def list_model_summary( async def list_model_summary(
model_manager: ModelManager,
page: int = Query(default=0, description="The page to get"), page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"), per_page: int = Query(default=10, description="The number of models per page"),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"), order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]: ) -> PaginatedResults[ModelSummary]:
"""Gets a page of model summary data.""" """Gets a page of model summary data."""
record_store = ApiDependencies.invoker.services.model_manager.store results: PaginatedResults[ModelSummary] = model_manager.store.list_models(
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by) page=page, per_page=per_page, order_by=order_by
)
return results return results
@ -234,11 +255,11 @@ async def list_model_summary(
}, },
) )
async def get_model_metadata( async def get_model_metadata(
model_manager: ModelManager,
key: str = Path(description="Key of the model repo metadata to fetch."), key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]: ) -> Optional[AnyModelRepoMetadata]:
"""Get a model metadata object.""" """Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store result: Optional[AnyModelRepoMetadata] = model_manager.store.get_metadata(key)
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result return result
@ -247,10 +268,9 @@ async def get_model_metadata(
"/tags", "/tags",
operation_id="list_tags", operation_id="list_tags",
) )
async def list_tags() -> Set[str]: async def list_tags(model_manager: ModelManager) -> Set[str]:
"""Get a unique set of all the model tags.""" """Get a unique set of all the model tags."""
record_store = ApiDependencies.invoker.services.model_manager.store result: Set[str] = model_manager.store.list_tags()
result: Set[str] = record_store.list_tags()
return result return result
@ -270,6 +290,8 @@ class FoundModel(BaseModel):
response_model=List[FoundModel], response_model=List[FoundModel],
) )
async def scan_for_models( async def scan_for_models(
model_manager: ModelManager,
services: Services,
scan_path: str = Query(description="Directory path to search for models", default=None), scan_path: str = Query(description="Directory path to search for models", default=None),
) -> List[FoundModel]: ) -> List[FoundModel]:
path = pathlib.Path(scan_path) path = pathlib.Path(scan_path)
@ -282,7 +304,7 @@ async def scan_for_models(
search = ModelSearch() search = ModelSearch()
try: try:
found_model_paths = search.search(path) found_model_paths = search.search(path)
models_path = ApiDependencies.invoker.services.configuration.models_path models_path = services.configuration.models_path
# If the search path includes the main models directory, we need to exclude core models from the list. # If the search path includes the main models directory, we need to exclude core models from the list.
# TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed # TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed
@ -290,7 +312,7 @@ async def scan_for_models(
core_models_path = pathlib.Path(models_path, "core").resolve() core_models_path = pathlib.Path(models_path, "core").resolve()
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)] non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr() installed_models = model_manager.store.search_by_attr()
resolved_installed_model_paths: list[str] = [] resolved_installed_model_paths: list[str] = []
installed_model_sources: list[str] = [] installed_model_sources: list[str] = []
@ -328,10 +350,11 @@ async def scan_for_models(
operation_id="search_by_metadata_tags", operation_id="search_by_metadata_tags",
) )
async def search_by_metadata_tags( async def search_by_metadata_tags(
model_manager: ModelManager,
tags: Set[str] = Query(default=None, description="Tags to search for"), tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList: ) -> ModelsList:
"""Get a list of models.""" """Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store record_store = model_manager.store
results = record_store.search_by_metadata_tag(tags) results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results) return ModelsList(models=results)
@ -351,14 +374,15 @@ async def search_by_metadata_tags(
status_code=200, status_code=200,
) )
async def update_model_record( async def update_model_record(
services: Services,
key: Annotated[str, Path(description="Unique key of model")], key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[ info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
], ],
) -> AnyModelConfig: ) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger logger = services.logger
record_store = ApiDependencies.invoker.services.model_manager.store record_store = services.model_manager.store
try: try:
model_response: AnyModelConfig = record_store.update_model(key, config=info) model_response: AnyModelConfig = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}") logger.info(f"Updated model: {key}")
@ -380,6 +404,7 @@ async def update_model_record(
status_code=204, status_code=204,
) )
async def del_model_record( async def del_model_record(
services: Services,
key: str = Path(description="Unique key of model to remove from model registry."), key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response: ) -> Response:
""" """
@ -388,10 +413,10 @@ async def del_model_record(
The configuration record will be removed. The corresponding weights files will be The configuration record will be removed. The corresponding weights files will be
deleted as well if they reside within the InvokeAI "models" directory. deleted as well if they reside within the InvokeAI "models" directory.
""" """
logger = ApiDependencies.invoker.services.logger logger = services.logger
try: try:
installer = ApiDependencies.invoker.services.model_manager.install installer = services.model_manager.install
installer.delete(key) installer.delete(key)
logger.info(f"Deleted model: {key}") logger.info(f"Deleted model: {key}")
return Response(status_code=204) return Response(status_code=204)
@ -414,13 +439,14 @@ async def del_model_record(
status_code=201, status_code=201,
) )
async def add_model_record( async def add_model_record(
services: Services,
config: Annotated[ config: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input) AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
], ],
) -> AnyModelConfig: ) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type.""" """Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger logger = services.logger
record_store = ApiDependencies.invoker.services.model_manager.store record_store = services.model_manager.store
if config.key == "<NOKEY>": if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest() config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}") logger.info(f"Created model {config.key} for {config.name}")
@ -450,6 +476,7 @@ async def add_model_record(
status_code=201, status_code=201,
) )
async def install_model( async def install_model(
services: Services,
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"), source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
# TODO(MM2): Can we type this? # TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body( config: Optional[Dict[str, Any]] = Body(
@ -485,10 +512,10 @@ async def install_model(
See the documentation for `import_model_record` for more information on See the documentation for `import_model_record` for more information on
interpreting the job information returned by this route. interpreting the job information returned by this route.
""" """
logger = ApiDependencies.invoker.services.logger logger = services.logger
try: try:
installer = ApiDependencies.invoker.services.model_manager.install installer = services.model_manager.install
result: ModelInstallJob = installer.heuristic_import( result: ModelInstallJob = installer.heuristic_import(
source=source, source=source,
config=config, config=config,
@ -511,7 +538,7 @@ async def install_model(
"/import", "/import",
operation_id="list_model_install_jobs", operation_id="list_model_install_jobs",
) )
async def list_model_install_jobs() -> List[ModelInstallJob]: async def list_model_install_jobs(services: Services) -> List[ModelInstallJob]:
"""Return the list of model install jobs. """Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on Install jobs have a numeric `id`, a `status`, and other fields that provide information on
@ -531,7 +558,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
See the example and schema below for more information. See the example and schema below for more information.
""" """
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs() jobs: List[ModelInstallJob] = services.model_manager.install.list_jobs()
return jobs return jobs
@ -543,13 +570,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
404: {"description": "No such job"}, 404: {"description": "No such job"},
}, },
) )
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: async def get_model_install_job(services: Services, id: int = Path(description="Model install id")) -> ModelInstallJob:
""" """
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs' Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
for information on the format of the return value. for information on the format of the return value.
""" """
try: try:
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id) result: ModelInstallJob = services.model_manager.install.get_job_by_id(id)
return result return result
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@ -564,9 +591,9 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
}, },
status_code=201, status_code=201,
) )
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None: async def cancel_model_install_job(services: Services, id: int = Path(description="Model install job ID")) -> None:
"""Cancel the model install job(s) corresponding to the given job ID.""" """Cancel the model install job(s) corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_manager.install installer = services.model_manager.install
try: try:
job = installer.get_job_by_id(id) job = installer.get_job_by_id(id)
except ValueError as e: except ValueError as e:
@ -582,9 +609,9 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
400: {"description": "Bad request"}, 400: {"description": "Bad request"},
}, },
) )
async def prune_model_install_jobs() -> Response: async def prune_model_install_jobs(model_manager: ModelManager) -> Response:
"""Prune all completed and errored jobs from the install job list.""" """Prune all completed and errored jobs from the install job list."""
ApiDependencies.invoker.services.model_manager.install.prune_jobs() model_manager.install.prune_jobs()
return Response(status_code=204) return Response(status_code=204)
@ -596,14 +623,14 @@ async def prune_model_install_jobs() -> Response:
400: {"description": "Bad request"}, 400: {"description": "Bad request"},
}, },
) )
async def sync_models_to_config() -> Response: async def sync_models_to_config(model_manager: ModelManager) -> Response:
""" """
Traverse the models and autoimport directories. Traverse the models and autoimport directories.
Model files without a corresponding Model files without a corresponding
record in the database are added. Orphan records without a models file are deleted. record in the database are added. Orphan records without a models file are deleted.
""" """
ApiDependencies.invoker.services.model_manager.install.sync_to_config() model_manager.install.sync_to_config()
return Response(status_code=204) return Response(status_code=204)
@ -621,6 +648,7 @@ async def sync_models_to_config() -> Response:
}, },
) )
async def convert_model( async def convert_model(
services: Services,
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."), key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
) -> AnyModelConfig: ) -> AnyModelConfig:
""" """
@ -628,11 +656,11 @@ async def convert_model(
Note that during the conversion process the key and model hash will change. Note that during the conversion process the key and model hash will change.
The return value is the model configuration for the converted model. The return value is the model configuration for the converted model.
""" """
model_manager = ApiDependencies.invoker.services.model_manager model_manager = services.model_manager
logger = ApiDependencies.invoker.services.logger logger = services.logger
loader = ApiDependencies.invoker.services.model_manager.load loader = services.model_manager.load
store = ApiDependencies.invoker.services.model_manager.store store = services.model_manager.store
installer = ApiDependencies.invoker.services.model_manager.install installer = services.model_manager.install
try: try:
model_config = store.get_model(key) model_config = store.get_model(key)
@ -700,6 +728,7 @@ async def convert_model(
}, },
) )
async def merge( async def merge(
services: Services,
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3), keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None), merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
@ -726,11 +755,11 @@ async def merge(
merge_dest_directory Specify a directory to store the merged model in [models directory] merge_dest_directory Specify a directory to store the merged model in [models directory]
``` ```
""" """
logger = ApiDependencies.invoker.services.logger logger = services.logger
try: try:
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}") logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
installer = ApiDependencies.invoker.services.model_manager.install installer = services.model_manager.install
merger = ModelMerger(installer) merger = ModelMerger(installer)
model_names = [installer.record_store.get_model(x).name for x in keys] model_names = [installer.record_store.get_model(x).name for x in keys]
response = merger.merge_diffusion_models_and_save( response = merger.merge_diffusion_models_and_save(