make model manager v2 ready for PR review

- Replace legacy model manager service with the v2 manager.

- Update invocations to use new load interface.

- Fixed many but not all type checking errors in the invocations. Most
  were unrelated to model manager

- Updated routes. All the new routes live under the route tag
  `model_manager_v2`. To avoid confusion with the old routes,
  they have the URL prefix `/api/v2/models`. The old routes
  have been de-registered.

- Added a pytest for the loader.

- Updated documentation in contributing/MODEL_MANAGER.md
This commit is contained in:
Lincoln Stein
2024-02-10 18:09:45 -05:00
committed by psychedelicious
parent 2b1dc74080
commit 94e8d1b6d5
36 changed files with 680 additions and 435 deletions

View File

@ -8,9 +8,6 @@ from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMe
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
from invokeai.backend.model_manager.load.model_cache import ModelCache
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -30,9 +27,7 @@ from ..services.invocation_queue.invocation_queue_memory import MemoryInvocation
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_install import ModelInstallService
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@ -98,28 +93,10 @@ class ApiDependencies:
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
model_loader = AnyModelLoader(
app_config=config,
logger=logger,
ram_cache=ModelCache(
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
),
convert_cache=ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
),
)
model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader)
download_queue_service = DownloadQueueService(event_bus=events)
model_install_service = ModelInstallService(
app_config=config,
record_store=model_record_service,
download_queue=download_queue_service,
metadata_store=ModelMetadataStore(db=db),
event_bus=events,
model_manager = ModelManagerService.build_model_manager(
app_config=configuration, db=db, download_queue=download_queue_service, events=events
)
model_manager = ModelManagerService(config, logger) # TO DO: legacy model manager v1. Remove
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
@ -143,9 +120,7 @@ class ApiDependencies:
invocation_cache=invocation_cache,
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
download_queue=download_queue_service,
model_install=model_install_service,
names=names,
performance_statistics=performance_statistics,
processor=processor,

View File

@ -32,7 +32,7 @@ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
class ModelsList(BaseModel):
@ -52,7 +52,7 @@ class ModelTagSet(BaseModel):
tags: Set[str]
@model_records_router.get(
@model_manager_v2_router.get(
"/",
operation_id="list_model_records",
)
@ -65,7 +65,7 @@ async def list_model_records(
),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
@ -81,7 +81,7 @@ async def list_model_records(
return ModelsList(models=found_models)
@model_records_router.get(
@model_manager_v2_router.get(
"/i/{key}",
operation_id="get_model_record",
responses={
@ -94,24 +94,27 @@ async def get_model_record(
key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
"""Get a model record"""
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
try:
return record_store.get_model(key)
config: AnyModelConfig = record_store.get_model(key)
return config
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.get("/meta", operation_id="list_model_summary")
@model_manager_v2_router.get("/meta", operation_id="list_model_summary")
async def list_model_summary(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]:
"""Gets a page of model summary data."""
return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by)
record_store = ApiDependencies.invoker.services.model_manager.store
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
return results
@model_records_router.get(
@model_manager_v2_router.get(
"/meta/i/{key}",
operation_id="get_model_metadata",
responses={
@ -124,24 +127,25 @@ async def get_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
"""Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_records
result = record_store.get_metadata(key)
record_store = ApiDependencies.invoker.services.model_manager.store
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
if not result:
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
return result
@model_records_router.get(
@model_manager_v2_router.get(
"/tags",
operation_id="list_tags",
)
async def list_tags() -> Set[str]:
"""Get a unique set of all the model tags."""
record_store = ApiDependencies.invoker.services.model_records
return record_store.list_tags()
record_store = ApiDependencies.invoker.services.model_manager.store
result: Set[str] = record_store.list_tags()
return result
@model_records_router.get(
@model_manager_v2_router.get(
"/tags/search",
operation_id="search_by_metadata_tags",
)
@ -149,12 +153,12 @@ async def search_by_metadata_tags(
tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results)
@model_records_router.patch(
@model_manager_v2_router.patch(
"/i/{key}",
operation_id="update_model_record",
responses={
@ -172,9 +176,9 @@ async def update_model_record(
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
try:
model_response = record_store.update_model(key, config=info)
model_response: AnyModelConfig = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@ -184,7 +188,7 @@ async def update_model_record(
return model_response
@model_records_router.delete(
@model_manager_v2_router.delete(
"/i/{key}",
operation_id="del_model_record",
responses={
@ -205,7 +209,7 @@ async def del_model_record(
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
installer = ApiDependencies.invoker.services.model_manager.install
installer.delete(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
@ -214,7 +218,7 @@ async def del_model_record(
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.post(
@model_manager_v2_router.post(
"/i/",
operation_id="add_model_record",
responses={
@ -229,7 +233,7 @@ async def add_model_record(
) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
record_store = ApiDependencies.invoker.services.model_manager.store
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
@ -243,10 +247,11 @@ async def add_model_record(
raise HTTPException(status_code=415)
# now fetch it out
return record_store.get_model(config.key)
result: AnyModelConfig = record_store.get_model(config.key)
return result
@model_records_router.post(
@model_manager_v2_router.post(
"/import",
operation_id="import_model_record",
responses={
@ -322,7 +327,7 @@ async def import_model(
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
installer = ApiDependencies.invoker.services.model_manager.install
result: ModelInstallJob = installer.import_model(
source=source,
config=config,
@ -340,17 +345,17 @@ async def import_model(
return result
@model_records_router.get(
@model_manager_v2_router.get(
"/import",
operation_id="list_model_install_jobs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
"""Return list of model install jobs."""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
return jobs
@model_records_router.get(
@model_manager_v2_router.get(
"/import/{id}",
operation_id="get_model_install_job",
responses={
@ -361,12 +366,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
"""Return model install job corresponding to the given source."""
try:
return ApiDependencies.invoker.services.model_install.get_job_by_id(id)
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
return result
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.delete(
@model_manager_v2_router.delete(
"/import/{id}",
operation_id="cancel_model_install_job",
responses={
@ -377,7 +383,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
)
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
"""Cancel the model install job(s) corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_install
installer = ApiDependencies.invoker.services.model_manager.install
try:
job = installer.get_job_by_id(id)
except ValueError as e:
@ -385,7 +391,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job)
@model_records_router.patch(
@model_manager_v2_router.patch(
"/import",
operation_id="prune_model_install_jobs",
responses={
@ -395,11 +401,11 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
)
async def prune_model_install_jobs() -> Response:
"""Prune all completed and errored jobs from the install job list."""
ApiDependencies.invoker.services.model_install.prune_jobs()
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
return Response(status_code=204)
@model_records_router.patch(
@model_manager_v2_router.patch(
"/sync",
operation_id="sync_models_to_config",
responses={
@ -414,11 +420,11 @@ async def sync_models_to_config() -> Response:
Model files without a corresponding
record in the database are added. Orphan records without a models file are deleted.
"""
ApiDependencies.invoker.services.model_install.sync_to_config()
ApiDependencies.invoker.services.model_manager.install.sync_to_config()
return Response(status_code=204)
@model_records_router.put(
@model_manager_v2_router.put(
"/merge",
operation_id="merge",
)
@ -451,7 +457,7 @@ async def merge(
try:
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
installer = ApiDependencies.invoker.services.model_install
installer = ApiDependencies.invoker.services.model_manager.install
merger = ModelMerger(installer)
model_names = [installer.record_store.get_model(x).name for x in keys]
response = merger.merge_diffusion_models_and_save(

View File

@ -8,8 +8,7 @@ from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import MergeInterpolationMethod
from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType
from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
InvalidModelException,