mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): it works
This commit is contained in:
parent
22c34c343a
commit
5bf9891553
@ -63,9 +63,7 @@ class ApiDependencies:
|
|||||||
|
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
|
|
||||||
image_file_storage = DiskImageFileStorage(
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
f"{output_folder}/images", metadata_service=metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: build a file/path manager?
|
# TODO: build a file/path manager?
|
||||||
db_location = os.path.join(output_folder, "invokeai.db")
|
db_location = os.path.join(output_folder, "invokeai.db")
|
||||||
|
@ -1,165 +1,47 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
import io
|
from fastapi import HTTPException, Path
|
||||||
from datetime import datetime, timezone
|
from fastapi.responses import FileResponse
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from typing import Any
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
|
|
||||||
from fastapi.responses import FileResponse, Response
|
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from PIL import Image
|
|
||||||
from invokeai.app.api.models.images import (
|
|
||||||
ImageResponse,
|
|
||||||
ImageResponseMetadata,
|
|
||||||
)
|
|
||||||
from invokeai.app.models.image import ImageType
|
from invokeai.app.models.image import ImageType
|
||||||
from invokeai.app.services.item_storage import PaginatedResults
|
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
images_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"])
|
image_files_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"])
|
||||||
|
|
||||||
|
|
||||||
# @images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
@image_files_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
||||||
# async def get_image(
|
|
||||||
# image_type: ImageType = Path(description="The type of image to get"),
|
|
||||||
# image_name: str = Path(description="The name of the image to get"),
|
|
||||||
# ) -> FileResponse:
|
|
||||||
# """Gets an image"""
|
|
||||||
|
|
||||||
# path = ApiDependencies.invoker.services.images.get_path(
|
|
||||||
# image_type=image_type, image_name=image_name
|
|
||||||
# )
|
|
||||||
|
|
||||||
# if ApiDependencies.invoker.services.images.validate_path(path):
|
|
||||||
# return FileResponse(path)
|
|
||||||
# else:
|
|
||||||
# raise HTTPException(status_code=404)
|
|
||||||
|
|
||||||
|
|
||||||
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
|
||||||
async def get_image(
|
async def get_image(
|
||||||
image_type: ImageType = Path(description="The type of the image to get"),
|
image_type: ImageType = Path(description="The type of the image to get"),
|
||||||
image_name: str = Path(description="The id of the image to get"),
|
image_name: str = Path(description="The id of the image to get"),
|
||||||
) -> FileResponse:
|
) -> FileResponse:
|
||||||
"""Gets an image"""
|
"""Gets an image"""
|
||||||
|
|
||||||
path = ApiDependencies.invoker.services.images.get_path(
|
try:
|
||||||
|
path = ApiDependencies.invoker.services.images_new.get_path(
|
||||||
image_type=image_type, image_name=image_name
|
image_type=image_type, image_name=image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
|
||||||
return FileResponse(path)
|
return FileResponse(path)
|
||||||
else:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
|
@image_files_router.get(
|
||||||
async def delete_image(
|
"/{image_type}/{image_name}/thumbnail", operation_id="get_thumbnail"
|
||||||
image_type: ImageType = Path(description="The type of the image to delete"),
|
|
||||||
image_name: str = Path(description="The name of the image to delete"),
|
|
||||||
) -> None:
|
|
||||||
"""Deletes an image and its thumbnail"""
|
|
||||||
|
|
||||||
ApiDependencies.invoker.services.images.delete(
|
|
||||||
image_type=image_type, image_name=image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
|
||||||
"/{image_type}/thumbnails/{thumbnail_id}", operation_id="get_thumbnail"
|
|
||||||
)
|
)
|
||||||
async def get_thumbnail(
|
async def get_thumbnail(
|
||||||
image_type: ImageType = Path(description="The type of the thumbnail to get"),
|
image_type: ImageType = Path(
|
||||||
thumbnail_id: str = Path(description="The id of the thumbnail to get"),
|
description="The type of the image whose thumbnail to get"
|
||||||
) -> FileResponse | Response:
|
),
|
||||||
|
image_name: str = Path(description="The id of the image whose thumbnail to get"),
|
||||||
|
) -> FileResponse:
|
||||||
"""Gets a thumbnail"""
|
"""Gets a thumbnail"""
|
||||||
|
|
||||||
path = ApiDependencies.invoker.services.images.get_path(
|
|
||||||
image_type=image_type, image_name=thumbnail_id, is_thumbnail=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
|
||||||
return FileResponse(path)
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=404)
|
|
||||||
|
|
||||||
|
|
||||||
@images_router.post(
|
|
||||||
"/uploads/",
|
|
||||||
operation_id="upload_image",
|
|
||||||
responses={
|
|
||||||
201: {
|
|
||||||
"description": "The image was uploaded successfully",
|
|
||||||
"model": ImageResponse,
|
|
||||||
},
|
|
||||||
415: {"description": "Image upload failed"},
|
|
||||||
},
|
|
||||||
status_code=201,
|
|
||||||
)
|
|
||||||
async def upload_image(
|
|
||||||
file: UploadFile, image_type: ImageType, request: Request, response: Response
|
|
||||||
) -> ImageResponse:
|
|
||||||
if not file.content_type.startswith("image"):
|
|
||||||
raise HTTPException(status_code=415, detail="Not an image")
|
|
||||||
|
|
||||||
contents = await file.read()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
img = Image.open(io.BytesIO(contents))
|
path = ApiDependencies.invoker.services.images_new.get_path(
|
||||||
except:
|
image_type=image_type, image_name=image_name, thumbnail=True
|
||||||
# Error opening the image
|
|
||||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
|
||||||
|
|
||||||
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
|
||||||
|
|
||||||
saved_image = ApiDependencies.invoker.services.images.save(
|
|
||||||
image_type, filename, img
|
|
||||||
)
|
)
|
||||||
|
|
||||||
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
|
return FileResponse(path)
|
||||||
|
except Exception as e:
|
||||||
image_url = ApiDependencies.invoker.services.images.get_uri(
|
raise HTTPException(status_code=404)
|
||||||
image_type, saved_image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
|
|
||||||
image_type, saved_image.image_name, True
|
|
||||||
)
|
|
||||||
|
|
||||||
res = ImageResponse(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=saved_image.image_name,
|
|
||||||
image_url=image_url,
|
|
||||||
thumbnail_url=thumbnail_url,
|
|
||||||
metadata=ImageResponseMetadata(
|
|
||||||
created=saved_image.created,
|
|
||||||
width=img.width,
|
|
||||||
height=img.height,
|
|
||||||
invokeai=invokeai_metadata,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
response.status_code = 201
|
|
||||||
response.headers["Location"] = image_url
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
|
||||||
"/",
|
|
||||||
operation_id="list_images",
|
|
||||||
responses={200: {"model": PaginatedResults[ImageResponse]}},
|
|
||||||
)
|
|
||||||
async def list_images(
|
|
||||||
image_type: ImageType = Query(
|
|
||||||
default=ImageType.RESULT, description="The type of images to get"
|
|
||||||
),
|
|
||||||
page: int = Query(default=0, description="The page of images to get"),
|
|
||||||
per_page: int = Query(default=10, description="The number of images per page"),
|
|
||||||
) -> PaginatedResults[ImageResponse]:
|
|
||||||
"""Gets a list of images"""
|
|
||||||
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
|
|
||||||
return result
|
|
||||||
|
@ -71,7 +71,7 @@ async def shutdown_event():
|
|||||||
|
|
||||||
app.include_router(sessions.session_router, prefix="/api")
|
app.include_router(sessions.session_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(image_files.images_router, prefix="/api")
|
app.include_router(image_files.image_files_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(models.models_router, prefix="/api")
|
app.include_router(models.models_router, prefix="/api")
|
||||||
|
|
||||||
|
@ -93,34 +93,42 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
# each time it is called. We only need the first one.
|
# each time it is called. We only need the first one.
|
||||||
generate_output = next(outputs)
|
generate_output = next(outputs)
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now and ignore the seed
|
image_dto = context.services.images_new.create(
|
||||||
# TODO: pre-seed?
|
image=generate_output.image,
|
||||||
# TODO: can this return multiple results? Should it?
|
|
||||||
image_type = ImageType.RESULT
|
|
||||||
image_name = context.services.images.create_name(
|
|
||||||
context.graph_execution_state_id, self.id
|
|
||||||
)
|
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
|
||||||
session_id=context.graph_execution_state_id, node=self
|
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(
|
|
||||||
image_type, image_name, generate_output.image, metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images_db.set(
|
|
||||||
id=image_name,
|
|
||||||
image_type=ImageType.RESULT,
|
image_type=ImageType.RESULT,
|
||||||
image_category=ImageCategory.IMAGE,
|
image_category=ImageCategory.IMAGE,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
metadata=GeneratedImageOrLatentsMetadata(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Results are image and seed, unwrap for now and ignore the seed
|
||||||
|
# TODO: pre-seed?
|
||||||
|
# TODO: can this return multiple results? Should it?
|
||||||
|
# image_type = ImageType.RESULT
|
||||||
|
# image_name = context.services.images.create_name(
|
||||||
|
# context.graph_execution_state_id, self.id
|
||||||
|
# )
|
||||||
|
|
||||||
|
# metadata = context.services.metadata.build_metadata(
|
||||||
|
# session_id=context.graph_execution_state_id, node=self
|
||||||
|
# )
|
||||||
|
|
||||||
|
# context.services.images.save(
|
||||||
|
# image_type, image_name, generate_output.image, metadata
|
||||||
|
# )
|
||||||
|
|
||||||
|
# context.services.images_db.set(
|
||||||
|
# id=image_name,
|
||||||
|
# image_type=ImageType.RESULT,
|
||||||
|
# image_category=ImageCategory.IMAGE,
|
||||||
|
# session_id=context.graph_execution_state_id,
|
||||||
|
# node_id=self.id,
|
||||||
|
# metadata=GeneratedImageOrLatentsMetadata(),
|
||||||
|
# )
|
||||||
|
|
||||||
return build_image_output(
|
return build_image_output(
|
||||||
image_type=image_type,
|
image_type=image_dto.image_type,
|
||||||
image_name=image_name,
|
image_name=image_dto.image_name,
|
||||||
image=generate_output.image,
|
image=generate_output.image,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -2,8 +2,11 @@ from typing import Optional
|
|||||||
from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr
|
from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr
|
||||||
|
|
||||||
|
|
||||||
class GeneratedImageOrLatentsMetadata(BaseModel):
|
class ImageMetadata(BaseModel):
|
||||||
"""Core generation metadata for an image/tensor generated in InvokeAI.
|
"""
|
||||||
|
Core generation metadata for an image/tensor generated in InvokeAI.
|
||||||
|
|
||||||
|
Also includes any metadata from the image's PNG tEXt chunks.
|
||||||
|
|
||||||
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.
|
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.
|
||||||
|
|
||||||
@ -51,20 +54,6 @@ class GeneratedImageOrLatentsMetadata(BaseModel):
|
|||||||
# vae: Optional[str] = Field(default=None,description="The VAE used for decoding.")
|
# vae: Optional[str] = Field(default=None,description="The VAE used for decoding.")
|
||||||
# unet: Optional[str] = Field(default=None,description="The UNet used dor inference.")
|
# unet: Optional[str] = Field(default=None,description="The UNet used dor inference.")
|
||||||
# clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.")
|
# clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.")
|
||||||
|
|
||||||
|
|
||||||
class UploadedImageOrLatentsMetadata(BaseModel):
|
|
||||||
"""Limited metadata for an uploaded image/tensor."""
|
|
||||||
|
|
||||||
width: Optional[StrictInt] = Field(
|
|
||||||
default=None, description="Width of the image/tensor in pixels."
|
|
||||||
)
|
|
||||||
height: Optional[StrictInt] = Field(
|
|
||||||
default=None, description="Height of the image/tensor in pixels."
|
|
||||||
)
|
|
||||||
# The extra field will be the contents of the PNG file's tEXt chunk. It may have come
|
|
||||||
# from another SD application or InvokeAI, so it needs to be flexible.
|
|
||||||
# If the upload is a not an image or `image_latents` tensor, this will be omitted.
|
|
||||||
extra: Optional[StrictStr] = Field(
|
extra: Optional[StrictStr] = Field(
|
||||||
default=None, description="Extra metadata, extracted from the PNG tEXt chunk."
|
default=None, description="Extra metadata, extracted from the PNG tEXt chunk."
|
||||||
)
|
)
|
||||||
|
@ -1,28 +1,16 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from glob import glob
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Dict, List
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image as PILImageType
|
||||||
import PIL.Image as PILImage
|
from PIL import Image
|
||||||
|
from PIL.PngImagePlugin import PngInfo
|
||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
from invokeai.app.api.models.images import (
|
|
||||||
ImageResponse,
|
|
||||||
ImageResponseMetadata,
|
|
||||||
SavedImage,
|
|
||||||
)
|
|
||||||
from invokeai.app.models.image import ImageType
|
from invokeai.app.models.image import ImageType
|
||||||
from invokeai.app.services.metadata import (
|
|
||||||
InvokeAIMetadata,
|
|
||||||
MetadataServiceBase,
|
|
||||||
build_invokeai_metadata_pnginfo,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.item_storage import PaginatedResults
|
|
||||||
from invokeai.app.util.misc import get_timestamp
|
|
||||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||||
|
|
||||||
|
|
||||||
@ -48,61 +36,27 @@ class ImageFileStorageBase(ABC):
|
|||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||||
"""Retrieves an image as PIL Image."""
|
"""Retrieves an image as PIL Image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
# # TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
def list(
|
|
||||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
|
||||||
) -> PaginatedResults[ImageResponse]:
|
|
||||||
"""Gets a paginated list of images."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(
|
def get_path(
|
||||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Gets the internal path to an image or its thumbnail."""
|
"""Gets the internal path to an image or thumbnail."""
|
||||||
pass
|
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
|
||||||
@abstractmethod
|
|
||||||
def get_uri(
|
|
||||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
"""Gets the external URI to an image or its thumbnail."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
# @abstractmethod
|
|
||||||
# def get_image_location(
|
|
||||||
# self, image_type: ImageType, image_name: str
|
|
||||||
# ) -> str:
|
|
||||||
# """Gets the location of an image."""
|
|
||||||
# pass
|
|
||||||
|
|
||||||
# @abstractmethod
|
|
||||||
# def get_thumbnail_location(
|
|
||||||
# self, image_type: ImageType, image_name: str
|
|
||||||
# ) -> str:
|
|
||||||
# """Gets the location of an image's thumbnail."""
|
|
||||||
# pass
|
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
|
||||||
@abstractmethod
|
|
||||||
def validate_path(self, path: str) -> bool:
|
|
||||||
"""Validates an image path."""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
|
image: PILImageType,
|
||||||
image_type: ImageType,
|
image_type: ImageType,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
image: Image,
|
pnginfo: Optional[PngInfo] = None,
|
||||||
metadata: InvokeAIMetadata | None = None,
|
thumbnail_size: int = 256,
|
||||||
) -> SavedImage:
|
) -> None:
|
||||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -111,26 +65,20 @@ class ImageFileStorageBase(ABC):
|
|||||||
"""Deletes an image and its thumbnail (if one exists)."""
|
"""Deletes an image and its thumbnail (if one exists)."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def create_name(self, context_id: str, node_id: str) -> str:
|
|
||||||
"""Creates a unique contextual image filename."""
|
|
||||||
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
|
|
||||||
|
|
||||||
|
|
||||||
class DiskImageFileStorage(ImageFileStorageBase):
|
class DiskImageFileStorage(ImageFileStorageBase):
|
||||||
"""Stores images on disk"""
|
"""Stores images on disk"""
|
||||||
|
|
||||||
__output_folder: str
|
__output_folder: str
|
||||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||||
__cache: Dict[str, Image]
|
__cache: Dict[str, PILImageType]
|
||||||
__max_cache_size: int
|
__max_cache_size: int
|
||||||
__metadata_service: MetadataServiceBase
|
|
||||||
|
|
||||||
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
|
def __init__(self, output_folder: str):
|
||||||
self.__output_folder = output_folder
|
self.__output_folder = output_folder
|
||||||
self.__cache = dict()
|
self.__cache = dict()
|
||||||
self.__cache_ids = Queue()
|
self.__cache_ids = Queue()
|
||||||
self.__max_cache_size = 10 # TODO: get this from config
|
self.__max_cache_size = 10 # TODO: get this from config
|
||||||
self.__metadata_service = metadata_service
|
|
||||||
|
|
||||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@ -143,144 +91,38 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
parents=True, exist_ok=True
|
parents=True, exist_ok=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def list(
|
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
|
||||||
) -> PaginatedResults[ImageResponse]:
|
|
||||||
dir_path = os.path.join(self.__output_folder, image_type)
|
|
||||||
image_paths = glob(f"{dir_path}/*.png")
|
|
||||||
count = len(image_paths)
|
|
||||||
|
|
||||||
sorted_image_paths = sorted(
|
|
||||||
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
page_of_image_paths = sorted_image_paths[
|
|
||||||
page * per_page : (page + 1) * per_page
|
|
||||||
]
|
|
||||||
|
|
||||||
page_of_images: List[ImageResponse] = []
|
|
||||||
|
|
||||||
for path in page_of_image_paths:
|
|
||||||
filename = os.path.basename(path)
|
|
||||||
img = PILImage.open(path)
|
|
||||||
|
|
||||||
invokeai_metadata = self.__metadata_service.get_metadata(img)
|
|
||||||
|
|
||||||
page_of_images.append(
|
|
||||||
ImageResponse(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=filename,
|
|
||||||
# TODO: DiskImageStorage should not be building URLs...?
|
|
||||||
image_url=self.get_uri(image_type, filename),
|
|
||||||
thumbnail_url=self.get_uri(image_type, filename, True),
|
|
||||||
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
|
|
||||||
metadata=ImageResponseMetadata(
|
|
||||||
created=int(os.path.getctime(path)),
|
|
||||||
width=img.width,
|
|
||||||
height=img.height,
|
|
||||||
invokeai=invokeai_metadata,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
page_count_trunc = int(count / per_page)
|
|
||||||
page_count_mod = count % per_page
|
|
||||||
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
|
|
||||||
|
|
||||||
return PaginatedResults[ImageResponse](
|
|
||||||
items=page_of_images,
|
|
||||||
page=page,
|
|
||||||
pages=page_count,
|
|
||||||
per_page=per_page,
|
|
||||||
total=count,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_type, image_name)
|
||||||
cache_item = self.__get_cache(image_path)
|
cache_item = self.__get_cache(image_path)
|
||||||
if cache_item:
|
if cache_item:
|
||||||
return cache_item
|
return cache_item
|
||||||
|
|
||||||
image = PILImage.open(image_path)
|
image = Image.open(image_path)
|
||||||
self.__set_cache(image_path, image)
|
self.__set_cache(image_path, image)
|
||||||
return image
|
return image
|
||||||
except Exception as e:
|
except FileNotFoundError as e:
|
||||||
raise ImageFileStorageBase.ImageFileNotFoundException from e
|
raise ImageFileStorageBase.ImageFileNotFoundException from e
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
|
||||||
def get_path(
|
|
||||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
# strip out any relative path shenanigans
|
|
||||||
basename = os.path.basename(image_name)
|
|
||||||
|
|
||||||
if is_thumbnail:
|
|
||||||
path = os.path.join(
|
|
||||||
self.__output_folder, image_type, "thumbnails", basename
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
path = os.path.join(self.__output_folder, image_type, basename)
|
|
||||||
|
|
||||||
abspath = os.path.abspath(path)
|
|
||||||
|
|
||||||
return abspath
|
|
||||||
|
|
||||||
def get_uri(
|
|
||||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
# strip out any relative path shenanigans
|
|
||||||
basename = os.path.basename(image_name)
|
|
||||||
|
|
||||||
if is_thumbnail:
|
|
||||||
thumbnail_basename = get_thumbnail_name(basename)
|
|
||||||
uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
|
|
||||||
else:
|
|
||||||
uri = f"api/v1/images/{image_type.value}/{basename}"
|
|
||||||
|
|
||||||
return uri
|
|
||||||
|
|
||||||
def validate_path(self, path: str) -> bool:
|
|
||||||
try:
|
|
||||||
os.stat(path)
|
|
||||||
return True
|
|
||||||
except FileNotFoundError:
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
|
image: PILImageType,
|
||||||
image_type: ImageType,
|
image_type: ImageType,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
image: Image,
|
pnginfo: Optional[PngInfo] = None,
|
||||||
metadata: InvokeAIMetadata | None = None,
|
thumbnail_size: int = 256,
|
||||||
) -> SavedImage:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_type, image_name)
|
||||||
|
|
||||||
# TODO: Reading the image and then saving it strips the metadata...
|
|
||||||
if metadata:
|
|
||||||
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
|
|
||||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||||
else:
|
|
||||||
image.save(image_path) # this saved image has an empty info
|
|
||||||
|
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
thumbnail_path = self.get_path(
|
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
|
||||||
image_type, thumbnail_name, is_thumbnail=True
|
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||||
)
|
|
||||||
thumbnail_image = make_thumbnail(image)
|
|
||||||
thumbnail_image.save(thumbnail_path)
|
thumbnail_image.save(thumbnail_path)
|
||||||
|
|
||||||
self.__set_cache(image_path, image)
|
self.__set_cache(image_path, image)
|
||||||
self.__set_cache(thumbnail_path, thumbnail_image)
|
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||||
|
|
||||||
return SavedImage(
|
|
||||||
image_name=image_name,
|
|
||||||
thumbnail_name=thumbnail_name,
|
|
||||||
created=int(os.path.getctime(image_path)),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ImageFileStorageBase.ImageFileSaveException from e
|
raise ImageFileStorageBase.ImageFileSaveException from e
|
||||||
|
|
||||||
@ -304,10 +146,29 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ImageFileStorageBase.ImageFileDeleteException from e
|
raise ImageFileStorageBase.ImageFileDeleteException from e
|
||||||
|
|
||||||
def __get_cache(self, image_name: str) -> Image | None:
|
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
|
def get_path(
|
||||||
|
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||||
|
) -> str:
|
||||||
|
# strip out any relative path shenanigans
|
||||||
|
basename = os.path.basename(image_name)
|
||||||
|
|
||||||
|
if thumbnail:
|
||||||
|
thumbnail_name = get_thumbnail_name(basename)
|
||||||
|
path = os.path.join(
|
||||||
|
self.__output_folder, image_type, "thumbnails", thumbnail_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
path = os.path.join(self.__output_folder, image_type, basename)
|
||||||
|
|
||||||
|
abspath = os.path.abspath(path)
|
||||||
|
|
||||||
|
return abspath
|
||||||
|
|
||||||
|
def __get_cache(self, image_name: str) -> PILImageType | None:
|
||||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||||
|
|
||||||
def __set_cache(self, image_name: str, image: Image):
|
def __set_cache(self, image_name: str, image: PILImageType):
|
||||||
if not image_name in self.__cache:
|
if not image_name in self.__cache:
|
||||||
self.__cache[image_name] = image
|
self.__cache[image_name] = image
|
||||||
self.__cache_ids.put(
|
self.__cache_ids.put(
|
||||||
|
@ -1,25 +1,18 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from invokeai.app.models.metadata import (
|
|
||||||
GeneratedImageOrLatentsMetadata,
|
|
||||||
UploadedImageOrLatentsMetadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from invokeai.app.models.metadata import (
|
|
||||||
GeneratedImageOrLatentsMetadata,
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
UploadedImageOrLatentsMetadata,
|
|
||||||
)
|
|
||||||
from invokeai.app.models.image import (
|
from invokeai.app.models.image import (
|
||||||
ImageCategory,
|
ImageCategory,
|
||||||
ImageType,
|
ImageType,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.util.create_enum_table import create_enum_table
|
from invokeai.app.services.util.create_enum_table import create_enum_table
|
||||||
from invokeai.app.services.models.image_record import ImageRecord
|
from invokeai.app.services.models.image_record import (
|
||||||
from invokeai.app.services.util.deserialize_image_record import (
|
ImageRecord,
|
||||||
deserialize_image_record,
|
deserialize_image_record,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -76,9 +69,7 @@ class ImageRecordStorageBase(ABC):
|
|||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
session_id: Optional[str],
|
session_id: Optional[str],
|
||||||
node_id: Optional[str],
|
node_id: Optional[str],
|
||||||
metadata: Optional[
|
metadata: Optional[ImageMetadata],
|
||||||
GeneratedImageOrLatentsMetadata | UploadedImageOrLatentsMetadata
|
|
||||||
],
|
|
||||||
created_at: str = datetime.datetime.utcnow().isoformat(),
|
created_at: str = datetime.datetime.utcnow().isoformat(),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Saves an image record."""
|
"""Saves an image record."""
|
||||||
@ -288,9 +279,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
session_id: Optional[str],
|
session_id: Optional[str],
|
||||||
node_id: Optional[str],
|
node_id: Optional[str],
|
||||||
metadata: Union[
|
metadata: Optional[ImageMetadata],
|
||||||
GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata, None
|
|
||||||
],
|
|
||||||
created_at: str,
|
created_at: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
@ -306,7 +295,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
image_category,
|
image_category,
|
||||||
node_id,
|
node_id,
|
||||||
session_id,
|
session_id,
|
||||||
metadata
|
metadata,
|
||||||
created_at
|
created_at
|
||||||
)
|
)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?);
|
VALUES (?, ?, ?, ?, ?, ?, ?);
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import json
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
import uuid
|
import uuid
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
from PIL import PngImagePlugin
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageType
|
from invokeai.app.models.image import ImageCategory, ImageType
|
||||||
from invokeai.app.models.metadata import (
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
GeneratedImageOrLatentsMetadata,
|
|
||||||
UploadedImageOrLatentsMetadata,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.image_record_storage import (
|
from invokeai.app.services.image_record_storage import (
|
||||||
ImageRecordStorageBase,
|
ImageRecordStorageBase,
|
||||||
)
|
)
|
||||||
@ -22,8 +23,95 @@ from invokeai.app.services.urls import UrlServiceBase
|
|||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class ImageServiceABC(ABC):
|
||||||
|
"""
|
||||||
|
High-level service for image management.
|
||||||
|
|
||||||
|
Provides methods for creating, retrieving, and deleting images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
image: PILImageType,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
node_id: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
metadata: Optional[ImageMetadata] = None,
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""Creates an image, storing the file and its metadata."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||||
|
"""Gets an image as a PIL image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||||
|
"""Gets an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||||
|
"""Gets an image's path"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_image_url(self, image_type: ImageType, image_name: str) -> str:
|
||||||
|
"""Gets an image's URL"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str:
|
||||||
|
"""Gets an image's URL"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||||
|
"""Gets an image DTO."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
page: int = 0,
|
||||||
|
per_page: int = 10,
|
||||||
|
) -> PaginatedResults[ImageDTO]:
|
||||||
|
"""Gets a paginated list of image DTOs."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, image_type: ImageType, image_name: str):
|
||||||
|
"""Deletes an image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||||
|
"""Adds a tag to an image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||||
|
"""Removes a tag from an image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def favorite(self, image_type: ImageType, image_id: str) -> None:
|
||||||
|
"""Favorites an image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
|
||||||
|
"""Unfavorites an image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ImageServiceDependencies:
|
class ImageServiceDependencies:
|
||||||
"""Service dependencies for the ImageManagementService."""
|
"""Service dependencies for the ImageService."""
|
||||||
|
|
||||||
records: ImageRecordStorageBase
|
records: ImageRecordStorageBase
|
||||||
files: ImageFileStorageBase
|
files: ImageFileStorageBase
|
||||||
@ -46,9 +134,7 @@ class ImageServiceDependencies:
|
|||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
|
|
||||||
class ImageService:
|
class ImageService(ImageServiceABC):
|
||||||
"""High-level service for image management."""
|
|
||||||
|
|
||||||
_services: ImageServiceDependencies
|
_services: ImageServiceDependencies
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -67,21 +153,6 @@ class ImageService:
|
|||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_image_name(
|
|
||||||
self,
|
|
||||||
image_type: ImageType,
|
|
||||||
image_category: ImageCategory,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
session_id: Optional[str] = None,
|
|
||||||
) -> str:
|
|
||||||
"""Creates an image name."""
|
|
||||||
uuid_str = str(uuid.uuid4())
|
|
||||||
|
|
||||||
if node_id is not None and session_id is not None:
|
|
||||||
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
|
|
||||||
|
|
||||||
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
|
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
@ -89,11 +160,8 @@ class ImageService:
|
|||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
metadata: Optional[
|
metadata: Optional[ImageMetadata] = None,
|
||||||
Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata]
|
|
||||||
] = None,
|
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Creates an image, storing the file and its metadata."""
|
|
||||||
image_name = self._create_image_name(
|
image_name = self._create_image_name(
|
||||||
image_type=image_type,
|
image_type=image_type,
|
||||||
image_category=image_category,
|
image_category=image_category,
|
||||||
@ -103,13 +171,19 @@ class ImageService:
|
|||||||
|
|
||||||
timestamp = get_iso_timestamp()
|
timestamp = get_iso_timestamp()
|
||||||
|
|
||||||
|
if metadata is not None:
|
||||||
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
|
pnginfo.add_text("invokeai", json.dumps(metadata))
|
||||||
|
else:
|
||||||
|
pnginfo = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||||
self._services.files.save(
|
self._services.files.save(
|
||||||
image_type=image_type,
|
image_type=image_type,
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image=image,
|
image=image,
|
||||||
metadata=metadata,
|
pnginfo=pnginfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._services.records.save(
|
self._services.records.save(
|
||||||
@ -144,25 +218,40 @@ class ImageService:
|
|||||||
except ImageFileStorageBase.ImageFileSaveException:
|
except ImageFileStorageBase.ImageFileSaveException:
|
||||||
self._services.logger.error("Failed to save image file")
|
self._services.logger.error("Failed to save image file")
|
||||||
raise
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem saving image record and file")
|
||||||
|
raise e
|
||||||
|
|
||||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||||
"""Gets an image as a PIL image."""
|
|
||||||
try:
|
try:
|
||||||
return self._services.files.get(image_type, image_name)
|
return self._services.files.get(image_type, image_name)
|
||||||
except ImageFileStorageBase.ImageFileNotFoundException:
|
except ImageFileStorageBase.ImageFileNotFoundException:
|
||||||
self._services.logger.error("Failed to get image file")
|
self._services.logger.error("Failed to get image file")
|
||||||
raise
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image file")
|
||||||
|
raise e
|
||||||
|
|
||||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||||
"""Gets an image record."""
|
|
||||||
try:
|
try:
|
||||||
return self._services.records.get(image_type, image_name)
|
return self._services.records.get(image_type, image_name)
|
||||||
except ImageRecordStorageBase.ImageRecordNotFoundException:
|
except ImageRecordStorageBase.ImageRecordNotFoundException:
|
||||||
self._services.logger.error("Failed to get image record")
|
self._services.logger.error("Image record not found")
|
||||||
raise
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image record")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_path(
|
||||||
|
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
return self._services.files.get_path(image_type, image_name, thumbnail)
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image path")
|
||||||
|
raise e
|
||||||
|
|
||||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||||
"""Gets an image DTO."""
|
|
||||||
try:
|
try:
|
||||||
image_record = self._services.records.get(image_type, image_name)
|
image_record = self._services.records.get(image_type, image_name)
|
||||||
|
|
||||||
@ -174,21 +263,11 @@ class ImageService:
|
|||||||
|
|
||||||
return image_dto
|
return image_dto
|
||||||
except ImageRecordStorageBase.ImageRecordNotFoundException:
|
except ImageRecordStorageBase.ImageRecordNotFoundException:
|
||||||
self._services.logger.error("Failed to get image DTO")
|
self._services.logger.error("Image record not found")
|
||||||
raise
|
|
||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str):
|
|
||||||
"""Deletes an image."""
|
|
||||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
|
||||||
try:
|
|
||||||
self._services.files.delete(image_type, image_name)
|
|
||||||
self._services.records.delete(image_type, image_name)
|
|
||||||
except ImageRecordStorageBase.ImageRecordDeleteException:
|
|
||||||
self._services.logger.error(f"Failed to delete image record")
|
|
||||||
raise
|
|
||||||
except ImageFileStorageBase.ImageFileDeleteException:
|
|
||||||
self._services.logger.error(f"Failed to delete image file")
|
|
||||||
raise
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image DTO")
|
||||||
|
raise e
|
||||||
|
|
||||||
def get_many(
|
def get_many(
|
||||||
self,
|
self,
|
||||||
@ -197,7 +276,6 @@ class ImageService:
|
|||||||
page: int = 0,
|
page: int = 0,
|
||||||
per_page: int = 10,
|
per_page: int = 10,
|
||||||
) -> PaginatedResults[ImageDTO]:
|
) -> PaginatedResults[ImageDTO]:
|
||||||
"""Gets a paginated list of image DTOs."""
|
|
||||||
try:
|
try:
|
||||||
results = self._services.records.get_many(
|
results = self._services.records.get_many(
|
||||||
image_type,
|
image_type,
|
||||||
@ -225,21 +303,47 @@ class ImageService:
|
|||||||
total=results.total,
|
total=results.total,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Failed to get paginated image DTOs")
|
self._services.logger.error("Problem getting paginated image DTOs")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def delete(self, image_type: ImageType, image_name: str):
|
||||||
|
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||||
|
try:
|
||||||
|
self._services.files.delete(image_type, image_name)
|
||||||
|
self._services.records.delete(image_type, image_name)
|
||||||
|
except ImageRecordStorageBase.ImageRecordDeleteException:
|
||||||
|
self._services.logger.error(f"Failed to delete image record")
|
||||||
|
raise
|
||||||
|
except ImageFileStorageBase.ImageFileDeleteException:
|
||||||
|
self._services.logger.error(f"Failed to delete image file")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem deleting image record and file")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||||
"""Adds a tag to an image."""
|
|
||||||
raise NotImplementedError("The 'add_tag' method is not implemented yet.")
|
raise NotImplementedError("The 'add_tag' method is not implemented yet.")
|
||||||
|
|
||||||
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||||
"""Removes a tag from an image."""
|
|
||||||
raise NotImplementedError("The 'remove_tag' method is not implemented yet.")
|
raise NotImplementedError("The 'remove_tag' method is not implemented yet.")
|
||||||
|
|
||||||
def favorite(self, image_type: ImageType, image_id: str) -> None:
|
def favorite(self, image_type: ImageType, image_id: str) -> None:
|
||||||
"""Favorites an image."""
|
|
||||||
raise NotImplementedError("The 'favorite' method is not implemented yet.")
|
raise NotImplementedError("The 'favorite' method is not implemented yet.")
|
||||||
|
|
||||||
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
|
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
|
||||||
"""Unfavorites an image."""
|
|
||||||
raise NotImplementedError("The 'unfavorite' method is not implemented yet.")
|
raise NotImplementedError("The 'unfavorite' method is not implemented yet.")
|
||||||
|
|
||||||
|
def _create_image_name(
|
||||||
|
self,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
node_id: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create a unique image name."""
|
||||||
|
uuid_str = str(uuid.uuid4())
|
||||||
|
|
||||||
|
if node_id is not None and session_id is not None:
|
||||||
|
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
|
||||||
|
|
||||||
|
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
import sqlite3
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from invokeai.app.models.metadata import (
|
|
||||||
GeneratedImageOrLatentsMetadata,
|
|
||||||
UploadedImageOrLatentsMetadata,
|
|
||||||
)
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageType
|
from invokeai.app.models.image import ImageCategory, ImageType
|
||||||
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
|
|
||||||
|
|
||||||
class ImageRecord(BaseModel):
|
class ImageRecord(BaseModel):
|
||||||
@ -19,9 +18,9 @@ class ImageRecord(BaseModel):
|
|||||||
)
|
)
|
||||||
session_id: Optional[str] = Field(default=None, description="The session ID.")
|
session_id: Optional[str] = Field(default=None, description="The session ID.")
|
||||||
node_id: Optional[str] = Field(default=None, description="The node ID.")
|
node_id: Optional[str] = Field(default=None, description="The node ID.")
|
||||||
metadata: Optional[
|
metadata: Optional[ImageMetadata] = Field(
|
||||||
Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata]
|
default=None, description="The image's metadata."
|
||||||
] = Field(default=None, description="The image's metadata.")
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageDTO(ImageRecord):
|
class ImageDTO(ImageRecord):
|
||||||
@ -46,3 +45,27 @@ def image_record_to_dto(
|
|||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
thumbnail_url=thumbnail_url,
|
thumbnail_url=thumbnail_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
|
||||||
|
"""Deserializes an image record."""
|
||||||
|
|
||||||
|
image_dict = dict(image_row)
|
||||||
|
|
||||||
|
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
|
||||||
|
|
||||||
|
raw_metadata = image_dict.get("metadata", "{}")
|
||||||
|
|
||||||
|
metadata = ImageMetadata.parse_raw(raw_metadata)
|
||||||
|
|
||||||
|
return ImageRecord(
|
||||||
|
image_name=image_dict.get("id", "unknown"),
|
||||||
|
session_id=image_dict.get("session_id", None),
|
||||||
|
node_id=image_dict.get("node_id", None),
|
||||||
|
metadata=metadata,
|
||||||
|
image_type=image_type,
|
||||||
|
image_category=ImageCategory(
|
||||||
|
image_dict.get("image_category", ImageCategory.IMAGE.value)
|
||||||
|
),
|
||||||
|
created_at=image_dict.get("created_at", get_iso_timestamp()),
|
||||||
|
)
|
||||||
|
@ -25,8 +25,8 @@ class LocalUrlService(UrlServiceBase):
|
|||||||
|
|
||||||
def get_image_url(self, image_type: ImageType, image_name: str) -> str:
|
def get_image_url(self, image_type: ImageType, image_name: str) -> str:
|
||||||
image_basename = os.path.basename(image_name)
|
image_basename = os.path.basename(image_name)
|
||||||
return f"{self._base_url}/images/{image_type.value}/{image_basename}"
|
return f"{self._base_url}/files/images/{image_type.value}/{image_basename}"
|
||||||
|
|
||||||
def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str:
|
def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str:
|
||||||
thumbnail_basename = get_thumbnail_name(os.path.basename(image_name))
|
image_basename = os.path.basename(image_name)
|
||||||
return f"{self._base_url}/images/{image_type.value}/thumbnails/{thumbnail_basename}"
|
return f"{self._base_url}/files/images/{image_type.value}/{image_basename}/thumbnail"
|
||||||
|
@ -1,33 +0,0 @@
|
|||||||
from invokeai.app.models.metadata import (
|
|
||||||
GeneratedImageOrLatentsMetadata,
|
|
||||||
UploadedImageOrLatentsMetadata,
|
|
||||||
)
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageType
|
|
||||||
from invokeai.app.services.models.image_record import ImageRecord
|
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_image_record(image: dict) -> ImageRecord:
|
|
||||||
"""Deserializes an image record."""
|
|
||||||
|
|
||||||
# All values *should* be present, except `session_id` and `node_id`, but provide some defaults just in case
|
|
||||||
|
|
||||||
image_type = ImageType(image.get("image_type", ImageType.RESULT.value))
|
|
||||||
raw_metadata = image.get("metadata", {})
|
|
||||||
|
|
||||||
if image_type == ImageType.UPLOAD:
|
|
||||||
metadata = UploadedImageOrLatentsMetadata.parse_obj(raw_metadata)
|
|
||||||
else:
|
|
||||||
metadata = GeneratedImageOrLatentsMetadata.parse_obj(raw_metadata)
|
|
||||||
|
|
||||||
return ImageRecord(
|
|
||||||
image_name=image.get("id", "unknown"),
|
|
||||||
session_id=image.get("session_id", None),
|
|
||||||
node_id=image.get("node_id", None),
|
|
||||||
metadata=metadata,
|
|
||||||
image_type=image_type,
|
|
||||||
image_category=ImageCategory(
|
|
||||||
image.get("image_category", ImageCategory.IMAGE.value)
|
|
||||||
),
|
|
||||||
created_at=image.get("created_at", get_iso_timestamp()),
|
|
||||||
)
|
|
Loading…
Reference in New Issue
Block a user