feat(nodes): it works

This commit is contained in:
psychedelicious 2023-05-21 22:15:44 +10:00 committed by Kent Keirsey
parent 22c34c343a
commit 5bf9891553
11 changed files with 302 additions and 481 deletions

View File

@ -63,9 +63,7 @@ class ApiDependencies:
urls = LocalUrlService()
image_file_storage = DiskImageFileStorage(
f"{output_folder}/images", metadata_service=metadata
)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")

View File

@ -1,165 +1,47 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
from typing import Any
import uuid
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from fastapi import HTTPException, Path
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
)
from invokeai.app.models.image import ImageType
from invokeai.app.services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"])
image_files_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"])
# @images_router.get("/{image_type}/{image_name}", operation_id="get_image")
# async def get_image(
# image_type: ImageType = Path(description="The type of image to get"),
# image_name: str = Path(description="The name of the image to get"),
# ) -> FileResponse:
# """Gets an image"""
# path = ApiDependencies.invoker.services.images.get_path(
# image_type=image_type, image_name=image_name
# )
# if ApiDependencies.invoker.services.images.validate_path(path):
# return FileResponse(path)
# else:
# raise HTTPException(status_code=404)
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
@image_files_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of the image to get"),
image_name: str = Path(description="The id of the image to get"),
) -> FileResponse:
"""Gets an image"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name
)
try:
path = ApiDependencies.invoker.services.images_new.get_path(
image_type=image_type, image_name=image_name
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
except Exception as e:
raise HTTPException(status_code=404)
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Path(description="The type of the image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image and its thumbnail"""
ApiDependencies.invoker.services.images.delete(
image_type=image_type, image_name=image_name
)
@images_router.get(
"/{image_type}/thumbnails/{thumbnail_id}", operation_id="get_thumbnail"
@image_files_router.get(
"/{image_type}/{image_name}/thumbnail", operation_id="get_thumbnail"
)
async def get_thumbnail(
image_type: ImageType = Path(description="The type of the thumbnail to get"),
thumbnail_id: str = Path(description="The id of the thumbnail to get"),
) -> FileResponse | Response:
image_type: ImageType = Path(
description="The type of the image whose thumbnail to get"
),
image_name: str = Path(description="The id of the image whose thumbnail to get"),
) -> FileResponse:
"""Gets a thumbnail"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=thumbnail_id, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.post(
"/uploads/",
operation_id="upload_image",
responses={
201: {
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
415: {"description": "Image upload failed"},
},
status_code=201,
)
async def upload_image(
file: UploadFile, image_type: ImageType, request: Request, response: Response
) -> ImageResponse:
if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await file.read()
try:
img = Image.open(io.BytesIO(contents))
except:
# Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image")
path = ApiDependencies.invoker.services.images_new.get_path(
image_type=image_type, image_name=image_name, thumbnail=True
)
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
saved_image = ApiDependencies.invoker.services.images.save(
image_type, filename, img
)
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
image_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name, True
)
res = ImageResponse(
image_type=image_type,
image_name=saved_image.image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
metadata=ImageResponseMetadata(
created=saved_image.created,
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
response.status_code = 201
response.headers["Location"] = image_url
return res
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
image_type: ImageType = Query(
default=ImageType.RESULT, description="The type of images to get"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
return result
return FileResponse(path)
except Exception as e:
raise HTTPException(status_code=404)

View File

@ -71,7 +71,7 @@ async def shutdown_event():
app.include_router(sessions.session_router, prefix="/api")
app.include_router(image_files.images_router, prefix="/api")
app.include_router(image_files.image_files_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")

View File

@ -93,34 +93,42 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# each time it is called. We only need the first one.
generate_output = next(outputs)
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(
image_type, image_name, generate_output.image, metadata
)
context.services.images_db.set(
id=image_name,
image_dto = context.services.images_new.create(
image=generate_output.image,
image_type=ImageType.RESULT,
image_category=ImageCategory.IMAGE,
session_id=context.graph_execution_state_id,
node_id=self.id,
metadata=GeneratedImageOrLatentsMetadata(),
)
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
# image_type = ImageType.RESULT
# image_name = context.services.images.create_name(
# context.graph_execution_state_id, self.id
# )
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# context.services.images.save(
# image_type, image_name, generate_output.image, metadata
# )
# context.services.images_db.set(
# id=image_name,
# image_type=ImageType.RESULT,
# image_category=ImageCategory.IMAGE,
# session_id=context.graph_execution_state_id,
# node_id=self.id,
# metadata=GeneratedImageOrLatentsMetadata(),
# )
return build_image_output(
image_type=image_type,
image_name=image_name,
image_type=image_dto.image_type,
image_name=image_dto.image_name,
image=generate_output.image,
)

View File

@ -2,8 +2,11 @@ from typing import Optional
from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr
class GeneratedImageOrLatentsMetadata(BaseModel):
"""Core generation metadata for an image/tensor generated in InvokeAI.
class ImageMetadata(BaseModel):
"""
Core generation metadata for an image/tensor generated in InvokeAI.
Also includes any metadata from the image's PNG tEXt chunks.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.
@ -51,20 +54,6 @@ class GeneratedImageOrLatentsMetadata(BaseModel):
# vae: Optional[str] = Field(default=None,description="The VAE used for decoding.")
# unet: Optional[str] = Field(default=None,description="The UNet used dor inference.")
# clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.")
class UploadedImageOrLatentsMetadata(BaseModel):
"""Limited metadata for an uploaded image/tensor."""
width: Optional[StrictInt] = Field(
default=None, description="Width of the image/tensor in pixels."
)
height: Optional[StrictInt] = Field(
default=None, description="Height of the image/tensor in pixels."
)
# The extra field will be the contents of the PNG file's tEXt chunk. It may have come
# from another SD application or InvokeAI, so it needs to be flexible.
# If the upload is a not an image or `image_latents` tensor, this will be omitted.
extra: Optional[StrictStr] = Field(
default=None, description="Extra metadata, extracted from the PNG tEXt chunk."
)

View File

@ -1,28 +1,16 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import os
from glob import glob
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, List
from typing import Dict, Optional
from PIL.Image import Image
import PIL.Image as PILImage
from PIL.Image import Image as PILImageType
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from send2trash import send2trash
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
SavedImage,
)
from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import (
InvokeAIMetadata,
MetadataServiceBase,
build_invokeai_metadata_pnginfo,
)
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.misc import get_timestamp
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -48,61 +36,27 @@ class ImageFileStorageBase(ABC):
super().__init__(message)
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> Image:
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
"""Gets a paginated list of images."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
# # TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the internal path to an image or its thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the external URI to an image or its thumbnail."""
pass
# @abstractmethod
# def get_image_location(
# self, image_type: ImageType, image_name: str
# ) -> str:
# """Gets the location of an image."""
# pass
# @abstractmethod
# def get_thumbnail_location(
# self, image_type: ImageType, image_name: str
# ) -> str:
# """Gets the location of an image's thumbnail."""
# pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image path."""
"""Gets the internal path to an image or thumbnail."""
pass
@abstractmethod
def save(
self,
image: PILImageType,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> SavedImage:
pnginfo: Optional[PngInfo] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
pass
@ -111,26 +65,20 @@ class ImageFileStorageBase(ABC):
"""Deletes an image and its thumbnail (if one exists)."""
pass
def create_name(self, context_id: str, node_id: str) -> str:
"""Creates a unique contextual image filename."""
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: str
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, Image]
__cache: Dict[str, PILImageType]
__max_cache_size: int
__metadata_service: MetadataServiceBase
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
def __init__(self, output_folder: str):
self.__output_folder = output_folder
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
self.__metadata_service = metadata_service
Path(output_folder).mkdir(parents=True, exist_ok=True)
@ -143,144 +91,38 @@ class DiskImageFileStorage(ImageFileStorageBase):
parents=True, exist_ok=True
)
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
dir_path = os.path.join(self.__output_folder, image_type)
image_paths = glob(f"{dir_path}/*.png")
count = len(image_paths)
sorted_image_paths = sorted(
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
)
page_of_image_paths = sorted_image_paths[
page * per_page : (page + 1) * per_page
]
page_of_images: List[ImageResponse] = []
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
invokeai_metadata = self.__metadata_service.get_metadata(img)
page_of_images.append(
ImageResponse(
image_type=image_type,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=self.get_uri(image_type, filename),
thumbnail_url=self.get_uri(image_type, filename, True),
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
metadata=ImageResponseMetadata(
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
)
page_count_trunc = int(count / per_page)
page_count_mod = count % per_page
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
return PaginatedResults[ImageResponse](
items=page_of_images,
page=page,
pages=page_count,
per_page=per_page,
total=count,
)
def get(self, image_type: ImageType, image_name: str) -> Image:
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
try:
image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = PILImage.open(image_path)
image = Image.open(image_path)
self.__set_cache(image_path, image)
return image
except Exception as e:
except FileNotFoundError as e:
raise ImageFileStorageBase.ImageFileNotFoundException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
path = os.path.join(
self.__output_folder, image_type, "thumbnails", basename
)
else:
path = os.path.join(self.__output_folder, image_type, basename)
abspath = os.path.abspath(path)
return abspath
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
thumbnail_basename = get_thumbnail_name(basename)
uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
else:
uri = f"api/v1/images/{image_type.value}/{basename}"
return uri
def validate_path(self, path: str) -> bool:
try:
os.stat(path)
return True
except FileNotFoundError:
return False
except Exception as e:
raise e
def save(
self,
image: PILImageType,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> SavedImage:
pnginfo: Optional[PngInfo] = None,
thumbnail_size: int = 256,
) -> None:
try:
image_path = self.get_path(image_type, image_name)
# TODO: Reading the image and then saving it strips the metadata...
if metadata:
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path) # this saved image has an empty info
image.save(image_path, "PNG", pnginfo=pnginfo)
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(
image_type, thumbnail_name, is_thumbnail=True
)
thumbnail_image = make_thumbnail(image)
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
return SavedImage(
image_name=image_name,
thumbnail_name=thumbnail_name,
created=int(os.path.getctime(image_path)),
)
except Exception as e:
raise ImageFileStorageBase.ImageFileSaveException from e
@ -304,10 +146,29 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e:
raise ImageFileStorageBase.ImageFileDeleteException from e
def __get_cache(self, image_name: str) -> Image | None:
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if thumbnail:
thumbnail_name = get_thumbnail_name(basename)
path = os.path.join(
self.__output_folder, image_type, "thumbnails", thumbnail_name
)
else:
path = os.path.join(self.__output_folder, image_type, basename)
abspath = os.path.abspath(path)
return abspath
def __get_cache(self, image_name: str) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: Image):
def __set_cache(self, image_name: str, image: PILImageType):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(

View File

@ -1,25 +1,18 @@
from abc import ABC, abstractmethod
import datetime
from typing import Optional
from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
import sqlite3
import threading
from typing import Optional, Union
from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.models.image import (
ImageCategory,
ImageType,
)
from invokeai.app.services.util.create_enum_table import create_enum_table
from invokeai.app.services.models.image_record import ImageRecord
from invokeai.app.services.util.deserialize_image_record import (
from invokeai.app.services.models.image_record import (
ImageRecord,
deserialize_image_record,
)
@ -76,9 +69,7 @@ class ImageRecordStorageBase(ABC):
image_category: ImageCategory,
session_id: Optional[str],
node_id: Optional[str],
metadata: Optional[
GeneratedImageOrLatentsMetadata | UploadedImageOrLatentsMetadata
],
metadata: Optional[ImageMetadata],
created_at: str = datetime.datetime.utcnow().isoformat(),
) -> None:
"""Saves an image record."""
@ -288,9 +279,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_category: ImageCategory,
session_id: Optional[str],
node_id: Optional[str],
metadata: Union[
GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata, None
],
metadata: Optional[ImageMetadata],
created_at: str,
) -> None:
try:
@ -306,7 +295,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_category,
node_id,
session_id,
metadata
metadata,
created_at
)
VALUES (?, ?, ?, ?, ?, ?, ?);

