InvokeAI/invokeai/app/api/routers/style_presets.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

271 lines
10 KiB
Python
Raw Normal View History

import csv
2024-08-07 13:58:27 +00:00
import io
import json
2024-08-07 13:58:27 +00:00
import traceback
from codecs import iterdecode
2024-08-07 13:58:27 +00:00
from typing import Optional
import pydantic
2024-08-08 18:21:37 +00:00
from fastapi import APIRouter, File, Form, HTTPException, Path, UploadFile
2024-08-07 13:58:27 +00:00
from fastapi.responses import FileResponse
from PIL import Image
from pydantic import BaseModel, Field, ValidationError
from invokeai.app.api.dependencies import ApiDependencies
2024-08-07 13:58:27 +00:00
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
2024-08-07 14:36:38 +00:00
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
from invokeai.app.services.style_preset_records.style_preset_records_common import (
2024-08-07 13:58:27 +00:00
PresetData,
PresetType,
StylePresetChanges,
StylePresetImportListTypeAdapter,
StylePresetNotFoundError,
2024-08-07 13:58:27 +00:00
StylePresetRecordWithImage,
StylePresetWithoutId,
)
# Text fields of a style preset as submitted by the client. The routes receive them
# as a JSON string in the multipart "data" form field and validate it with this model.
# Every field is required.
class StylePresetUpdateFormData(BaseModel):
    name: str = Field(..., description="Preset name")
    positive_prompt: str = Field(..., description="Positive prompt")
    negative_prompt: str = Field(..., description="Negative prompt")
# Creation form: the update fields plus the (required) preset type. The update form
# has no type field, so the update route never receives one.
class StylePresetCreateFormData(StylePresetUpdateFormData):
    type: PresetType = Field(..., description="Preset type")
# Router for the style preset endpoints: every route registered on it is prefixed
# with "/v1/style_presets" and grouped under the "style_presets" OpenAPI tag.
style_presets_router = APIRouter(
    prefix="/v1/style_presets",
    tags=["style_presets"],
)
@style_presets_router.get(
    "/i/{style_preset_id}",
    operation_id="get_style_preset",
    responses={200: {"model": StylePresetRecordWithImage}},
)
async def get_style_preset(
    style_preset_id: str = Path(description="The style preset to get"),
) -> StylePresetRecordWithImage:
    """Gets a style preset"""
    # Look up the image URL and the stored record; a missing record becomes a 404.
    services = ApiDependencies.invoker.services
    try:
        image_url = services.style_preset_image_files.get_url(style_preset_id)
        record = services.style_preset_records.get(style_preset_id)
    except StylePresetNotFoundError:
        raise HTTPException(status_code=404, detail="Style preset not found")
    return StylePresetRecordWithImage(image=image_url, **record.model_dump())
