style preset images

This commit is contained in:
Mary Hipp
2024-08-07 09:58:27 -04:00
parent 2604fd9fde
commit cc96dcf0ed
28 changed files with 786 additions and 255 deletions

View File

@ -31,6 +31,8 @@ from invokeai.app.services.session_processor.session_processor_default import (
)
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
from invokeai.app.services.style_preset_images.style_preset_images_default import StylePresetImageFileStorageDisk
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
@ -75,6 +77,7 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images")
model_images_folder = config.models_path
style_preset_images_folder = config.style_preset_images_path
db = init_db(config=config, logger=logger, image_files=image_files)
@ -111,6 +114,9 @@ class ApiDependencies:
urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db)
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_images_service = StylePresetImageFileStorageDisk(
style_preset_images_folder / "style_preset_images"
)
services = InvocationServices(
board_image_records=board_image_records,
@ -137,6 +143,7 @@ class ApiDependencies:
tensors=tensors,
conditioning=conditioning,
style_preset_records=style_preset_records,
style_preset_images_service=style_preset_images_service,
)
ApiDependencies.invoker = Invoker(services)

View File

@ -1,10 +1,18 @@
from fastapi import APIRouter, Body, HTTPException, Path
import io
import traceback
from typing import Optional
from fastapi import APIRouter, Body, File, Form, HTTPException, Path, UploadFile
from fastapi.responses import FileResponse
from PIL import Image
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetData,
StylePresetChanges,
StylePresetNotFoundError,
StylePresetRecordDTO,
StylePresetRecordWithImage,
StylePresetWithoutId,
)
@ -15,15 +23,17 @@ style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_preset
"/i/{style_preset_id}",
operation_id="get_style_preset",
responses={
200: {"model": StylePresetRecordDTO},
200: {"model": StylePresetRecordWithImage},
},
)
async def get_style_preset(
style_preset_id: str = Path(description="The style preset to get"),
) -> StylePresetRecordDTO:
) -> StylePresetRecordWithImage:
"""Gets a style preset"""
try:
return ApiDependencies.invoker.services.style_preset_records.get(style_preset_id)
image = ApiDependencies.invoker.services.style_preset_images_service.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.get(style_preset_id)
return StylePresetRecordWithImage(image=image, **style_preset.model_dump())
except StylePresetNotFoundError:
raise HTTPException(status_code=404, detail="Style preset not found")
@ -32,15 +42,45 @@ async def get_style_preset(
"/i/{style_preset_id}",
operation_id="update_style_preset",
responses={
200: {"model": StylePresetRecordDTO},
200: {"model": StylePresetRecordWithImage},
},
)
async def update_style_preset(
style_preset_id: str = Path(description="The id of the style preset to update"),
changes: StylePresetChanges = Body(description="The updated style preset", embed=True),
) -> StylePresetRecordDTO:
name: str = Form(description="The name of the style preset to create"),
positive_prompt: str = Form(description="The positive prompt of the style preset"),
negative_prompt: str = Form(description="The negative prompt of the style preset"),
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
) -> StylePresetRecordWithImage:
"""Updates a style preset"""
return ApiDependencies.invoker.services.style_preset_records.update(id=style_preset_id, changes=changes)
if image is not None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
try:
ApiDependencies.invoker.services.style_preset_images_service.save(pil_image, style_preset_id)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
else:
try:
ApiDependencies.invoker.services.style_preset_images_service.delete(style_preset_id)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
changes = StylePresetChanges(name=name, preset_data=preset_data)
style_preset_image = ApiDependencies.invoker.services.style_preset_images_service.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.update(id=style_preset_id, changes=changes)
return StylePresetRecordWithImage(image=style_preset_image, **style_preset.model_dump())
@style_presets_router.delete(
@ -51,6 +91,7 @@ async def delete_style_preset(
style_preset_id: str = Path(description="The style preset to delete"),
) -> None:
"""Deletes a style preset"""
ApiDependencies.invoker.services.style_preset_images_service.delete(style_preset_id)
ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id)
@ -58,23 +99,87 @@ async def delete_style_preset(
"/",
operation_id="create_style_preset",
responses={
200: {"model": StylePresetRecordDTO},
200: {"model": StylePresetRecordWithImage},
},
)
async def create_style_preset(
style_preset: StylePresetWithoutId = Body(description="The style preset to create", embed=True),
) -> StylePresetRecordDTO:
name: str = Form(description="The name of the style preset to create"),
positive_prompt: str = Form(description="The positive prompt of the style preset"),
negative_prompt: str = Form(description="The negative prompt of the style preset"),
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
) -> StylePresetRecordWithImage:
"""Creates a style preset"""
return ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
style_preset = StylePresetWithoutId(name=name, preset_data=preset_data)
new_style_preset = ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
if image is not None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
try:
ApiDependencies.invoker.services.style_preset_images_service.save(pil_image, new_style_preset.id)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
preset_image = ApiDependencies.invoker.services.style_preset_images_service.get_url(new_style_preset.id)
return StylePresetRecordWithImage(image=preset_image, **new_style_preset.model_dump())
@style_presets_router.get(
"/",
operation_id="list_style_presets",
responses={
200: {"model": list[StylePresetRecordDTO]},
200: {"model": list[StylePresetRecordWithImage]},
},
)
async def list_style_presets() -> list[StylePresetRecordDTO]:
async def list_style_presets() -> list[StylePresetRecordWithImage]:
"""Gets a page of style presets"""
return ApiDependencies.invoker.services.style_preset_records.get_many()
style_presets_with_image: list[StylePresetRecordWithImage] = []
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many()
for preset in style_presets:
image = ApiDependencies.invoker.services.style_preset_images_service.get_url(preset.id)
style_preset_with_image = StylePresetRecordWithImage(image=image, **preset.model_dump())
style_presets_with_image.append(style_preset_with_image)
return style_presets_with_image
@style_presets_router.get(
"/i/{style_preset_id}/image",
operation_id="get_style_preset_image",
responses={
200: {
"description": "The style preset image was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The style preset image could not be found"},
},
status_code=200,
)
async def get_style_preset_image(
style_preset_id: str = Path(description="The id of the style preset image to get"),
) -> FileResponse:
"""Gets an image file that previews the model"""
try:
path = ApiDependencies.invoker.services.style_preset_images_service.get_path(style_preset_id)
response = FileResponse(
path,
media_type="image/png",
filename=style_preset_id + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)

