feat(nodes): address feedback

- Address database feedback:
  - Remove all the extraneous tables. Only an `images` table now:
  - `image_type` and `image_category` are unrestricted strings. When creating images, the provided values are checked to ensure they are a valid type and category.
  - Add `updated_at` and `deleted_at` columns. `deleted_at` is currently unused.
  - Use SQLite's built-in timestamp features to populate these. Add a trigger to update `updated_at` when the row is updated. Currently no way to update a row.
  - Rename the `id` column in `images` to `image_name`
- Rename `ImageCategory.IMAGE` to `ImageCategory.GENERAL`
- Move all exceptions outside their base classes to make them more portable.
- Add `width` and `height` columns to the database. These store the actual dimensions of the image file, whereas the metadata's `width` and `height` refer to the respective generation parameters and are nullable.
- Make `deserialize_image_record` take a `dict` instead of `sqlite3.Row`
- Improve comments throughout
- Tidy up unused code/files and some minor organisation
This commit is contained in:
psychedelicious 2023-05-23 18:59:43 +10:00 committed by Kent Keirsey
parent 021e5a2aa3
commit 035425ef24
12 changed files with 273 additions and 846 deletions

View File

@ -30,7 +30,7 @@ async def upload_image(
image_type: ImageType,
request: Request,
response: Response,
image_category: ImageCategory = ImageCategory.IMAGE,
image_category: ImageCategory = ImageCategory.GENERAL,
) -> ImageDTO:
"""Uploads an image"""
if not file.content_type.startswith("image"):

View File

@ -95,7 +95,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_dto = context.services.images_new.create(
image=generate_output.image,
image_type=ImageType.RESULT,
image_category=ImageCategory.IMAGE,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
)
@ -119,7 +119,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# context.services.images_db.set(
# id=image_name,
# image_type=ImageType.RESULT,
# image_category=ImageCategory.IMAGE,
# image_category=ImageCategory.GENERAL,
# session_id=context.graph_execution_state_id,
# node_id=self.id,
# metadata=GeneratedImageOrLatentsMetadata(),

View File

@ -372,7 +372,7 @@ class LatentsToImageInvocation(BaseInvocation):
image_dto = context.services.images_new.create(
image=image,
image_type=ImageType.RESULT,
image_category=ImageCategory.IMAGE,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
)

View File

@ -2,7 +2,7 @@ from enum import Enum
from typing import Optional, Tuple
from pydantic import BaseModel, Field
from invokeai.app.util.enum import MetaEnum
from invokeai.app.util.metaenum import MetaEnum
class ImageType(str, Enum, metaclass=MetaEnum):
@ -13,20 +13,32 @@ class ImageType(str, Enum, metaclass=MetaEnum):
INTERMEDIATE = "intermediates"
class InvalidImageTypeException(ValueError):
"""Raised when a provided value is not a valid ImageType.
Subclasses `ValueError`.
"""
def __init__(self, message="Invalid image type."):
super().__init__(message)
class ImageCategory(str, Enum, metaclass=MetaEnum):
"""The category of an image. Use ImageCategory.OTHER for non-default categories."""
IMAGE = "image"
CONTROL_IMAGE = "control_image"
GENERAL = "general"
CONTROL = "control"
OTHER = "other"
def is_image_type(obj):
try:
ImageType(obj)
except ValueError:
return False
return True
class InvalidImageCategoryException(ValueError):
"""Raised when a provided value is not a valid ImageCategory.
Subclasses `ValueError`.
"""
def __init__(self, message="Invalid image category."):
super().__init__(message)
class ImageField(BaseModel):

View File

