mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make model manager v2 ready for PR review
- Replace legacy model manager service with the v2 manager. - Update invocations to use new load interface. - Fixed many but not all type checking errors in the invocations. Most were unrelated to model manager - Updated routes. All the new routes live under the route tag `model_manager_v2`. To avoid confusion with the old routes, they have the URL prefix `/api/v2/models`. The old routes have been de-registered. - Added a pytest for the loader. - Updated documentation in contributing/MODEL_MANAGER.md
This commit is contained in:
committed by
psychedelicious
parent
2b1dc74080
commit
94e8d1b6d5
@ -8,9 +8,6 @@ from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMe
|
||||
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
||||
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager.load import AnyModelLoader, ModelConvertCache
|
||||
from invokeai.backend.model_manager.load.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
@ -30,9 +27,7 @@ from ..services.invocation_queue.invocation_queue_memory import MemoryInvocation
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.model_install import ModelInstallService
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
@ -98,28 +93,10 @@ class ApiDependencies:
|
||||
conditioning = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
)
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_record_service = ModelRecordServiceSQL(db=db)
|
||||
model_loader = AnyModelLoader(
|
||||
app_config=config,
|
||||
logger=logger,
|
||||
ram_cache=ModelCache(
|
||||
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
|
||||
),
|
||||
convert_cache=ModelConvertCache(
|
||||
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
|
||||
),
|
||||
)
|
||||
model_record_service = ModelRecordServiceSQL(db=db, loader=model_loader)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
model_install_service = ModelInstallService(
|
||||
app_config=config,
|
||||
record_store=model_record_service,
|
||||
download_queue=download_queue_service,
|
||||
metadata_store=ModelMetadataStore(db=db),
|
||||
event_bus=events,
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration, db=db, download_queue=download_queue_service, events=events
|
||||
)
|
||||
model_manager = ModelManagerService(config, logger) # TO DO: legacy model manager v1. Remove
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
processor = DefaultInvocationProcessor()
|
||||
@ -143,9 +120,7 @@ class ApiDependencies:
|
||||
invocation_cache=invocation_cache,
|
||||
logger=logger,
|
||||
model_manager=model_manager,
|
||||
model_records=model_record_service,
|
||||
download_queue=download_queue_service,
|
||||
model_install=model_install_service,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
processor=processor,
|
||||
|
@ -32,7 +32,7 @@ from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
|
||||
model_manager_v2_router = APIRouter(prefix="/v2/models", tags=["model_manager_v2"])
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
@ -52,7 +52,7 @@ class ModelTagSet(BaseModel):
|
||||
tags: Set[str]
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/",
|
||||
operation_id="list_model_records",
|
||||
)
|
||||
@ -65,7 +65,7 @@ async def list_model_records(
|
||||
),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
found_models: list[AnyModelConfig] = []
|
||||
if base_models:
|
||||
for base_model in base_models:
|
||||
@ -81,7 +81,7 @@ async def list_model_records(
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/i/{key}",
|
||||
operation_id="get_model_record",
|
||||
responses={
|
||||
@ -94,24 +94,27 @@ async def get_model_record(
|
||||
key: str = Path(description="Key of the model record to fetch."),
|
||||
) -> AnyModelConfig:
|
||||
"""Get a model record"""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
return record_store.get_model(key)
|
||||
config: AnyModelConfig = record_store.get_model(key)
|
||||
return config
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.get("/meta", operation_id="list_model_summary")
|
||||
@model_manager_v2_router.get("/meta", operation_id="list_model_summary")
|
||||
async def list_model_summary(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of models per page"),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Gets a page of model summary data."""
|
||||
return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
return results
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/meta/i/{key}",
|
||||
operation_id="get_model_metadata",
|
||||
responses={
|
||||
@ -124,24 +127,25 @@ async def get_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
result = record_store.get_metadata(key)
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
)
|
||||
async def list_tags() -> Set[str]:
|
||||
"""Get a unique set of all the model tags."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
return record_store.list_tags()
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Set[str] = record_store.list_tags()
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/tags/search",
|
||||
operation_id="search_by_metadata_tags",
|
||||
)
|
||||
@ -149,12 +153,12 @@ async def search_by_metadata_tags(
|
||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results = record_store.search_by_metadata_tag(tags)
|
||||
return ModelsList(models=results)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
@model_manager_v2_router.patch(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
responses={
|
||||
@ -172,9 +176,9 @@ async def update_model_record(
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
model_response = record_store.update_model(key, config=info)
|
||||
model_response: AnyModelConfig = record_store.update_model(key, config=info)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -184,7 +188,7 @@ async def update_model_record(
|
||||
return model_response
|
||||
|
||||
|
||||
@model_records_router.delete(
|
||||
@model_manager_v2_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="del_model_record",
|
||||
responses={
|
||||
@ -205,7 +209,7 @@ async def del_model_record(
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
installer.delete(key)
|
||||
logger.info(f"Deleted model: {key}")
|
||||
return Response(status_code=204)
|
||||
@ -214,7 +218,7 @@ async def del_model_record(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.post(
|
||||
@model_manager_v2_router.post(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
@ -229,7 +233,7 @@ async def add_model_record(
|
||||
) -> AnyModelConfig:
|
||||
"""Add a model using the configuration information appropriate for its type."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
@ -243,10 +247,11 @@ async def add_model_record(
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
# now fetch it out
|
||||
return record_store.get_model(config.key)
|
||||
result: AnyModelConfig = record_store.get_model(config.key)
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.post(
|
||||
@model_manager_v2_router.post(
|
||||
"/import",
|
||||
operation_id="import_model_record",
|
||||
responses={
|
||||
@ -322,7 +327,7 @@ async def import_model(
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
result: ModelInstallJob = installer.import_model(
|
||||
source=source,
|
||||
config=config,
|
||||
@ -340,17 +345,17 @@ async def import_model(
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
)
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
"""Return list of model install jobs."""
|
||||
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
|
||||
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
|
||||
return jobs
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
@model_manager_v2_router.get(
|
||||
"/import/{id}",
|
||||
operation_id="get_model_install_job",
|
||||
responses={
|
||||
@ -361,12 +366,13 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
|
||||
"""Return model install job corresponding to the given source."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.model_install.get_job_by_id(id)
|
||||
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.delete(
|
||||
@model_manager_v2_router.delete(
|
||||
"/import/{id}",
|
||||
operation_id="cancel_model_install_job",
|
||||
responses={
|
||||
@ -377,7 +383,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
|
||||
)
|
||||
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
|
||||
"""Cancel the model install job(s) corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
job = installer.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
@ -385,7 +391,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
||||
installer.cancel_job(job)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
@model_manager_v2_router.patch(
|
||||
"/import",
|
||||
operation_id="prune_model_install_jobs",
|
||||
responses={
|
||||
@ -395,11 +401,11 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
||||
)
|
||||
async def prune_model_install_jobs() -> Response:
|
||||
"""Prune all completed and errored jobs from the install job list."""
|
||||
ApiDependencies.invoker.services.model_install.prune_jobs()
|
||||
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
@model_manager_v2_router.patch(
|
||||
"/sync",
|
||||
operation_id="sync_models_to_config",
|
||||
responses={
|
||||
@ -414,11 +420,11 @@ async def sync_models_to_config() -> Response:
|
||||
Model files without a corresponding
|
||||
record in the database are added. Orphan records without a models file are deleted.
|
||||
"""
|
||||
ApiDependencies.invoker.services.model_install.sync_to_config()
|
||||
ApiDependencies.invoker.services.model_manager.install.sync_to_config()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@model_records_router.put(
|
||||
@model_manager_v2_router.put(
|
||||
"/merge",
|
||||
operation_id="merge",
|
||||
)
|
||||
@ -451,7 +457,7 @@ async def merge(
|
||||
try:
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
merger = ModelMerger(installer)
|
||||
model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
response = merger.merge_diffusion_models_and_save(
|
@ -8,8 +8,7 @@ from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||
from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod, ModelType
|
||||
from invokeai.backend.model_management.models import (
|
||||
OPENAPI_MODEL_CONFIGS,
|
||||
InvalidModelException,
|
||||
|
Reference in New Issue
Block a user