mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make model manager v2 ready for PR review
- Replace legacy model manager service with the v2 manager. - Update invocations to use new load interface. - Fixed many but not all type checking errors in the invocations. Most were unrelated to model manager - Updated routes. All the new routes live under the route tag `model_manager_v2`. To avoid confusion with the old routes, they have the URL prefix `/api/v2/models`. The old routes have been de-registered. - Added a pytest for the loader. - Updated documentation in contributing/MODEL_MANAGER.md
This commit is contained in:
committed by
psychedelicious
parent
2b1dc74080
commit
94e8d1b6d5
@ -32,7 +32,7 @@ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
|
||||
model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
@ -52,7 +52,7 @@ class ModelTagSet(BaseModel):
|
||||
tags: Set[str]
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/",
|
||||
operation_id="list_model_records",
|
||||
)
|
||||
@ -65,7 +65,7 @@ async def list_model_records(
|
||||
),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
found_models: list[AnyModelConfig] = []
|
||||
if base_models:
|
||||
for base_model in base_models:
|
||||
@ -81,7 +81,7 @@ async def list_model_records(
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/i/{key}",
|
||||
operation_id="get_model_record",
|
||||
responses={
|
||||
@ -94,24 +94,27 @@ async def get_model_record(
|
||||
key: str = Path(description="Key of the model record to fetch."),
|
||||
) -> AnyModelConfig:
|
||||
"""Get a model record"""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
return record_store.get_model(key)
|
||||
config: AnyModelConfig = record_store.get_model(key)
|
||||
return config
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.get("/meta", operation_id="list_model_summary")
|
||||
@model_manager_v2_router.get("/meta", operation_id="list_model_summary")
|
||||
async def list_model_summary(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of models per page"),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Gets a page of model summary data."""
|
||||
return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
return results
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/meta/i/{key}",
|
||||
operation_id="get_model_metadata",
|
||||
responses={
|
||||
@ -124,24 +127,25 @@ async def get_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
result = record_store.get_metadata(key)
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
)
|
||||
async def list_tags() -> Set[str]:
|
||||
"""Get a unique set of all the model tags."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
return record_store.list_tags()
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Set[str] = record_store.list_tags()
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/tags/search",
|
||||
operation_id="search_by_metadata_tags",
|
||||
)
|
||||
@ -149,12 +153,12 @@ async def search_by_metadata_tags(
|
||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results = record_store.search_by_metadata_tag(tags)
|
||||
return ModelsList(models=results)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
@model_manager_v2_router.patch(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
responses={
|
||||
@ -172,9 +176,9 @@ async def update_model_record(
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
model_response = record_store.update_model(key, config=info)
|
||||
model_response: AnyModelConfig = record_store.update_model(key, config=info)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -184,7 +188,7 @@ async def update_model_record(
|
||||
return model_response
|
||||
|
||||
|
||||
@model_records_router.delete(
|
||||
@model_manager_v2_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="del_model_record",
|
||||
responses={
|
||||
@ -205,7 +209,7 @@ async def del_model_record(
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
installer.delete(key)
|
||||
logger.info(f"Deleted model: {key}")
|
||||
return Response(status_code=204)
|
||||
@ -214,7 +218,7 @@ async def del_model_record(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.post(
|
||||
@model_manager_v2_router.post(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
@ -229,7 +233,7 @@ async def add_model_record(
|
||||
) -> AnyModelConfig:
|
||||
"""Add a model using the configuration information appropriate for its type."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
@ -243,10 +247,11 @@ async def add_model_record(
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
# now fetch it out
|
||||
return record_store.get_model(config.key)
|
||||
result: AnyModelConfig = record_store.get_model(config.key)
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.post(
|
||||
@model_manager_v2_router.post(
|
||||
"/import",
|
||||
operation_id="import_model_record",
|
||||
responses={
|
||||
@ -322,7 +327,7 @@ async def import_model(
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
result: ModelInstallJob = installer.import_model(
|
||||
source=source,
|
||||
config=config,
|
||||
@ -340,17 +345,17 @@ async def import_model(
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
)
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
"""Return list of model install jobs."""
|
||||
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
|
||||
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
|
||||
return jobs
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/import/{id}",
|
||||
operation_id="get_model_install_job",
|
||||
responses={
|
||||
@ -361,12 +366,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
|
||||
"""Return model install job corresponding to the given source."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.model_install.get_job_by_id(id)
|
||||
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.delete(
|
||||
@model_manager_v2_router.delete(
|
||||
"/import/{id}",
|
||||
operation_id="cancel_model_install_job",
|
||||
responses={
|
||||
@ -377,7 +383,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
|
||||
)
|
||||
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
|
||||
"""Cancel the model install job(s) corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
job = installer.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
@ -385,7 +391,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
||||
installer.cancel_job(job)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
@model_manager_v2_router.patch(
|
||||
"/import",
|
||||
operation_id="prune_model_install_jobs",
|
||||
responses={
|
||||
@ -395,11 +401,11 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
||||
)
|
||||
async def prune_model_install_jobs() -> Response:
|
||||
"""Prune all completed and errored jobs from the install job list."""
|
||||
ApiDependencies.invoker.services.model_install.prune_jobs()
|
||||
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
@model_manager_v2_router.patch(
|
||||
"/sync",
|
||||
operation_id="sync_models_to_config",
|
||||
responses={
|
||||
@ -414,11 +420,11 @@ async def sync_models_to_config() -> Response:
|
||||
Model files without a corresponding
|
||||
record in the database are added. Orphan records without a models file are deleted.
|
||||
"""
|
||||
ApiDependencies.invoker.services.model_install.sync_to_config()
|
||||
ApiDependencies.invoker.services.model_manager.install.sync_to_config()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@model_records_router.put(
|
||||
@model_manager_v2_router.put(
|
||||
"/merge",
|
||||
operation_id="merge",
|
||||
)
|
||||
@ -451,7 +457,7 @@ async def merge(
|
||||
try:
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
merger = ModelMerger(installer)
|
||||
model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
response = merger.merge_diffusion_models_and_save(
|
@ -8,8 +8,7 @@ from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||
from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType
|
||||
from invokeai.backend.model_management.models import (
|
||||
OPENAPI_MODEL_CONFIGS,
|
||||
InvalidModelException,
|
||||
|
Reference in New Issue
Block a user