diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 7bb0f23dc8..2571c50507 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -6,7 +6,7 @@ import pathlib import shutil import traceback from copy import deepcopy -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type from fastapi import Body, Path, Query, Response, UploadFile from fastapi.responses import FileResponse @@ -52,6 +52,13 @@ class ModelsList(BaseModel): model_config = ConfigDict(use_enum_values=True) +def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig: + """Add a cover image URL to a model configuration.""" + cover_image = dependencies.invoker.services.model_images.get_url(config.key) + config.cover_image = cover_image + return config + + ############################################################################## # These are example inputs and outputs that are used in places where Swagger # is unable to generate a correct example. @@ -118,8 +125,7 @@ async def list_model_records( record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) ) for model in found_models: - cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key) - model.cover_image = cover_image + model = add_cover_image_to_model_config(model, ApiDependencies) return ModelsList(models=found_models) @@ -160,12 +166,9 @@ async def get_model_record( key: str = Path(description="Key of the model record to fetch."), ) -> AnyModelConfig: """Get a model record""" - record_store = ApiDependencies.invoker.services.model_manager.store try: - config: AnyModelConfig = record_store.get_model(key) - cover_image = ApiDependencies.invoker.services.model_images.get_url(key) - config.cover_image = cover_image - return config + config = ApiDependencies.invoker.services.model_manager.store.get_model(key) + return add_cover_image_to_model_config(config, ApiDependencies) except UnknownModelException as e: raise HTTPException(status_code=404, detail=str(e)) @@ -294,14 +297,15 @@ async def update_model_record( installer = ApiDependencies.invoker.services.model_manager.install try: record_store.update_model(key, changes=changes) - model_response: AnyModelConfig = installer.sync_model_path(key) + config = installer.sync_model_path(key) + config = add_cover_image_to_model_config(config, ApiDependencies) logger.info(f"Updated model: {key}") except UnknownModelException as e: raise HTTPException(status_code=404, detail=str(e)) except ValueError as e: logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) - return model_response + return config @model_manager_router.get(