View File

@ -153,6 +153,7 @@ class InvokeAIAppConfig(BaseSettings):
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
style_preset_images_path: Path = Field(default=Path("style_preset_images"), description="Path to directory for style preset images.")
# LOGGING
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
if TYPE_CHECKING:
@ -63,6 +64,7 @@ class InvocationServices:
tensors: "ObjectSerializerBase[torch.Tensor]",
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
style_preset_records: "StylePresetRecordsStorageBase",
style_preset_images_service: "StylePresetImageFileStorageBase",
):
self.board_images = board_images
self.board_image_records = board_image_records
@ -88,3 +90,4 @@ class InvocationServices:
self.tensors = tensors
self.conditioning = conditioning
self.style_preset_records = style_preset_records
self.style_preset_images_service = style_preset_images_service

View File

@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from pathlib import Path
from PIL.Image import Image as PILImageType
class StylePresetImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, style_preset_id: str) -> PILImageType:
"""Retrieves a style preset image as PIL Image."""
pass
@abstractmethod
def get_path(self, style_preset_id: str) -> Path:
"""Gets the internal path to a style preset image."""
pass
@abstractmethod
def get_url(self, style_preset_id: str) -> str | None:
"""Gets the URL to fetch a style preset image."""
pass
@abstractmethod
def save(self, image: PILImageType, style_preset_id: str) -> None:
"""Saves a style preset image."""
pass
@abstractmethod
def delete(self, style_preset_id: str) -> None:
"""Deletes a style preset image."""
pass

View File

@ -0,0 +1,19 @@
class StylePresetImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Style preset image file not found"):
super().__init__(message)
class StylePresetImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Style preset image file not saved"):
super().__init__(message)
class StylePresetImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Style preset image file not deleted"):
super().__init__(message)

View File

@ -0,0 +1,84 @@
from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
from invokeai.app.services.style_preset_images.style_preset_images_common import (
StylePresetImageFileDeleteException,
StylePresetImageFileNotFoundException,
StylePresetImageFileSaveException,
)
from invokeai.app.util.misc import uuid_string
from invokeai.app.util.thumbnails import make_thumbnail
class StylePresetImageFileStorageDisk(StylePresetImageFileStorageBase):
"""Stores images on disk"""
def __init__(self, style_preset_images_folder: Path):
self._style_preset_images_folder = style_preset_images_folder
self._validate_storage_folders()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def get(self, style_preset_id: str) -> PILImageType:
try:
path = self.get_path(style_preset_id)
if not self._validate_path(path):
raise StylePresetImageFileNotFoundException
return Image.open(path)
except FileNotFoundError as e:
raise StylePresetImageFileNotFoundException from e
def save(self, image: PILImageType, style_preset_id: str) -> None:
try:
self._validate_storage_folders()
image_path = self._style_preset_images_folder / (style_preset_id + ".webp")
thumbnail = make_thumbnail(image, 256)
thumbnail.save(image_path, format="webp")
except Exception as e:
raise StylePresetImageFileSaveException from e
def get_path(self, style_preset_id: str) -> Path:
path = self._style_preset_images_folder / (style_preset_id + ".webp")
return path
def get_url(self, style_preset_id: str) -> str | None:
path = self.get_path(style_preset_id)
if not self._validate_path(path):
return
url = self._invoker.services.urls.get_style_preset_image_url(style_preset_id)
# The image URL never changes, so we must add random query string to it to prevent caching
url += f"?{uuid_string()}"
return url
def delete(self, style_preset_id: str) -> None:
try:
path = self.get_path(style_preset_id)
if not self._validate_path(path):
raise StylePresetImageFileNotFoundException
send2trash(path)
except Exception as e:
raise StylePresetImageFileDeleteException from e
def _validate_path(self, path: Path) -> bool:
"""Validates the path given for an image."""
return path.exists()
def _validate_storage_folders(self) -> None:
"""Checks if the required folders exist and create them if they don't"""
self._style_preset_images_folder.mkdir(parents=True, exist_ok=True)

View File

@ -39,3 +39,7 @@ class StylePresetRecordDTO(StylePresetWithoutId):
StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
class StylePresetRecordWithImage(StylePresetRecordDTO):
image: Optional[str] = Field(description="The path for image")

View File

@ -13,3 +13,8 @@ class UrlServiceBase(ABC):
def get_model_image_url(self, model_key: str) -> str:
"""Gets the URL for a model image"""
pass
@abstractmethod
def get_style_preset_image_url(self, style_preset_id: str) -> str:
"""Gets the URL for a style preset image"""
pass

View File

@ -19,3 +19,6 @@ class LocalUrlService(UrlServiceBase):
def get_model_image_url(self, model_key: str) -> str:
return f"{self._base_url_v2}/models/i/{model_key}/image"
def get_style_preset_image_url(self, style_preset_id: str) -> str:
return f"{self._base_url}/style_presets/i/{style_preset_id}/image"