mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): address feedback
- Address database feedback: - Remove all the extraneous tables. Only an `images` table now: - `image_type` and `image_category` are unrestricted strings. When creating images, the provided values are checked to ensure they are a valid type and category. - Add `updated_at` and `deleted_at` columns. `deleted_at` is currently unused. - Use SQLite's built-in timestamp features to populate these. Add a trigger to update `updated_at` when the row is updated. Currently no way to update a row. - Rename the `id` column in `images` to `image_name` - Rename `ImageCategory.IMAGE` to `ImageCategory.GENERAL` - Move all exceptions outside their base classes to make them more portable. - Add `width` and `height` columns to the database. These store the actual dimensions of the image file, whereas the metadata's `width` and `height` refer to the respective generation parameters and are nullable. - Make `deserialize_image_record` take a `dict` instead of `sqlite3.Row` - Improve comments throughout - Tidy up unused code/files and some minor organisation
This commit is contained in:
committed by
Kent Keirsey
parent
021e5a2aa3
commit
035425ef24
@ -30,7 +30,7 @@ async def upload_image(
|
|||||||
image_type: ImageType,
|
image_type: ImageType,
|
||||||
request: Request,
|
request: Request,
|
||||||
response: Response,
|
response: Response,
|
||||||
image_category: ImageCategory = ImageCategory.IMAGE,
|
image_category: ImageCategory = ImageCategory.GENERAL,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Uploads an image"""
|
"""Uploads an image"""
|
||||||
if not file.content_type.startswith("image"):
|
if not file.content_type.startswith("image"):
|
||||||
|
@ -95,7 +95,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
image_dto = context.services.images_new.create(
|
image_dto = context.services.images_new.create(
|
||||||
image=generate_output.image,
|
image=generate_output.image,
|
||||||
image_type=ImageType.RESULT,
|
image_type=ImageType.RESULT,
|
||||||
image_category=ImageCategory.IMAGE,
|
image_category=ImageCategory.GENERAL,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
)
|
)
|
||||||
@ -119,7 +119,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
# context.services.images_db.set(
|
# context.services.images_db.set(
|
||||||
# id=image_name,
|
# id=image_name,
|
||||||
# image_type=ImageType.RESULT,
|
# image_type=ImageType.RESULT,
|
||||||
# image_category=ImageCategory.IMAGE,
|
# image_category=ImageCategory.GENERAL,
|
||||||
# session_id=context.graph_execution_state_id,
|
# session_id=context.graph_execution_state_id,
|
||||||
# node_id=self.id,
|
# node_id=self.id,
|
||||||
# metadata=GeneratedImageOrLatentsMetadata(),
|
# metadata=GeneratedImageOrLatentsMetadata(),
|
||||||
|
@ -372,7 +372,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
image_dto = context.services.images_new.create(
|
image_dto = context.services.images_new.create(
|
||||||
image=image,
|
image=image,
|
||||||
image_type=ImageType.RESULT,
|
image_type=ImageType.RESULT,
|
||||||
image_category=ImageCategory.IMAGE,
|
image_category=ImageCategory.GENERAL,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
)
|
)
|
||||||
|
@ -2,7 +2,7 @@ from enum import Enum
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.util.enum import MetaEnum
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
|
|
||||||
|
|
||||||
class ImageType(str, Enum, metaclass=MetaEnum):
|
class ImageType(str, Enum, metaclass=MetaEnum):
|
||||||
@ -13,20 +13,32 @@ class ImageType(str, Enum, metaclass=MetaEnum):
|
|||||||
INTERMEDIATE = "intermediates"
|
INTERMEDIATE = "intermediates"
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidImageTypeException(ValueError):
|
||||||
|
"""Raised when a provided value is not a valid ImageType.
|
||||||
|
|
||||||
|
Subclasses `ValueError`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message="Invalid image type."):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||||
"""The category of an image. Use ImageCategory.OTHER for non-default categories."""
|
"""The category of an image. Use ImageCategory.OTHER for non-default categories."""
|
||||||
|
|
||||||
IMAGE = "image"
|
GENERAL = "general"
|
||||||
CONTROL_IMAGE = "control_image"
|
CONTROL = "control"
|
||||||
OTHER = "other"
|
OTHER = "other"
|
||||||
|
|
||||||
|
|
||||||
def is_image_type(obj):
|
class InvalidImageCategoryException(ValueError):
|
||||||
try:
|
"""Raised when a provided value is not a valid ImageCategory.
|
||||||
ImageType(obj)
|
|
||||||
except ValueError:
|
Subclasses `ValueError`.
|
||||||
return False
|
"""
|
||||||
return True
|
|
||||||
|
def __init__(self, message="Invalid image category."):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class ImageField(BaseModel):
|
class ImageField(BaseModel):
|
||||||
|
@ -26,50 +26,66 @@ class ImageMetadata(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="The type of the ancestor node of the image output node.",
|
description="The type of the ancestor node of the image output node.",
|
||||||
)
|
)
|
||||||
|
"""The type of the ancestor node of the image output node."""
|
||||||
positive_conditioning: Optional[StrictStr] = Field(
|
positive_conditioning: Optional[StrictStr] = Field(
|
||||||
default=None, description="The positive conditioning."
|
default=None, description="The positive conditioning."
|
||||||
)
|
)
|
||||||
|
"""The positive conditioning"""
|
||||||
negative_conditioning: Optional[StrictStr] = Field(
|
negative_conditioning: Optional[StrictStr] = Field(
|
||||||
default=None, description="The negative conditioning."
|
default=None, description="The negative conditioning."
|
||||||
)
|
)
|
||||||
|
"""The negative conditioning"""
|
||||||
width: Optional[StrictInt] = Field(
|
width: Optional[StrictInt] = Field(
|
||||||
default=None, description="Width of the image/latents in pixels."
|
default=None, description="Width of the image/latents in pixels."
|
||||||
)
|
)
|
||||||
|
"""Width of the image/latents in pixels"""
|
||||||
height: Optional[StrictInt] = Field(
|
height: Optional[StrictInt] = Field(
|
||||||
default=None, description="Height of the image/latents in pixels."
|
default=None, description="Height of the image/latents in pixels."
|
||||||
)
|
)
|
||||||
|
"""Height of the image/latents in pixels"""
|
||||||
seed: Optional[StrictInt] = Field(
|
seed: Optional[StrictInt] = Field(
|
||||||
default=None, description="The seed used for noise generation."
|
default=None, description="The seed used for noise generation."
|
||||||
)
|
)
|
||||||
|
"""The seed used for noise generation"""
|
||||||
cfg_scale: Optional[StrictFloat] = Field(
|
cfg_scale: Optional[StrictFloat] = Field(
|
||||||
default=None, description="The classifier-free guidance scale."
|
default=None, description="The classifier-free guidance scale."
|
||||||
)
|
)
|
||||||
|
"""The classifier-free guidance scale"""
|
||||||
steps: Optional[StrictInt] = Field(
|
steps: Optional[StrictInt] = Field(
|
||||||
default=None, description="The number of steps used for inference."
|
default=None, description="The number of steps used for inference."
|
||||||
)
|
)
|
||||||
|
"""The number of steps used for inference"""
|
||||||
scheduler: Optional[StrictStr] = Field(
|
scheduler: Optional[StrictStr] = Field(
|
||||||
default=None, description="The scheduler used for inference."
|
default=None, description="The scheduler used for inference."
|
||||||
)
|
)
|
||||||
|
"""The scheduler used for inference"""
|
||||||
model: Optional[StrictStr] = Field(
|
model: Optional[StrictStr] = Field(
|
||||||
default=None, description="The model used for inference."
|
default=None, description="The model used for inference."
|
||||||
)
|
)
|
||||||
|
"""The model used for inference"""
|
||||||
strength: Optional[StrictFloat] = Field(
|
strength: Optional[StrictFloat] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The strength used for image-to-image/latents-to-latents.",
|
description="The strength used for image-to-image/latents-to-latents.",
|
||||||
)
|
)
|
||||||
|
"""The strength used for image-to-image/latents-to-latents."""
|
||||||
latents: Optional[StrictStr] = Field(
|
latents: Optional[StrictStr] = Field(
|
||||||
default=None, description="The ID of the initial latents."
|
default=None, description="The ID of the initial latents."
|
||||||
)
|
)
|
||||||
|
"""The ID of the initial latents"""
|
||||||
vae: Optional[StrictStr] = Field(
|
vae: Optional[StrictStr] = Field(
|
||||||
default=None, description="The VAE used for decoding."
|
default=None, description="The VAE used for decoding."
|
||||||
)
|
)
|
||||||
|
"""The VAE used for decoding"""
|
||||||
unet: Optional[StrictStr] = Field(
|
unet: Optional[StrictStr] = Field(
|
||||||
default=None, description="The UNet used dor inference."
|
default=None, description="The UNet used dor inference."
|
||||||
)
|
)
|
||||||
|
"""The UNet used dor inference"""
|
||||||
clip: Optional[StrictStr] = Field(
|
clip: Optional[StrictStr] = Field(
|
||||||
default=None, description="The CLIP Encoder used for conditioning."
|
default=None, description="The CLIP Encoder used for conditioning."
|
||||||
)
|
)
|
||||||
|
"""The CLIP Encoder used for conditioning"""
|
||||||
extra: Optional[StrictStr] = Field(
|
extra: Optional[StrictStr] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
|
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
|
||||||
)
|
)
|
||||||
|
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""
|
||||||
|
@ -1,28 +0,0 @@
|
|||||||
# TODO: Make a new model for this
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from invokeai.app.util.enum import MetaEnum
|
|
||||||
|
|
||||||
|
|
||||||
class ResourceType(str, Enum, metaclass=MetaEnum):
|
|
||||||
"""The type of a resource."""
|
|
||||||
|
|
||||||
IMAGES = "images"
|
|
||||||
TENSORS = "tensors"
|
|
||||||
|
|
||||||
|
|
||||||
# class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
|
||||||
# """The origin of a resource (eg image or tensor)."""
|
|
||||||
|
|
||||||
# RESULTS = "results"
|
|
||||||
# UPLOADS = "uploads"
|
|
||||||
# INTERMEDIATES = "intermediates"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TensorKind(str, Enum, metaclass=MetaEnum):
|
|
||||||
"""The kind of a tensor. Use TensorKind.OTHER for non-default kinds."""
|
|
||||||
|
|
||||||
IMAGE_LATENTS = "image_latents"
|
|
||||||
CONDITIONING = "conditioning"
|
|
||||||
OTHER = "other"
|
|
@ -1,578 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 40,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from abc import ABC, abstractmethod\n",
|
|
||||||
"from enum import Enum\n",
|
|
||||||
"import enum\n",
|
|
||||||
"import sqlite3\n",
|
|
||||||
"import threading\n",
|
|
||||||
"from typing import Optional, Type, TypeVar, Union\n",
|
|
||||||
"from PIL.Image import Image as PILImage\n",
|
|
||||||
"from pydantic import BaseModel, Field\n",
|
|
||||||
"from torch import Tensor"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 41,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"\n",
|
|
||||||
"class ResourceOrigin(str, Enum):\n",
|
|
||||||
" \"\"\"The origin of a resource (eg image or tensor).\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
" RESULTS = \"results\"\n",
|
|
||||||
" UPLOADS = \"uploads\"\n",
|
|
||||||
" INTERMEDIATES = \"intermediates\"\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"class ImageKind(str, Enum):\n",
|
|
||||||
" \"\"\"The kind of an image. Use ImageKind.OTHER for non-default kinds.\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
" IMAGE = \"image\"\n",
|
|
||||||
" CONTROL_IMAGE = \"control_image\"\n",
|
|
||||||
" OTHER = \"other\"\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"class TensorKind(str, Enum):\n",
|
|
||||||
" \"\"\"The kind of a tensor. Use TensorKind.OTHER for non-default kinds.\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
" IMAGE_LATENTS = \"image_latents\"\n",
|
|
||||||
" CONDITIONING = \"conditioning\"\n",
|
|
||||||
" OTHER = \"other\"\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 42,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"\n",
|
|
||||||
"def create_sql_values_string_from_string_enum(enum: Type[Enum]):\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" Creates a string of the form \"('value1'), ('value2'), ..., ('valueN')\" from a StrEnum.\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
" delimiter = \", \"\n",
|
|
||||||
" values = [f\"('{e.value}')\" for e in enum]\n",
|
|
||||||
" return delimiter.join(values)\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"def create_sql_table_from_enum(\n",
|
|
||||||
" enum: Type[Enum],\n",
|
|
||||||
" table_name: str,\n",
|
|
||||||
" primary_key_name: str,\n",
|
|
||||||
" conn: sqlite3.Connection,\n",
|
|
||||||
" cursor: sqlite3.Cursor,\n",
|
|
||||||
" lock: threading.Lock,\n",
|
|
||||||
"):\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" Creates and populates a table to be used as a functional enum.\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
" try:\n",
|
|
||||||
" lock.acquire()\n",
|
|
||||||
"\n",
|
|
||||||
" values_string = create_sql_values_string_from_string_enum(enum)\n",
|
|
||||||
"\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" f\"\"\"--sql\n",
|
|
||||||
" CREATE TABLE IF NOT EXISTS {table_name} (\n",
|
|
||||||
" {primary_key_name} TEXT PRIMARY KEY\n",
|
|
||||||
" );\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" f\"\"\"--sql\n",
|
|
||||||
" INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" conn.commit()\n",
|
|
||||||
" finally:\n",
|
|
||||||
" lock.release()\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"`resource_origins` functions as an enum for the ResourceOrigin model.\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"# def create_resource_origins_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
|
||||||
"# create_sql_table_from_enum(\n",
|
|
||||||
"# enum=ResourceOrigin,\n",
|
|
||||||
"# table_name=\"resource_origins\",\n",
|
|
||||||
"# primary_key_name=\"origin_name\",\n",
|
|
||||||
"# conn=conn,\n",
|
|
||||||
"# cursor=cursor,\n",
|
|
||||||
"# lock=lock,\n",
|
|
||||||
"# )\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"`image_kinds` functions as an enum for the ImageType model.\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"# def create_image_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
|
||||||
" # create_sql_table_from_enum(\n",
|
|
||||||
" # enum=ImageKind,\n",
|
|
||||||
" # table_name=\"image_kinds\",\n",
|
|
||||||
" # primary_key_name=\"kind_name\",\n",
|
|
||||||
" # conn=conn,\n",
|
|
||||||
" # cursor=cursor,\n",
|
|
||||||
" # lock=lock,\n",
|
|
||||||
" # )\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"`tensor_kinds` functions as an enum for the TensorType model.\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"# def create_tensor_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
|
||||||
" # create_sql_table_from_enum(\n",
|
|
||||||
" # enum=TensorKind,\n",
|
|
||||||
" # table_name=\"tensor_kinds\",\n",
|
|
||||||
" # primary_key_name=\"kind_name\",\n",
|
|
||||||
" # conn=conn,\n",
|
|
||||||
" # cursor=cursor,\n",
|
|
||||||
" # lock=lock,\n",
|
|
||||||
" # )\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"`images` stores all images, regardless of type\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"def create_images_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
|
||||||
" try:\n",
|
|
||||||
" lock.acquire()\n",
|
|
||||||
"\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE TABLE IF NOT EXISTS images (\n",
|
|
||||||
" id TEXT PRIMARY KEY,\n",
|
|
||||||
" origin TEXT,\n",
|
|
||||||
" image_kind TEXT,\n",
|
|
||||||
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
|
|
||||||
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
|
|
||||||
" FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name)\n",
|
|
||||||
" );\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin);\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind);\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" conn.commit()\n",
|
|
||||||
" finally:\n",
|
|
||||||
" lock.release()\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"`images_results` stores additional data specific to `results` images.\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"def create_images_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
|
||||||
" try:\n",
|
|
||||||
" lock.acquire()\n",
|
|
||||||
"\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE TABLE IF NOT EXISTS images_results (\n",
|
|
||||||
" images_id TEXT PRIMARY KEY,\n",
|
|
||||||
" session_id TEXT NOT NULL,\n",
|
|
||||||
" node_id TEXT NOT NULL,\n",
|
|
||||||
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
|
|
||||||
" );\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_results_images_id ON images_results(images_id);\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" conn.commit()\n",
|
|
||||||
" finally:\n",
|
|
||||||
" lock.release()\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"`images_intermediates` stores additional data specific to `intermediates` images\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"def create_images_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
|
||||||
" try:\n",
|
|
||||||
" lock.acquire()\n",
|
|
||||||
"\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE TABLE IF NOT EXISTS images_intermediates (\n",
|
|
||||||
" images_id TEXT PRIMARY KEY,\n",
|
|
||||||
" session_id TEXT NOT NULL,\n",
|
|
||||||
" node_id TEXT NOT NULL,\n",
|
|
||||||
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
|
|
||||||
" );\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_intermediates_images_id ON images_intermediates(images_id);\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" conn.commit()\n",
|
|
||||||
" finally:\n",
|
|
||||||
" lock.release()\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"`images_metadata` stores basic metadata for any image type\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"def create_images_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
|
||||||
" try:\n",
|
|
||||||
" lock.acquire()\n",
|
|
||||||
"\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE TABLE IF NOT EXISTS images_metadata (\n",
|
|
||||||
" images_id TEXT PRIMARY KEY,\n",
|
|
||||||
" metadata TEXT,\n",
|
|
||||||
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
|
|
||||||
" );\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id);\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" conn.commit()\n",
|
|
||||||
" finally:\n",
|
|
||||||
" lock.release()\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"# `tensors` table: stores references to tensor\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"def create_tensors_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
|
||||||
" try:\n",
|
|
||||||
" lock.acquire()\n",
|
|
||||||
"\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE TABLE IF NOT EXISTS tensors (\n",
|
|
||||||
" id TEXT PRIMARY KEY,\n",
|
|
||||||
" origin TEXT,\n",
|
|
||||||
" tensor_kind TEXT,\n",
|
|
||||||
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
|
|
||||||
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
|
|
||||||
" FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name)\n",
|
|
||||||
" );\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id);\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin);\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind);\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" conn.commit()\n",
|
|
||||||
" finally:\n",
|
|
||||||
" lock.release()\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"# `tensors_results` stores additional data specific to `result` tensor\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"def create_tensors_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
|
||||||
" try:\n",
|
|
||||||
" lock.acquire()\n",
|
|
||||||
"\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE TABLE IF NOT EXISTS tensors_results (\n",
|
|
||||||
" tensors_id TEXT PRIMARY KEY,\n",
|
|
||||||
" session_id TEXT NOT NULL,\n",
|
|
||||||
" node_id TEXT NOT NULL,\n",
|
|
||||||
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
|
|
||||||
" );\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_results_tensors_id ON tensors_results(tensors_id);\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" conn.commit()\n",
|
|
||||||
" finally:\n",
|
|
||||||
" lock.release()\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"# `tensors_intermediates` stores additional data specific to `intermediate` tensor\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"def create_tensors_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
|
||||||
" try:\n",
|
|
||||||
" lock.acquire()\n",
|
|
||||||
"\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE TABLE IF NOT EXISTS tensors_intermediates (\n",
|
|
||||||
" tensors_id TEXT PRIMARY KEY,\n",
|
|
||||||
" session_id TEXT NOT NULL,\n",
|
|
||||||
" node_id TEXT NOT NULL,\n",
|
|
||||||
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
|
|
||||||
" );\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_intermediates_tensors_id ON tensors_intermediates(tensors_id);\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" conn.commit()\n",
|
|
||||||
" finally:\n",
|
|
||||||
" lock.release()\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"# `tensors_metadata` table: stores generated/transformed metadata for tensor\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"def create_tensors_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
|
||||||
" try:\n",
|
|
||||||
" lock.acquire()\n",
|
|
||||||
"\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE TABLE IF NOT EXISTS tensors_metadata (\n",
|
|
||||||
" tensors_id TEXT PRIMARY KEY,\n",
|
|
||||||
" metadata TEXT,\n",
|
|
||||||
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
|
|
||||||
" );\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cursor.execute(\n",
|
|
||||||
" \"\"\"--sql\n",
|
|
||||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensors_id ON tensors_metadata(tensors_id);\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" )\n",
|
|
||||||
" conn.commit()\n",
|
|
||||||
" finally:\n",
|
|
||||||
" lock.release()\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 43,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import os\n",
|
|
||||||
"db_path = '/home/bat/Documents/Code/outputs/test.db'\n",
|
|
||||||
"if (os.path.exists(db_path)):\n",
|
|
||||||
" os.remove(db_path)\n",
|
|
||||||
"\n",
|
|
||||||
"conn = sqlite3.connect(\n",
|
|
||||||
" db_path, check_same_thread=False\n",
|
|
||||||
")\n",
|
|
||||||
"cursor = conn.cursor()\n",
|
|
||||||
"lock = threading.Lock()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 44,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"create_sql_table_from_enum(\n",
|
|
||||||
" enum=ResourceOrigin,\n",
|
|
||||||
" table_name=\"resource_origins\",\n",
|
|
||||||
" primary_key_name=\"origin_name\",\n",
|
|
||||||
" conn=conn,\n",
|
|
||||||
" cursor=cursor,\n",
|
|
||||||
" lock=lock,\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"create_sql_table_from_enum(\n",
|
|
||||||
" enum=ImageKind,\n",
|
|
||||||
" table_name=\"image_kinds\",\n",
|
|
||||||
" primary_key_name=\"kind_name\",\n",
|
|
||||||
" conn=conn,\n",
|
|
||||||
" cursor=cursor,\n",
|
|
||||||
" lock=lock,\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"create_sql_table_from_enum(\n",
|
|
||||||
" enum=TensorKind,\n",
|
|
||||||
" table_name=\"tensor_kinds\",\n",
|
|
||||||
" primary_key_name=\"kind_name\",\n",
|
|
||||||
" conn=conn,\n",
|
|
||||||
" cursor=cursor,\n",
|
|
||||||
" lock=lock,\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 45,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"create_images_table(conn, cursor, lock)\n",
|
|
||||||
"create_images_results_table(conn, cursor, lock)\n",
|
|
||||||
"create_images_intermediates_table(conn, cursor, lock)\n",
|
|
||||||
"create_images_metadata_table(conn, cursor, lock)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 46,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"create_tensors_table(conn, cursor, lock)\n",
|
|
||||||
"create_tensors_results_table(conn, cursor, lock)\n",
|
|
||||||
"create_tensors_intermediates_table(conn, cursor, lock)\n",
|
|
||||||
"create_tensors_metadata_table(conn, cursor, lock)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 59,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"\n",
|
|
||||||
"from pydantic import StrictStr\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"class GeneratedImageOrLatentsMetadata(BaseModel):\n",
|
|
||||||
" \"\"\"Core generation metadata for an image/tensor generated in InvokeAI.\n",
|
|
||||||
"\n",
|
|
||||||
" Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.\n",
|
|
||||||
"\n",
|
|
||||||
" Full metadata may be accessed by querying for the session in the `graph_executions` table.\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
" positive_conditioning: Optional[StrictStr] = Field(\n",
|
|
||||||
" default=None, description=\"The positive conditioning.\"\n",
|
|
||||||
" )\n",
|
|
||||||
" negative_conditioning: Optional[str] = Field(\n",
|
|
||||||
" default=None, description=\"The negative conditioning.\"\n",
|
|
||||||
" )\n",
|
|
||||||
" width: Optional[int] = Field(\n",
|
|
||||||
" default=None, description=\"Width of the image/tensor in pixels.\"\n",
|
|
||||||
" )\n",
|
|
||||||
" height: Optional[int] = Field(\n",
|
|
||||||
" default=None, description=\"Height of the image/tensor in pixels.\"\n",
|
|
||||||
" )\n",
|
|
||||||
" seed: Optional[int] = Field(\n",
|
|
||||||
" default=None, description=\"The seed used for noise generation.\"\n",
|
|
||||||
" )\n",
|
|
||||||
" cfg_scale: Optional[float] = Field(\n",
|
|
||||||
" default=None, description=\"The classifier-free guidance scale.\"\n",
|
|
||||||
" )\n",
|
|
||||||
" steps: Optional[int] = Field(\n",
|
|
||||||
" default=None, description=\"The number of steps used for inference.\"\n",
|
|
||||||
" )\n",
|
|
||||||
" scheduler: Optional[str] = Field(\n",
|
|
||||||
" default=None, description=\"The scheduler used for inference.\"\n",
|
|
||||||
" )\n",
|
|
||||||
" model: Optional[str] = Field(\n",
|
|
||||||
" default=None, description=\"The model used for inference.\"\n",
|
|
||||||
" )\n",
|
|
||||||
" strength: Optional[float] = Field(\n",
|
|
||||||
" default=None,\n",
|
|
||||||
" description=\"The strength used for image-to-image/tensor-to-tensor.\",\n",
|
|
||||||
" )\n",
|
|
||||||
" image: Optional[str] = Field(\n",
|
|
||||||
" default=None, description=\"The ID of the initial image.\"\n",
|
|
||||||
" )\n",
|
|
||||||
" tensor: Optional[str] = Field(\n",
|
|
||||||
" default=None, description=\"The ID of the initial tensor.\"\n",
|
|
||||||
" )\n",
|
|
||||||
" # Pending model refactor:\n",
|
|
||||||
" # vae: Optional[str] = Field(default=None,description=\"The VAE used for decoding.\")\n",
|
|
||||||
" # unet: Optional[str] = Field(default=None,description=\"The UNet used dor inference.\")\n",
|
|
||||||
" # clip: Optional[str] = Field(default=None,description=\"The CLIP Encoder used for conditioning.\")\n",
|
|
||||||
"\n",
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 61,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"GeneratedImageOrLatentsMetadata(positive_conditioning='123', negative_conditioning=None, width=None, height=None, seed=None, cfg_scale=None, steps=None, scheduler=None, model=None, strength=None, image=None, tensor=None)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 61,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"GeneratedImageOrLatentsMetadata(positive_conditioning='123')"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": ".venv",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.10.6"
|
|
||||||
},
|
|
||||||
"orig_nbformat": 4
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 2
|
|
||||||
}
|
|
@ -14,27 +14,31 @@ from invokeai.app.models.metadata import ImageMetadata
|
|||||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Should these excpetions subclass existing python exceptions?
|
||||||
|
class ImageFileNotFoundException(Exception):
|
||||||
|
"""Raised when an image file is not found in storage."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image file not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFileSaveException(Exception):
|
||||||
|
"""Raised when an image cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image file not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFileDeleteException(Exception):
|
||||||
|
"""Raised when an image cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image file not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class ImageFileStorageBase(ABC):
|
class ImageFileStorageBase(ABC):
|
||||||
"""Low-level service responsible for storing and retrieving image files."""
|
"""Low-level service responsible for storing and retrieving image files."""
|
||||||
|
|
||||||
class ImageFileNotFoundException(Exception):
|
|
||||||
"""Raised when an image file is not found in storage."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image file not found"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
class ImageFileSaveException(Exception):
|
|
||||||
"""Raised when an image cannot be saved."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image file not saved"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
class ImageFileDeleteException(Exception):
|
|
||||||
"""Raised when an image cannot be deleted."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image file not deleted"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||||
"""Retrieves an image as PIL Image."""
|
"""Retrieves an image as PIL Image."""
|
||||||
@ -102,7 +106,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
self.__set_cache(image_path, image)
|
self.__set_cache(image_path, image)
|
||||||
return image
|
return image
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
raise ImageFileStorageBase.ImageFileNotFoundException from e
|
raise ImageFileNotFoundException from e
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
@ -130,7 +134,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
self.__set_cache(image_path, image)
|
self.__set_cache(image_path, image)
|
||||||
self.__set_cache(thumbnail_path, thumbnail_image)
|
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ImageFileStorageBase.ImageFileSaveException from e
|
raise ImageFileSaveException from e
|
||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||||
try:
|
try:
|
||||||
@ -150,7 +154,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
if thumbnail_path in self.__cache:
|
if thumbnail_path in self.__cache:
|
||||||
del self.__cache[thumbnail_path]
|
del self.__cache[thumbnail_path]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ImageFileStorageBase.ImageFileDeleteException from e
|
raise ImageFileDeleteException from e
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
def get_path(
|
def get_path(
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from typing import Optional, cast
|
||||||
from typing import Optional, Type
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
@ -18,62 +17,32 @@ from invokeai.app.services.models.image_record import (
|
|||||||
from invokeai.app.services.item_storage import PaginatedResults
|
from invokeai.app.services.item_storage import PaginatedResults
|
||||||
|
|
||||||
|
|
||||||
def create_sql_values_string_from_string_enum(enum: Type[Enum]):
|
# TODO: Should these excpetions subclass existing python exceptions?
|
||||||
"""
|
class ImageRecordNotFoundException(Exception):
|
||||||
Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum.
|
"""Raised when an image record is not found."""
|
||||||
"""
|
|
||||||
|
|
||||||
delimiter = ", "
|
def __init__(self, message="Image record not found"):
|
||||||
values = [f"('{e.value}')" for e in enum]
|
super().__init__(message)
|
||||||
return delimiter.join(values)
|
|
||||||
|
|
||||||
|
|
||||||
def create_enum_table(
|
class ImageRecordSaveException(Exception):
|
||||||
enum: Type[Enum],
|
"""Raised when an image record cannot be saved."""
|
||||||
table_name: str,
|
|
||||||
primary_key_name: str,
|
|
||||||
cursor: sqlite3.Cursor,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Creates and populates a table to be used as a functional enum.
|
|
||||||
"""
|
|
||||||
|
|
||||||
values_string = create_sql_values_string_from_string_enum(enum)
|
def __init__(self, message="Image record not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
f"""--sql
|
class ImageRecordDeleteException(Exception):
|
||||||
CREATE TABLE IF NOT EXISTS {table_name} (
|
"""Raised when an image record cannot be deleted."""
|
||||||
{primary_key_name} TEXT PRIMARY KEY
|
|
||||||
);
|
def __init__(self, message="Image record not deleted"):
|
||||||
"""
|
super().__init__(message)
|
||||||
)
|
|
||||||
cursor.execute(
|
|
||||||
f"""--sql
|
|
||||||
INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordStorageBase(ABC):
|
class ImageRecordStorageBase(ABC):
|
||||||
"""Low-level service responsible for interfacing with the image record store."""
|
"""Low-level service responsible for interfacing with the image record store."""
|
||||||
|
|
||||||
class ImageRecordNotFoundException(Exception):
|
# TODO: Implement an `update()` method
|
||||||
"""Raised when an image record is not found."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image record not found"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
class ImageRecordSaveException(Exception):
|
|
||||||
"""Raised when an image record cannot be saved."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image record not saved"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
class ImageRecordDeleteException(Exception):
|
|
||||||
"""Raised when an image record cannot be deleted."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image record not deleted"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||||
@ -91,6 +60,8 @@ class ImageRecordStorageBase(ABC):
|
|||||||
"""Gets a page of image records."""
|
"""Gets a page of image records."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||||
|
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||||
"""Deletes an image record."""
|
"""Deletes an image record."""
|
||||||
@ -102,11 +73,12 @@ class ImageRecordStorageBase(ABC):
|
|||||||
image_name: str,
|
image_name: str,
|
||||||
image_type: ImageType,
|
image_type: ImageType,
|
||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
session_id: Optional[str],
|
session_id: Optional[str],
|
||||||
node_id: Optional[str],
|
node_id: Optional[str],
|
||||||
metadata: Optional[ImageMetadata],
|
metadata: Optional[ImageMetadata],
|
||||||
created_at: str = datetime.datetime.utcnow().isoformat(),
|
) -> datetime:
|
||||||
) -> None:
|
|
||||||
"""Saves an image record."""
|
"""Saves an image record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -141,17 +113,23 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
|
|
||||||
# Create the `images` table.
|
# Create the `images` table.
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
"""--sql
|
||||||
CREATE TABLE IF NOT EXISTS images (
|
CREATE TABLE IF NOT EXISTS images (
|
||||||
id TEXT PRIMARY KEY,
|
image_name TEXT NOT NULL PRIMARY KEY,
|
||||||
image_type TEXT, -- non-nullable via foreign key constraint
|
-- This is an enum in python, unrestricted string here for flexibility
|
||||||
image_category TEXT, -- non-nullable via foreign key constraint
|
image_type TEXT NOT NULL,
|
||||||
session_id TEXT, -- nullable
|
-- This is an enum in python, unrestricted string here for flexibility
|
||||||
node_id TEXT, -- nullable
|
image_category TEXT NOT NULL,
|
||||||
metadata TEXT, -- nullable
|
width INTEGER NOT NULL,
|
||||||
created_at TEXT NOT NULL,
|
height INTEGER NOT NULL,
|
||||||
FOREIGN KEY(image_type) REFERENCES image_types(type_name),
|
session_id TEXT,
|
||||||
FOREIGN KEY(image_category) REFERENCES image_categories(category_name)
|
node_id TEXT,
|
||||||
|
metadata TEXT,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
-- Updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
-- Soft delete, currently unused
|
||||||
|
deleted_at DATETIME
|
||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
@ -159,7 +137,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
# Create the `images` table indices.
|
# Create the `images` table indices.
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -172,53 +150,22 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
|
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the tables for image-related enums
|
|
||||||
create_enum_table(
|
|
||||||
enum=ImageType,
|
|
||||||
table_name="image_types",
|
|
||||||
primary_key_name="type_name",
|
|
||||||
cursor=self._cursor,
|
|
||||||
)
|
|
||||||
|
|
||||||
create_enum_table(
|
|
||||||
enum=ImageCategory,
|
|
||||||
table_name="image_categories",
|
|
||||||
primary_key_name="category_name",
|
|
||||||
cursor=self._cursor,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the `tags` table. TODO: do this elsewhere, shouldn't be in images db service
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
CREATE TABLE IF NOT EXISTS tags (
|
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
tag_name TEXT UNIQUE NOT NULL
|
|
||||||
);
|
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the `images_tags` junction table.
|
# Add trigger for `updated_at`.
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
CREATE TABLE IF NOT EXISTS images_tags (
|
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
|
||||||
image_id TEXT,
|
AFTER UPDATE
|
||||||
tag_id INTEGER,
|
ON images FOR EACH ROW
|
||||||
PRIMARY KEY (image_id, tag_id),
|
BEGIN
|
||||||
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE,
|
UPDATE images SET updated_at = current_timestamp
|
||||||
FOREIGN KEY(tag_id) REFERENCES tags(id) ON DELETE CASCADE
|
WHERE image_name = old.image_name;
|
||||||
);
|
END;
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the `images_favorites` table.
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
CREATE TABLE IF NOT EXISTS images_favorites (
|
|
||||||
image_id TEXT PRIMARY KEY,
|
|
||||||
favorited_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -229,22 +176,22 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
f"""--sql
|
||||||
SELECT * FROM images
|
SELECT * FROM images
|
||||||
WHERE id = ?;
|
WHERE image_name = ?;
|
||||||
""",
|
""",
|
||||||
(image_name,),
|
(image_name,),
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self._cursor.fetchone()
|
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
raise self.ImageRecordNotFoundException from e
|
raise ImageRecordNotFoundException from e
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
raise self.ImageRecordNotFoundException
|
raise ImageRecordNotFoundException
|
||||||
|
|
||||||
return deserialize_image_record(result)
|
return deserialize_image_record(dict(result))
|
||||||
|
|
||||||
def get_many(
|
def get_many(
|
||||||
self,
|
self,
|
||||||
@ -260,14 +207,15 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
f"""--sql
|
f"""--sql
|
||||||
SELECT * FROM images
|
SELECT * FROM images
|
||||||
WHERE image_type = ? AND image_category = ?
|
WHERE image_type = ? AND image_category = ?
|
||||||
|
ORDER BY created_at DESC
|
||||||
LIMIT ? OFFSET ?;
|
LIMIT ? OFFSET ?;
|
||||||
""",
|
""",
|
||||||
(image_type.value, image_category.value, per_page, page * per_page),
|
(image_type.value, image_category.value, per_page, page * per_page),
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self._cursor.fetchall()
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
|
||||||
images = list(map(lambda r: deserialize_image_record(r), result))
|
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||||
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
@ -296,14 +244,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
DELETE FROM images
|
DELETE FROM images
|
||||||
WHERE id = ?;
|
WHERE image_name = ?;
|
||||||
""",
|
""",
|
||||||
(image_name,),
|
(image_name,),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
raise ImageRecordStorageBase.ImageRecordDeleteException from e
|
raise ImageRecordDeleteException from e
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
@ -313,10 +261,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
image_type: ImageType,
|
image_type: ImageType,
|
||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
session_id: Optional[str],
|
session_id: Optional[str],
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
node_id: Optional[str],
|
node_id: Optional[str],
|
||||||
metadata: Optional[ImageMetadata],
|
metadata: Optional[ImageMetadata],
|
||||||
created_at: str,
|
) -> datetime:
|
||||||
) -> None:
|
|
||||||
try:
|
try:
|
||||||
metadata_json = (
|
metadata_json = (
|
||||||
None if metadata is None else metadata.json(exclude_none=True)
|
None if metadata is None else metadata.json(exclude_none=True)
|
||||||
@ -325,29 +274,44 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
INSERT OR IGNORE INTO images (
|
INSERT OR IGNORE INTO images (
|
||||||
id,
|
image_name,
|
||||||
image_type,
|
image_type,
|
||||||
image_category,
|
image_category,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
node_id,
|
node_id,
|
||||||
session_id,
|
session_id,
|
||||||
metadata,
|
metadata
|
||||||
created_at
|
|
||||||
)
|
)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?);
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?);
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
image_name,
|
image_name,
|
||||||
image_type.value,
|
image_type.value,
|
||||||
image_category.value,
|
image_category.value,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
node_id,
|
node_id,
|
||||||
session_id,
|
session_id,
|
||||||
metadata_json,
|
metadata_json,
|
||||||
created_at,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
|
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT created_at
|
||||||
|
FROM images
|
||||||
|
WHERE image_name = ?;
|
||||||
|
""",
|
||||||
|
(image_name,),
|
||||||
|
)
|
||||||
|
|
||||||
|
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
|
||||||
|
|
||||||
|
return created_at
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
raise ImageRecordStorageBase.ImageRecordNotFoundException from e
|
raise ImageRecordSaveException from e
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
@ -4,9 +4,17 @@ from typing import Optional, TYPE_CHECKING, Union
|
|||||||
import uuid
|
import uuid
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageType
|
from invokeai.app.models.image import (
|
||||||
|
ImageCategory,
|
||||||
|
ImageType,
|
||||||
|
InvalidImageCategoryException,
|
||||||
|
InvalidImageTypeException,
|
||||||
|
)
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
from invokeai.app.services.image_record_storage import (
|
from invokeai.app.services.image_record_storage import (
|
||||||
|
ImageRecordDeleteException,
|
||||||
|
ImageRecordNotFoundException,
|
||||||
|
ImageRecordSaveException,
|
||||||
ImageRecordStorageBase,
|
ImageRecordStorageBase,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.models.image_record import (
|
from invokeai.app.services.models.image_record import (
|
||||||
@ -14,7 +22,12 @@ from invokeai.app.services.models.image_record import (
|
|||||||
ImageDTO,
|
ImageDTO,
|
||||||
image_record_to_dto,
|
image_record_to_dto,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.image_file_storage import ImageFileStorageBase
|
from invokeai.app.services.image_file_storage import (
|
||||||
|
ImageFileDeleteException,
|
||||||
|
ImageFileNotFoundException,
|
||||||
|
ImageFileSaveException,
|
||||||
|
ImageFileStorageBase,
|
||||||
|
)
|
||||||
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
|
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
|
||||||
from invokeai.app.services.metadata import MetadataServiceBase
|
from invokeai.app.services.metadata import MetadataServiceBase
|
||||||
from invokeai.app.services.urls import UrlServiceBase
|
from invokeai.app.services.urls import UrlServiceBase
|
||||||
@ -50,6 +63,11 @@ class ImageServiceABC(ABC):
|
|||||||
"""Gets an image record."""
|
"""Gets an image record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||||
|
"""Gets an image DTO."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||||
"""Gets an image's path"""
|
"""Gets an image's path"""
|
||||||
@ -62,11 +80,6 @@ class ImageServiceABC(ABC):
|
|||||||
"""Gets an image's or thumbnail's URL"""
|
"""Gets an image's or thumbnail's URL"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
|
||||||
"""Gets an image DTO."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_many(
|
def get_many(
|
||||||
self,
|
self,
|
||||||
@ -83,26 +96,6 @@ class ImageServiceABC(ABC):
|
|||||||
"""Deletes an image."""
|
"""Deletes an image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
|
||||||
"""Adds a tag to an image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
|
||||||
"""Removes a tag from an image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def favorite(self, image_type: ImageType, image_id: str) -> None:
|
|
||||||
"""Favorites an image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
|
|
||||||
"""Unfavorites an image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ImageServiceDependencies:
|
class ImageServiceDependencies:
|
||||||
"""Service dependencies for the ImageService."""
|
"""Service dependencies for the ImageService."""
|
||||||
@ -160,6 +153,12 @@ class ImageService(ImageServiceABC):
|
|||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
|
if image_type not in ImageType:
|
||||||
|
raise InvalidImageTypeException
|
||||||
|
|
||||||
|
if image_category not in ImageCategory:
|
||||||
|
raise InvalidImageCategoryException
|
||||||
|
|
||||||
image_name = self._create_image_name(
|
image_name = self._create_image_name(
|
||||||
image_type=image_type,
|
image_type=image_type,
|
||||||
image_category=image_category,
|
image_category=image_category,
|
||||||
@ -167,11 +166,25 @@ class ImageService(ImageServiceABC):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
timestamp = get_iso_timestamp()
|
|
||||||
metadata = self._get_metadata(session_id, node_id)
|
metadata = self._get_metadata(session_id, node_id)
|
||||||
|
|
||||||
|
(width, height) = image.size
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||||
|
created_at = self._services.records.save(
|
||||||
|
# Non-nullable fields
|
||||||
|
image_name=image_name,
|
||||||
|
image_type=image_type,
|
||||||
|
image_category=image_category,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
# Nullable fields
|
||||||
|
node_id=node_id,
|
||||||
|
session_id=session_id,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
self._services.files.save(
|
self._services.files.save(
|
||||||
image_type=image_type,
|
image_type=image_type,
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
@ -179,36 +192,34 @@ class ImageService(ImageServiceABC):
|
|||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._services.records.save(
|
|
||||||
image_name=image_name,
|
|
||||||
image_type=image_type,
|
|
||||||
image_category=image_category,
|
|
||||||
node_id=node_id,
|
|
||||||
session_id=session_id,
|
|
||||||
metadata=metadata,
|
|
||||||
created_at=timestamp,
|
|
||||||
)
|
|
||||||
|
|
||||||
image_url = self._services.urls.get_image_url(image_type, image_name)
|
image_url = self._services.urls.get_image_url(image_type, image_name)
|
||||||
thumbnail_url = self._services.urls.get_image_url(
|
thumbnail_url = self._services.urls.get_image_url(
|
||||||
image_type, image_name, True
|
image_type, image_name, True
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageDTO(
|
return ImageDTO(
|
||||||
|
# Non-nullable fields
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_type=image_type,
|
image_type=image_type,
|
||||||
image_category=image_category,
|
image_category=image_category,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
# Nullable fields
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
created_at=timestamp,
|
# Meta fields
|
||||||
|
created_at=created_at,
|
||||||
|
updated_at=created_at, # this is always the same as the created_at at this time
|
||||||
|
deleted_at=None,
|
||||||
|
# Extra non-nullable fields for DTO
|
||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
thumbnail_url=thumbnail_url,
|
thumbnail_url=thumbnail_url,
|
||||||
)
|
)
|
||||||
except ImageRecordStorageBase.ImageRecordSaveException:
|
except ImageRecordSaveException:
|
||||||
self._services.logger.error("Failed to save image record")
|
self._services.logger.error("Failed to save image record")
|
||||||
raise
|
raise
|
||||||
except ImageFileStorageBase.ImageFileSaveException:
|
except ImageFileSaveException:
|
||||||
self._services.logger.error("Failed to save image file")
|
self._services.logger.error("Failed to save image file")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -218,7 +229,7 @@ class ImageService(ImageServiceABC):
|
|||||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||||
try:
|
try:
|
||||||
return self._services.files.get(image_type, image_name)
|
return self._services.files.get(image_type, image_name)
|
||||||
except ImageFileStorageBase.ImageFileNotFoundException:
|
except ImageFileNotFoundException:
|
||||||
self._services.logger.error("Failed to get image file")
|
self._services.logger.error("Failed to get image file")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -228,7 +239,7 @@ class ImageService(ImageServiceABC):
|
|||||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||||
try:
|
try:
|
||||||
return self._services.records.get(image_type, image_name)
|
return self._services.records.get(image_type, image_name)
|
||||||
except ImageRecordStorageBase.ImageRecordNotFoundException:
|
except ImageRecordNotFoundException:
|
||||||
self._services.logger.error("Image record not found")
|
self._services.logger.error("Image record not found")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -246,7 +257,7 @@ class ImageService(ImageServiceABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return image_dto
|
return image_dto
|
||||||
except ImageRecordStorageBase.ImageRecordNotFoundException:
|
except ImageRecordNotFoundException:
|
||||||
self._services.logger.error("Image record not found")
|
self._services.logger.error("Image record not found")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -311,32 +322,19 @@ class ImageService(ImageServiceABC):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str):
|
def delete(self, image_type: ImageType, image_name: str):
|
||||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
|
||||||
try:
|
try:
|
||||||
self._services.files.delete(image_type, image_name)
|
self._services.files.delete(image_type, image_name)
|
||||||
self._services.records.delete(image_type, image_name)
|
self._services.records.delete(image_type, image_name)
|
||||||
except ImageRecordStorageBase.ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error(f"Failed to delete image record")
|
self._services.logger.error(f"Failed to delete image record")
|
||||||
raise
|
raise
|
||||||
except ImageFileStorageBase.ImageFileDeleteException:
|
except ImageFileDeleteException:
|
||||||
self._services.logger.error(f"Failed to delete image file")
|
self._services.logger.error(f"Failed to delete image file")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem deleting image record and file")
|
self._services.logger.error("Problem deleting image record and file")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
|
||||||
raise NotImplementedError("The 'add_tag' method is not implemented yet.")
|
|
||||||
|
|
||||||
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
|
||||||
raise NotImplementedError("The 'remove_tag' method is not implemented yet.")
|
|
||||||
|
|
||||||
def favorite(self, image_type: ImageType, image_id: str) -> None:
|
|
||||||
raise NotImplementedError("The 'favorite' method is not implemented yet.")
|
|
||||||
|
|
||||||
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
|
|
||||||
raise NotImplementedError("The 'unfavorite' method is not implemented yet.")
|
|
||||||
|
|
||||||
def _create_image_name(
|
def _create_image_name(
|
||||||
self,
|
self,
|
||||||
image_type: ImageType,
|
image_type: ImageType,
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import sqlite3
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from invokeai.app.models.image import ImageCategory, ImageType
|
from invokeai.app.models.image import ImageCategory, ImageType
|
||||||
@ -10,30 +9,60 @@ from invokeai.app.util.misc import get_iso_timestamp
|
|||||||
class ImageRecord(BaseModel):
|
class ImageRecord(BaseModel):
|
||||||
"""Deserialized image record."""
|
"""Deserialized image record."""
|
||||||
|
|
||||||
image_name: str = Field(description="The name of the image.")
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
|
"""The unique name of the image."""
|
||||||
image_type: ImageType = Field(description="The type of the image.")
|
image_type: ImageType = Field(description="The type of the image.")
|
||||||
|
"""The type of the image."""
|
||||||
image_category: ImageCategory = Field(description="The category of the image.")
|
image_category: ImageCategory = Field(description="The category of the image.")
|
||||||
|
"""The category of the image."""
|
||||||
|
width: int = Field(description="The width of the image in px.")
|
||||||
|
"""The actual width of the image in px. This may be different from the width in metadata."""
|
||||||
|
height: int = Field(description="The height of the image in px.")
|
||||||
|
"""The actual height of the image in px. This may be different from the height in metadata."""
|
||||||
created_at: Union[datetime.datetime, str] = Field(
|
created_at: Union[datetime.datetime, str] = Field(
|
||||||
description="The created timestamp of the image."
|
description="The created timestamp of the image."
|
||||||
)
|
)
|
||||||
session_id: Optional[str] = Field(default=None, description="The session ID.")
|
"""The created timestamp of the image."""
|
||||||
node_id: Optional[str] = Field(default=None, description="The node ID.")
|
updated_at: Union[datetime.datetime, str] = Field(
|
||||||
metadata: Optional[ImageMetadata] = Field(
|
description="The updated timestamp of the image."
|
||||||
default=None, description="The image's metadata."
|
|
||||||
)
|
)
|
||||||
|
"""The updated timestamp of the image."""
|
||||||
|
deleted_at: Union[datetime.datetime, str, None] = Field(
|
||||||
|
description="The deleted timestamp of the image."
|
||||||
|
)
|
||||||
|
"""The deleted timestamp of the image."""
|
||||||
|
session_id: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The session ID that generated this image, if it is a generated image.",
|
||||||
|
)
|
||||||
|
"""The session ID that generated this image, if it is a generated image."""
|
||||||
|
node_id: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The node ID that generated this image, if it is a generated image.",
|
||||||
|
)
|
||||||
|
"""The node ID that generated this image, if it is a generated image."""
|
||||||
|
metadata: Optional[ImageMetadata] = Field(
|
||||||
|
default=None,
|
||||||
|
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
|
||||||
|
)
|
||||||
|
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
|
||||||
|
|
||||||
|
|
||||||
class ImageUrlsDTO(BaseModel):
|
class ImageUrlsDTO(BaseModel):
|
||||||
"""The URLs for an image and its thumbnaill"""
|
"""The URLs for an image and its thumbnail."""
|
||||||
|
|
||||||
image_name: str = Field(description="The name of the image.")
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
|
"""The unique name of the image."""
|
||||||
image_type: ImageType = Field(description="The type of the image.")
|
image_type: ImageType = Field(description="The type of the image.")
|
||||||
|
"""The type of the image."""
|
||||||
image_url: str = Field(description="The URL of the image.")
|
image_url: str = Field(description="The URL of the image.")
|
||||||
thumbnail_url: str = Field(description="The thumbnail URL of the image.")
|
"""The URL of the image."""
|
||||||
|
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||||
|
"""The URL of the image's thumbnail."""
|
||||||
|
|
||||||
|
|
||||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||||
"""Deserialized image record with URLs."""
|
"""Deserialized image record, enriched for the frontend with URLs."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -43,24 +72,29 @@ def image_record_to_dto(
|
|||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Converts an image record to an image DTO."""
|
"""Converts an image record to an image DTO."""
|
||||||
return ImageDTO(
|
return ImageDTO(
|
||||||
image_name=image_record.image_name,
|
**image_record.dict(),
|
||||||
image_type=image_record.image_type,
|
|
||||||
image_category=image_record.image_category,
|
|
||||||
created_at=image_record.created_at,
|
|
||||||
session_id=image_record.session_id,
|
|
||||||
node_id=image_record.node_id,
|
|
||||||
metadata=image_record.metadata,
|
|
||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
thumbnail_url=thumbnail_url,
|
thumbnail_url=thumbnail_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
|
def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||||
"""Deserializes an image record."""
|
"""Deserializes an image record."""
|
||||||
|
|
||||||
image_dict = dict(image_row)
|
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||||
|
|
||||||
|
image_name = image_dict.get("image_name", "unknown")
|
||||||
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
|
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
|
||||||
|
image_category = ImageCategory(
|
||||||
|
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||||
|
)
|
||||||
|
width = image_dict.get("width", 0)
|
||||||
|
height = image_dict.get("height", 0)
|
||||||
|
session_id = image_dict.get("session_id", None)
|
||||||
|
node_id = image_dict.get("node_id", None)
|
||||||
|
created_at = image_dict.get("created_at", get_iso_timestamp())
|
||||||
|
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||||
|
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||||
|
|
||||||
raw_metadata = image_dict.get("metadata")
|
raw_metadata = image_dict.get("metadata")
|
||||||
|
|
||||||
@ -70,13 +104,15 @@ def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
|
|||||||
metadata = None
|
metadata = None
|
||||||
|
|
||||||
return ImageRecord(
|
return ImageRecord(
|
||||||
image_name=image_dict.get("id", "unknown"),
|
image_name=image_name,
|
||||||
session_id=image_dict.get("session_id", None),
|
|
||||||
node_id=image_dict.get("node_id", None),
|
|
||||||
metadata=metadata,
|
|
||||||
image_type=image_type,
|
image_type=image_type,
|
||||||
image_category=ImageCategory(
|
image_category=image_category,
|
||||||
image_dict.get("image_category", ImageCategory.IMAGE.value)
|
width=width,
|
||||||
),
|
height=height,
|
||||||
created_at=image_dict.get("created_at", get_iso_timestamp()),
|
session_id=session_id,
|
||||||
|
node_id=node_id,
|
||||||
|
metadata=metadata,
|
||||||
|
created_at=created_at,
|
||||||
|
updated_at=updated_at,
|
||||||
|
deleted_at=deleted_at,
|
||||||
)
|
)
|
||||||
|
@ -2,7 +2,10 @@ from enum import EnumMeta
|
|||||||
|
|
||||||
|
|
||||||
class MetaEnum(EnumMeta):
|
class MetaEnum(EnumMeta):
|
||||||
"""Metaclass to support `in` syntax value checking in String Enums"""
|
"""Metaclass to support additional features in Enums.
|
||||||
|
|
||||||
|
- `in` operator support: `'value' in MyEnum -> bool`
|
||||||
|
"""
|
||||||
|
|
||||||
def __contains__(cls, item):
|
def __contains__(cls, item):
|
||||||
try:
|
try:
|
Reference in New Issue
Block a user