Merge branch 'maryhipp/trigger-phrases-main' of https://github.com/invoke-ai/InvokeAI into maryhipp/trigger-phrases-main

This commit is contained in:
Mary Hipp 2024-03-04 14:56:55 -05:00
commit d87cce9174
2 changed files with 11 additions and 39 deletions

View File

@ -5,7 +5,6 @@ import pathlib
import shutil import shutil
from hashlib import sha1 from hashlib import sha1
from random import randbytes from random import randbytes
import traceback
from typing import Any, Dict, List, Optional, Set from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
@ -245,43 +244,6 @@ async def get_model_metadata(
return result return result
@model_manager_router.patch(
"/i/{key}/metadata",
operation_id="update_model_metadata",
responses={
201: {
"description": "The model metadata was updated successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def update_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
changes: ModelMetadataChanges = Body(description="The changes")
) -> Optional[AnyModelRepoMetadata]:
"""Updates or creates a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
try:
original_metadata = record_store.get_metadata(key)
if original_metadata:
original_metadata.trigger_phrases = changes.trigger_phrases
metadata_store.update_metadata(key, original_metadata)
else:
metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases))
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"An error occurred while updating the model metadata: {e}",
)
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.patch( @model_manager_router.patch(
"/i/{key}/metadata", "/i/{key}/metadata",
@ -308,10 +270,19 @@ async def update_model_metadata(
if changes.default_settings: if changes.default_settings:
original_metadata.default_settings = changes.default_settings original_metadata.default_settings = changes.default_settings
if changes.trigger_phrases:
original_metadata.trigger_phrases = changes.trigger_phrases
metadata_store.update_metadata(key, original_metadata) metadata_store.update_metadata(key, original_metadata)
else: else:
metadata_store.add_metadata( metadata_store.add_metadata(
key, BaseMetadata(name="", author="", default_settings=changes.default_settings) key,
BaseMetadata(
name="",
author="",
default_settings=changes.default_settings,
trigger_phrases=changes.trigger_phrases,
),
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(

View File

@ -12,6 +12,7 @@ from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"): class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
"""A set of changes to apply to model metadata. """A set of changes to apply to model metadata.