@ -26,50 +26,66 @@ class ImageMetadata(BaseModel):
default=None,
description="The type of the ancestor node of the image output node.",
)
"""The type of the ancestor node of the image output node."""
positive_conditioning: Optional[StrictStr] = Field(
default=None, description="The positive conditioning."
)
"""The positive conditioning"""
negative_conditioning: Optional[StrictStr] = Field(
default=None, description="The negative conditioning."
)
"""The negative conditioning"""
width: Optional[StrictInt] = Field(
default=None, description="Width of the image/latents in pixels."
)
"""Width of the image/latents in pixels"""
height: Optional[StrictInt] = Field(
default=None, description="Height of the image/latents in pixels."
)
"""Height of the image/latents in pixels"""
seed: Optional[StrictInt] = Field(
default=None, description="The seed used for noise generation."
)
"""The seed used for noise generation"""
cfg_scale: Optional[StrictFloat] = Field(
default=None, description="The classifier-free guidance scale."
)
"""The classifier-free guidance scale"""
steps: Optional[StrictInt] = Field(
default=None, description="The number of steps used for inference."
)
"""The number of steps used for inference"""
scheduler: Optional[StrictStr] = Field(
default=None, description="The scheduler used for inference."
)
"""The scheduler used for inference"""
model: Optional[StrictStr] = Field(
default=None, description="The model used for inference."
)
"""The model used for inference"""
strength: Optional[StrictFloat] = Field(
default=None,
description="The strength used for image-to-image/latents-to-latents.",
)
"""The strength used for image-to-image/latents-to-latents."""
latents: Optional[StrictStr] = Field(
default=None, description="The ID of the initial latents."
)
"""The ID of the initial latents"""
vae: Optional[StrictStr] = Field(
default=None, description="The VAE used for decoding."
)
"""The VAE used for decoding"""
unet: Optional[StrictStr] = Field(
default=None, description="The UNet used dor inference."
)
"""The UNet used dor inference"""
clip: Optional[StrictStr] = Field(
default=None, description="The CLIP Encoder used for conditioning."
)
"""The CLIP Encoder used for conditioning"""
extra: Optional[StrictStr] = Field(
default=None,
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
)
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""

View File

@ -1,28 +0,0 @@
# TODO: Make a new model for this
from enum import Enum
from invokeai.app.util.enum import MetaEnum
class ResourceType(str, Enum, metaclass=MetaEnum):
"""The type of a resource."""
IMAGES = "images"
TENSORS = "tensors"
# class ResourceOrigin(str, Enum, metaclass=MetaEnum):
# """The origin of a resource (eg image or tensor)."""
# RESULTS = "results"
# UPLOADS = "uploads"
# INTERMEDIATES = "intermediates"
class TensorKind(str, Enum, metaclass=MetaEnum):
"""The kind of a tensor. Use TensorKind.OTHER for non-default kinds."""
IMAGE_LATENTS = "image_latents"
CONDITIONING = "conditioning"
OTHER = "other"

View File

