use dependency injector pattern in the models route

This commit is contained in:
Lincoln Stein 2024-02-29 23:06:08 -05:00 committed by psychedelicious
parent f8b54930f0
commit e877198de0

View File

@ -7,13 +7,15 @@ from hashlib import sha1
from random import randbytes
from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response
from fastapi import Body, Depends, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
@ -39,6 +41,22 @@ from ..dependencies import ApiDependencies
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
def get_services() -> InvocationServices:
"""DI magic to return services from the ApiDependencies global."""
return ApiDependencies.invoker.services
Services = Annotated[InvocationServices, Depends(get_services)]
def get_model_manager(services: Services) -> ModelManagerServiceBase:
"""DI magic to return the model manager from the ApiDependencies global."""
return services.model_manager
ModelManager = Annotated[ModelManagerServiceBase, Depends(get_model_manager)]
class ModelsList(BaseModel):
"""Return list of configs."""
@ -141,6 +159,7 @@ example_model_metadata = {
operation_id="list_model_records",
)
async def list_model_records(
model_manager: ModelManager,
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
@ -149,7 +168,7 @@ async def list_model_records(
),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store
record_store = model_manager.store
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
@ -171,15 +190,14 @@ async def list_model_records(
response_model=AnyModelConfig,
)
async def get_model_records_by_attrs(
model_manager: ModelManager,
name: str = Query(description="The name of the model"),
type: ModelType = Query(description="The type of the model"),
base: BaseModelType = Query(description="The base model of the model"),
) -> AnyModelConfig:
"""Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old
model manager, which identified models by a combination of name, base and type."""
configs = ApiDependencies.invoker.services.model_manager.store.search_by_attr(
base_model=base, model_type=type, model_name=name
)
configs = model_manager.store.search_by_attr(base_model=base, model_type=type, model_name=name)
if not configs:
raise HTTPException(status_code=404, detail="No model found with these attributes")
@ -199,10 +217,11 @@ async def get_model_records_by_attrs(
},
)
async def get_model_record(
model_manager: ModelManager,
key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
"""Get a model record"""
record_store = ApiDependencies.invoker.services.model_manager.store
record_store = model_manager.store
try:
config: AnyModelConfig = record_store.get_model(key)
return config
@ -212,13 +231,15 @@ async def get_model_record(
@model_manager_router.get("/summary", operation_id="list_model_summary")
async def list_model_summary(
model_manager: ModelManager,
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]:
"""Gets a page of model summary data."""
record_store = ApiDependencies.invoker.services.model_manager.store
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
results: PaginatedResults[ModelSummary] = model_manager.store.list_models(
page=page, per_page=per_page, order_by=order_by
)
return results
@ -234,11 +255,11 @@ async def list_model_summary(
},
)
async def get_model_metadata(
model_manager: ModelManager,
key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
"""Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
result: Optional[AnyModelRepoMetadata] = model_manager.store.get_metadata(key)
return result
@ -247,10 +268,9 @@ async def get_model_metadata(
"/tags",
operation_id="list_tags",
)
async def list_tags() -> Set[str]:
async def list_tags(model_manager: ModelManager) -> Set[str]:
"""Get a unique set of all the model tags."""
record_store = ApiDependencies.invoker.services.model_manager.store
result: Set[str] = record_store.list_tags()
result: Set[str] = model_manager.store.list_tags()
return result
@ -270,6 +290,8 @@ class FoundModel(BaseModel):
response_model=List[FoundModel],
)
async def scan_for_models(
model_manager: ModelManager,
services: Services,
scan_path: str = Query(description="Directory path to search for models", default=None),
) -> List[FoundModel]:
path = pathlib.Path(scan_path)
@ -282,7 +304,7 @@ async def scan_for_models(
search = ModelSearch()
try:
found_model_paths = search.search(path)
models_path = ApiDependencies.invoker.services.configuration.models_path
models_path = services.configuration.models_path
# If the search path includes the main models directory, we need to exclude core models from the list.
# TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed
@ -290,7 +312,7 @@ async def scan_for_models(
core_models_path = pathlib.Path(models_path, "core").resolve()
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
installed_models = model_manager.store.search_by_attr()
resolved_installed_model_paths: list[str] = []
installed_model_sources: list[str] = []
@ -328,10 +350,11 @@ async def scan_for_models(
operation_id="search_by_metadata_tags",
)
async def search_by_metadata_tags(
model_manager: ModelManager,
tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store
record_store = model_manager.store
results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results)
@ -351,14 +374,15 @@ async def search_by_metadata_tags(
status_code=200,
)
async def update_model_record(
services: Services,
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
logger = services.logger
record_store = services.model_manager.store
try:
model_response: AnyModelConfig = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}")
@ -380,6 +404,7 @@ async def update_model_record(
status_code=204,
)
async def del_model_record(
services: Services,
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""
@ -388,10 +413,10 @@ async def del_model_record(
The configuration record will be removed. The corresponding weights files will be
deleted as well if they reside within the InvokeAI "models" directory.
"""
logger = ApiDependencies.invoker.services.logger
logger = services.logger
try:
installer = ApiDependencies.invoker.services.model_manager.install
installer = services.model_manager.install
installer.delete(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
@ -414,13 +439,14 @@ async def del_model_record(
status_code=201,
)
async def add_model_record(
services: Services,
config: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
logger = services.logger
record_store = services.model_manager.store
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
@ -450,6 +476,7 @@ async def add_model_record(
status_code=201,
)
async def install_model(
services: Services,
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
# TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body(
@ -485,10 +512,10 @@ async def install_model(
See the documentation for `import_model_record` for more information on
interpreting the job information returned by this route.
"""
logger = ApiDependencies.invoker.services.logger
logger = services.logger
try:
installer = ApiDependencies.invoker.services.model_manager.install
installer = services.model_manager.install
result: ModelInstallJob = installer.heuristic_import(
source=source,
config=config,
@ -511,7 +538,7 @@ async def install_model(
"/import",
operation_id="list_model_install_jobs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
async def list_model_install_jobs(services: Services) -> List[ModelInstallJob]:
"""Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
@ -531,7 +558,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
See the example and schema below for more information.
"""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
jobs: List[ModelInstallJob] = services.model_manager.install.list_jobs()
return jobs
@ -543,13 +570,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
404: {"description": "No such job"},
},
)
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
async def get_model_install_job(services: Services, id: int = Path(description="Model install id")) -> ModelInstallJob:
"""
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
for information on the format of the return value.
"""
try:
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
result: ModelInstallJob = services.model_manager.install.get_job_by_id(id)
return result
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@ -564,9 +591,9 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
},
status_code=201,
)
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
async def cancel_model_install_job(services: Services, id: int = Path(description="Model install job ID")) -> None:
"""Cancel the model install job(s) corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_manager.install
installer = services.model_manager.install
try:
job = installer.get_job_by_id(id)
except ValueError as e:
@ -582,9 +609,9 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
400: {"description": "Bad request"},
},
)
async def prune_model_install_jobs() -> Response:
async def prune_model_install_jobs(model_manager: ModelManager) -> Response:
"""Prune all completed and errored jobs from the install job list."""
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
model_manager.install.prune_jobs()
return Response(status_code=204)
@ -596,14 +623,14 @@ async def prune_model_install_jobs() -> Response:
400: {"description": "Bad request"},
},
)
async def sync_models_to_config() -> Response:
async def sync_models_to_config(model_manager: ModelManager) -> Response:
"""
Traverse the models and autoimport directories.
Model files without a corresponding
record in the database are added. Orphan records without a models file are deleted.
"""
ApiDependencies.invoker.services.model_manager.install.sync_to_config()
model_manager.install.sync_to_config()
return Response(status_code=204)
@ -621,6 +648,7 @@ async def sync_models_to_config() -> Response:
},
)
async def convert_model(
services: Services,
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
) -> AnyModelConfig:
"""
@ -628,11 +656,11 @@ async def convert_model(
Note that during the conversion process the key and model hash will change.
The return value is the model configuration for the converted model.
"""
model_manager = ApiDependencies.invoker.services.model_manager
logger = ApiDependencies.invoker.services.logger
loader = ApiDependencies.invoker.services.model_manager.load
store = ApiDependencies.invoker.services.model_manager.store
installer = ApiDependencies.invoker.services.model_manager.install
model_manager = services.model_manager
logger = services.logger
loader = services.model_manager.load
store = services.model_manager.store
installer = services.model_manager.install
try:
model_config = store.get_model(key)
@ -700,6 +728,7 @@ async def convert_model(
},
)
async def merge(
services: Services,
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
@ -726,11 +755,11 @@ async def merge(
merge_dest_directory Specify a directory to store the merged model in [models directory]
```
"""
logger = ApiDependencies.invoker.services.logger
logger = services.logger
try:
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
installer = ApiDependencies.invoker.services.model_manager.install
installer = services.model_manager.install
merger = ModelMerger(installer)
model_names = [installer.record_store.get_model(x).name for x in keys]
response = merger.merge_diffusion_models_and_save(