@style_presets_router.patch(
    "/i/{style_preset_id}",
    operation_id="update_style_preset",
    responses={
        200: {"model": StylePresetRecordWithImage},
    },
)
async def update_style_preset(
    image: Optional[UploadFile] = File(description="The image file to upload", default=None),
    style_preset_id: str = Path(description="The id of the style preset to update"),
    data: str = Form(description="The data of the style preset to update"),
) -> StylePresetRecordWithImage:
    """Updates a style preset

    `data` is a JSON object with `name`, `positive_prompt` and `negative_prompt`.
    If `image` is provided it replaces the preset's image; if it is omitted, any
    existing image is removed.

    Responds 400 for invalid `data`, 415 if `image` is not a readable image,
    404 if the preset does not exist and 409 if the image cannot be saved.
    """
    services = ApiDependencies.invoker.services

    # Validate everything the client sent BEFORE touching storage, so a rejected
    # request (400/415/404) changes nothing. model_validate_json also rejects
    # malformed JSON and non-object payloads as a ValidationError, which previously
    # escaped as JSONDecodeError/TypeError (HTTP 500).
    try:
        validated_data = StylePresetUpdateFormData.model_validate_json(data)
    except pydantic.ValidationError:
        raise HTTPException(status_code=400, detail="Invalid preset data")

    pil_image = None
    if image is not None:
        if not image.content_type or not image.content_type.startswith("image"):
            raise HTTPException(status_code=415, detail="Not an image")
        contents = await image.read()
        try:
            pil_image = Image.open(io.BytesIO(contents))
        except Exception:
            services.logger.error(traceback.format_exc())
            raise HTTPException(status_code=415, detail="Failed to read image")

    # Check the preset exists before writing an image file for its id.
    try:
        services.style_preset_records.get(style_preset_id)
    except StylePresetNotFoundError:
        raise HTTPException(status_code=404, detail="Style preset not found")

    if pil_image is not None:
        try:
            services.style_preset_image_files.save(style_preset_id, pil_image)
        except ValueError as e:
            raise HTTPException(status_code=409, detail=str(e))
    else:
        # No image in the request means the preset's image is removed; a preset
        # that never had one is not an error.
        try:
            services.style_preset_image_files.delete(style_preset_id)
        except StylePresetImageFileNotFoundException:
            pass

    preset_data = PresetData(
        positive_prompt=validated_data.positive_prompt,
        negative_prompt=validated_data.negative_prompt,
    )
    changes = StylePresetChanges(name=validated_data.name, preset_data=preset_data)
    style_preset = services.style_preset_records.update(style_preset_id=style_preset_id, changes=changes)
    style_preset_image = services.style_preset_image_files.get_url(style_preset_id)
    return StylePresetRecordWithImage(image=style_preset_image, **style_preset.model_dump())
@style_presets_router.delete(
    "/i/{style_preset_id}",
    operation_id="delete_style_preset",
)
async def delete_style_preset(
    style_preset_id: str = Path(description="The style preset to delete"),
) -> None:
    """Deletes a style preset"""
    services = ApiDependencies.invoker.services
    # Drop the image file first. Presets without an image are normal, so a
    # missing file is ignored rather than reported.
    try:
        services.style_preset_image_files.delete(style_preset_id)
    except StylePresetImageFileNotFoundException:
        pass
    # Then remove the record itself.
    services.style_preset_records.delete(style_preset_id)
@style_presets_router.post(
    "/",
    operation_id="create_style_preset",
    responses={
        200: {"model": StylePresetRecordWithImage},
    },
)
async def create_style_preset(
    image: Optional[UploadFile] = File(description="The image file to upload", default=None),
    data: str = Form(description="The data of the style preset to create"),
) -> StylePresetRecordWithImage:
    """Creates a style preset

    `data` is a JSON object with `name`, `type`, `positive_prompt` and
    `negative_prompt`. An optional `image` is stored as the preset's image.

    Responds 400 for invalid `data`, 415 if `image` is not a readable image and
    409 if the image cannot be saved (the newly created record is then deleted).
    """
    services = ApiDependencies.invoker.services

    # model_validate_json turns malformed JSON / non-object payloads into a
    # ValidationError (400) instead of JSONDecodeError/TypeError (500).
    try:
        validated_data = StylePresetCreateFormData.model_validate_json(data)
    except pydantic.ValidationError:
        raise HTTPException(status_code=400, detail="Invalid preset data")

    # Decode the optional image BEFORE creating the record, so a bad upload (415)
    # does not leave an orphaned preset behind.
    pil_image = None
    if image is not None:
        if not image.content_type or not image.content_type.startswith("image"):
            raise HTTPException(status_code=415, detail="Not an image")
        contents = await image.read()
        try:
            pil_image = Image.open(io.BytesIO(contents))
        except Exception:
            services.logger.error(traceback.format_exc())
            raise HTTPException(status_code=415, detail="Failed to read image")

    preset_data = PresetData(
        positive_prompt=validated_data.positive_prompt,
        negative_prompt=validated_data.negative_prompt,
    )
    style_preset = StylePresetWithoutId(
        name=validated_data.name,
        preset_data=preset_data,
        type=validated_data.type,
    )
    new_style_preset = services.style_preset_records.create(style_preset=style_preset)

    if pil_image is not None:
        try:
            services.style_preset_image_files.save(new_style_preset.id, pil_image)
        except ValueError as e:
            # Roll back the record so a failed request has no lasting effect.
            services.style_preset_records.delete(new_style_preset.id)
            raise HTTPException(status_code=409, detail=str(e))

    preset_image = services.style_preset_image_files.get_url(new_style_preset.id)
    return StylePresetRecordWithImage(image=preset_image, **new_style_preset.model_dump())
@style_presets_router.get(
    "/",
    operation_id="list_style_presets",
    responses={200: {"model": list[StylePresetRecordWithImage]}},
)
async def list_style_presets() -> list[StylePresetRecordWithImage]:
    """Gets a page of style presets"""
    # Pair every record returned by get_many() with its image URL,
    # preserving the order in which the records are returned.
    services = ApiDependencies.invoker.services
    return [
        StylePresetRecordWithImage(
            image=services.style_preset_image_files.get_url(record.id),
            **record.model_dump(),
        )
        for record in services.style_preset_records.get_many()
    ]
@style_presets_router.get(
    "/i/{style_preset_id}/image",
    operation_id="get_style_preset_image",
    responses={
        200: {
            "description": "The style preset image was fetched successfully",
        },
        400: {"description": "Bad request"},
        404: {"description": "The style preset image could not be found"},
    },
    status_code=200,
)
async def get_style_preset_image(
    style_preset_id: str = Path(description="The id of the style preset image to get"),
) -> FileResponse:
    """Gets the image file that previews a style preset"""
    # The try guards only the path lookup: any lookup failure is reported as 404,
    # while an unexpected error building the response is no longer disguised as one.
    try:
        path = ApiDependencies.invoker.services.style_preset_image_files.get_path(style_preset_id)
    except Exception:
        raise HTTPException(status_code=404)

    response = FileResponse(
        path,
        media_type="image/png",
        filename=style_preset_id + ".png",
        content_disposition_type="inline",
    )
    # Same client cache lifetime as the model manager's image endpoint (IMAGE_MAX_AGE).
    response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
    return response
@style_presets_router.post(
    "/import",
    operation_id="import_style_presets",
)
async def import_style_presets(file: UploadFile = File(description="The file to import")):
    """Imports style presets from an uploaded CSV or JSON file.

    CSV files must have 'name', 'prompt' and 'negative_prompt' columns; JSON files
    must be a list of objects with those keys. Every imported preset is created with
    type PresetType.User. Responds 400 for an unsupported file type or for content
    that cannot be decoded, parsed or validated.
    """
    # Compare only the media type: clients may append parameters such as
    # "application/json; charset=utf-8", and media types are case-insensitive.
    content_type = (file.content_type or "").split(";", 1)[0].strip().lower()
    if content_type not in ("text/csv", "application/json"):
        raise HTTPException(status_code=400, detail="Unsupported file type")

    try:
        if content_type == "text/csv":
            # "utf-8-sig" strips the byte-order mark that spreadsheet exports often
            # prepend; with plain "utf-8" it would be glued onto the first column name.
            csv_reader = csv.DictReader(iterdecode(file.file, "utf-8-sig"))
            data = list(csv_reader)
        else:  # content_type == "application/json"
            json_data = await file.read()
            data = json.loads(json_data)
        imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
    except (ValidationError, UnicodeDecodeError, json.JSONDecodeError):
        # Undecodable or malformed input is a client error (400), not a server error.
        if content_type == "text/csv":
            raise HTTPException(
                status_code=400,
                detail="Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt'",
            )
        raise HTTPException(
            status_code=400,
            detail="Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt'",
        )
    else:
        style_presets: list[StylePresetWithoutId] = [
            StylePresetWithoutId(
                name=imported.name,
                preset_data=PresetData(
                    positive_prompt=imported.positive_prompt,
                    negative_prompt=imported.negative_prompt,
                ),
                type=PresetType.User,
            )
            for imported in imported_presets
        ]
        ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
    finally:
        file.file.close()