@ -1,578 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"from abc import ABC, abstractmethod\n",
"from enum import Enum\n",
"import enum\n",
"import sqlite3\n",
"import threading\n",
"from typing import Optional, Type, TypeVar, Union\n",
"from PIL.Image import Image as PILImage\n",
"from pydantic import BaseModel, Field\n",
"from torch import Tensor"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class ResourceOrigin(str, Enum):\n",
" \"\"\"The origin of a resource (eg image or tensor).\"\"\"\n",
"\n",
" RESULTS = \"results\"\n",
" UPLOADS = \"uploads\"\n",
" INTERMEDIATES = \"intermediates\"\n",
"\n",
"\n",
"class ImageKind(str, Enum):\n",
" \"\"\"The kind of an image. Use ImageKind.OTHER for non-default kinds.\"\"\"\n",
"\n",
" IMAGE = \"image\"\n",
" CONTROL_IMAGE = \"control_image\"\n",
" OTHER = \"other\"\n",
"\n",
"\n",
"class TensorKind(str, Enum):\n",
" \"\"\"The kind of a tensor. Use TensorKind.OTHER for non-default kinds.\"\"\"\n",
"\n",
" IMAGE_LATENTS = \"image_latents\"\n",
" CONDITIONING = \"conditioning\"\n",
" OTHER = \"other\"\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def create_sql_values_string_from_string_enum(enum: Type[Enum]):\n",
" \"\"\"\n",
" Creates a string of the form \"('value1'), ('value2'), ..., ('valueN')\" from a StrEnum.\n",
" \"\"\"\n",
"\n",
" delimiter = \", \"\n",
" values = [f\"('{e.value}')\" for e in enum]\n",
" return delimiter.join(values)\n",
"\n",
"\n",
"def create_sql_table_from_enum(\n",
" enum: Type[Enum],\n",
" table_name: str,\n",
" primary_key_name: str,\n",
" conn: sqlite3.Connection,\n",
" cursor: sqlite3.Cursor,\n",
" lock: threading.Lock,\n",
"):\n",
" \"\"\"\n",
" Creates and populates a table to be used as a functional enum.\n",
" \"\"\"\n",
"\n",
" try:\n",
" lock.acquire()\n",
"\n",
" values_string = create_sql_values_string_from_string_enum(enum)\n",
"\n",
" cursor.execute(\n",
" f\"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS {table_name} (\n",
" {primary_key_name} TEXT PRIMARY KEY\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" f\"\"\"--sql\n",
" INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`resource_origins` functions as an enum for the ResourceOrigin model.\n",
"\"\"\"\n",
"\n",
"\n",
"# def create_resource_origins_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
"# create_sql_table_from_enum(\n",
"# enum=ResourceOrigin,\n",
"# table_name=\"resource_origins\",\n",
"# primary_key_name=\"origin_name\",\n",
"# conn=conn,\n",
"# cursor=cursor,\n",
"# lock=lock,\n",
"# )\n",
"\n",
"\n",
"\"\"\"\n",
"`image_kinds` functions as an enum for the ImageType model.\n",
"\"\"\"\n",
"\n",
"\n",
"# def create_image_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" # create_sql_table_from_enum(\n",
" # enum=ImageKind,\n",
" # table_name=\"image_kinds\",\n",
" # primary_key_name=\"kind_name\",\n",
" # conn=conn,\n",
" # cursor=cursor,\n",
" # lock=lock,\n",
" # )\n",
"\n",
"\n",
"\"\"\"\n",
"`tensor_kinds` functions as an enum for the TensorType model.\n",
"\"\"\"\n",
"\n",
"\n",
"# def create_tensor_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" # create_sql_table_from_enum(\n",
" # enum=TensorKind,\n",
" # table_name=\"tensor_kinds\",\n",
" # primary_key_name=\"kind_name\",\n",
" # conn=conn,\n",
" # cursor=cursor,\n",
" # lock=lock,\n",
" # )\n",
"\n",
"\n",
"\"\"\"\n",
"`images` stores all images, regardless of type\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images (\n",
" id TEXT PRIMARY KEY,\n",
" origin TEXT,\n",
" image_kind TEXT,\n",
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
" FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name)\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`images_results` stores additional data specific to `results` images.\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images_results (\n",
" images_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_results_images_id ON images_results(images_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`images_intermediates` stores additional data specific to `intermediates` images\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images_intermediates (\n",
" images_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_intermediates_images_id ON images_intermediates(images_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`images_metadata` stores basic metadata for any image type\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images_metadata (\n",
" images_id TEXT PRIMARY KEY,\n",
" metadata TEXT,\n",
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors` table: stores references to tensor\n",
"\n",
"\n",
"def create_tensors_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors (\n",
" id TEXT PRIMARY KEY,\n",
" origin TEXT,\n",
" tensor_kind TEXT,\n",
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
" FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name)\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors_results` stores additional data specific to `result` tensor\n",
"\n",
"\n",
"def create_tensors_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors_results (\n",
" tensors_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_results_tensors_id ON tensors_results(tensors_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors_intermediates` stores additional data specific to `intermediate` tensor\n",
"\n",
"\n",
"def create_tensors_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors_intermediates (\n",
" tensors_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_intermediates_tensors_id ON tensors_intermediates(tensors_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors_metadata` table: stores generated/transformed metadata for tensor\n",
"\n",
"\n",
"def create_tensors_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors_metadata (\n",
" tensors_id TEXT PRIMARY KEY,\n",
" metadata TEXT,\n",
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensors_id ON tensors_metadata(tensors_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"db_path = '/home/bat/Documents/Code/outputs/test.db'\n",
"if (os.path.exists(db_path)):\n",
" os.remove(db_path)\n",
"\n",
"conn = sqlite3.connect(\n",
" db_path, check_same_thread=False\n",
")\n",
"cursor = conn.cursor()\n",
"lock = threading.Lock()"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"create_sql_table_from_enum(\n",
" enum=ResourceOrigin,\n",
" table_name=\"resource_origins\",\n",
" primary_key_name=\"origin_name\",\n",
" conn=conn,\n",
" cursor=cursor,\n",
" lock=lock,\n",
")\n",
"\n",
"create_sql_table_from_enum(\n",
" enum=ImageKind,\n",
" table_name=\"image_kinds\",\n",
" primary_key_name=\"kind_name\",\n",
" conn=conn,\n",
" cursor=cursor,\n",
" lock=lock,\n",
")\n",
"\n",
"create_sql_table_from_enum(\n",
" enum=TensorKind,\n",
" table_name=\"tensor_kinds\",\n",
" primary_key_name=\"kind_name\",\n",
" conn=conn,\n",
" cursor=cursor,\n",
" lock=lock,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"create_images_table(conn, cursor, lock)\n",
"create_images_results_table(conn, cursor, lock)\n",
"create_images_intermediates_table(conn, cursor, lock)\n",
"create_images_metadata_table(conn, cursor, lock)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"create_tensors_table(conn, cursor, lock)\n",
"create_tensors_results_table(conn, cursor, lock)\n",
"create_tensors_intermediates_table(conn, cursor, lock)\n",
"create_tensors_metadata_table(conn, cursor, lock)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from pydantic import StrictStr\n",
"\n",
"\n",
"class GeneratedImageOrLatentsMetadata(BaseModel):\n",
" \"\"\"Core generation metadata for an image/tensor generated in InvokeAI.\n",
"\n",
" Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.\n",
"\n",
" Full metadata may be accessed by querying for the session in the `graph_executions` table.\n",
" \"\"\"\n",
"\n",
" positive_conditioning: Optional[StrictStr] = Field(\n",
" default=None, description=\"The positive conditioning.\"\n",
" )\n",
" negative_conditioning: Optional[str] = Field(\n",
" default=None, description=\"The negative conditioning.\"\n",
" )\n",
" width: Optional[int] = Field(\n",
" default=None, description=\"Width of the image/tensor in pixels.\"\n",
" )\n",
" height: Optional[int] = Field(\n",
" default=None, description=\"Height of the image/tensor in pixels.\"\n",
" )\n",
" seed: Optional[int] = Field(\n",
" default=None, description=\"The seed used for noise generation.\"\n",
" )\n",
" cfg_scale: Optional[float] = Field(\n",
" default=None, description=\"The classifier-free guidance scale.\"\n",
" )\n",
" steps: Optional[int] = Field(\n",
" default=None, description=\"The number of steps used for inference.\"\n",
" )\n",
" scheduler: Optional[str] = Field(\n",
" default=None, description=\"The scheduler used for inference.\"\n",
" )\n",
" model: Optional[str] = Field(\n",
" default=None, description=\"The model used for inference.\"\n",
" )\n",
" strength: Optional[float] = Field(\n",
" default=None,\n",
" description=\"The strength used for image-to-image/tensor-to-tensor.\",\n",
" )\n",
" image: Optional[str] = Field(\n",
" default=None, description=\"The ID of the initial image.\"\n",
" )\n",
" tensor: Optional[str] = Field(\n",
" default=None, description=\"The ID of the initial tensor.\"\n",
" )\n",
" # Pending model refactor:\n",
" # vae: Optional[str] = Field(default=None,description=\"The VAE used for decoding.\")\n",
" # unet: Optional[str] = Field(default=None,description=\"The UNet used dor inference.\")\n",
" # clip: Optional[str] = Field(default=None,description=\"The CLIP Encoder used for conditioning.\")\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GeneratedImageOrLatentsMetadata(positive_conditioning='123', negative_conditioning=None, width=None, height=None, seed=None, cfg_scale=None, steps=None, scheduler=None, model=None, strength=None, image=None, tensor=None)"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"GeneratedImageOrLatentsMetadata(positive_conditioning='123')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -14,27 +14,31 @@ from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
# TODO: Should these excpetions subclass existing python exceptions?
class ImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Image file not found"):
super().__init__(message)
class ImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Image file not saved"):
super().__init__(message)
class ImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Image file not deleted"):
super().__init__(message)
class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
class ImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Image file not found"):
super().__init__(message)
class ImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Image file not saved"):
super().__init__(message)
class ImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Image file not deleted"):
super().__init__(message)
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
@ -102,7 +106,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
self.__set_cache(image_path, image)
return image
except FileNotFoundError as e:
raise ImageFileStorageBase.ImageFileNotFoundException from e
raise ImageFileNotFoundException from e
def save(
self,
@ -130,7 +134,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
except Exception as e:
raise ImageFileStorageBase.ImageFileSaveException from e
raise ImageFileSaveException from e
def delete(self, image_type: ImageType, image_name: str) -> None:
try:
@ -150,7 +154,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
except Exception as e:
raise ImageFileStorageBase.ImageFileDeleteException from e
raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(

View File

@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
import datetime
from enum import Enum
from typing import Optional, Type
from datetime import datetime
from typing import Optional, cast
import sqlite3
import threading
from typing import Optional, Union
@ -18,62 +17,32 @@ from invokeai.app.services.models.image_record import (
from invokeai.app.services.item_storage import PaginatedResults
def create_sql_values_string_from_string_enum(enum: Type[Enum]):
"""
Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum.
"""
# TODO: Should these excpetions subclass existing python exceptions?
class ImageRecordNotFoundException(Exception):
"""Raised when an image record is not found."""
delimiter = ", "
values = [f"('{e.value}')" for e in enum]
return delimiter.join(values)
def __init__(self, message="Image record not found"):
super().__init__(message)
def create_enum_table(
enum: Type[Enum],
table_name: str,
primary_key_name: str,
cursor: sqlite3.Cursor,
):
"""
Creates and populates a table to be used as a functional enum.
"""
class ImageRecordSaveException(Exception):
"""Raised when an image record cannot be saved."""
values_string = create_sql_values_string_from_string_enum(enum)
def __init__(self, message="Image record not saved"):
super().__init__(message)
cursor.execute(
f"""--sql
CREATE TABLE IF NOT EXISTS {table_name} (
{primary_key_name} TEXT PRIMARY KEY
);
"""
)
cursor.execute(
f"""--sql
INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};
"""
)
class ImageRecordDeleteException(Exception):
"""Raised when an image record cannot be deleted."""
def __init__(self, message="Image record not deleted"):
super().__init__(message)
class ImageRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the image record store."""
class ImageRecordNotFoundException(Exception):
"""Raised when an image record is not found."""
def __init__(self, message="Image record not found"):
super().__init__(message)
class ImageRecordSaveException(Exception):
"""Raised when an image record cannot be saved."""
def __init__(self, message="Image record not saved"):
super().__init__(message)
class ImageRecordDeleteException(Exception):
"""Raised when an image record cannot be deleted."""
def __init__(self, message="Image record not deleted"):
super().__init__(message)
# TODO: Implement an `update()` method
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
@ -91,6 +60,8 @@ class ImageRecordStorageBase(ABC):
"""Gets a page of image records."""
pass
# TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image record."""
@ -102,11 +73,12 @@ class ImageRecordStorageBase(ABC):
image_name: str,
image_type: ImageType,
image_category: ImageCategory,
width: int,
height: int,
session_id: Optional[str],
node_id: Optional[str],
metadata: Optional[ImageMetadata],
created_at: str = datetime.datetime.utcnow().isoformat(),
) -> None:
) -> datetime:
"""Saves an image record."""
pass
@ -141,17 +113,23 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Create the `images` table.
self._cursor.execute(
f"""--sql
"""--sql
CREATE TABLE IF NOT EXISTS images (
id TEXT PRIMARY KEY,
image_type TEXT, -- non-nullable via foreign key constraint
image_category TEXT, -- non-nullable via foreign key constraint
session_id TEXT, -- nullable
node_id TEXT, -- nullable
metadata TEXT, -- nullable
created_at TEXT NOT NULL,
FOREIGN KEY(image_type) REFERENCES image_types(type_name),
FOREIGN KEY(image_category) REFERENCES image_categories(category_name)
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_type TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
@ -159,7 +137,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Create the `images` table indices.
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
"""
)
self._cursor.execute(
@ -172,53 +150,22 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
"""
)
# Create the tables for image-related enums
create_enum_table(
enum=ImageType,
table_name="image_types",
primary_key_name="type_name",
cursor=self._cursor,
)
create_enum_table(
enum=ImageCategory,
table_name="image_categories",
primary_key_name="category_name",
cursor=self._cursor,
)
# Create the `tags` table. TODO: do this elsewhere, shouldn't be in images db service
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tags (
id INTEGER PRIMARY KEY AUTOINCREMENT,
tag_name TEXT UNIQUE NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
"""
)
# Create the `images_tags` junction table.
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images_tags (
image_id TEXT,
tag_id INTEGER,
PRIMARY KEY (image_id, tag_id),
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE,
FOREIGN KEY(tag_id) REFERENCES tags(id) ON DELETE CASCADE
);
"""
)
# Create the `images_favorites` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images_favorites (
image_id TEXT PRIMARY KEY,
favorited_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE
);
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = current_timestamp
WHERE image_name = old.image_name;
END;
"""
)
@ -229,22 +176,22 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._cursor.execute(
f"""--sql
SELECT * FROM images
WHERE id = ?;
WHERE image_name = ?;
""",
(image_name,),
)
result = self._cursor.fetchone()
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise self.ImageRecordNotFoundException from e
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
if not result:
raise self.ImageRecordNotFoundException
raise ImageRecordNotFoundException
return deserialize_image_record(result)
return deserialize_image_record(dict(result))
def get_many(
self,
@ -260,14 +207,15 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
f"""--sql
SELECT * FROM images
WHERE image_type = ? AND image_category = ?
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
""",
(image_type.value, image_category.value, per_page, page * per_page),
)
result = self._cursor.fetchall()
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(r), result))
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
self._cursor.execute(
"""--sql
@ -296,14 +244,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._cursor.execute(
"""--sql
DELETE FROM images
WHERE id = ?;
WHERE image_name = ?;
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordStorageBase.ImageRecordDeleteException from e
raise ImageRecordDeleteException from e
finally:
self._lock.release()
@ -313,10 +261,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_type: ImageType,
image_category: ImageCategory,
session_id: Optional[str],
width: int,
height: int,
node_id: Optional[str],
metadata: Optional[ImageMetadata],
created_at: str,
) -> None:
) -> datetime:
try:
metadata_json = (
None if metadata is None else metadata.json(exclude_none=True)
@ -325,29 +274,44 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
id,
image_name,
image_type,
image_category,
width,
height,
node_id,
session_id,
metadata,
created_at
metadata
)
VALUES (?, ?, ?, ?, ?, ?, ?);
VALUES (?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_type.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata_json,
created_at,
),
)
self._conn.commit()
self._cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
return created_at
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordStorageBase.ImageRecordNotFoundException from e
raise ImageRecordSaveException from e
finally:
self._lock.release()

View File

@ -4,9 +4,17 @@ from typing import Optional, TYPE_CHECKING, Union
import uuid
from PIL.Image import Image as PILImageType
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.image import (
ImageCategory,
ImageType,
InvalidImageCategoryException,
InvalidImageTypeException,
)
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
ImageRecordSaveException,
ImageRecordStorageBase,
)
from invokeai.app.services.models.image_record import (
@ -14,7 +22,12 @@ from invokeai.app.services.models.image_record import (
ImageDTO,
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import ImageFileStorageBase
from invokeai.app.services.image_file_storage import (
ImageFileDeleteException,
ImageFileNotFoundException,
ImageFileSaveException,
ImageFileStorageBase,
)
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase
@ -50,6 +63,11 @@ class ImageServiceABC(ABC):
"""Gets an image record."""
pass
@abstractmethod
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str:
"""Gets an image's path"""
@ -62,11 +80,6 @@ class ImageServiceABC(ABC):
"""Gets an image's or thumbnail's URL"""
pass
@abstractmethod
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
def get_many(
self,
@ -83,26 +96,6 @@ class ImageServiceABC(ABC):
"""Deletes an image."""
pass
@abstractmethod
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Adds a tag to an image."""
pass
@abstractmethod
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Removes a tag from an image."""
pass
@abstractmethod
def favorite(self, image_type: ImageType, image_id: str) -> None:
"""Favorites an image."""
pass
@abstractmethod
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
"""Unfavorites an image."""
pass
class ImageServiceDependencies:
"""Service dependencies for the ImageService."""
@ -160,6 +153,12 @@ class ImageService(ImageServiceABC):
node_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> ImageDTO:
if image_type not in ImageType:
raise InvalidImageTypeException
if image_category not in ImageCategory:
raise InvalidImageCategoryException
image_name = self._create_image_name(
image_type=image_type,
image_category=image_category,
@ -167,11 +166,25 @@ class ImageService(ImageServiceABC):
session_id=session_id,
)
timestamp = get_iso_timestamp()
metadata = self._get_metadata(session_id, node_id)
(width, height) = image.size
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
created_at = self._services.records.save(
# Non-nullable fields
image_name=image_name,
image_type=image_type,
image_category=image_category,
width=width,
height=height,
# Nullable fields
node_id=node_id,
session_id=session_id,
metadata=metadata,
)
self._services.files.save(
image_type=image_type,
image_name=image_name,
@ -179,36 +192,34 @@ class ImageService(ImageServiceABC):
metadata=metadata,
)
self._services.records.save(
image_name=image_name,
image_type=image_type,
image_category=image_category,
node_id=node_id,
session_id=session_id,
metadata=metadata,
created_at=timestamp,
)
image_url = self._services.urls.get_image_url(image_type, image_name)
thumbnail_url = self._services.urls.get_image_url(
image_type, image_name, True
)
return ImageDTO(
# Non-nullable fields
image_name=image_name,
image_type=image_type,
image_category=image_category,
width=width,
height=height,
# Nullable fields
node_id=node_id,
session_id=session_id,
metadata=metadata,
created_at=timestamp,
# Meta fields
created_at=created_at,
updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None,
# Extra non-nullable fields for DTO
image_url=image_url,
thumbnail_url=thumbnail_url,
)
except ImageRecordStorageBase.ImageRecordSaveException:
except ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
raise
except ImageFileStorageBase.ImageFileSaveException:
except ImageFileSaveException:
self._services.logger.error("Failed to save image file")
raise
except Exception as e:
@ -218,7 +229,7 @@ class ImageService(ImageServiceABC):
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
try:
return self._services.files.get(image_type, image_name)
except ImageFileStorageBase.ImageFileNotFoundException:
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
except Exception as e:
@ -228,7 +239,7 @@ class ImageService(ImageServiceABC):
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
try:
return self._services.records.get(image_type, image_name)
except ImageRecordStorageBase.ImageRecordNotFoundException:
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
except Exception as e:
@ -246,7 +257,7 @@ class ImageService(ImageServiceABC):
)
return image_dto
except ImageRecordStorageBase.ImageRecordNotFoundException:
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
except Exception as e:
@ -311,32 +322,19 @@ class ImageService(ImageServiceABC):
raise e
def delete(self, image_type: ImageType, image_name: str):
# TODO: Consider using a transaction here to ensure consistency between storage and database
try:
self._services.files.delete(image_type, image_name)
self._services.records.delete(image_type, image_name)
except ImageRecordStorageBase.ImageRecordDeleteException:
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise
except ImageFileStorageBase.ImageFileDeleteException:
except ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image file")
raise
except Exception as e:
self._services.logger.error("Problem deleting image record and file")
raise e
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
raise NotImplementedError("The 'add_tag' method is not implemented yet.")
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
raise NotImplementedError("The 'remove_tag' method is not implemented yet.")
def favorite(self, image_type: ImageType, image_id: str) -> None:
raise NotImplementedError("The 'favorite' method is not implemented yet.")
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
raise NotImplementedError("The 'unfavorite' method is not implemented yet.")
def _create_image_name(
self,
image_type: ImageType,

View File

@ -1,5 +1,4 @@
import datetime
import sqlite3
from typing import Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageCategory, ImageType
@ -10,30 +9,60 @@ from invokeai.app.util.misc import get_iso_timestamp
class ImageRecord(BaseModel):
"""Deserialized image record."""
image_name: str = Field(description="The name of the image.")
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.")
"""The type of the image."""
image_category: ImageCategory = Field(description="The category of the image.")
"""The category of the image."""
width: int = Field(description="The width of the image in px.")
"""The actual width of the image in px. This may be different from the width in metadata."""
height: int = Field(description="The height of the image in px.")
"""The actual height of the image in px. This may be different from the height in metadata."""
created_at: Union[datetime.datetime, str] = Field(
description="The created timestamp of the image."
)
session_id: Optional[str] = Field(default=None, description="The session ID.")
node_id: Optional[str] = Field(default=None, description="The node ID.")
metadata: Optional[ImageMetadata] = Field(
default=None, description="The image's metadata."
"""The created timestamp of the image."""
updated_at: Union[datetime.datetime, str] = Field(
description="The updated timestamp of the image."
)
"""The updated timestamp of the image."""
deleted_at: Union[datetime.datetime, str, None] = Field(
description="The deleted timestamp of the image."
)
"""The deleted timestamp of the image."""
session_id: Optional[str] = Field(
default=None,
description="The session ID that generated this image, if it is a generated image.",
)
"""The session ID that generated this image, if it is a generated image."""
node_id: Optional[str] = Field(
default=None,
description="The node ID that generated this image, if it is a generated image.",
)
"""The node ID that generated this image, if it is a generated image."""
metadata: Optional[ImageMetadata] = Field(
default=None,
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
)
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
class ImageUrlsDTO(BaseModel):
"""The URLs for an image and its thumbnaill"""
"""The URLs for an image and its thumbnail."""
image_name: str = Field(description="The name of the image.")
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.")
"""The type of the image."""
image_url: str = Field(description="The URL of the image.")
thumbnail_url: str = Field(description="The thumbnail URL of the image.")
"""The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
"""The URL of the image's thumbnail."""
class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record with URLs."""
"""Deserialized image record, enriched for the frontend with URLs."""
pass
@ -43,24 +72,29 @@ def image_record_to_dto(
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
image_name=image_record.image_name,
image_type=image_record.image_type,
image_category=image_record.image_category,
created_at=image_record.created_at,
session_id=image_record.session_id,
node_id=image_record.node_id,
metadata=image_record.metadata,
**image_record.dict(),
image_url=image_url,
thumbnail_url=thumbnail_url,
)
def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
def deserialize_image_record(image_dict: dict) -> ImageRecord:
"""Deserializes an image record."""
image_dict = dict(image_row)
# Retrieve all the values, setting "reasonable" defaults if they are not present.
image_name = image_dict.get("image_name", "unknown")
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
image_category = ImageCategory(
image_dict.get("image_category", ImageCategory.GENERAL.value)
)
width = image_dict.get("width", 0)
height = image_dict.get("height", 0)
session_id = image_dict.get("session_id", None)
node_id = image_dict.get("node_id", None)
created_at = image_dict.get("created_at", get_iso_timestamp())
updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
raw_metadata = image_dict.get("metadata")
@ -70,13 +104,15 @@ def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
metadata = None
return ImageRecord(
image_name=image_dict.get("id", "unknown"),
session_id=image_dict.get("session_id", None),
node_id=image_dict.get("node_id", None),
metadata=metadata,
image_name=image_name,
image_type=image_type,
image_category=ImageCategory(
image_dict.get("image_category", ImageCategory.IMAGE.value)
),
created_at=image_dict.get("created_at", get_iso_timestamp()),
image_category=image_category,
width=width,
height=height,
session_id=session_id,
node_id=node_id,
metadata=metadata,
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
)

View File

@ -2,7 +2,10 @@ from enum import EnumMeta
class MetaEnum(EnumMeta):
"""Metaclass to support `in` syntax value checking in String Enums"""
"""Metaclass to support additional features in Enums.
- `in` operator support: `'value' in MyEnum -> bool`
"""
def __contains__(cls, item):
try: