fix(api): allow updating of type for style preset

This commit is contained in:
Mary Hipp 2024-08-19 15:25:30 -04:00 committed by Mary Hipp Rogers
parent a85d69ce3d
commit 3e7923d072
2 changed files with 6 additions and 7 deletions

View File

@ -26,13 +26,10 @@ from invokeai.app.services.style_preset_records.style_preset_records_common impo
) )
class StylePresetUpdateFormData(BaseModel): class StylePresetFormData(BaseModel):
name: str = Field(description="Preset name") name: str = Field(description="Preset name")
positive_prompt: str = Field(description="Positive prompt") positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt") negative_prompt: str = Field(description="Negative prompt")
class StylePresetCreateFormData(StylePresetUpdateFormData):
type: PresetType = Field(description="Preset type") type: PresetType = Field(description="Preset type")
@ -95,9 +92,10 @@ async def update_style_preset(
try: try:
parsed_data = json.loads(data) parsed_data = json.loads(data)
validated_data = StylePresetUpdateFormData(**parsed_data) validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name name = validated_data.name
type = validated_data.type
positive_prompt = validated_data.positive_prompt positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt negative_prompt = validated_data.negative_prompt
@ -105,7 +103,7 @@ async def update_style_preset(
raise HTTPException(status_code=400, detail="Invalid preset data") raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt) preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
changes = StylePresetChanges(name=name, preset_data=preset_data) changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id) style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.update( style_preset = ApiDependencies.invoker.services.style_preset_records.update(
@ -145,7 +143,7 @@ async def create_style_preset(
try: try:
parsed_data = json.loads(data) parsed_data = json.loads(data)
validated_data = StylePresetCreateFormData(**parsed_data) validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name name = validated_data.name
type = validated_data.type type = validated_data.type

View File

@ -32,6 +32,7 @@ class PresetType(str, Enum, metaclass=MetaEnum):
class StylePresetChanges(BaseModel, extra="forbid"): class StylePresetChanges(BaseModel, extra="forbid"):
name: Optional[str] = Field(default=None, description="The style preset's new name.") name: Optional[str] = Field(default=None, description="The style preset's new name.")
preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.") preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
type: Optional[PresetType] = Field(description="The updated type of the style preset")
class StylePresetWithoutId(BaseModel): class StylePresetWithoutId(BaseModel):