diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 8c274ab8f9..21094f4e5c 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -67,6 +67,7 @@ IMAGE_DTO_COLS = ", ".join( "created_at", "updated_at", "deleted_at", + "pinned" ], ) ) @@ -139,6 +140,7 @@ class ImageRecordStorageBase(ABC): node_id: Optional[str], metadata: Optional[dict], is_intermediate: bool = False, + pinned: bool = False ) -> datetime: """Saves an image record.""" pass @@ -200,6 +202,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """ ) + self._cursor.execute("PRAGMA table_info(images)") + columns = [column[1] for column in self._cursor.fetchall()] + + if "pinned" not in columns: + self._cursor.execute( + """--sql + ALTER TABLE images ADD COLUMN pinned BOOLEAN DEFAULT FALSE; + """ + ) + # Create the `images` table indices. self._cursor.execute( """--sql @@ -222,6 +234,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """ ) + self._cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_images_pinned ON images(pinned); + """ + ) + # Add trigger for `updated_at`. self._cursor.execute( """--sql @@ -321,6 +339,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): (changes.is_intermediate, image_name), ) + # Change the image's `pinned`` state + if changes.pinned is not None: + self._cursor.execute( + f"""--sql + UPDATE images + SET pinned = ? + WHERE image_name = ?; + """, + (changes.pinned, image_name), + ) + self._conn.commit() except sqlite3.Error as e: self._conn.rollback() @@ -500,6 +529,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): node_id: Optional[str], metadata: Optional[dict], is_intermediate: bool = False, + pinned: bool = False ) -> datetime: try: metadata_json = None if metadata is None else json.dumps(metadata) @@ -515,9 +545,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): node_id, session_id, metadata, - is_intermediate + is_intermediate, + pinned ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?); """, ( image_name, @@ -529,6 +560,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): session_id, metadata_json, is_intermediate, + pinned, ), ) self._conn.commit() diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index 294b760630..27cbea011f 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -39,6 +39,8 @@ class ImageRecord(BaseModelExcludeNull): description="The node ID that generated this image, if it is a generated image.", ) """The node ID that generated this image, if it is a generated image.""" + pinned: bool = Field(description="Whether this image is pinned.") + """Whether this image is pinned.""" class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid): @@ -48,6 +50,7 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid): - `image_category`: change the category of an image - `session_id`: change the session associated with an image - `is_intermediate`: change the image's `is_intermediate` flag + - `pinned`: change whether the image is pinned """ image_category: Optional[ImageCategory] = Field(description="The image's new category.") @@ -59,6 +62,8 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid): """The image's new session ID.""" is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.") """The image's new `is_intermediate` flag.""" + pinned: Optional[StrictBool] = Field(default=None, description="The image's new `pinned` state") + """The image's new `pinned` state.""" class ImageUrlsDTO(BaseModelExcludeNull): @@ -113,6 +118,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: updated_at = image_dict.get("updated_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) is_intermediate = image_dict.get("is_intermediate", False) + pinned = image_dict.get("pinned", False) return ImageRecord( image_name=image_name, @@ -126,4 +132,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: updated_at=updated_at, deleted_at=deleted_at, is_intermediate=is_intermediate, + pinned=pinned )