fix(api): add cover image to update model response

Fixes a bug where the image _appears_ to be reset when editing a model.

See: https://old.reddit.com/r/StableDiffusion/comments/1cnx40d/invoke_42_control_layers_regional_guidance_w_text/l3asdej/
This commit is contained in:
psychedelicious 2024-05-10 10:42:34 +10:00
parent 5da8cde4fc
commit 9cdb801c1c

View File

@ -6,7 +6,7 @@ import pathlib
import shutil import shutil
import traceback import traceback
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Type
from fastapi import Body, Path, Query, Response, UploadFile from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
@ -52,6 +52,13 @@ class ModelsList(BaseModel):
model_config = ConfigDict(use_enum_values=True) model_config = ConfigDict(use_enum_values=True)
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
"""Add a cover image URL to a model configuration."""
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
config.cover_image = cover_image
return config
############################################################################## ##############################################################################
# These are example inputs and outputs that are used in places where Swagger # These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example. # is unable to generate a correct example.
@ -118,8 +125,7 @@ async def list_model_records(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
) )
for model in found_models: for model in found_models:
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key) model = add_cover_image_to_model_config(model, ApiDependencies)
model.cover_image = cover_image
return ModelsList(models=found_models) return ModelsList(models=found_models)
@ -160,12 +166,9 @@ async def get_model_record(
key: str = Path(description="Key of the model record to fetch."), key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig: ) -> AnyModelConfig:
"""Get a model record""" """Get a model record"""
record_store = ApiDependencies.invoker.services.model_manager.store
try: try:
config: AnyModelConfig = record_store.get_model(key) config = ApiDependencies.invoker.services.model_manager.store.get_model(key)
cover_image = ApiDependencies.invoker.services.model_images.get_url(key) return add_cover_image_to_model_config(config, ApiDependencies)
config.cover_image = cover_image
return config
except UnknownModelException as e: except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@ -294,14 +297,15 @@ async def update_model_record(
installer = ApiDependencies.invoker.services.model_manager.install installer = ApiDependencies.invoker.services.model_manager.install
try: try:
record_store.update_model(key, changes=changes) record_store.update_model(key, changes=changes)
model_response: AnyModelConfig = installer.sync_model_path(key) config = installer.sync_model_path(key)
config = add_cover_image_to_model_config(config, ApiDependencies)
logger.info(f"Updated model: {key}") logger.info(f"Updated model: {key}")
except UnknownModelException as e: except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
except ValueError as e: except ValueError as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
return model_response return config
@model_manager_router.get( @model_manager_router.get(