make model manager v2 ready for PR review

- Replace legacy model manager service with the v2 manager.

- Update invocations to use new load interface.

- Fixed many but not all type checking errors in the invocations. Most
  were unrelated to model manager

- Updated routes. All the new routes live under the route tag
  `model_manager_v2`. To avoid confusion with the old routes,
  they have the URL prefix `/api/v2/models`. The old routes
  have been de-registered.

- Added a pytest for the loader.

- Updated documentation in contributing/MODEL_MANAGER.md
This commit is contained in:
Lincoln Stein
2024-02-10 18:09:45 -05:00
committed by psychedelicious
parent 2b1dc74080
commit 94e8d1b6d5
36 changed files with 680 additions and 435 deletions

View File

@ -32,7 +32,7 @@ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
class ModelsList(BaseModel):
@ -52,7 +52,7 @@ class ModelTagSet(BaseModel):
tags: Set[str]
@model_records_router.get(
@model_manager_v2_router.get(
"/",
operation_id="list_model_records",
)
@ -65,7 +65,7 @@ async def list_model_records(
),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
@ -81,7 +81,7 @@ async def list_model_records(
return ModelsList(models=found_models)
@model_records_router.get(
@model_manager_v2_router.get(
"/i/{key}",
operation_id="get_model_record",
responses={
@ -94,24 +94,27 @@ async def get_model_record(
key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
"""Get a model record"""
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
try:
return record_store.get_model(key)
config: AnyModelConfig = record_store.get_model(key)
return config
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.get("/meta", operation_id="list_model_summary")
@model_manager_v2_router.get("/meta", operation_id="list_model_summary")
async def list_model_summary(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]:
"""Gets a page of model summary data."""
return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by)
record_store = ApiDependencies.invoker.services.model_manager.store
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
return results
@model_records_router.get(
@model_manager_v2_router.get(
"/meta/i/{key}",
operation_id="get_model_metadata",
responses={
@ -124,24 +127,25 @@ async def get_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
"""Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_records
result = record_store.get_metadata(key)
record_store = ApiDependencies.invoker.services.model_manager.store
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
if not result:
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
return result
@model_records_router.get(
@model_manager_v2_router.get(
"/tags",
operation_id="list_tags",
)
async def list_tags() -> Set[str]:
"""Get a unique set of all the model tags."""
record_store = ApiDependencies.invoker.services.model_records
return record_store.list_tags()
record_store = ApiDependencies.invoker.services.model_manager.store
result: Set[str] = record_store.list_tags()
return result
@model_records_router.get(
@model_manager_v2_router.get(
"/tags/search",
operation_id="search_by_metadata_tags",
)
@ -149,12 +153,12 @@ async def search_by_metadata_tags(
tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results)
@model_records_router.patch(
@model_manager_v2_router.patch(
"/i/{key}",
operation_id="update_model_record",
responses={
@ -172,9 +176,9 @@ async def update_model_record(
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
try:
model_response = record_store.update_model(key, config=info)
model_response: AnyModelConfig = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@ -184,7 +188,7 @@ async def update_model_record(
return model_response
@model_records_router.delete(
@model_manager_v2_router.delete(
"/i/{key}",
operation_id="del_model_record",
responses={
@ -205,7 +209,7 @@ async def del_model_record(
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
installer = ApiDependencies.invoker.services.model_manager.install
installer.delete(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
@ -214,7 +218,7 @@ async def del_model_record(
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.post(
@model_manager_v2_router.post(
"/i/",
operation_id="add_model_record",
responses={
@ -229,7 +233,7 @@ async def add_model_record(
) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
@ -243,10 +247,11 @@ async def add_model_record(
raise HTTPException(status_code=415)
# now fetch it out
return record_store.get_model(config.key)
result: AnyModelConfig = record_store.get_model(config.key)
return result
@model_records_router.post(
@model_manager_v2_router.post(
"/import",
operation_id="import_model_record",
responses={
@ -322,7 +327,7 @@ async def import_model(
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
installer = ApiDependencies.invoker.services.model_manager.install
result: ModelInstallJob = installer.import_model(
source=source,
config=config,
@ -340,17 +345,17 @@ async def import_model(
return result
@model_records_router.get(
@model_manager_v2_router.get(
"/import",
operation_id="list_model_install_jobs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
"""Return list of model install jobs."""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
return jobs
@model_records_router.get(
@model_manager_v2_router.get(
"/import/{id}",
operation_id="get_model_install_job",
responses={
@ -361,12 +366,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
"""Return model install job corresponding to the given source."""
try:
return ApiDependencies.invoker.services.model_install.get_job_by_id(id)
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
return result
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.delete(
@model_manager_v2_router.delete(
"/import/{id}",
operation_id="cancel_model_install_job",
responses={
@ -377,7 +383,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
)
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
"""Cancel the model install job(s) corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_install
installer = ApiDependencies.invoker.services.model_manager.install
try:
job = installer.get_job_by_id(id)
except ValueError as e:
@ -385,7 +391,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job)
@model_records_router.patch(
@model_manager_v2_router.patch(
"/import",
operation_id="prune_model_install_jobs",
responses={
@ -395,11 +401,11 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
)
async def prune_model_install_jobs() -> Response:
"""Prune all completed and errored jobs from the install job list."""
ApiDependencies.invoker.services.model_install.prune_jobs()
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
return Response(status_code=204)
@model_records_router.patch(
@model_manager_v2_router.patch(
"/sync",
operation_id="sync_models_to_config",
responses={
@ -414,11 +420,11 @@ async def sync_models_to_config() -> Response:
Model files without a corresponding
record in the database are added. Orphan records without a models file are deleted.
"""
ApiDependencies.invoker.services.model_install.sync_to_config()
ApiDependencies.invoker.services.model_manager.install.sync_to_config()
return Response(status_code=204)
@model_records_router.put(
@model_manager_v2_router.put(
"/merge",
operation_id="merge",
)
@ -451,7 +457,7 @@ async def merge(
try:
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
installer = ApiDependencies.invoker.services.model_install
installer = ApiDependencies.invoker.services.model_manager.install
merger = ModelMerger(installer)
model_names = [installer.record_store.get_model(x).name for x in keys]
response = merger.merge_diffusion_models_and_save(

View File

@ -8,8 +8,7 @@ from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import MergeInterpolationMethod
from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType
from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
InvalidModelException,