(api) add ability to pin and unpin images

This commit is contained in:
maryhipp 2023-08-11 08:21:05 -07:00 committed by psychedelicious
parent 2b7dd3e236
commit 04a9894e77
2 changed files with 41 additions and 2 deletions

View File

@ -67,6 +67,7 @@ IMAGE_DTO_COLS = ", ".join(
"created_at",
"updated_at",
"deleted_at",
"pinned"
],
)
)
@ -139,6 +140,7 @@ class ImageRecordStorageBase(ABC):
node_id: Optional[str],
metadata: Optional[dict],
is_intermediate: bool = False,
pinned: bool = False
) -> datetime:
"""Saves an image record."""
pass
@ -200,6 +202,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
self._cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in self._cursor.fetchall()]
if "pinned" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images ADD COLUMN pinned BOOLEAN DEFAULT FALSE;
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql
@ -222,6 +234,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_pinned ON images(pinned);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
@ -321,6 +339,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(changes.is_intermediate, image_name),
)
# Change the image's `pinned`` state
if changes.pinned is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET pinned = ?
WHERE image_name = ?;
""",
(changes.pinned, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
@ -500,6 +529,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str],
metadata: Optional[dict],
is_intermediate: bool = False,
pinned: bool = False
) -> datetime:
try:
metadata_json = None if metadata is None else json.dumps(metadata)
@ -515,9 +545,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id,
session_id,
metadata,
is_intermediate
is_intermediate,
pinned
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
@ -529,6 +560,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id,
metadata_json,
is_intermediate,
pinned,
),
)
self._conn.commit()

View File

@ -39,6 +39,8 @@ class ImageRecord(BaseModelExcludeNull):
description="The node ID that generated this image, if it is a generated image.",
)
"""The node ID that generated this image, if it is a generated image."""
pinned: bool = Field(description="Whether this image is pinned.")
"""Whether this image is pinned."""
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
@ -48,6 +50,7 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
- `image_category`: change the category of an image
- `session_id`: change the session associated with an image
- `is_intermediate`: change the image's `is_intermediate` flag
- `pinned`: change whether the image is pinned
"""
image_category: Optional[ImageCategory] = Field(description="The image's new category.")
@ -59,6 +62,8 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
"""The image's new session ID."""
is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.")
"""The image's new `is_intermediate` flag."""
pinned: Optional[StrictBool] = Field(default=None, description="The image's new `pinned` state")
"""The image's new `pinned` state."""
class ImageUrlsDTO(BaseModelExcludeNull):
@ -113,6 +118,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False)
pinned = image_dict.get("pinned", False)
return ImageRecord(
image_name=image_name,
@ -126,4 +132,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at=updated_at,
deleted_at=deleted_at,
is_intermediate=is_intermediate,
pinned=pinned
)