mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
use dependency injector pattern in the models route
This commit is contained in:
parent
f8b54930f0
commit
e877198de0
@ -7,13 +7,15 @@ from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi import Body, Depends, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
@ -39,6 +41,22 @@ from ..dependencies import ApiDependencies
|
||||
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
||||
|
||||
|
||||
def get_services() -> InvocationServices:
|
||||
"""DI magic to return services from the ApiDependencies global."""
|
||||
return ApiDependencies.invoker.services
|
||||
|
||||
|
||||
Services = Annotated[InvocationServices, Depends(get_services)]
|
||||
|
||||
|
||||
def get_model_manager(services: Services) -> ModelManagerServiceBase:
|
||||
"""DI magic to return the model manager from the ApiDependencies global."""
|
||||
return services.model_manager
|
||||
|
||||
|
||||
ModelManager = Annotated[ModelManagerServiceBase, Depends(get_model_manager)]
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
|
||||
@ -141,6 +159,7 @@ example_model_metadata = {
|
||||
operation_id="list_model_records",
|
||||
)
|
||||
async def list_model_records(
|
||||
model_manager: ModelManager,
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
|
||||
@ -149,7 +168,7 @@ async def list_model_records(
|
||||
),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
record_store = model_manager.store
|
||||
found_models: list[AnyModelConfig] = []
|
||||
if base_models:
|
||||
for base_model in base_models:
|
||||
@ -171,15 +190,14 @@ async def list_model_records(
|
||||
response_model=AnyModelConfig,
|
||||
)
|
||||
async def get_model_records_by_attrs(
|
||||
model_manager: ModelManager,
|
||||
name: str = Query(description="The name of the model"),
|
||||
type: ModelType = Query(description="The type of the model"),
|
||||
base: BaseModelType = Query(description="The base model of the model"),
|
||||
) -> AnyModelConfig:
|
||||
"""Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old
|
||||
model manager, which identified models by a combination of name, base and type."""
|
||||
configs = ApiDependencies.invoker.services.model_manager.store.search_by_attr(
|
||||
base_model=base, model_type=type, model_name=name
|
||||
)
|
||||
configs = model_manager.store.search_by_attr(base_model=base, model_type=type, model_name=name)
|
||||
if not configs:
|
||||
raise HTTPException(status_code=404, detail="No model found with these attributes")
|
||||
|
||||
@ -199,10 +217,11 @@ async def get_model_records_by_attrs(
|
||||
},
|
||||
)
|
||||
async def get_model_record(
|
||||
model_manager: ModelManager,
|
||||
key: str = Path(description="Key of the model record to fetch."),
|
||||
) -> AnyModelConfig:
|
||||
"""Get a model record"""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
record_store = model_manager.store
|
||||
try:
|
||||
config: AnyModelConfig = record_store.get_model(key)
|
||||
return config
|
||||
@ -212,13 +231,15 @@ async def get_model_record(
|
||||
|
||||
@model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||
async def list_model_summary(
|
||||
model_manager: ModelManager,
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of models per page"),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Gets a page of model summary data."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
results: PaginatedResults[ModelSummary] = model_manager.store.list_models(
|
||||
page=page, per_page=per_page, order_by=order_by
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
@ -234,11 +255,11 @@ async def list_model_summary(
|
||||
},
|
||||
)
|
||||
async def get_model_metadata(
|
||||
model_manager: ModelManager,
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
result: Optional[AnyModelRepoMetadata] = model_manager.store.get_metadata(key)
|
||||
|
||||
return result
|
||||
|
||||
@ -247,10 +268,9 @@ async def get_model_metadata(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
)
|
||||
async def list_tags() -> Set[str]:
|
||||
async def list_tags(model_manager: ModelManager) -> Set[str]:
|
||||
"""Get a unique set of all the model tags."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Set[str] = record_store.list_tags()
|
||||
result: Set[str] = model_manager.store.list_tags()
|
||||
return result
|
||||
|
||||
|
||||
@ -270,6 +290,8 @@ class FoundModel(BaseModel):
|
||||
response_model=List[FoundModel],
|
||||
)
|
||||
async def scan_for_models(
|
||||
model_manager: ModelManager,
|
||||
services: Services,
|
||||
scan_path: str = Query(description="Directory path to search for models", default=None),
|
||||
) -> List[FoundModel]:
|
||||
path = pathlib.Path(scan_path)
|
||||
@ -282,7 +304,7 @@ async def scan_for_models(
|
||||
search = ModelSearch()
|
||||
try:
|
||||
found_model_paths = search.search(path)
|
||||
models_path = ApiDependencies.invoker.services.configuration.models_path
|
||||
models_path = services.configuration.models_path
|
||||
|
||||
# If the search path includes the main models directory, we need to exclude core models from the list.
|
||||
# TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed
|
||||
@ -290,7 +312,7 @@ async def scan_for_models(
|
||||
core_models_path = pathlib.Path(models_path, "core").resolve()
|
||||
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
|
||||
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
|
||||
installed_models = model_manager.store.search_by_attr()
|
||||
resolved_installed_model_paths: list[str] = []
|
||||
installed_model_sources: list[str] = []
|
||||
|
||||
@ -328,10 +350,11 @@ async def scan_for_models(
|
||||
operation_id="search_by_metadata_tags",
|
||||
)
|
||||
async def search_by_metadata_tags(
|
||||
model_manager: ModelManager,
|
||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
record_store = model_manager.store
|
||||
results = record_store.search_by_metadata_tag(tags)
|
||||
return ModelsList(models=results)
|
||||
|
||||
@ -351,14 +374,15 @@ async def search_by_metadata_tags(
|
||||
status_code=200,
|
||||
)
|
||||
async def update_model_record(
|
||||
services: Services,
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
info: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
logger = services.logger
|
||||
record_store = services.model_manager.store
|
||||
try:
|
||||
model_response: AnyModelConfig = record_store.update_model(key, config=info)
|
||||
logger.info(f"Updated model: {key}")
|
||||
@ -380,6 +404,7 @@ async def update_model_record(
|
||||
status_code=204,
|
||||
)
|
||||
async def del_model_record(
|
||||
services: Services,
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""
|
||||
@ -388,10 +413,10 @@ async def del_model_record(
|
||||
The configuration record will be removed. The corresponding weights files will be
|
||||
deleted as well if they reside within the InvokeAI "models" directory.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
logger = services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
installer = services.model_manager.install
|
||||
installer.delete(key)
|
||||
logger.info(f"Deleted model: {key}")
|
||||
return Response(status_code=204)
|
||||
@ -414,13 +439,14 @@ async def del_model_record(
|
||||
status_code=201,
|
||||
)
|
||||
async def add_model_record(
|
||||
services: Services,
|
||||
config: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Add a model using the configuration information appropriate for its type."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
logger = services.logger
|
||||
record_store = services.model_manager.store
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
@ -450,6 +476,7 @@ async def add_model_record(
|
||||
status_code=201,
|
||||
)
|
||||
async def install_model(
|
||||
services: Services,
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
# TODO(MM2): Can we type this?
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
@ -485,10 +512,10 @@ async def install_model(
|
||||
See the documentation for `import_model_record` for more information on
|
||||
interpreting the job information returned by this route.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
logger = services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
installer = services.model_manager.install
|
||||
result: ModelInstallJob = installer.heuristic_import(
|
||||
source=source,
|
||||
config=config,
|
||||
@ -511,7 +538,7 @@ async def install_model(
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
)
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
async def list_model_install_jobs(services: Services) -> List[ModelInstallJob]:
|
||||
"""Return the list of model install jobs.
|
||||
|
||||
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
||||
@ -531,7 +558,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
|
||||
See the example and schema below for more information.
|
||||
"""
|
||||
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
|
||||
jobs: List[ModelInstallJob] = services.model_manager.install.list_jobs()
|
||||
return jobs
|
||||
|
||||
|
||||
@ -543,13 +570,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
404: {"description": "No such job"},
|
||||
},
|
||||
)
|
||||
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
|
||||
async def get_model_install_job(services: Services, id: int = Path(description="Model install id")) -> ModelInstallJob:
|
||||
"""
|
||||
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
|
||||
for information on the format of the return value.
|
||||
"""
|
||||
try:
|
||||
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
|
||||
result: ModelInstallJob = services.model_manager.install.get_job_by_id(id)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -564,9 +591,9 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
|
||||
async def cancel_model_install_job(services: Services, id: int = Path(description="Model install job ID")) -> None:
|
||||
"""Cancel the model install job(s) corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
installer = services.model_manager.install
|
||||
try:
|
||||
job = installer.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
@ -582,9 +609,9 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_model_install_jobs() -> Response:
|
||||
async def prune_model_install_jobs(model_manager: ModelManager) -> Response:
|
||||
"""Prune all completed and errored jobs from the install job list."""
|
||||
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
|
||||
model_manager.install.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@ -596,14 +623,14 @@ async def prune_model_install_jobs() -> Response:
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def sync_models_to_config() -> Response:
|
||||
async def sync_models_to_config(model_manager: ModelManager) -> Response:
|
||||
"""
|
||||
Traverse the models and autoimport directories.
|
||||
|
||||
Model files without a corresponding
|
||||
record in the database are added. Orphan records without a models file are deleted.
|
||||
"""
|
||||
ApiDependencies.invoker.services.model_manager.install.sync_to_config()
|
||||
model_manager.install.sync_to_config()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@ -621,6 +648,7 @@ async def sync_models_to_config() -> Response:
|
||||
},
|
||||
)
|
||||
async def convert_model(
|
||||
services: Services,
|
||||
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
@ -628,11 +656,11 @@ async def convert_model(
|
||||
Note that during the conversion process the key and model hash will change.
|
||||
The return value is the model configuration for the converted model.
|
||||
"""
|
||||
model_manager = ApiDependencies.invoker.services.model_manager
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
loader = ApiDependencies.invoker.services.model_manager.load
|
||||
store = ApiDependencies.invoker.services.model_manager.store
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
model_manager = services.model_manager
|
||||
logger = services.logger
|
||||
loader = services.model_manager.load
|
||||
store = services.model_manager.store
|
||||
installer = services.model_manager.install
|
||||
|
||||
try:
|
||||
model_config = store.get_model(key)
|
||||
@ -700,6 +728,7 @@ async def convert_model(
|
||||
},
|
||||
)
|
||||
async def merge(
|
||||
services: Services,
|
||||
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
@ -726,11 +755,11 @@ async def merge(
|
||||
merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||
```
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
logger = services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
installer = services.model_manager.install
|
||||
merger = ModelMerger(installer)
|
||||
model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
response = merger.merge_diffusion_models_and_save(
|
||||
|
Loading…
Reference in New Issue
Block a user