mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): address feedback
- Address database feedback: - Remove all the extraneous tables. Only an `images` table now: - `image_type` and `image_category` are unrestricted strings. When creating images, the provided values are checked to ensure they are a valid type and category. - Add `updated_at` and `deleted_at` columns. `deleted_at` is currently unused. - Use SQLite's built-in timestamp features to populate these. Add a trigger to update `updated_at` when the row is updated. Currently no way to update a row. - Rename the `id` column in `images` to `image_name` - Rename `ImageCategory.IMAGE` to `ImageCategory.GENERAL` - Move all exceptions outside their base classes to make them more portable. - Add `width` and `height` columns to the database. These store the actual dimensions of the image file, whereas the metadata's `width` and `height` refer to the respective generation parameters and are nullable. - Make `deserialize_image_record` take a `dict` instead of `sqlite3.Row` - Improve comments throughout - Tidy up unused code/files and some minor organisation
This commit is contained in:
parent
021e5a2aa3
commit
035425ef24
@ -30,7 +30,7 @@ async def upload_image(
|
||||
image_type: ImageType,
|
||||
request: Request,
|
||||
response: Response,
|
||||
image_category: ImageCategory = ImageCategory.IMAGE,
|
||||
image_category: ImageCategory = ImageCategory.GENERAL,
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
if not file.content_type.startswith("image"):
|
||||
|
@ -95,7 +95,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
image_dto = context.services.images_new.create(
|
||||
image=generate_output.image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_category=ImageCategory.IMAGE,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
)
|
||||
@ -119,7 +119,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
# context.services.images_db.set(
|
||||
# id=image_name,
|
||||
# image_type=ImageType.RESULT,
|
||||
# image_category=ImageCategory.IMAGE,
|
||||
# image_category=ImageCategory.GENERAL,
|
||||
# session_id=context.graph_execution_state_id,
|
||||
# node_id=self.id,
|
||||
# metadata=GeneratedImageOrLatentsMetadata(),
|
||||
|
@ -372,7 +372,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
image_dto = context.services.images_new.create(
|
||||
image=image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_category=ImageCategory.IMAGE,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
)
|
||||
|
@ -2,7 +2,7 @@ from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.util.enum import MetaEnum
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
|
||||
|
||||
class ImageType(str, Enum, metaclass=MetaEnum):
|
||||
@ -13,20 +13,32 @@ class ImageType(str, Enum, metaclass=MetaEnum):
|
||||
INTERMEDIATE = "intermediates"
|
||||
|
||||
|
||||
class InvalidImageTypeException(ValueError):
|
||||
"""Raised when a provided value is not a valid ImageType.
|
||||
|
||||
Subclasses `ValueError`.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Invalid image type."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||
"""The category of an image. Use ImageCategory.OTHER for non-default categories."""
|
||||
|
||||
IMAGE = "image"
|
||||
CONTROL_IMAGE = "control_image"
|
||||
GENERAL = "general"
|
||||
CONTROL = "control"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
def is_image_type(obj):
|
||||
try:
|
||||
ImageType(obj)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
class InvalidImageCategoryException(ValueError):
|
||||
"""Raised when a provided value is not a valid ImageCategory.
|
||||
|
||||
Subclasses `ValueError`.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Invalid image category."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
|
@ -26,50 +26,66 @@ class ImageMetadata(BaseModel):
|
||||
default=None,
|
||||
description="The type of the ancestor node of the image output node.",
|
||||
)
|
||||
"""The type of the ancestor node of the image output node."""
|
||||
positive_conditioning: Optional[StrictStr] = Field(
|
||||
default=None, description="The positive conditioning."
|
||||
)
|
||||
"""The positive conditioning"""
|
||||
negative_conditioning: Optional[StrictStr] = Field(
|
||||
default=None, description="The negative conditioning."
|
||||
)
|
||||
"""The negative conditioning"""
|
||||
width: Optional[StrictInt] = Field(
|
||||
default=None, description="Width of the image/latents in pixels."
|
||||
)
|
||||
"""Width of the image/latents in pixels"""
|
||||
height: Optional[StrictInt] = Field(
|
||||
default=None, description="Height of the image/latents in pixels."
|
||||
)
|
||||
"""Height of the image/latents in pixels"""
|
||||
seed: Optional[StrictInt] = Field(
|
||||
default=None, description="The seed used for noise generation."
|
||||
)
|
||||
"""The seed used for noise generation"""
|
||||
cfg_scale: Optional[StrictFloat] = Field(
|
||||
default=None, description="The classifier-free guidance scale."
|
||||
)
|
||||
"""The classifier-free guidance scale"""
|
||||
steps: Optional[StrictInt] = Field(
|
||||
default=None, description="The number of steps used for inference."
|
||||
)
|
||||
"""The number of steps used for inference"""
|
||||
scheduler: Optional[StrictStr] = Field(
|
||||
default=None, description="The scheduler used for inference."
|
||||
)
|
||||
"""The scheduler used for inference"""
|
||||
model: Optional[StrictStr] = Field(
|
||||
default=None, description="The model used for inference."
|
||||
)
|
||||
"""The model used for inference"""
|
||||
strength: Optional[StrictFloat] = Field(
|
||||
default=None,
|
||||
description="The strength used for image-to-image/latents-to-latents.",
|
||||
)
|
||||
"""The strength used for image-to-image/latents-to-latents."""
|
||||
latents: Optional[StrictStr] = Field(
|
||||
default=None, description="The ID of the initial latents."
|
||||
)
|
||||
"""The ID of the initial latents"""
|
||||
vae: Optional[StrictStr] = Field(
|
||||
default=None, description="The VAE used for decoding."
|
||||
)
|
||||
"""The VAE used for decoding"""
|
||||
unet: Optional[StrictStr] = Field(
|
||||
default=None, description="The UNet used dor inference."
|
||||
)
|
||||
"""The UNet used dor inference"""
|
||||
clip: Optional[StrictStr] = Field(
|
||||
default=None, description="The CLIP Encoder used for conditioning."
|
||||
)
|
||||
"""The CLIP Encoder used for conditioning"""
|
||||
extra: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
|
||||
)
|
||||
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""
|
||||
|
@ -1,28 +0,0 @@
|
||||
# TODO: Make a new model for this
|
||||
from enum import Enum
|
||||
|
||||
from invokeai.app.util.enum import MetaEnum
|
||||
|
||||
|
||||
class ResourceType(str, Enum, metaclass=MetaEnum):
|
||||
"""The type of a resource."""
|
||||
|
||||
IMAGES = "images"
|
||||
TENSORS = "tensors"
|
||||
|
||||
|
||||
# class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||
# """The origin of a resource (eg image or tensor)."""
|
||||
|
||||
# RESULTS = "results"
|
||||
# UPLOADS = "uploads"
|
||||
# INTERMEDIATES = "intermediates"
|
||||
|
||||
|
||||
|
||||
class TensorKind(str, Enum, metaclass=MetaEnum):
|
||||
"""The kind of a tensor. Use TensorKind.OTHER for non-default kinds."""
|
||||
|
||||
IMAGE_LATENTS = "image_latents"
|
||||
CONDITIONING = "conditioning"
|
||||
OTHER = "other"
|
@ -1,578 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from abc import ABC, abstractmethod\n",
|
||||
"from enum import Enum\n",
|
||||
"import enum\n",
|
||||
"import sqlite3\n",
|
||||
"import threading\n",
|
||||
"from typing import Optional, Type, TypeVar, Union\n",
|
||||
"from PIL.Image import Image as PILImage\n",
|
||||
"from pydantic import BaseModel, Field\n",
|
||||
"from torch import Tensor"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"class ResourceOrigin(str, Enum):\n",
|
||||
" \"\"\"The origin of a resource (eg image or tensor).\"\"\"\n",
|
||||
"\n",
|
||||
" RESULTS = \"results\"\n",
|
||||
" UPLOADS = \"uploads\"\n",
|
||||
" INTERMEDIATES = \"intermediates\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ImageKind(str, Enum):\n",
|
||||
" \"\"\"The kind of an image. Use ImageKind.OTHER for non-default kinds.\"\"\"\n",
|
||||
"\n",
|
||||
" IMAGE = \"image\"\n",
|
||||
" CONTROL_IMAGE = \"control_image\"\n",
|
||||
" OTHER = \"other\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class TensorKind(str, Enum):\n",
|
||||
" \"\"\"The kind of a tensor. Use TensorKind.OTHER for non-default kinds.\"\"\"\n",
|
||||
"\n",
|
||||
" IMAGE_LATENTS = \"image_latents\"\n",
|
||||
" CONDITIONING = \"conditioning\"\n",
|
||||
" OTHER = \"other\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"def create_sql_values_string_from_string_enum(enum: Type[Enum]):\n",
|
||||
" \"\"\"\n",
|
||||
" Creates a string of the form \"('value1'), ('value2'), ..., ('valueN')\" from a StrEnum.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" delimiter = \", \"\n",
|
||||
" values = [f\"('{e.value}')\" for e in enum]\n",
|
||||
" return delimiter.join(values)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_sql_table_from_enum(\n",
|
||||
" enum: Type[Enum],\n",
|
||||
" table_name: str,\n",
|
||||
" primary_key_name: str,\n",
|
||||
" conn: sqlite3.Connection,\n",
|
||||
" cursor: sqlite3.Cursor,\n",
|
||||
" lock: threading.Lock,\n",
|
||||
"):\n",
|
||||
" \"\"\"\n",
|
||||
" Creates and populates a table to be used as a functional enum.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" values_string = create_sql_values_string_from_string_enum(enum)\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" f\"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS {table_name} (\n",
|
||||
" {primary_key_name} TEXT PRIMARY KEY\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" f\"\"\"--sql\n",
|
||||
" INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`resource_origins` functions as an enum for the ResourceOrigin model.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# def create_resource_origins_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
"# create_sql_table_from_enum(\n",
|
||||
"# enum=ResourceOrigin,\n",
|
||||
"# table_name=\"resource_origins\",\n",
|
||||
"# primary_key_name=\"origin_name\",\n",
|
||||
"# conn=conn,\n",
|
||||
"# cursor=cursor,\n",
|
||||
"# lock=lock,\n",
|
||||
"# )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`image_kinds` functions as an enum for the ImageType model.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# def create_image_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" # create_sql_table_from_enum(\n",
|
||||
" # enum=ImageKind,\n",
|
||||
" # table_name=\"image_kinds\",\n",
|
||||
" # primary_key_name=\"kind_name\",\n",
|
||||
" # conn=conn,\n",
|
||||
" # cursor=cursor,\n",
|
||||
" # lock=lock,\n",
|
||||
" # )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`tensor_kinds` functions as an enum for the TensorType model.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# def create_tensor_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" # create_sql_table_from_enum(\n",
|
||||
" # enum=TensorKind,\n",
|
||||
" # table_name=\"tensor_kinds\",\n",
|
||||
" # primary_key_name=\"kind_name\",\n",
|
||||
" # conn=conn,\n",
|
||||
" # cursor=cursor,\n",
|
||||
" # lock=lock,\n",
|
||||
" # )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`images` stores all images, regardless of type\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_images_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS images (\n",
|
||||
" id TEXT PRIMARY KEY,\n",
|
||||
" origin TEXT,\n",
|
||||
" image_kind TEXT,\n",
|
||||
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
|
||||
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
|
||||
" FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name)\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`images_results` stores additional data specific to `results` images.\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_images_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS images_results (\n",
|
||||
" images_id TEXT PRIMARY KEY,\n",
|
||||
" session_id TEXT NOT NULL,\n",
|
||||
" node_id TEXT NOT NULL,\n",
|
||||
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_results_images_id ON images_results(images_id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`images_intermediates` stores additional data specific to `intermediates` images\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_images_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS images_intermediates (\n",
|
||||
" images_id TEXT PRIMARY KEY,\n",
|
||||
" session_id TEXT NOT NULL,\n",
|
||||
" node_id TEXT NOT NULL,\n",
|
||||
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_intermediates_images_id ON images_intermediates(images_id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"`images_metadata` stores basic metadata for any image type\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_images_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS images_metadata (\n",
|
||||
" images_id TEXT PRIMARY KEY,\n",
|
||||
" metadata TEXT,\n",
|
||||
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# `tensors` table: stores references to tensor\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_tensors_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS tensors (\n",
|
||||
" id TEXT PRIMARY KEY,\n",
|
||||
" origin TEXT,\n",
|
||||
" tensor_kind TEXT,\n",
|
||||
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
|
||||
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
|
||||
" FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name)\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# `tensors_results` stores additional data specific to `result` tensor\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_tensors_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS tensors_results (\n",
|
||||
" tensors_id TEXT PRIMARY KEY,\n",
|
||||
" session_id TEXT NOT NULL,\n",
|
||||
" node_id TEXT NOT NULL,\n",
|
||||
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_results_tensors_id ON tensors_results(tensors_id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# `tensors_intermediates` stores additional data specific to `intermediate` tensor\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_tensors_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS tensors_intermediates (\n",
|
||||
" tensors_id TEXT PRIMARY KEY,\n",
|
||||
" session_id TEXT NOT NULL,\n",
|
||||
" node_id TEXT NOT NULL,\n",
|
||||
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_intermediates_tensors_id ON tensors_intermediates(tensors_id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# `tensors_metadata` table: stores generated/transformed metadata for tensor\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_tensors_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
|
||||
" try:\n",
|
||||
" lock.acquire()\n",
|
||||
"\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE TABLE IF NOT EXISTS tensors_metadata (\n",
|
||||
" tensors_id TEXT PRIMARY KEY,\n",
|
||||
" metadata TEXT,\n",
|
||||
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
|
||||
" );\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" cursor.execute(\n",
|
||||
" \"\"\"--sql\n",
|
||||
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensors_id ON tensors_metadata(tensors_id);\n",
|
||||
" \"\"\"\n",
|
||||
" )\n",
|
||||
" conn.commit()\n",
|
||||
" finally:\n",
|
||||
" lock.release()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"db_path = '/home/bat/Documents/Code/outputs/test.db'\n",
|
||||
"if (os.path.exists(db_path)):\n",
|
||||
" os.remove(db_path)\n",
|
||||
"\n",
|
||||
"conn = sqlite3.connect(\n",
|
||||
" db_path, check_same_thread=False\n",
|
||||
")\n",
|
||||
"cursor = conn.cursor()\n",
|
||||
"lock = threading.Lock()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"create_sql_table_from_enum(\n",
|
||||
" enum=ResourceOrigin,\n",
|
||||
" table_name=\"resource_origins\",\n",
|
||||
" primary_key_name=\"origin_name\",\n",
|
||||
" conn=conn,\n",
|
||||
" cursor=cursor,\n",
|
||||
" lock=lock,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"create_sql_table_from_enum(\n",
|
||||
" enum=ImageKind,\n",
|
||||
" table_name=\"image_kinds\",\n",
|
||||
" primary_key_name=\"kind_name\",\n",
|
||||
" conn=conn,\n",
|
||||
" cursor=cursor,\n",
|
||||
" lock=lock,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"create_sql_table_from_enum(\n",
|
||||
" enum=TensorKind,\n",
|
||||
" table_name=\"tensor_kinds\",\n",
|
||||
" primary_key_name=\"kind_name\",\n",
|
||||
" conn=conn,\n",
|
||||
" cursor=cursor,\n",
|
||||
" lock=lock,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"create_images_table(conn, cursor, lock)\n",
|
||||
"create_images_results_table(conn, cursor, lock)\n",
|
||||
"create_images_intermediates_table(conn, cursor, lock)\n",
|
||||
"create_images_metadata_table(conn, cursor, lock)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"create_tensors_table(conn, cursor, lock)\n",
|
||||
"create_tensors_results_table(conn, cursor, lock)\n",
|
||||
"create_tensors_intermediates_table(conn, cursor, lock)\n",
|
||||
"create_tensors_metadata_table(conn, cursor, lock)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 59,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"from pydantic import StrictStr\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class GeneratedImageOrLatentsMetadata(BaseModel):\n",
|
||||
" \"\"\"Core generation metadata for an image/tensor generated in InvokeAI.\n",
|
||||
"\n",
|
||||
" Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.\n",
|
||||
"\n",
|
||||
" Full metadata may be accessed by querying for the session in the `graph_executions` table.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" positive_conditioning: Optional[StrictStr] = Field(\n",
|
||||
" default=None, description=\"The positive conditioning.\"\n",
|
||||
" )\n",
|
||||
" negative_conditioning: Optional[str] = Field(\n",
|
||||
" default=None, description=\"The negative conditioning.\"\n",
|
||||
" )\n",
|
||||
" width: Optional[int] = Field(\n",
|
||||
" default=None, description=\"Width of the image/tensor in pixels.\"\n",
|
||||
" )\n",
|
||||
" height: Optional[int] = Field(\n",
|
||||
" default=None, description=\"Height of the image/tensor in pixels.\"\n",
|
||||
" )\n",
|
||||
" seed: Optional[int] = Field(\n",
|
||||
" default=None, description=\"The seed used for noise generation.\"\n",
|
||||
" )\n",
|
||||
" cfg_scale: Optional[float] = Field(\n",
|
||||
" default=None, description=\"The classifier-free guidance scale.\"\n",
|
||||
" )\n",
|
||||
" steps: Optional[int] = Field(\n",
|
||||
" default=None, description=\"The number of steps used for inference.\"\n",
|
||||
" )\n",
|
||||
" scheduler: Optional[str] = Field(\n",
|
||||
" default=None, description=\"The scheduler used for inference.\"\n",
|
||||
" )\n",
|
||||
" model: Optional[str] = Field(\n",
|
||||
" default=None, description=\"The model used for inference.\"\n",
|
||||
" )\n",
|
||||
" strength: Optional[float] = Field(\n",
|
||||
" default=None,\n",
|
||||
" description=\"The strength used for image-to-image/tensor-to-tensor.\",\n",
|
||||
" )\n",
|
||||
" image: Optional[str] = Field(\n",
|
||||
" default=None, description=\"The ID of the initial image.\"\n",
|
||||
" )\n",
|
||||
" tensor: Optional[str] = Field(\n",
|
||||
" default=None, description=\"The ID of the initial tensor.\"\n",
|
||||
" )\n",
|
||||
" # Pending model refactor:\n",
|
||||
" # vae: Optional[str] = Field(default=None,description=\"The VAE used for decoding.\")\n",
|
||||
" # unet: Optional[str] = Field(default=None,description=\"The UNet used dor inference.\")\n",
|
||||
" # clip: Optional[str] = Field(default=None,description=\"The CLIP Encoder used for conditioning.\")\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 61,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"GeneratedImageOrLatentsMetadata(positive_conditioning='123', negative_conditioning=None, width=None, height=None, seed=None, cfg_scale=None, steps=None, scheduler=None, model=None, strength=None, image=None, tensor=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 61,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"GeneratedImageOrLatentsMetadata(positive_conditioning='123')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -14,27 +14,31 @@ from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message="Image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image file not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
class ImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message="Image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
class ImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
class ImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image file not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
@ -102,7 +106,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
except FileNotFoundError as e:
|
||||
raise ImageFileStorageBase.ImageFileNotFoundException from e
|
||||
raise ImageFileNotFoundException from e
|
||||
|
||||
def save(
|
||||
self,
|
||||
@ -130,7 +134,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
self.__set_cache(image_path, image)
|
||||
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||
except Exception as e:
|
||||
raise ImageFileStorageBase.ImageFileSaveException from e
|
||||
raise ImageFileSaveException from e
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
try:
|
||||
@ -150,7 +154,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
except Exception as e:
|
||||
raise ImageFileStorageBase.ImageFileDeleteException from e
|
||||
raise ImageFileDeleteException from e
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(
|
||||
|
@ -1,7 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional, Type
|
||||
from datetime import datetime
|
||||
from typing import Optional, cast
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
@ -18,62 +17,32 @@ from invokeai.app.services.models.image_record import (
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
|
||||
def create_sql_values_string_from_string_enum(enum: Type[Enum]):
|
||||
"""
|
||||
Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum.
|
||||
"""
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ImageRecordNotFoundException(Exception):
|
||||
"""Raised when an image record is not found."""
|
||||
|
||||
delimiter = ", "
|
||||
values = [f"('{e.value}')" for e in enum]
|
||||
return delimiter.join(values)
|
||||
def __init__(self, message="Image record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def create_enum_table(
|
||||
enum: Type[Enum],
|
||||
table_name: str,
|
||||
primary_key_name: str,
|
||||
cursor: sqlite3.Cursor,
|
||||
):
|
||||
"""
|
||||
Creates and populates a table to be used as a functional enum.
|
||||
"""
|
||||
class ImageRecordSaveException(Exception):
|
||||
"""Raised when an image record cannot be saved."""
|
||||
|
||||
values_string = create_sql_values_string_from_string_enum(enum)
|
||||
def __init__(self, message="Image record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
{primary_key_name} TEXT PRIMARY KEY
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};
|
||||
"""
|
||||
)
|
||||
|
||||
class ImageRecordDeleteException(Exception):
|
||||
"""Raised when an image record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the image record store."""
|
||||
|
||||
class ImageRecordNotFoundException(Exception):
|
||||
"""Raised when an image record is not found."""
|
||||
|
||||
def __init__(self, message="Image record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
class ImageRecordSaveException(Exception):
|
||||
"""Raised when an image record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
class ImageRecordDeleteException(Exception):
|
||||
"""Raised when an image record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image record not deleted"):
|
||||
super().__init__(message)
|
||||
# TODO: Implement an `update()` method
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||
@ -91,6 +60,8 @@ class ImageRecordStorageBase(ABC):
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
|
||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
"""Deletes an image record."""
|
||||
@ -102,11 +73,12 @@ class ImageRecordStorageBase(ABC):
|
||||
image_name: str,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
width: int,
|
||||
height: int,
|
||||
session_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
created_at: str = datetime.datetime.utcnow().isoformat(),
|
||||
) -> None:
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
|
||||
@ -141,17 +113,23 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
# Create the `images` table.
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
id TEXT PRIMARY KEY,
|
||||
image_type TEXT, -- non-nullable via foreign key constraint
|
||||
image_category TEXT, -- non-nullable via foreign key constraint
|
||||
session_id TEXT, -- nullable
|
||||
node_id TEXT, -- nullable
|
||||
metadata TEXT, -- nullable
|
||||
created_at TEXT NOT NULL,
|
||||
FOREIGN KEY(image_type) REFERENCES image_types(type_name),
|
||||
FOREIGN KEY(image_category) REFERENCES image_categories(category_name)
|
||||
image_name TEXT NOT NULL PRIMARY KEY,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_type TEXT NOT NULL,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_category TEXT NOT NULL,
|
||||
width INTEGER NOT NULL,
|
||||
height INTEGER NOT NULL,
|
||||
session_id TEXT,
|
||||
node_id TEXT,
|
||||
metadata TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME
|
||||
);
|
||||
"""
|
||||
)
|
||||
@ -159,7 +137,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
# Create the `images` table indices.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
@ -172,53 +150,22 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the tables for image-related enums
|
||||
create_enum_table(
|
||||
enum=ImageType,
|
||||
table_name="image_types",
|
||||
primary_key_name="type_name",
|
||||
cursor=self._cursor,
|
||||
)
|
||||
|
||||
create_enum_table(
|
||||
enum=ImageCategory,
|
||||
table_name="image_categories",
|
||||
primary_key_name="category_name",
|
||||
cursor=self._cursor,
|
||||
)
|
||||
|
||||
# Create the `tags` table. TODO: do this elsewhere, shouldn't be in images db service
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS tags (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
tag_name TEXT UNIQUE NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `images_tags` junction table.
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS images_tags (
|
||||
image_id TEXT,
|
||||
tag_id INTEGER,
|
||||
PRIMARY KEY (image_id, tag_id),
|
||||
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY(tag_id) REFERENCES tags(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `images_favorites` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS images_favorites (
|
||||
image_id TEXT PRIMARY KEY,
|
||||
favorited_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE images SET updated_at = current_timestamp
|
||||
WHERE image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
@ -229,22 +176,22 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT * FROM images
|
||||
WHERE id = ?;
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = self._cursor.fetchone()
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise self.ImageRecordNotFoundException from e
|
||||
raise ImageRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
raise self.ImageRecordNotFoundException
|
||||
raise ImageRecordNotFoundException
|
||||
|
||||
return deserialize_image_record(result)
|
||||
return deserialize_image_record(dict(result))
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
@ -260,14 +207,15 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
f"""--sql
|
||||
SELECT * FROM images
|
||||
WHERE image_type = ? AND image_category = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(image_type.value, image_category.value, per_page, page * per_page),
|
||||
)
|
||||
|
||||
result = self._cursor.fetchall()
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
|
||||
images = list(map(lambda r: deserialize_image_record(r), result))
|
||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@ -296,14 +244,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE id = ?;
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordStorageBase.ImageRecordDeleteException from e
|
||||
raise ImageRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
@ -313,10 +261,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
session_id: Optional[str],
|
||||
width: int,
|
||||
height: int,
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
created_at: str,
|
||||
) -> None:
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = (
|
||||
None if metadata is None else metadata.json(exclude_none=True)
|
||||
@ -325,29 +274,44 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
id,
|
||||
image_name,
|
||||
image_type,
|
||||
image_category,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
created_at
|
||||
metadata
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?);
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
image_type.value,
|
||||
image_category.value,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata_json,
|
||||
created_at,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT created_at
|
||||
FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
|
||||
|
||||
return created_at
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordStorageBase.ImageRecordNotFoundException from e
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
@ -4,9 +4,17 @@ from typing import Optional, TYPE_CHECKING, Union
|
||||
import uuid
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageType
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ImageType,
|
||||
InvalidImageCategoryException,
|
||||
InvalidImageTypeException,
|
||||
)
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordDeleteException,
|
||||
ImageRecordNotFoundException,
|
||||
ImageRecordSaveException,
|
||||
ImageRecordStorageBase,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
@ -14,7 +22,12 @@ from invokeai.app.services.models.image_record import (
|
||||
ImageDTO,
|
||||
image_record_to_dto,
|
||||
)
|
||||
from invokeai.app.services.image_file_storage import ImageFileStorageBase
|
||||
from invokeai.app.services.image_file_storage import (
|
||||
ImageFileDeleteException,
|
||||
ImageFileNotFoundException,
|
||||
ImageFileSaveException,
|
||||
ImageFileStorageBase,
|
||||
)
|
||||
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
@ -50,6 +63,11 @@ class ImageServiceABC(ABC):
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
"""Gets an image's path"""
|
||||
@ -62,11 +80,6 @@ class ImageServiceABC(ABC):
|
||||
"""Gets an image's or thumbnail's URL"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
@ -83,26 +96,6 @@ class ImageServiceABC(ABC):
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||
"""Adds a tag to an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||
"""Removes a tag from an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def favorite(self, image_type: ImageType, image_id: str) -> None:
|
||||
"""Favorites an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
|
||||
"""Unfavorites an image."""
|
||||
pass
|
||||
|
||||
|
||||
class ImageServiceDependencies:
|
||||
"""Service dependencies for the ImageService."""
|
||||
@ -160,6 +153,12 @@ class ImageService(ImageServiceABC):
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> ImageDTO:
|
||||
if image_type not in ImageType:
|
||||
raise InvalidImageTypeException
|
||||
|
||||
if image_category not in ImageCategory:
|
||||
raise InvalidImageCategoryException
|
||||
|
||||
image_name = self._create_image_name(
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
@ -167,11 +166,25 @@ class ImageService(ImageServiceABC):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
timestamp = get_iso_timestamp()
|
||||
metadata = self._get_metadata(session_id, node_id)
|
||||
|
||||
(width, height) = image.size
|
||||
|
||||
try:
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
created_at = self._services.records.save(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
self._services.files.save(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
@ -179,36 +192,34 @@ class ImageService(ImageServiceABC):
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
self._services.records.save(
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
created_at=timestamp,
|
||||
)
|
||||
|
||||
image_url = self._services.urls.get_image_url(image_type, image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(
|
||||
image_type, image_name, True
|
||||
)
|
||||
|
||||
return ImageDTO(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
created_at=timestamp,
|
||||
# Meta fields
|
||||
created_at=created_at,
|
||||
updated_at=created_at, # this is always the same as the created_at at this time
|
||||
deleted_at=None,
|
||||
# Extra non-nullable fields for DTO
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
except ImageRecordStorageBase.ImageRecordSaveException:
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to save image record")
|
||||
raise
|
||||
except ImageFileStorageBase.ImageFileSaveException:
|
||||
except ImageFileSaveException:
|
||||
self._services.logger.error("Failed to save image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
@ -218,7 +229,7 @@ class ImageService(ImageServiceABC):
|
||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.files.get(image_type, image_name)
|
||||
except ImageFileStorageBase.ImageFileNotFoundException:
|
||||
except ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
@ -228,7 +239,7 @@ class ImageService(ImageServiceABC):
|
||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.records.get(image_type, image_name)
|
||||
except ImageRecordStorageBase.ImageRecordNotFoundException:
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
@ -246,7 +257,7 @@ class ImageService(ImageServiceABC):
|
||||
)
|
||||
|
||||
return image_dto
|
||||
except ImageRecordStorageBase.ImageRecordNotFoundException:
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
@ -311,32 +322,19 @@ class ImageService(ImageServiceABC):
|
||||
raise e
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str):
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
try:
|
||||
self._services.files.delete(image_type, image_name)
|
||||
self._services.records.delete(image_type, image_name)
|
||||
except ImageRecordStorageBase.ImageRecordDeleteException:
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image record")
|
||||
raise
|
||||
except ImageFileStorageBase.ImageFileDeleteException:
|
||||
except ImageFileDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem deleting image record and file")
|
||||
raise e
|
||||
|
||||
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||
raise NotImplementedError("The 'add_tag' method is not implemented yet.")
|
||||
|
||||
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||
raise NotImplementedError("The 'remove_tag' method is not implemented yet.")
|
||||
|
||||
def favorite(self, image_type: ImageType, image_id: str) -> None:
|
||||
raise NotImplementedError("The 'favorite' method is not implemented yet.")
|
||||
|
||||
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
|
||||
raise NotImplementedError("The 'unfavorite' method is not implemented yet.")
|
||||
|
||||
def _create_image_name(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
|
@ -1,5 +1,4 @@
|
||||
import datetime
|
||||
import sqlite3
|
||||
from typing import Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from invokeai.app.models.image import ImageCategory, ImageType
|
||||
@ -10,30 +9,60 @@ from invokeai.app.util.misc import get_iso_timestamp
|
||||
class ImageRecord(BaseModel):
|
||||
"""Deserialized image record."""
|
||||
|
||||
image_name: str = Field(description="The name of the image.")
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_type: ImageType = Field(description="The type of the image.")
|
||||
"""The type of the image."""
|
||||
image_category: ImageCategory = Field(description="The category of the image.")
|
||||
"""The category of the image."""
|
||||
width: int = Field(description="The width of the image in px.")
|
||||
"""The actual width of the image in px. This may be different from the width in metadata."""
|
||||
height: int = Field(description="The height of the image in px.")
|
||||
"""The actual height of the image in px. This may be different from the height in metadata."""
|
||||
created_at: Union[datetime.datetime, str] = Field(
|
||||
description="The created timestamp of the image."
|
||||
)
|
||||
session_id: Optional[str] = Field(default=None, description="The session ID.")
|
||||
node_id: Optional[str] = Field(default=None, description="The node ID.")
|
||||
metadata: Optional[ImageMetadata] = Field(
|
||||
default=None, description="The image's metadata."
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime.datetime, str] = Field(
|
||||
description="The updated timestamp of the image."
|
||||
)
|
||||
"""The updated timestamp of the image."""
|
||||
deleted_at: Union[datetime.datetime, str, None] = Field(
|
||||
description="The deleted timestamp of the image."
|
||||
)
|
||||
"""The deleted timestamp of the image."""
|
||||
session_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The session ID that generated this image, if it is a generated image.",
|
||||
)
|
||||
"""The session ID that generated this image, if it is a generated image."""
|
||||
node_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The node ID that generated this image, if it is a generated image.",
|
||||
)
|
||||
"""The node ID that generated this image, if it is a generated image."""
|
||||
metadata: Optional[ImageMetadata] = Field(
|
||||
default=None,
|
||||
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
|
||||
)
|
||||
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
|
||||
|
||||
|
||||
class ImageUrlsDTO(BaseModel):
|
||||
"""The URLs for an image and its thumbnaill"""
|
||||
"""The URLs for an image and its thumbnail."""
|
||||
|
||||
image_name: str = Field(description="The name of the image.")
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_type: ImageType = Field(description="The type of the image.")
|
||||
"""The type of the image."""
|
||||
image_url: str = Field(description="The URL of the image.")
|
||||
thumbnail_url: str = Field(description="The thumbnail URL of the image.")
|
||||
"""The URL of the image."""
|
||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||
"""The URL of the image's thumbnail."""
|
||||
|
||||
|
||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
"""Deserialized image record with URLs."""
|
||||
"""Deserialized image record, enriched for the frontend with URLs."""
|
||||
|
||||
pass
|
||||
|
||||
@ -43,24 +72,29 @@ def image_record_to_dto(
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
image_name=image_record.image_name,
|
||||
image_type=image_record.image_type,
|
||||
image_category=image_record.image_category,
|
||||
created_at=image_record.created_at,
|
||||
session_id=image_record.session_id,
|
||||
node_id=image_record.node_id,
|
||||
metadata=image_record.metadata,
|
||||
**image_record.dict(),
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
|
||||
|
||||
def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
|
||||
def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
"""Deserializes an image record."""
|
||||
|
||||
image_dict = dict(image_row)
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
image_name = image_dict.get("image_name", "unknown")
|
||||
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
|
||||
image_category = ImageCategory(
|
||||
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||
)
|
||||
width = image_dict.get("width", 0)
|
||||
height = image_dict.get("height", 0)
|
||||
session_id = image_dict.get("session_id", None)
|
||||
node_id = image_dict.get("node_id", None)
|
||||
created_at = image_dict.get("created_at", get_iso_timestamp())
|
||||
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||
|
||||
raw_metadata = image_dict.get("metadata")
|
||||
|
||||
@ -70,13 +104,15 @@ def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
|
||||
metadata = None
|
||||
|
||||
return ImageRecord(
|
||||
image_name=image_dict.get("id", "unknown"),
|
||||
session_id=image_dict.get("session_id", None),
|
||||
node_id=image_dict.get("node_id", None),
|
||||
metadata=metadata,
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_category=ImageCategory(
|
||||
image_dict.get("image_category", ImageCategory.IMAGE.value)
|
||||
),
|
||||
created_at=image_dict.get("created_at", get_iso_timestamp()),
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
session_id=session_id,
|
||||
node_id=node_id,
|
||||
metadata=metadata,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
)
|
||||
|
@ -2,7 +2,10 @@ from enum import EnumMeta
|
||||
|
||||
|
||||
class MetaEnum(EnumMeta):
|
||||
"""Metaclass to support `in` syntax value checking in String Enums"""
|
||||
"""Metaclass to support additional features in Enums.
|
||||
|
||||
- `in` operator support: `'value' in MyEnum -> bool`
|
||||
"""
|
||||
|
||||
def __contains__(cls, item):
|
||||
try:
|
Loading…
Reference in New Issue
Block a user