View File

@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
import json
from logging import Logger
from typing import Optional, Union
import uuid
from PIL.Image import Image as PILImageType
from PIL import PngImagePlugin
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
)
@ -22,8 +23,95 @@ from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.misc import get_iso_timestamp
class ImageServiceABC(ABC):
"""
High-level service for image management.
Provides methods for creating, retrieving, and deleting images.
"""
@abstractmethod
def create(
self,
image: PILImageType,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
metadata: Optional[ImageMetadata] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@abstractmethod
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
pass
@abstractmethod
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str:
"""Gets an image's path"""
pass
@abstractmethod
def get_image_url(self, image_type: ImageType, image_name: str) -> str:
"""Gets an image's URL"""
pass
@abstractmethod
def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str:
"""Gets an image's URL"""
pass
@abstractmethod
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str):
"""Deletes an image."""
pass
@abstractmethod
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Adds a tag to an image."""
pass
@abstractmethod
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Removes a tag from an image."""
pass
@abstractmethod
def favorite(self, image_type: ImageType, image_id: str) -> None:
"""Favorites an image."""
pass
@abstractmethod
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
"""Unfavorites an image."""
pass
class ImageServiceDependencies:
"""Service dependencies for the ImageManagementService."""
"""Service dependencies for the ImageService."""
records: ImageRecordStorageBase
files: ImageFileStorageBase
@ -46,9 +134,7 @@ class ImageServiceDependencies:
self.logger = logger
class ImageService:
"""High-level service for image management."""
class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
def __init__(
@ -67,21 +153,6 @@ class ImageService:
logger=logger,
)
def _create_image_name(
self,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
"""Creates an image name."""
uuid_str = str(uuid.uuid4())
if node_id is not None and session_id is not None:
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
def create(
self,
image: PILImageType,
@ -89,11 +160,8 @@ class ImageService:
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
metadata: Optional[
Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata]
] = None,
metadata: Optional[ImageMetadata] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
image_name = self._create_image_name(
image_type=image_type,
image_category=image_category,
@ -103,13 +171,19 @@ class ImageService:
timestamp = get_iso_timestamp()
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai", json.dumps(metadata))
else:
pnginfo = None
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
self._services.files.save(
image_type=image_type,
image_name=image_name,
image=image,
metadata=metadata,
pnginfo=pnginfo,
)
self._services.records.save(
@ -144,25 +218,40 @@ class ImageService:
except ImageFileStorageBase.ImageFileSaveException:
self._services.logger.error("Failed to save image file")
raise
except Exception as e:
self._services.logger.error("Problem saving image record and file")
raise e
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
try:
return self._services.files.get(image_type, image_name)
except ImageFileStorageBase.ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
except Exception as e:
self._services.logger.error("Problem getting image file")
raise e
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
"""Gets an image record."""
try:
return self._services.records.get(image_type, image_name)
except ImageRecordStorageBase.ImageRecordNotFoundException:
self._services.logger.error("Failed to get image record")
self._services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image record")
raise e
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.files.get_path(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
try:
image_record = self._services.records.get(image_type, image_name)
@ -174,21 +263,11 @@ class ImageService:
return image_dto
except ImageRecordStorageBase.ImageRecordNotFoundException:
self._services.logger.error("Failed to get image DTO")
raise
def delete(self, image_type: ImageType, image_name: str):
"""Deletes an image."""
# TODO: Consider using a transaction here to ensure consistency between storage and database
try:
self._services.files.delete(image_type, image_name)
self._services.records.delete(image_type, image_name)
except ImageRecordStorageBase.ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise
except ImageFileStorageBase.ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image file")
self._services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image DTO")
raise e
def get_many(
self,
@ -197,7 +276,6 @@ class ImageService:
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
try:
results = self._services.records.get_many(
image_type,
@ -225,21 +303,47 @@ class ImageService:
total=results.total,
)
except Exception as e:
self._services.logger.error("Failed to get paginated image DTOs")
self._services.logger.error("Problem getting paginated image DTOs")
raise e
def delete(self, image_type: ImageType, image_name: str):
# TODO: Consider using a transaction here to ensure consistency between storage and database
try:
self._services.files.delete(image_type, image_name)
self._services.records.delete(image_type, image_name)
except ImageRecordStorageBase.ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise
except ImageFileStorageBase.ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image file")
raise
except Exception as e:
self._services.logger.error("Problem deleting image record and file")
raise e
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Adds a tag to an image."""
raise NotImplementedError("The 'add_tag' method is not implemented yet.")
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Removes a tag from an image."""
raise NotImplementedError("The 'remove_tag' method is not implemented yet.")
def favorite(self, image_type: ImageType, image_id: str) -> None:
"""Favorites an image."""
raise NotImplementedError("The 'favorite' method is not implemented yet.")
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
"""Unfavorites an image."""
raise NotImplementedError("The 'unfavorite' method is not implemented yet.")
def _create_image_name(
self,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
"""Create a unique image name."""
uuid_str = str(uuid.uuid4())
if node_id is not None and session_id is not None:
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"

View File

@ -1,11 +1,10 @@
import datetime
import sqlite3
from typing import Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp
class ImageRecord(BaseModel):
@ -19,9 +18,9 @@ class ImageRecord(BaseModel):
)
session_id: Optional[str] = Field(default=None, description="The session ID.")
node_id: Optional[str] = Field(default=None, description="The node ID.")
metadata: Optional[
Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata]
] = Field(default=None, description="The image's metadata.")
metadata: Optional[ImageMetadata] = Field(
default=None, description="The image's metadata."
)
class ImageDTO(ImageRecord):
@ -46,3 +45,27 @@ def image_record_to_dto(
image_url=image_url,
thumbnail_url=thumbnail_url,
)
def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
"""Deserializes an image record."""
image_dict = dict(image_row)
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
raw_metadata = image_dict.get("metadata", "{}")
metadata = ImageMetadata.parse_raw(raw_metadata)
return ImageRecord(
image_name=image_dict.get("id", "unknown"),
session_id=image_dict.get("session_id", None),
node_id=image_dict.get("node_id", None),
metadata=metadata,
image_type=image_type,
image_category=ImageCategory(
image_dict.get("image_category", ImageCategory.IMAGE.value)
),
created_at=image_dict.get("created_at", get_iso_timestamp()),
)

View File

@ -25,8 +25,8 @@ class LocalUrlService(UrlServiceBase):
def get_image_url(self, image_type: ImageType, image_name: str) -> str:
image_basename = os.path.basename(image_name)
return f"{self._base_url}/images/{image_type.value}/{image_basename}"
return f"{self._base_url}/files/images/{image_type.value}/{image_basename}"
def get_thumbnail_url(self, image_type: ImageType, image_name: str) -> str:
thumbnail_basename = get_thumbnail_name(os.path.basename(image_name))
return f"{self._base_url}/images/{image_type.value}/thumbnails/{thumbnail_basename}"
image_basename = os.path.basename(image_name)
return f"{self._base_url}/files/images/{image_type.value}/{image_basename}/thumbnail"

View File

@ -1,33 +0,0 @@
from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.services.models.image_record import ImageRecord
from invokeai.app.util.misc import get_iso_timestamp
def deserialize_image_record(image: dict) -> ImageRecord:
"""Deserializes an image record."""
# All values *should* be present, except `session_id` and `node_id`, but provide some defaults just in case
image_type = ImageType(image.get("image_type", ImageType.RESULT.value))
raw_metadata = image.get("metadata", {})
if image_type == ImageType.UPLOAD:
metadata = UploadedImageOrLatentsMetadata.parse_obj(raw_metadata)
else:
metadata = GeneratedImageOrLatentsMetadata.parse_obj(raw_metadata)
return ImageRecord(
image_name=image.get("id", "unknown"),
session_id=image.get("session_id", None),
node_id=image.get("node_id", None),
metadata=metadata,
image_type=image_type,
image_category=ImageCategory(
image.get("image_category", ImageCategory.IMAGE.value)
),
created_at=image.get("created_at", get_iso_timestamp()),
)