From 33a0af46375c38d28c576d03ef18835729fc8f52 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 09:10:02 +1000 Subject: [PATCH 01/34] feat(nodes): add nameservice Currenly only used to make names for images, but when latents, conditioning, etc are managed in DB, will do the same for them. Intended to eventually support custom naming schemes. --- invokeai/app/api/dependencies.py | 4 ++- invokeai/app/cli_app.py | 3 ++ invokeai/app/services/image_record_storage.py | 1 - invokeai/app/services/images.py | 30 +++++-------------- invokeai/app/services/resource_name.py | 30 +++++++++++++++++++ 5 files changed, 44 insertions(+), 24 deletions(-) create mode 100644 invokeai/app/services/resource_name.py diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index dfef5d2176..0dfb5505b6 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -5,6 +5,7 @@ import os from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.images import ImageService from invokeai.app.services.metadata import CoreMetadataService +from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.urls import LocalUrlService from invokeai.backend.util.logging import InvokeAILogger @@ -67,7 +68,7 @@ class ApiDependencies: metadata = CoreMetadataService() image_record_storage = SqliteImageRecordStorage(db_location) image_file_storage = DiskImageFileStorage(f"{output_folder}/images") - + names = SimpleNameService() latents = ForwardCacheLatentsStorage( DiskLatentsStorage(f"{output_folder}/latents") ) @@ -78,6 +79,7 @@ class ApiDependencies: metadata=metadata, url=urls, logger=logger, + names=names, graph_execution_manager=graph_execution_manager, ) diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index de543d2d85..eb55ba45d2 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -16,6 +16,7 @@ from pydantic.fields import Field from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.images import ImageService from invokeai.app.services.metadata import CoreMetadataService +from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.urls import LocalUrlService @@ -229,6 +230,7 @@ def invoke_cli(): metadata = CoreMetadataService() image_record_storage = SqliteImageRecordStorage(db_location) image_file_storage = DiskImageFileStorage(f"{output_folder}/images") + names = SimpleNameService() images = ImageService( image_record_storage=image_record_storage, @@ -236,6 +238,7 @@ def invoke_cli(): metadata=metadata, url=urls, logger=logger, + names=names, graph_execution_manager=graph_execution_manager, ) diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 188a411a6b..9a73b68e21 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -103,7 +103,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): def __init__(self, filename: str) -> None: super().__init__() - self._filename = filename self._conn = sqlite3.connect(filename, check_same_thread=False) # Enable row factory to get rows as dictionaries (must be done before making the cursor!) diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index d0f7236fe2..6da7510702 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from logging import Logger +from os import name from typing import Optional, TYPE_CHECKING, Union import uuid from PIL.Image import Image as PILImageType @@ -31,6 +32,7 @@ from invokeai.app.services.image_file_storage import ( ) from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.metadata import MetadataServiceBase +from invokeai.app.services.resource_name import NameServiceBase from invokeai.app.services.urls import UrlServiceBase if TYPE_CHECKING: @@ -120,6 +122,7 @@ class ImageServiceDependencies: metadata: MetadataServiceBase urls: UrlServiceBase logger: Logger + names: NameServiceBase graph_execution_manager: ItemStorageABC["GraphExecutionState"] def __init__( @@ -129,6 +132,7 @@ class ImageServiceDependencies: metadata: MetadataServiceBase, url: UrlServiceBase, logger: Logger, + names: NameServiceBase, graph_execution_manager: ItemStorageABC["GraphExecutionState"], ): self.records = image_record_storage @@ -136,6 +140,7 @@ class ImageServiceDependencies: self.metadata = metadata self.urls = url self.logger = logger + self.names = names self.graph_execution_manager = graph_execution_manager @@ -149,6 +154,7 @@ class ImageService(ImageServiceABC): metadata: MetadataServiceBase, url: UrlServiceBase, logger: Logger, + names: NameServiceBase, graph_execution_manager: ItemStorageABC["GraphExecutionState"], ): self._services = ImageServiceDependencies( @@ -157,6 +163,7 @@ class ImageService(ImageServiceABC): metadata=metadata, url=url, logger=logger, + names=names, graph_execution_manager=graph_execution_manager, ) @@ -175,12 +182,7 @@ class ImageService(ImageServiceABC): if image_category not in ImageCategory: raise InvalidImageCategoryException - image_name = self._create_image_name( - image_type=image_type, - image_category=image_category, - node_id=node_id, - session_id=session_id, - ) + image_name = self._services.names.create_image_name() metadata = self._get_metadata(session_id, node_id) @@ -260,7 +262,6 @@ class ImageService(ImageServiceABC): except Exception as e: self._services.logger.error("Problem updating image record") raise e - def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: try: @@ -378,21 +379,6 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem deleting image record and file") raise e - def _create_image_name( - self, - image_type: ImageType, - image_category: ImageCategory, - node_id: Optional[str] = None, - session_id: Optional[str] = None, - ) -> str: - """Create a unique image name.""" - uuid_str = str(uuid.uuid4()) - - if node_id is not None and session_id is not None: - return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png" - - return f"{image_type.value}_{image_category.value}_{uuid_str}.png" - def _get_metadata( self, session_id: Optional[str] = None, node_id: Optional[str] = None ) -> Union[ImageMetadata, None]: diff --git a/invokeai/app/services/resource_name.py b/invokeai/app/services/resource_name.py new file mode 100644 index 0000000000..dd5a76cfc0 --- /dev/null +++ b/invokeai/app/services/resource_name.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod +from enum import Enum, EnumMeta +import uuid + + +class ResourceType(str, Enum, metaclass=EnumMeta): + """Enum for resource types.""" + + IMAGE = "image" + LATENT = "latent" + + +class NameServiceBase(ABC): + """Low-level service responsible for naming resources (images, latents, etc).""" + + # TODO: Add customizable naming schemes + @abstractmethod + def create_image_name(self) -> str: + """Creates a name for an image.""" + pass + + +class SimpleNameService(NameServiceBase): + """Creates image names from UUIDs.""" + + # TODO: Add customizable naming schemes + def create_image_name(self) -> str: + uuid_str = str(uuid.uuid4()) + filename = f"{uuid_str}.png" + return filename From ee0225f4baa27839e61c4d70e9b51eb02a2da858 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 09:17:06 +1000 Subject: [PATCH 02/34] fix(nodes): handle intermediates during `images.get_many()` --- invokeai/app/api/routers/images.py | 4 ++++ invokeai/app/services/image_record_storage.py | 10 ++++++---- invokeai/app/services/images.py | 2 ++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 920181ff8b..6208b02d6f 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -230,6 +230,9 @@ async def get_image_urls( async def list_images_with_metadata( image_type: ImageType = Query(description="The type of images to list"), image_category: ImageCategory = Query(description="The kind of images to list"), + is_intermediate: bool = Query( + default=False, description="The kind of images to list" + ), page: int = Query(default=0, description="The page of image metadata to get"), per_page: int = Query( default=10, description="The number of image metadata per page" @@ -240,6 +243,7 @@ async def list_images_with_metadata( image_dtos = ApiDependencies.invoker.services.images.get_many( image_type, image_category, + is_intermediate, page, per_page, ) diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 9a73b68e21..e0c97363f4 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -65,6 +65,7 @@ class ImageRecordStorageBase(ABC): self, image_type: ImageType, image_category: ImageCategory, + is_intermediate: bool = False, page: int = 0, per_page: int = 10, ) -> PaginatedResults[ImageRecord]: @@ -245,6 +246,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): self, image_type: ImageType, image_category: ImageCategory, + is_intermediate: bool = False, page: int = 0, per_page: int = 10, ) -> PaginatedResults[ImageRecord]: @@ -254,11 +256,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): self._cursor.execute( f"""--sql SELECT * FROM images - WHERE image_type = ? AND image_category = ? + WHERE image_type = ? AND image_category = ? AND is_intermediate = ? ORDER BY created_at DESC LIMIT ? OFFSET ?; """, - (image_type.value, image_category.value, per_page, page * per_page), + (image_type.value, image_category.value, is_intermediate, per_page, page * per_page), ) result = cast(list[sqlite3.Row], self._cursor.fetchall()) @@ -268,9 +270,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): self._cursor.execute( """--sql SELECT count(*) FROM images - WHERE image_type = ? AND image_category = ? + WHERE image_type = ? AND image_category = ? AND is_intermediate = ? """, - (image_type.value, image_category.value), + (image_type.value, image_category.value, is_intermediate), ) count = self._cursor.fetchone()[0] diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 6da7510702..bfb7977890 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -330,6 +330,7 @@ class ImageService(ImageServiceABC): self, image_type: ImageType, image_category: ImageCategory, + is_intermediate: bool = False, page: int = 0, per_page: int = 10, ) -> PaginatedResults[ImageDTO]: @@ -337,6 +338,7 @@ class ImageService(ImageServiceABC): results = self._services.records.get_many( image_type, image_category, + is_intermediate, page, per_page, ) From f51defeeb3a11d49c7ba449caf228ca590f5157a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 09:17:17 +1000 Subject: [PATCH 03/34] chore(ui): regen api client --- .../frontend/web/src/services/api/services/ImagesService.ts | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/invokeai/frontend/web/src/services/api/services/ImagesService.ts b/invokeai/frontend/web/src/services/api/services/ImagesService.ts index d01a97a45e..51b3e2e88d 100644 --- a/invokeai/frontend/web/src/services/api/services/ImagesService.ts +++ b/invokeai/frontend/web/src/services/api/services/ImagesService.ts @@ -24,6 +24,7 @@ export class ImagesService { public static listImagesWithMetadata({ imageType, imageCategory, + isIntermediate = false, page, perPage = 10, }: { @@ -35,6 +36,10 @@ export class ImagesService { * The kind of images to list */ imageCategory: ImageCategory, + /** + * The kind of images to list + */ + isIntermediate?: boolean, /** * The page of image metadata to get */ @@ -50,6 +55,7 @@ export class ImagesService { query: { 'image_type': imageType, 'image_category': imageCategory, + 'is_intermediate': isIntermediate, 'page': page, 'per_page': perPage, }, From f609ee21a2bc34f32b23b866877101f6767c1789 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 09:17:35 +1000 Subject: [PATCH 04/34] fix(ui): handle intermediates when fetching gallery --- invokeai/frontend/web/src/services/thunks/gallery.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/invokeai/frontend/web/src/services/thunks/gallery.ts b/invokeai/frontend/web/src/services/thunks/gallery.ts index 11960e00d2..384dd7b8e2 100644 --- a/invokeai/frontend/web/src/services/thunks/gallery.ts +++ b/invokeai/frontend/web/src/services/thunks/gallery.ts @@ -25,6 +25,7 @@ export const receivedResultImagesPage = createAppAsyncThunk< const response = await ImagesService.listImagesWithMetadata({ imageType: 'results', imageCategory: 'general', + isIntermediate: false, page: nextPage + pageOffset, perPage: IMAGES_PER_PAGE, }); @@ -55,6 +56,7 @@ export const receivedUploadImagesPage = createAppAsyncThunk< const response = await ImagesService.listImagesWithMetadata({ imageType: 'uploads', imageCategory: 'general', + isIntermediate: false, page: nextPage + pageOffset, perPage: IMAGES_PER_PAGE, }); From 3ea5e78322c174b268ea15e3d733dd1c877e4ca2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 12:11:05 +1000 Subject: [PATCH 05/34] fix(nodes): fix list images route param descriptions --- invokeai/app/api/routers/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 6208b02d6f..0694540faf 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -231,7 +231,7 @@ async def list_images_with_metadata( image_type: ImageType = Query(description="The type of images to list"), image_category: ImageCategory = Query(description="The kind of images to list"), is_intermediate: bool = Query( - default=False, description="The kind of images to list" + default=False, description="Whether to list intermediate images" ), page: int = Query(default=0, description="The page of image metadata to get"), per_page: int = Query( From bdab73701fb4ad3010be62a81a153e3e66e1032f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 18:31:01 +1000 Subject: [PATCH 06/34] fix(ui): canvas images not added to staging --- .../listeners/socketio/invocationComplete.ts | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts index 95e6d831c0..60a3cdfedf 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts @@ -57,14 +57,7 @@ export const addInvocationCompleteListener = () => { graph_execution_state_id === getState().canvas.layerState.stagingArea.sessionId ) { - const [{ payload: image }] = await take( - ( - action - ): action is ReturnType => - imageMetadataReceived.fulfilled.match(action) && - action.payload.image_name === image_name - ); - dispatch(addImageToStagingArea(image)); + dispatch(addImageToStagingArea(imageDTO)); } dispatch(progressImageSet(null)); From 9317b42e5fbc88edbd5745b64f1d12390308cadf Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 18:32:16 +1000 Subject: [PATCH 07/34] feat(nodes, ui): wip image types --- invokeai/app/api/routers/images.py | 36 +++++--- invokeai/app/services/image_record_storage.py | 92 ++++++++++++++----- invokeai/app/services/images.py | 28 +++--- invokeai/app/services/models/image_record.py | 4 + .../middleware/listenerMiddleware/index.ts | 12 +-- .../listeners/canvasMerged.ts | 3 + .../listeners/canvasSavedToGallery.ts | 19 +++- .../listeners/imageUploaded.ts | 15 +-- ...ImagesPage.ts => receivedGalleryImages.ts} | 14 +-- ...dImagesPage.ts => receivedUploadImages.ts} | 10 +- .../listeners/socketio/socketConnected.ts | 8 +- .../listeners/userInvokedCanvas.ts | 4 + .../common/components/ImageUploadOverlay.tsx | 4 +- .../src/common/components/ImageUploader.tsx | 15 +-- .../components/ImageGalleryContent.tsx | 12 +-- .../features/gallery/store/gallerySlice.ts | 8 +- .../features/gallery/store/resultsSlice.ts | 12 +-- .../features/gallery/store/uploadsSlice.ts | 15 ++- .../web/src/services/api/models/ImageDTO.ts | 4 + .../services/api/services/ImagesService.ts | 36 +++++--- .../web/src/services/thunks/gallery.ts | 9 +- .../frontend/web/src/services/thunks/image.ts | 9 +- 22 files changed, 218 insertions(+), 151 deletions(-) rename invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/{receivedResultImagesPage.ts => receivedGalleryImages.ts} (60%) rename invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/{receivedUploadImagesPage.ts => receivedUploadImages.ts} (73%) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 0694540faf..55556dd79a 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -34,11 +34,10 @@ async def upload_image( file: UploadFile, request: Request, response: Response, - image_category: ImageCategory = Query( - default=ImageCategory.GENERAL, description="The category of the image" - ), - is_intermediate: bool = Query( - default=False, description="Whether this is an intermediate image" + image_category: ImageCategory = Query(description="The category of the image"), + is_intermediate: bool = Query(description="Whether this is an intermediate image"), + show_in_gallery: bool = Query( + description="Whether this image should be shown in the gallery" ), session_id: Optional[str] = Query( default=None, description="The session ID associated with this upload, if any" @@ -63,6 +62,7 @@ async def upload_image( image_category=image_category, session_id=session_id, is_intermediate=is_intermediate, + show_in_gallery=show_in_gallery, ) response.status_code = 201 @@ -228,24 +228,30 @@ async def get_image_urls( response_model=PaginatedResults[ImageDTO], ) async def list_images_with_metadata( - image_type: ImageType = Query(description="The type of images to list"), - image_category: ImageCategory = Query(description="The kind of images to list"), - is_intermediate: bool = Query( - default=False, description="Whether to list intermediate images" + image_type: Optional[ImageType] = Query( + default=None, description="The type of images to list" ), - page: int = Query(default=0, description="The page of image metadata to get"), - per_page: int = Query( - default=10, description="The number of image metadata per page" + image_category: Optional[ImageCategory] = Query( + default=None, description="The kind of images to list" ), + is_intermediate: Optional[bool] = Query( + default=None, description="Whether to list intermediate images" + ), + show_in_gallery: Optional[bool] = Query( + default=None, description="Whether to list images that show in the gallery" + ), + page: int = Query(default=0, description="The page of images to get"), + per_page: int = Query(default=10, description="The number of images per page"), ) -> PaginatedResults[ImageDTO]: - """Gets a list of images with metadata""" + """Gets a list of images""" image_dtos = ApiDependencies.invoker.services.images.get_many( + page, + per_page, image_type, image_category, is_intermediate, - page, - per_page, + show_in_gallery, ) return image_dtos diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index e0c97363f4..b673acdf55 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -63,11 +63,12 @@ class ImageRecordStorageBase(ABC): @abstractmethod def get_many( self, - image_type: ImageType, - image_category: ImageCategory, - is_intermediate: bool = False, page: int = 0, per_page: int = 10, + image_type: Optional[ImageType] = None, + image_category: Optional[ImageCategory] = None, + is_intermediate: Optional[bool] = None, + show_in_gallery: Optional[bool] = None, ) -> PaginatedResults[ImageRecord]: """Gets a page of image records.""" pass @@ -91,6 +92,7 @@ class ImageRecordStorageBase(ABC): node_id: Optional[str], metadata: Optional[ImageMetadata], is_intermediate: bool = False, + show_in_gallery: bool = True, ) -> datetime: """Saves an image record.""" pass @@ -137,6 +139,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): session_id TEXT, node_id TEXT, metadata TEXT, + show_in_gallery BOOLEAN DEFAULT TRUE, is_intermediate BOOLEAN DEFAULT FALSE, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Updated via trigger @@ -224,7 +227,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """, (changes.image_category, image_name), ) - + # Change the session associated with the image if changes.session_id is not None: self._cursor.execute( @@ -244,36 +247,72 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): def get_many( self, - image_type: ImageType, - image_category: ImageCategory, - is_intermediate: bool = False, page: int = 0, per_page: int = 10, + image_type: Optional[ImageType] = None, + image_category: Optional[ImageCategory] = None, + is_intermediate: Optional[bool] = None, + show_in_gallery: Optional[bool] = None, ) -> PaginatedResults[ImageRecord]: try: self._lock.acquire() - self._cursor.execute( - f"""--sql - SELECT * FROM images - WHERE image_type = ? AND image_category = ? AND is_intermediate = ? - ORDER BY created_at DESC - LIMIT ? OFFSET ?; - """, - (image_type.value, image_category.value, is_intermediate, per_page, page * per_page), - ) + # Manually build two queries - one for the count, one for the records + + count_query = """--sql + SELECT COUNT(*) FROM images WHERE 1=1 + """ + + images_query = """--sql + SELECT * FROM images WHERE 1=1 + """ + + query_conditions = "" + query_params = [] + + if image_type is not None: + query_conditions += """--sql + AND image_type = ? + """ + query_params.append(image_type.value) + + if image_category is not None: + query_conditions += """--sql + AND image_category = ? + """ + query_params.append(image_category.value) + + if is_intermediate is not None: + query_conditions += """--sql + AND is_intermediate = ? + """ + query_params.append(is_intermediate) + + if show_in_gallery is not None: + query_conditions += """--sql + AND show_in_gallery = ? + """ + query_params.append(show_in_gallery) + + query_pagination = """--sql + ORDER BY created_at DESC LIMIT ? OFFSET ? + """ + + count_query += query_conditions + ";" + count_params = query_params.copy() + + images_query += query_conditions + query_pagination + ";" + images_params = query_params.copy() + images_params.append(per_page) + images_params.append(page * per_page) + + self._cursor.execute(images_query, images_params) result = cast(list[sqlite3.Row], self._cursor.fetchall()) images = list(map(lambda r: deserialize_image_record(dict(r)), result)) - self._cursor.execute( - """--sql - SELECT count(*) FROM images - WHERE image_type = ? AND image_category = ? AND is_intermediate = ? - """, - (image_type.value, image_category.value, is_intermediate), - ) + self._cursor.execute(count_query, count_params) count = self._cursor.fetchone()[0] except sqlite3.Error as e: @@ -316,6 +355,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): node_id: Optional[str], metadata: Optional[ImageMetadata], is_intermediate: bool = False, + show_in_gallery: bool = True, ) -> datetime: try: metadata_json = ( @@ -333,9 +373,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): node_id, session_id, metadata, - is_intermediate + is_intermediate, + show_in_gallery ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?); """, ( image_name, @@ -347,6 +388,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): session_id, metadata_json, is_intermediate, + show_in_gallery, ), ) self._conn.commit() diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index bfb7977890..1bde1acfd4 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -1,8 +1,6 @@ from abc import ABC, abstractmethod from logging import Logger -from os import name from typing import Optional, TYPE_CHECKING, Union -import uuid from PIL.Image import Image as PILImageType from invokeai.app.models.image import ( @@ -51,6 +49,7 @@ class ImageServiceABC(ABC): node_id: Optional[str] = None, session_id: Optional[str] = None, intermediate: bool = False, + show_in_gallery: bool = True, ) -> ImageDTO: """Creates an image, storing the file and its metadata.""" pass @@ -100,10 +99,12 @@ class ImageServiceABC(ABC): @abstractmethod def get_many( self, - image_type: ImageType, - image_category: ImageCategory, page: int = 0, per_page: int = 10, + image_type: Optional[ImageType] = None, + image_category: Optional[ImageCategory] = None, + is_intermediate: Optional[bool] = None, + show_in_gallery: Optional[bool] = None, ) -> PaginatedResults[ImageDTO]: """Gets a paginated list of image DTOs.""" pass @@ -175,6 +176,7 @@ class ImageService(ImageServiceABC): node_id: Optional[str] = None, session_id: Optional[str] = None, is_intermediate: bool = False, + show_in_gallery: bool = True, ) -> ImageDTO: if image_type not in ImageType: raise InvalidImageTypeException @@ -199,6 +201,7 @@ class ImageService(ImageServiceABC): height=height, # Meta fields is_intermediate=is_intermediate, + show_in_gallery=show_in_gallery, # Nullable fields node_id=node_id, session_id=session_id, @@ -233,6 +236,7 @@ class ImageService(ImageServiceABC): updated_at=created_at, # this is always the same as the created_at at this time deleted_at=None, is_intermediate=is_intermediate, + show_in_gallery=show_in_gallery, # Extra non-nullable fields for DTO image_url=image_url, thumbnail_url=thumbnail_url, @@ -328,28 +332,30 @@ class ImageService(ImageServiceABC): def get_many( self, - image_type: ImageType, - image_category: ImageCategory, - is_intermediate: bool = False, page: int = 0, per_page: int = 10, + image_type: Optional[ImageType] = None, + image_category: Optional[ImageCategory] = None, + is_intermediate: Optional[bool] = None, + show_in_gallery: Optional[bool] = None, ) -> PaginatedResults[ImageDTO]: try: results = self._services.records.get_many( + page, + per_page, image_type, image_category, is_intermediate, - page, - per_page, + show_in_gallery, ) image_dtos = list( map( lambda r: image_record_to_dto( r, - self._services.urls.get_image_url(image_type, r.image_name), + self._services.urls.get_image_url(r.image_type, r.image_name), self._services.urls.get_image_url( - image_type, r.image_name, True + r.image_type, r.image_name, True ), ), results.items, diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index 26e4929be2..faa6e1b41a 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -33,6 +33,8 @@ class ImageRecord(BaseModel): """The deleted timestamp of the image.""" is_intermediate: bool = Field(description="Whether this is an intermediate image.") """Whether this is an intermediate image.""" + show_in_gallery: bool = Field(description="Whether this image should be shown in the gallery.") + """Whether this image should be shown in the gallery.""" session_id: Optional[str] = Field( default=None, description="The session ID that generated this image, if it is a generated image.", @@ -117,6 +119,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: updated_at = image_dict.get("updated_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) is_intermediate = image_dict.get("is_intermediate", False) + show_in_gallery = image_dict.get("show_in_gallery", True) raw_metadata = image_dict.get("metadata") @@ -138,4 +141,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: updated_at=updated_at, deleted_at=deleted_at, is_intermediate=is_intermediate, + show_in_gallery=show_in_gallery, ) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 1fbc2f978c..b669becfe6 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -60,13 +60,13 @@ import { addSessionCanceledRejectedListener, } from './listeners/sessionCanceled'; import { - addReceivedResultImagesPageFulfilledListener, - addReceivedResultImagesPageRejectedListener, -} from './listeners/receivedResultImagesPage'; + addReceivedGalleryImagesFulfilledListener, + addReceivedGalleryImagesRejectedListener, +} from './listeners/receivedGalleryImages'; import { addReceivedUploadImagesPageFulfilledListener, addReceivedUploadImagesPageRejectedListener, -} from './listeners/receivedUploadImagesPage'; +} from './listeners/receivedUploadImages'; export const listenerMiddleware = createListenerMiddleware(); @@ -146,7 +146,7 @@ addSessionCanceledFulfilledListener(); addSessionCanceledRejectedListener(); // Gallery pages -addReceivedResultImagesPageFulfilledListener(); -addReceivedResultImagesPageRejectedListener(); +addReceivedGalleryImagesFulfilledListener(); +addReceivedGalleryImagesRejectedListener(); addReceivedUploadImagesPageFulfilledListener(); addReceivedUploadImagesPageRejectedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts index fbc9c9c225..fc4c7247cd 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts @@ -55,6 +55,9 @@ export const addCanvasMergedListener = () => { formData: { file: new File([blob], filename, { type: 'image/png' }), }, + imageCategory: 'general', + isIntermediate: true, + showInGallery: false, }) ); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts index 2df3dacea2..7656e58b57 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -4,13 +4,15 @@ import { log } from 'app/logging/useLogger'; import { imageUploaded } from 'services/thunks/image'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { addToast } from 'features/system/store/systemSlice'; +import { v4 as uuidv4 } from 'uuid'; +import { resultUpserted } from 'features/gallery/store/resultsSlice'; const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' }); export const addCanvasSavedToGalleryListener = () => { startAppListening({ actionCreator: canvasSavedToGallery, - effect: async (action, { dispatch, getState }) => { + effect: async (action, { dispatch, getState, take }) => { const state = getState(); const blob = await getBaseLayerBlob(state); @@ -27,13 +29,26 @@ export const addCanvasSavedToGalleryListener = () => { return; } + const filename = `mergedCanvas_${uuidv4()}.png`; + dispatch( imageUploaded({ formData: { - file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }), + file: new File([blob], filename, { type: 'image/png' }), }, + imageCategory: 'general', + isIntermediate: false, + showInGallery: true, }) ); + + const [{ payload: uploadedImageDTO }] = await take( + (action): action is ReturnType => + imageUploaded.fulfilled.match(action) && + action.meta.arg.formData.file.name === filename + ); + + dispatch(resultUpserted(uploadedImageDTO)); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index 5b177eae91..3d69eb8f9a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -3,10 +3,7 @@ import { uploadUpserted } from 'features/gallery/store/uploadsSlice'; import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageUploaded } from 'services/thunks/image'; import { addToast } from 'features/system/store/systemSlice'; -import { initialImageSelected } from 'features/parameters/store/actions'; -import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { resultUpserted } from 'features/gallery/store/resultsSlice'; -import { isResultsImageDTO, isUploadsImageDTO } from 'services/types/guards'; import { log } from 'app/logging/useLogger'; const moduleLog = log.child({ namespace: 'image' }); @@ -24,7 +21,7 @@ export const addImageUploadedFulfilledListener = () => { const state = getState(); // Handle uploads - if (isUploadsImageDTO(image)) { + if (!image.show_in_gallery && image.image_type === 'uploads') { dispatch(uploadUpserted(image)); dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); @@ -32,19 +29,11 @@ export const addImageUploadedFulfilledListener = () => { if (state.gallery.shouldAutoSwitchToNewImages) { dispatch(imageSelected(image)); } - - if (action.meta.arg.activeTabName === 'img2img') { - dispatch(initialImageSelected(image)); - } - - if (action.meta.arg.activeTabName === 'unifiedCanvas') { - dispatch(setInitialCanvasImage(image)); - } } // Handle results // TODO: Can this ever happen? I don't think so... - if (isResultsImageDTO(image)) { + if (image.show_in_gallery) { dispatch(resultUpserted(image)); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedResultImagesPage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedGalleryImages.ts similarity index 60% rename from invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedResultImagesPage.ts rename to invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedGalleryImages.ts index bcdd11ef97..aba81e1e72 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedResultImagesPage.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedGalleryImages.ts @@ -1,31 +1,31 @@ import { log } from 'app/logging/useLogger'; import { startAppListening } from '..'; -import { receivedResultImagesPage } from 'services/thunks/gallery'; +import { receivedGalleryImages } from 'services/thunks/gallery'; import { serializeError } from 'serialize-error'; const moduleLog = log.child({ namespace: 'gallery' }); -export const addReceivedResultImagesPageFulfilledListener = () => { +export const addReceivedGalleryImagesFulfilledListener = () => { startAppListening({ - actionCreator: receivedResultImagesPage.fulfilled, + actionCreator: receivedGalleryImages.fulfilled, effect: (action, { getState, dispatch }) => { const page = action.payload; moduleLog.debug( { data: { page } }, - `Received ${page.items.length} results` + `Received ${page.items.length} gallery images` ); }, }); }; -export const addReceivedResultImagesPageRejectedListener = () => { +export const addReceivedGalleryImagesRejectedListener = () => { startAppListening({ - actionCreator: receivedResultImagesPage.rejected, + actionCreator: receivedGalleryImages.rejected, effect: (action, { getState, dispatch }) => { if (action.payload) { moduleLog.debug( { data: { error: serializeError(action.payload.error) } }, - 'Problem receiving results' + 'Problem receiving gallery images' ); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedUploadImagesPage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedUploadImages.ts similarity index 73% rename from invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedUploadImagesPage.ts rename to invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedUploadImages.ts index 68813aae27..602fccf847 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedUploadImagesPage.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedUploadImages.ts @@ -1,18 +1,18 @@ import { log } from 'app/logging/useLogger'; import { startAppListening } from '..'; -import { receivedUploadImagesPage } from 'services/thunks/gallery'; +import { receivedUploadImages } from 'services/thunks/gallery'; import { serializeError } from 'serialize-error'; const moduleLog = log.child({ namespace: 'gallery' }); export const addReceivedUploadImagesPageFulfilledListener = () => { startAppListening({ - actionCreator: receivedUploadImagesPage.fulfilled, + actionCreator: receivedUploadImages.fulfilled, effect: (action, { getState, dispatch }) => { const page = action.payload; moduleLog.debug( { data: { page } }, - `Received ${page.items.length} uploads` + `Received ${page.items.length} uploaded images` ); }, }); @@ -20,12 +20,12 @@ export const addReceivedUploadImagesPageFulfilledListener = () => { export const addReceivedUploadImagesPageRejectedListener = () => { startAppListening({ - actionCreator: receivedUploadImagesPage.rejected, + actionCreator: receivedUploadImages.rejected, effect: (action, { getState, dispatch }) => { if (action.payload) { moduleLog.debug( { data: { error: serializeError(action.payload.error) } }, - 'Problem receiving uploads' + 'Problem receiving uploaded images' ); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index bc9ecbec1e..650918ba3c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -2,8 +2,8 @@ import { startAppListening } from '../..'; import { log } from 'app/logging/useLogger'; import { socketConnected } from 'services/events/actions'; import { - receivedResultImagesPage, - receivedUploadImagesPage, + receivedGalleryImages, + receivedUploadImages, } from 'services/thunks/gallery'; import { receivedModels } from 'services/thunks/model'; import { receivedOpenAPISchema } from 'services/thunks/schema'; @@ -24,11 +24,11 @@ export const addSocketConnectedListener = () => { // These thunks need to be dispatch in middleware; cannot handle in a reducer if (!results.ids.length) { - dispatch(receivedResultImagesPage()); + dispatch(receivedGalleryImages()); } if (!uploads.ids.length) { - dispatch(receivedUploadImagesPage()); + dispatch(receivedUploadImages()); } if (!models.ids.length) { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts index ae388b85cf..bc1d5d5f8a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts @@ -101,7 +101,9 @@ export const addUserInvokedCanvasListener = () => { formData: { file: new File([baseBlob], baseFilename, { type: 'image/png' }), }, + imageCategory: 'general', isIntermediate: true, + showInGallery: false, }) ); @@ -127,7 +129,9 @@ export const addUserInvokedCanvasListener = () => { formData: { file: new File([maskBlob], maskFilename, { type: 'image/png' }), }, + imageCategory: 'mask', isIntermediate: true, + showInGallery: false, }) ); diff --git a/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx b/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx index 28d9d32a71..862d806eb1 100644 --- a/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx +++ b/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx @@ -4,7 +4,6 @@ import { useHotkeys } from 'react-hotkeys-hook'; type ImageUploadOverlayProps = { isDragAccept: boolean; isDragReject: boolean; - overlaySecondaryText: string; setIsHandlingUpload: (isHandlingUpload: boolean) => void; }; @@ -12,7 +11,6 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => { const { isDragAccept, isDragReject: _isDragAccept, - overlaySecondaryText, setIsHandlingUpload, } = props; @@ -48,7 +46,7 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => { }} > {isDragAccept ? ( - Upload Image{overlaySecondaryText} + Drop to Upload ) : ( <> Invalid Upload diff --git a/invokeai/frontend/web/src/common/components/ImageUploader.tsx b/invokeai/frontend/web/src/common/components/ImageUploader.tsx index 628d44b6f1..a4e6e52cb8 100644 --- a/invokeai/frontend/web/src/common/components/ImageUploader.tsx +++ b/invokeai/frontend/web/src/common/components/ImageUploader.tsx @@ -69,11 +69,13 @@ const ImageUploader = (props: ImageUploaderProps) => { dispatch( imageUploaded({ formData: { file }, - activeTabName, + imageCategory: 'general', + isIntermediate: false, + showInGallery: false, }) ); }, - [dispatch, activeTabName] + [dispatch] ); const onDrop = useCallback( @@ -144,14 +146,6 @@ const ImageUploader = (props: ImageUploaderProps) => { }; }, [inputRef, open, setOpenUploaderFunction]); - const overlaySecondaryText = useMemo(() => { - if (['img2img', 'unifiedCanvas'].includes(activeTabName)) { - return ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`; - } - - return ''; - }, [t, activeTabName]); - return ( { )} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx index 468dfd694f..7c7fd29038 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx @@ -43,8 +43,8 @@ import HoverableImage from './HoverableImage'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { resultsAdapter } from '../store/resultsSlice'; import { - receivedResultImagesPage, - receivedUploadImagesPage, + receivedGalleryImages, + receivedUploadImages, } from 'services/thunks/gallery'; import { uploadsAdapter } from '../store/uploadsSlice'; import { createSelector } from '@reduxjs/toolkit'; @@ -151,11 +151,11 @@ const ImageGalleryContent = () => { const handleClickLoadMore = () => { if (currentCategory === 'results') { - dispatch(receivedResultImagesPage()); + dispatch(receivedGalleryImages()); } if (currentCategory === 'uploads') { - dispatch(receivedUploadImagesPage()); + dispatch(receivedUploadImages()); } }; @@ -211,9 +211,9 @@ const ImageGalleryContent = () => { const handleEndReached = useCallback(() => { if (currentCategory === 'results') { - dispatch(receivedResultImagesPage()); + dispatch(receivedGalleryImages()); } else if (currentCategory === 'uploads') { - dispatch(receivedUploadImagesPage()); + dispatch(receivedUploadImages()); } }, [dispatch, currentCategory]); diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index 9d6f5ece60..1a49aeac1e 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -1,8 +1,8 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import { - receivedResultImagesPage, - receivedUploadImagesPage, + receivedGalleryImages, + receivedUploadImages, } from '../../../services/thunks/gallery'; import { ImageDTO } from 'services/api'; @@ -60,7 +60,7 @@ export const gallerySlice = createSlice({ }, }, extraReducers(builder) { - builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => { + builder.addCase(receivedGalleryImages.fulfilled, (state, action) => { // rehydrate selectedImage URL when results list comes in // solves case when outdated URL is in local storage const selectedImage = state.selectedImage; @@ -76,7 +76,7 @@ export const gallerySlice = createSlice({ } } }); - builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => { + builder.addCase(receivedUploadImages.fulfilled, (state, action) => { // rehydrate selectedImage URL when results list comes in // solves case when outdated URL is in local storage const selectedImage = state.selectedImage; diff --git a/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts b/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts index 36f4c49401..ad05284119 100644 --- a/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts @@ -5,7 +5,7 @@ import { } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; import { - receivedResultImagesPage, + receivedGalleryImages, IMAGES_PER_PAGE, } from 'services/thunks/gallery'; import { ImageDTO } from 'services/api'; @@ -15,7 +15,7 @@ export type ResultsImageDTO = Omit & { image_type: 'results'; }; -export const resultsAdapter = createEntityAdapter({ +export const resultsAdapter = createEntityAdapter({ selectId: (image) => image.image_name, sortComparer: (a, b) => dateComparator(b.created_at, a.created_at), }); @@ -43,7 +43,7 @@ const resultsSlice = createSlice({ name: 'results', initialState: initialResultsState, reducers: { - resultUpserted: (state, action: PayloadAction) => { + resultUpserted: (state, action: PayloadAction) => { resultsAdapter.upsertOne(state, action.payload); state.upsertedImageCount += 1; }, @@ -52,18 +52,18 @@ const resultsSlice = createSlice({ /** * Received Result Images Page - PENDING */ - builder.addCase(receivedResultImagesPage.pending, (state) => { + builder.addCase(receivedGalleryImages.pending, (state) => { state.isLoading = true; }); /** * Received Result Images Page - FULFILLED */ - builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => { + builder.addCase(receivedGalleryImages.fulfilled, (state, action) => { const { page, pages } = action.payload; // We know these will all be of the results type, but it's not represented in the API types - const items = action.payload.items as ResultsImageDTO[]; + const items = action.payload.items; resultsAdapter.setMany(state, items); diff --git a/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts b/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts index 3058e82673..49e4d7e3ff 100644 --- a/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts @@ -5,10 +5,7 @@ import { } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; -import { - receivedUploadImagesPage, - IMAGES_PER_PAGE, -} from 'services/thunks/gallery'; +import { receivedUploadImages, IMAGES_PER_PAGE } from 'services/thunks/gallery'; import { ImageDTO } from 'services/api'; import { dateComparator } from 'common/util/dateComparator'; @@ -16,7 +13,7 @@ export type UploadsImageDTO = Omit & { image_type: 'uploads'; }; -export const uploadsAdapter = createEntityAdapter({ +export const uploadsAdapter = createEntityAdapter({ selectId: (image) => image.image_name, sortComparer: (a, b) => dateComparator(b.created_at, a.created_at), }); @@ -44,7 +41,7 @@ const uploadsSlice = createSlice({ name: 'uploads', initialState: initialUploadsState, reducers: { - uploadUpserted: (state, action: PayloadAction) => { + uploadUpserted: (state, action: PayloadAction) => { uploadsAdapter.upsertOne(state, action.payload); state.upsertedImageCount += 1; }, @@ -53,18 +50,18 @@ const uploadsSlice = createSlice({ /** * Received Upload Images Page - PENDING */ - builder.addCase(receivedUploadImagesPage.pending, (state) => { + builder.addCase(receivedUploadImages.pending, (state) => { state.isLoading = true; }); /** * Received Upload Images Page - FULFILLED */ - builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => { + builder.addCase(receivedUploadImages.fulfilled, (state, action) => { const { page, pages } = action.payload; // We know these will all be of the uploads type, but it's not represented in the API types - const items = action.payload.items as UploadsImageDTO[]; + const items = action.payload.items; uploadsAdapter.setMany(state, items); diff --git a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts index bc2f19f1b5..599ca51de4 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts @@ -54,6 +54,10 @@ export type ImageDTO = { * Whether this is an intermediate image. */ is_intermediate: boolean; + /** + * Whether this image should be shown in the gallery. + */ + show_in_gallery: boolean; /** * The session ID that generated this image, if it is a generated image. */ diff --git a/invokeai/frontend/web/src/services/api/services/ImagesService.ts b/invokeai/frontend/web/src/services/api/services/ImagesService.ts index 51b3e2e88d..a8d22e802b 100644 --- a/invokeai/frontend/web/src/services/api/services/ImagesService.ts +++ b/invokeai/frontend/web/src/services/api/services/ImagesService.ts @@ -17,35 +17,40 @@ export class ImagesService { /** * List Images With Metadata - * Gets a list of images with metadata + * Gets a list of images * @returns PaginatedResults_ImageDTO_ Successful Response * @throws ApiError */ public static listImagesWithMetadata({ imageType, imageCategory, - isIntermediate = false, + isIntermediate, + showInGallery, page, perPage = 10, }: { /** * The type of images to list */ - imageType: ImageType, + imageType?: ImageType, /** * The kind of images to list */ - imageCategory: ImageCategory, + imageCategory?: ImageCategory, /** - * The kind of images to list + * Whether to list intermediate images */ isIntermediate?: boolean, /** - * The page of image metadata to get + * Whether to list images that show in the gallery + */ + showInGallery?: boolean, + /** + * The page of images to get */ page?: number, /** - * The number of image metadata per page + * The number of images per page */ perPage?: number, }): CancelablePromise { @@ -56,6 +61,7 @@ export class ImagesService { 'image_type': imageType, 'image_category': imageCategory, 'is_intermediate': isIntermediate, + 'show_in_gallery': showInGallery, 'page': page, 'per_page': perPage, }, @@ -72,20 +78,25 @@ export class ImagesService { * @throws ApiError */ public static uploadImage({ - formData, imageCategory, - isIntermediate = false, + isIntermediate, + showInGallery, + formData, sessionId, }: { - formData: Body_upload_image, /** * The category of the image */ - imageCategory?: ImageCategory, + imageCategory: ImageCategory, /** * Whether this is an intermediate image */ - isIntermediate?: boolean, + isIntermediate: boolean, + /** + * Whether this image should be shown in the gallery + */ + showInGallery: boolean, + formData: Body_upload_image, /** * The session ID associated with this upload, if any */ @@ -97,6 +108,7 @@ export class ImagesService { query: { 'image_category': imageCategory, 'is_intermediate': isIntermediate, + 'show_in_gallery': showInGallery, 'session_id': sessionId, }, formData: formData, diff --git a/invokeai/frontend/web/src/services/thunks/gallery.ts b/invokeai/frontend/web/src/services/thunks/gallery.ts index 384dd7b8e2..03032a60ef 100644 --- a/invokeai/frontend/web/src/services/thunks/gallery.ts +++ b/invokeai/frontend/web/src/services/thunks/gallery.ts @@ -9,7 +9,7 @@ type ReceivedResultImagesPageThunkConfig = { }; }; -export const receivedResultImagesPage = createAppAsyncThunk< +export const receivedGalleryImages = createAppAsyncThunk< PaginatedResults_ImageDTO_, void, ReceivedResultImagesPageThunkConfig @@ -23,9 +23,8 @@ export const receivedResultImagesPage = createAppAsyncThunk< const pageOffset = Math.floor(upsertedImageCount / IMAGES_PER_PAGE); const response = await ImagesService.listImagesWithMetadata({ - imageType: 'results', - imageCategory: 'general', isIntermediate: false, + showInGallery: true, page: nextPage + pageOffset, perPage: IMAGES_PER_PAGE, }); @@ -40,7 +39,7 @@ type ReceivedUploadImagesPageThunkConfig = { }; }; -export const receivedUploadImagesPage = createAppAsyncThunk< +export const receivedUploadImages = createAppAsyncThunk< PaginatedResults_ImageDTO_, void, ReceivedUploadImagesPageThunkConfig @@ -55,8 +54,8 @@ export const receivedUploadImagesPage = createAppAsyncThunk< const response = await ImagesService.listImagesWithMetadata({ imageType: 'uploads', - imageCategory: 'general', isIntermediate: false, + showInGallery: false, page: nextPage + pageOffset, perPage: IMAGES_PER_PAGE, }); diff --git a/invokeai/frontend/web/src/services/thunks/image.ts b/invokeai/frontend/web/src/services/thunks/image.ts index f0c0456202..f324edad2b 100644 --- a/invokeai/frontend/web/src/services/thunks/image.ts +++ b/invokeai/frontend/web/src/services/thunks/image.ts @@ -32,11 +32,7 @@ export const imageMetadataReceived = createAppAsyncThunk( } ); -type ImageUploadedArg = Parameters<(typeof ImagesService)['uploadImage']>[0] & { - // extra arg to determine post-upload actions - we check for this when the image is uploaded - // to determine if we should set the init image - activeTabName?: InvokeTabName; -}; +type ImageUploadedArg = Parameters<(typeof ImagesService)['uploadImage']>[0]; /** * `ImagesService.uploadImage()` thunk @@ -45,8 +41,7 @@ export const imageUploaded = createAppAsyncThunk( 'api/imageUploaded', async (arg: ImageUploadedArg) => { // strip out `activeTabName` from arg - the route does not need it - const { activeTabName, ...rest } = arg; - const response = await ImagesService.uploadImage(rest); + const response = await ImagesService.uploadImage(arg); return response; } ); From fd47e70c929d7fa106affba778fac69d896a9d25 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 21:16:29 +1000 Subject: [PATCH 08/34] feat(nodes): use higher precision timestamps in db --- invokeai/app/services/image_record_storage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index b673acdf55..8afa7000fb 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -141,9 +141,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): metadata TEXT, show_in_gallery BOOLEAN DEFAULT TRUE, is_intermediate BOOLEAN DEFAULT FALSE, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Updated via trigger - updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Soft delete, currently unused deleted_at DATETIME ); From 160267c71a281f155b03da9a7cb15ac72e0886a0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 21:39:20 +1000 Subject: [PATCH 09/34] feat(nodes): refactor image types - Remove `ImageType` entirely, it is confusing - Create `ResourceOrigin`, may be `internal` or `external` - Revamp `ImageCategory`, may be `general`, `mask`, `control`, `user`, `other`. Expect to add more as time goes on - Update images `list` route to accept `include_categories` OR `exclude_categories` query parameters to afford finer-grained querying. All services are updated to accomodate this change. The new setup should account for our types of images, including the combinations we couldn't really handle until now: - Canvas init and masks - Canvas when saved-to-gallery or merged --- invokeai/app/api/models/images.py | 39 ------ invokeai/app/api/routers/images.py | 71 +++++----- invokeai/app/invocations/cv.py | 10 +- invokeai/app/invocations/generate.py | 22 ++-- invokeai/app/invocations/image.py | 70 +++++----- invokeai/app/invocations/infill.py | 20 +-- invokeai/app/invocations/latent.py | 8 +- invokeai/app/invocations/reconstruct.py | 8 +- invokeai/app/invocations/upscale.py | 8 +- invokeai/app/models/image.py | 54 ++++++-- invokeai/app/services/events.py | 4 +- invokeai/app/services/image_file_storage.py | 38 +++--- invokeai/app/services/image_record_storage.py | 122 +++++++++--------- invokeai/app/services/images.py | 98 +++++++------- invokeai/app/services/models/image_record.py | 18 +-- invokeai/app/services/urls.py | 10 +- invokeai/app/util/step_callback.py | 2 +- 17 files changed, 291 insertions(+), 311 deletions(-) delete mode 100644 invokeai/app/api/models/images.py diff --git a/invokeai/app/api/models/images.py b/invokeai/app/api/models/images.py deleted file mode 100644 index fa04702326..0000000000 --- a/invokeai/app/api/models/images.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import Optional -from pydantic import BaseModel, Field - -from invokeai.app.models.image import ImageType - - -class ImageResponseMetadata(BaseModel): - """An image's metadata. Used only in HTTP responses.""" - - created: int = Field(description="The creation timestamp of the image") - width: int = Field(description="The width of the image in pixels") - height: int = Field(description="The height of the image in pixels") - # invokeai: Optional[InvokeAIMetadata] = Field( - # description="The image's InvokeAI-specific metadata" - # ) - - -class ImageResponse(BaseModel): - """The response type for images""" - - image_type: ImageType = Field(description="The type of the image") - image_name: str = Field(description="The name of the image") - image_url: str = Field(description="The url of the image") - thumbnail_url: str = Field(description="The url of the image's thumbnail") - metadata: ImageResponseMetadata = Field(description="The image's metadata") - - -class ProgressImage(BaseModel): - """The progress image sent intermittently during processing""" - - width: int = Field(description="The effective width of the image in pixels") - height: int = Field(description="The effective height of the image in pixels") - dataURL: str = Field(description="The image data as a b64 data URL") - - -class SavedImage(BaseModel): - image_name: str = Field(description="The name of the saved image") - thumbnail_name: str = Field(description="The name of the saved thumbnail") - created: int = Field(description="The created timestamp of the saved image") diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 55556dd79a..f0399a2d07 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -6,7 +6,7 @@ from fastapi.responses import FileResponse from PIL import Image from invokeai.app.models.image import ( ImageCategory, - ImageType, + ResourceOrigin, ) from invokeai.app.services.models.image_record import ( ImageDTO, @@ -36,9 +36,6 @@ async def upload_image( response: Response, image_category: ImageCategory = Query(description="The category of the image"), is_intermediate: bool = Query(description="Whether this is an intermediate image"), - show_in_gallery: bool = Query( - description="Whether this image should be shown in the gallery" - ), session_id: Optional[str] = Query( default=None, description="The session ID associated with this upload, if any" ), @@ -58,11 +55,10 @@ async def upload_image( try: image_dto = ApiDependencies.invoker.services.images.create( image=pil_image, - image_type=ImageType.UPLOAD, + image_origin=ResourceOrigin.EXTERNAL, image_category=image_category, session_id=session_id, is_intermediate=is_intermediate, - show_in_gallery=show_in_gallery, ) response.status_code = 201 @@ -73,27 +69,27 @@ async def upload_image( raise HTTPException(status_code=500, detail="Failed to create image") -@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image") +@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image") async def delete_image( - image_type: ImageType = Path(description="The type of image to delete"), + image_origin: ResourceOrigin = Path(description="The origin of image to delete"), image_name: str = Path(description="The name of the image to delete"), ) -> None: """Deletes an image""" try: - ApiDependencies.invoker.services.images.delete(image_type, image_name) + ApiDependencies.invoker.services.images.delete(image_origin, image_name) except Exception as e: # TODO: Does this need any exception handling at all? pass @images_router.patch( - "/{image_type}/{image_name}", + "/{image_origin}/{image_name}", operation_id="update_image", response_model=ImageDTO, ) async def update_image( - image_type: ImageType = Path(description="The type of image to update"), + image_origin: ResourceOrigin = Path(description="The origin of image to update"), image_name: str = Path(description="The name of the image to update"), image_changes: ImageRecordChanges = Body( description="The changes to apply to the image" @@ -103,31 +99,31 @@ async def update_image( try: return ApiDependencies.invoker.services.images.update( - image_type, image_name, image_changes + image_origin, image_name, image_changes ) except Exception as e: raise HTTPException(status_code=400, detail="Failed to update image") @images_router.get( - "/{image_type}/{image_name}/metadata", + "/{image_origin}/{image_name}/metadata", operation_id="get_image_metadata", response_model=ImageDTO, ) async def get_image_metadata( - image_type: ImageType = Path(description="The type of image to get"), + image_origin: ResourceOrigin = Path(description="The origin of image to get"), image_name: str = Path(description="The name of image to get"), ) -> ImageDTO: """Gets an image's metadata""" try: - return ApiDependencies.invoker.services.images.get_dto(image_type, image_name) + return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name) except Exception as e: raise HTTPException(status_code=404) @images_router.get( - "/{image_type}/{image_name}", + "/{image_origin}/{image_name}", operation_id="get_image_full", response_class=Response, responses={ @@ -139,7 +135,7 @@ async def get_image_metadata( }, ) async def get_image_full( - image_type: ImageType = Path( + image_origin: ResourceOrigin = Path( description="The type of full-resolution image file to get" ), image_name: str = Path(description="The name of full-resolution image file to get"), @@ -147,7 +143,7 @@ async def get_image_full( """Gets a full-resolution image file""" try: - path = ApiDependencies.invoker.services.images.get_path(image_type, image_name) + path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name) if not ApiDependencies.invoker.services.images.validate_path(path): raise HTTPException(status_code=404) @@ -163,7 +159,7 @@ async def get_image_full( @images_router.get( - "/{image_type}/{image_name}/thumbnail", + "/{image_origin}/{image_name}/thumbnail", operation_id="get_image_thumbnail", response_class=Response, responses={ @@ -175,14 +171,14 @@ async def get_image_full( }, ) async def get_image_thumbnail( - image_type: ImageType = Path(description="The type of thumbnail image file to get"), + image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"), image_name: str = Path(description="The name of thumbnail image file to get"), ) -> FileResponse: """Gets a thumbnail image file""" try: path = ApiDependencies.invoker.services.images.get_path( - image_type, image_name, thumbnail=True + image_origin, image_name, thumbnail=True ) if not ApiDependencies.invoker.services.images.validate_path(path): raise HTTPException(status_code=404) @@ -195,25 +191,25 @@ async def get_image_thumbnail( @images_router.get( - "/{image_type}/{image_name}/urls", + "/{image_origin}/{image_name}/urls", operation_id="get_image_urls", response_model=ImageUrlsDTO, ) async def get_image_urls( - image_type: ImageType = Path(description="The type of the image whose URL to get"), + image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"), image_name: str = Path(description="The name of the image whose URL to get"), ) -> ImageUrlsDTO: """Gets an image and thumbnail URL""" try: image_url = ApiDependencies.invoker.services.images.get_url( - image_type, image_name + image_origin, image_name ) thumbnail_url = ApiDependencies.invoker.services.images.get_url( - image_type, image_name, thumbnail=True + image_origin, image_name, thumbnail=True ) return ImageUrlsDTO( - image_type=image_type, + image_origin=image_origin, image_name=image_name, image_url=image_url, thumbnail_url=thumbnail_url, @@ -228,30 +224,33 @@ async def get_image_urls( response_model=PaginatedResults[ImageDTO], ) async def list_images_with_metadata( - image_type: Optional[ImageType] = Query( - default=None, description="The type of images to list" + image_origin: Optional[ResourceOrigin] = Query( + default=None, description="The origin of images to list" ), - image_category: Optional[ImageCategory] = Query( - default=None, description="The kind of images to list" + include_categories: Optional[list[ImageCategory]] = Query( + default=None, description="The categories of image to include" + ), + exclude_categories: Optional[list[ImageCategory]] = Query( + default=None, description="The categories of image to exclude" ), is_intermediate: Optional[bool] = Query( default=None, description="Whether to list intermediate images" ), - show_in_gallery: Optional[bool] = Query( - default=None, description="Whether to list images that show in the gallery" - ), page: int = Query(default=0, description="The page of images to get"), per_page: int = Query(default=10, description="The number of images per page"), ) -> PaginatedResults[ImageDTO]: """Gets a list of images""" + if include_categories is not None and exclude_categories is not None: + raise HTTPException(status_code=400, detail="Cannot use both 'include_category' and 'exclude_category' at the same time.") + image_dtos = ApiDependencies.invoker.services.images.get_many( page, per_page, - image_type, - image_category, + image_origin, + include_categories, + exclude_categories, is_intermediate, - show_in_gallery, ) return image_dtos diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 5e9fe088b5..5275116a2a 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -7,7 +7,7 @@ import numpy from PIL import Image, ImageOps from pydantic import BaseModel, Field -from invokeai.app.models.image import ImageCategory, ImageField, ImageType +from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput @@ -37,10 +37,10 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) mask = context.services.images.get_pil_image( - self.mask.image_type, self.mask.image_name + self.mask.image_origin, self.mask.image_name ) # Convert to cv image/mask @@ -57,7 +57,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): image_dto = context.services.images.create( image=image_inpainted, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -67,7 +67,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 0385c6a9f0..d2ce59d247 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -10,9 +10,9 @@ import torch from pydantic import BaseModel, Field -from invokeai.app.models.image import ColorField, ImageField, ImageType +from invokeai.app.models.image import ColorField, ImageField, ResourceOrigin from invokeai.app.invocations.util.choose_model import choose_model -from invokeai.app.models.image import ImageCategory, ImageType +from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.generator.inpaint import infill_methods from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig @@ -120,7 +120,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): image_dto = context.services.images.create( image=generate_output.image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, @@ -130,7 +130,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -170,7 +170,7 @@ class ImageToImageInvocation(TextToImageInvocation): None if self.image is None else context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) ) @@ -201,7 +201,7 @@ class ImageToImageInvocation(TextToImageInvocation): image_dto = context.services.images.create( image=generator_output.image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, @@ -211,7 +211,7 @@ class ImageToImageInvocation(TextToImageInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -283,13 +283,13 @@ class InpaintInvocation(ImageToImageInvocation): None if self.image is None else context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) ) mask = ( None if self.mask is None - else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name) + else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name) ) # Handle invalid model parameter @@ -317,7 +317,7 @@ class InpaintInvocation(ImageToImageInvocation): image_dto = context.services.images.create( image=generator_output.image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, @@ -327,7 +327,7 @@ class InpaintInvocation(ImageToImageInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 69d51e6158..7633bfbc16 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,7 +7,7 @@ import numpy from PIL import Image, ImageFilter, ImageOps, ImageChops from pydantic import BaseModel, Field -from ..models.image import ImageCategory, ImageField, ImageType +from ..models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -72,12 +72,12 @@ class LoadImageInvocation(BaseInvocation): ) # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name) + image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name) return ImageOutput( image=ImageField( image_name=self.image.image_name, - image_type=self.image.image_type, + image_origin=self.image.image_origin, ), width=image.width, height=image.height, @@ -96,7 +96,7 @@ class ShowImageInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) if image: image.show() @@ -106,7 +106,7 @@ class ShowImageInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=self.image.image_name, - image_type=self.image.image_type, + image_origin=self.image.image_origin, ), width=image.width, height=image.height, @@ -129,7 +129,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) image_crop = Image.new( @@ -139,7 +139,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=image_crop, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -149,7 +149,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -172,17 +172,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: base_image = context.services.images.get_pil_image( - self.base_image.image_type, self.base_image.image_name + self.base_image.image_origin, self.base_image.image_name ) image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) mask = ( None if self.mask is None else ImageOps.invert( context.services.images.get_pil_image( - self.mask.image_type, self.mask.image_name + self.mask.image_origin, self.mask.image_name ) ) ) @@ -201,7 +201,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=new_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -211,7 +211,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -231,7 +231,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> MaskOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) image_mask = image.split()[-1] @@ -240,7 +240,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=image_mask, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.MASK, node_id=self.id, session_id=context.graph_execution_state_id, @@ -249,7 +249,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): return MaskOutput( mask=ImageField( - image_type=image_dto.image_type, image_name=image_dto.image_name + image_origin=image_dto.image_origin, image_name=image_dto.image_name ), width=image_dto.width, height=image_dto.height, @@ -269,17 +269,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image1 = context.services.images.get_pil_image( - self.image1.image_type, self.image1.image_name + self.image1.image_origin, self.image1.image_name ) image2 = context.services.images.get_pil_image( - self.image2.image_type, self.image2.image_name + self.image2.image_origin, self.image2.image_name ) multiply_image = ImageChops.multiply(image1, image2) image_dto = context.services.images.create( image=multiply_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -288,7 +288,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( - image_type=image_dto.image_type, image_name=image_dto.image_name + image_origin=image_dto.image_origin, image_name=image_dto.image_name ), width=image_dto.width, height=image_dto.height, @@ -311,14 +311,14 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) channel_image = image.getchannel(self.channel) image_dto = context.services.images.create( image=channel_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -327,7 +327,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( - image_type=image_dto.image_type, image_name=image_dto.image_name + image_origin=image_dto.image_origin, image_name=image_dto.image_name ), width=image_dto.width, height=image_dto.height, @@ -350,14 +350,14 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) converted_image = image.convert(self.mode) image_dto = context.services.images.create( image=converted_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -366,7 +366,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( - image_type=image_dto.image_type, image_name=image_dto.image_name + image_origin=image_dto.image_origin, image_name=image_dto.image_name ), width=image_dto.width, height=image_dto.height, @@ -387,7 +387,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) blur = ( @@ -399,7 +399,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=blur_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -409,7 +409,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -430,7 +430,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 @@ -440,7 +440,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=lerp_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -450,7 +450,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -471,7 +471,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) image_arr = numpy.asarray(image, dtype=numpy.float32) @@ -486,7 +486,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=ilerp_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -496,7 +496,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index ad60b62633..a06780c1f5 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -11,7 +11,7 @@ from invokeai.app.invocations.image import ImageOutput from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.image_util.patchmatch import PatchMatch -from ..models.image import ColorField, ImageCategory, ImageField, ImageType +from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin from .baseinvocation import ( BaseInvocation, InvocationContext, @@ -135,7 +135,7 @@ class InfillColorInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) solid_bg = Image.new("RGBA", image.size, self.color.tuple()) @@ -145,7 +145,7 @@ class InfillColorInvocation(BaseInvocation): image_dto = context.services.images.create( image=infilled, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -155,7 +155,7 @@ class InfillColorInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -180,7 +180,7 @@ class InfillTileInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) infilled = tile_fill_missing( @@ -190,7 +190,7 @@ class InfillTileInvocation(BaseInvocation): image_dto = context.services.images.create( image=infilled, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -200,7 +200,7 @@ class InfillTileInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -218,7 +218,7 @@ class InfillPatchMatchInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) if PatchMatch.patchmatch_available(): @@ -228,7 +228,7 @@ class InfillPatchMatchInvocation(BaseInvocation): image_dto = context.services.images.create( image=infilled, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -238,7 +238,7 @@ class InfillPatchMatchInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 4975b7b578..7085cfd308 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -28,7 +28,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig import numpy as np -from ..services.image_file_storage import ImageType +from ..services.image_file_storage import ResourceOrigin from .baseinvocation import BaseInvocation, InvocationContext from .image import ImageField, ImageOutput from .compel import ConditioningField @@ -468,7 +468,7 @@ class LatentsToImageInvocation(BaseInvocation): # and gnenerate unique image_name image_dto = context.services.images.create( image=image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, @@ -478,7 +478,7 @@ class LatentsToImageInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -576,7 +576,7 @@ class ImageToLatentsInvocation(BaseInvocation): # self.image.image_type, self.image.image_name # ) image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) # TODO: this only really needs the vae diff --git a/invokeai/app/invocations/reconstruct.py b/invokeai/app/invocations/reconstruct.py index db71e4201d..5313411400 100644 --- a/invokeai/app/invocations/reconstruct.py +++ b/invokeai/app/invocations/reconstruct.py @@ -2,7 +2,7 @@ from typing import Literal, Union from pydantic import Field -from invokeai.app.models.image import ImageCategory, ImageField, ImageType +from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput @@ -29,7 +29,7 @@ class RestoreFaceInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) results = context.services.restoration.upscale_and_reconstruct( image_list=[[image, 0]], @@ -43,7 +43,7 @@ class RestoreFaceInvocation(BaseInvocation): # TODO: can this return multiple results? image_dto = context.services.images.create( image=results[0][0], - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -53,7 +53,7 @@ class RestoreFaceInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 90c9e4bf4f..80e1567047 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -4,7 +4,7 @@ from typing import Literal, Union from pydantic import Field -from invokeai.app.models.image import ImageCategory, ImageField, ImageType +from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput @@ -31,7 +31,7 @@ class UpscaleInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) results = context.services.restoration.upscale_and_reconstruct( image_list=[[image, 0]], @@ -45,7 +45,7 @@ class UpscaleInvocation(BaseInvocation): # TODO: can this return multiple results? image_dto = context.services.images.create( image=results[0][0], - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -55,7 +55,7 @@ class UpscaleInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/models/image.py b/invokeai/app/models/image.py index 46b50145aa..6d48f2dbb1 100644 --- a/invokeai/app/models/image.py +++ b/invokeai/app/models/image.py @@ -5,30 +5,52 @@ from pydantic import BaseModel, Field from invokeai.app.util.metaenum import MetaEnum -class ImageType(str, Enum, metaclass=MetaEnum): - """The type of an image.""" +class ResourceOrigin(str, Enum, metaclass=MetaEnum): + """The origin of a resource (eg image). - RESULT = "results" - UPLOAD = "uploads" + - INTERNAL: The resource was created by the application. + - EXTERNAL: The resource was not created by the application. + This may be a user-initiated upload, or an internal application upload (eg Canvas init image). + """ + + INTERNAL = "internal" + """The resource was created by the application.""" + EXTERNAL = "external" + """The resource was not created by the application. + This may be a user-initiated upload, or an internal application upload (eg Canvas init image). + """ -class InvalidImageTypeException(ValueError): - """Raised when a provided value is not a valid ImageType. +class InvalidOriginException(ValueError): + """Raised when a provided value is not a valid ResourceOrigin. Subclasses `ValueError`. """ - def __init__(self, message="Invalid image type."): + def __init__(self, message="Invalid resource origin."): super().__init__(message) class ImageCategory(str, Enum, metaclass=MetaEnum): - """The category of an image. Use ImageCategory.OTHER for non-default categories.""" + """The category of an image. + + - GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose. + - MASK: The image is a mask image. + - CONTROL: The image is a ControlNet control image. + - USER: The image is a user-provide image. + - OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes. + """ GENERAL = "general" - CONTROL = "control" + """GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.""" MASK = "mask" + """MASK: The image is a mask image.""" + CONTROL = "control" + """CONTROL: The image is a ControlNet control image.""" + USER = "user" + """USER: The image is a user-provide image.""" OTHER = "other" + """OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.""" class InvalidImageCategoryException(ValueError): @@ -44,13 +66,13 @@ class InvalidImageCategoryException(ValueError): class ImageField(BaseModel): """An image field used for passing image objects between invocations""" - image_type: ImageType = Field( - default=ImageType.RESULT, description="The type of the image" + image_origin: ResourceOrigin = Field( + default=ResourceOrigin.INTERNAL, description="The type of the image" ) image_name: Optional[str] = Field(default=None, description="The name of the image") class Config: - schema_extra = {"required": ["image_type", "image_name"]} + schema_extra = {"required": ["image_origin", "image_name"]} class ColorField(BaseModel): @@ -61,3 +83,11 @@ class ColorField(BaseModel): def tuple(self) -> Tuple[int, int, int, int]: return (self.r, self.g, self.b, self.a) + + +class ProgressImage(BaseModel): + """The progress image sent intermittently during processing""" + + width: int = Field(description="The effective width of the image in pixels") + height: int = Field(description="The effective height of the image in pixels") + dataURL: str = Field(description="The image data as a b64 data URL") diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index a3e7cdd5dc..788f24dbce 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -1,7 +1,7 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from typing import Any, Optional -from invokeai.app.api.models.images import ProgressImage +from typing import Any +from invokeai.app.models.image import ProgressImage from invokeai.app.util.misc import get_timestamp diff --git a/invokeai/app/services/image_file_storage.py b/invokeai/app/services/image_file_storage.py index 46070b3bf2..68a994ea75 100644 --- a/invokeai/app/services/image_file_storage.py +++ b/invokeai/app/services/image_file_storage.py @@ -9,7 +9,7 @@ from PIL.Image import Image as PILImageType from PIL import Image, PngImagePlugin from send2trash import send2trash -from invokeai.app.models.image import ImageType +from invokeai.app.models.image import ResourceOrigin from invokeai.app.models.metadata import ImageMetadata from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail @@ -40,13 +40,13 @@ class ImageFileStorageBase(ABC): """Low-level service responsible for storing and retrieving image files.""" @abstractmethod - def get(self, image_type: ImageType, image_name: str) -> PILImageType: + def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: """Retrieves an image as PIL Image.""" pass @abstractmethod def get_path( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: """Gets the internal path to an image or thumbnail.""" pass @@ -62,7 +62,7 @@ class ImageFileStorageBase(ABC): def save( self, image: PILImageType, - image_type: ImageType, + image_origin: ResourceOrigin, image_name: str, metadata: Optional[ImageMetadata] = None, thumbnail_size: int = 256, @@ -71,7 +71,7 @@ class ImageFileStorageBase(ABC): pass @abstractmethod - def delete(self, image_type: ImageType, image_name: str) -> None: + def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: """Deletes an image and its thumbnail (if one exists).""" pass @@ -93,17 +93,17 @@ class DiskImageFileStorage(ImageFileStorageBase): Path(output_folder).mkdir(parents=True, exist_ok=True) # TODO: don't hard-code. get/save/delete should maybe take subpath? - for image_type in ImageType: - Path(os.path.join(output_folder, image_type)).mkdir( + for image_origin in ResourceOrigin: + Path(os.path.join(output_folder, image_origin)).mkdir( parents=True, exist_ok=True ) - Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir( + Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir( parents=True, exist_ok=True ) - def get(self, image_type: ImageType, image_name: str) -> PILImageType: + def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: try: - image_path = self.get_path(image_type, image_name) + image_path = self.get_path(image_origin, image_name) cache_item = self.__get_cache(image_path) if cache_item: return cache_item @@ -117,13 +117,13 @@ class DiskImageFileStorage(ImageFileStorageBase): def save( self, image: PILImageType, - image_type: ImageType, + image_origin: ResourceOrigin, image_name: str, metadata: Optional[ImageMetadata] = None, thumbnail_size: int = 256, ) -> None: try: - image_path = self.get_path(image_type, image_name) + image_path = self.get_path(image_origin, image_name) if metadata is not None: pnginfo = PngImagePlugin.PngInfo() @@ -133,7 +133,7 @@ class DiskImageFileStorage(ImageFileStorageBase): image.save(image_path, "PNG") thumbnail_name = get_thumbnail_name(image_name) - thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True) + thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True) thumbnail_image = make_thumbnail(image, thumbnail_size) thumbnail_image.save(thumbnail_path) @@ -142,10 +142,10 @@ class DiskImageFileStorage(ImageFileStorageBase): except Exception as e: raise ImageFileSaveException from e - def delete(self, image_type: ImageType, image_name: str) -> None: + def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: try: basename = os.path.basename(image_name) - image_path = self.get_path(image_type, basename) + image_path = self.get_path(image_origin, basename) if os.path.exists(image_path): send2trash(image_path) @@ -153,7 +153,7 @@ class DiskImageFileStorage(ImageFileStorageBase): del self.__cache[image_path] thumbnail_name = get_thumbnail_name(image_name) - thumbnail_path = self.get_path(image_type, thumbnail_name, True) + thumbnail_path = self.get_path(image_origin, thumbnail_name, True) if os.path.exists(thumbnail_path): send2trash(thumbnail_path) @@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase): # TODO: make this a bit more flexible for e.g. cloud storage def get_path( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: # strip out any relative path shenanigans basename = os.path.basename(image_name) @@ -172,10 +172,10 @@ class DiskImageFileStorage(ImageFileStorageBase): if thumbnail: thumbnail_name = get_thumbnail_name(basename) path = os.path.join( - self.__output_folder, image_type, "thumbnails", thumbnail_name + self.__output_folder, image_origin, "thumbnails", thumbnail_name ) else: - path = os.path.join(self.__output_folder, image_type, basename) + path = os.path.join(self.__output_folder, image_origin, basename) abspath = os.path.abspath(path) diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 8afa7000fb..6b6d1ce7b2 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -8,7 +8,7 @@ from typing import Optional, Union from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.image import ( ImageCategory, - ImageType, + ResourceOrigin, ) from invokeai.app.services.models.image_record import ( ImageRecord, @@ -46,7 +46,7 @@ class ImageRecordStorageBase(ABC): # TODO: Implement an `update()` method @abstractmethod - def get(self, image_type: ImageType, image_name: str) -> ImageRecord: + def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: """Gets an image record.""" pass @@ -54,7 +54,7 @@ class ImageRecordStorageBase(ABC): def update( self, image_name: str, - image_type: ImageType, + image_origin: ResourceOrigin, changes: ImageRecordChanges, ) -> None: """Updates an image record.""" @@ -65,10 +65,10 @@ class ImageRecordStorageBase(ABC): self, page: int = 0, per_page: int = 10, - image_type: Optional[ImageType] = None, - image_category: Optional[ImageCategory] = None, + image_origin: Optional[ResourceOrigin] = None, + include_categories: Optional[list[ImageCategory]] = None, + exclude_categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - show_in_gallery: Optional[bool] = None, ) -> PaginatedResults[ImageRecord]: """Gets a page of image records.""" pass @@ -76,7 +76,7 @@ class ImageRecordStorageBase(ABC): # TODO: The database has a nullable `deleted_at` column, currently unused. # Should we implement soft deletes? Would need coordination with ImageFileStorage. @abstractmethod - def delete(self, image_type: ImageType, image_name: str) -> None: + def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: """Deletes an image record.""" pass @@ -84,7 +84,7 @@ class ImageRecordStorageBase(ABC): def save( self, image_name: str, - image_type: ImageType, + image_origin: ResourceOrigin, image_category: ImageCategory, width: int, height: int, @@ -92,7 +92,6 @@ class ImageRecordStorageBase(ABC): node_id: Optional[str], metadata: Optional[ImageMetadata], is_intermediate: bool = False, - show_in_gallery: bool = True, ) -> datetime: """Saves an image record.""" pass @@ -131,7 +130,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): CREATE TABLE IF NOT EXISTS images ( image_name TEXT NOT NULL PRIMARY KEY, -- This is an enum in python, unrestricted string here for flexibility - image_type TEXT NOT NULL, + image_origin TEXT NOT NULL, -- This is an enum in python, unrestricted string here for flexibility image_category TEXT NOT NULL, width INTEGER NOT NULL, @@ -139,7 +138,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): session_id TEXT, node_id TEXT, metadata TEXT, - show_in_gallery BOOLEAN DEFAULT TRUE, is_intermediate BOOLEAN DEFAULT FALSE, created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Updated via trigger @@ -158,7 +156,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): ) self._cursor.execute( """--sql - CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type); + CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin); """ ) self._cursor.execute( @@ -185,7 +183,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """ ) - def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]: + def get( + self, image_origin: ResourceOrigin, image_name: str + ) -> Union[ImageRecord, None]: try: self._lock.acquire() @@ -212,7 +212,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): def update( self, image_name: str, - image_type: ImageType, + image_origin: ResourceOrigin, changes: ImageRecordChanges, ) -> None: try: @@ -249,71 +249,72 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): self, page: int = 0, per_page: int = 10, - image_type: Optional[ImageType] = None, - image_category: Optional[ImageCategory] = None, + image_origin: Optional[ResourceOrigin] = None, + include_categories: Optional[list[ImageCategory]] = None, + exclude_categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - show_in_gallery: Optional[bool] = None, ) -> PaginatedResults[ImageRecord]: try: self._lock.acquire() # Manually build two queries - one for the count, one for the records - count_query = """--sql - SELECT COUNT(*) FROM images WHERE 1=1 - """ - - images_query = """--sql - SELECT * FROM images WHERE 1=1 - """ + count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n""" + images_query = f"""SELECT * FROM images WHERE 1=1\n""" query_conditions = "" query_params = [] - if image_type is not None: - query_conditions += """--sql - AND image_type = ? - """ - query_params.append(image_type.value) + if image_origin is not None: + query_conditions += f"""AND image_origin = ?\n""" + query_params.append(image_origin.value) - if image_category is not None: - query_conditions += """--sql - AND image_category = ? - """ - query_params.append(image_category.value) + if include_categories is not None: + ## Convert the enum values to unique list of strings + include_category_strings = list( + map(lambda c: c.value, set(include_categories)) + ) + # Create the correct length of placeholders + placeholders = ",".join("?" * len(include_category_strings)) + query_conditions += f"AND image_category IN ( {placeholders} )\n" + + # Unpack the included categories into the query params + query_params.append(*include_category_strings) + + if exclude_categories is not None: + ## Convert the enum values to unique list of strings + exclude_category_strings = list( + map(lambda c: c.value, set(exclude_categories)) + ) + + # Create the correct length of placeholders + placeholders = ",".join("?" * len(exclude_category_strings)) + query_conditions += f"AND image_category NOT IN ( {placeholders} )\n" + + # Unpack the included categories into the query params + query_params.append(*exclude_category_strings) if is_intermediate is not None: - query_conditions += """--sql - AND is_intermediate = ? - """ + query_conditions += f"""AND is_intermediate = ?\n""" query_params.append(is_intermediate) - if show_in_gallery is not None: - query_conditions += """--sql - AND show_in_gallery = ? - """ - query_params.append(show_in_gallery) - - query_pagination = """--sql - ORDER BY created_at DESC LIMIT ? OFFSET ? - """ - - count_query += query_conditions + ";" - count_params = query_params.copy() + query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n""" + # Final images query with pagination images_query += query_conditions + query_pagination + ";" + # Add all the parameters images_params = query_params.copy() images_params.append(per_page) images_params.append(page * per_page) - + # Build the list of images, deserializing each row self._cursor.execute(images_query, images_params) - result = cast(list[sqlite3.Row], self._cursor.fetchall()) - images = list(map(lambda r: deserialize_image_record(dict(r)), result)) + # Set up and execute the count query, without pagination + count_query += query_conditions + ";" + count_params = query_params.copy() self._cursor.execute(count_query, count_params) - count = self._cursor.fetchone()[0] except sqlite3.Error as e: self._conn.rollback() @@ -327,7 +328,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): items=images, page=page, pages=pageCount, per_page=per_page, total=count ) - def delete(self, image_type: ImageType, image_name: str) -> None: + def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: try: self._lock.acquire() self._cursor.execute( @@ -347,7 +348,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): def save( self, image_name: str, - image_type: ImageType, + image_origin: ResourceOrigin, image_category: ImageCategory, session_id: Optional[str], width: int, @@ -355,7 +356,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): node_id: Optional[str], metadata: Optional[ImageMetadata], is_intermediate: bool = False, - show_in_gallery: bool = True, ) -> datetime: try: metadata_json = ( @@ -366,21 +366,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """--sql INSERT OR IGNORE INTO images ( image_name, - image_type, + image_origin, image_category, width, height, node_id, session_id, metadata, - is_intermediate, - show_in_gallery + is_intermediate ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?); + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); """, ( image_name, - image_type.value, + image_origin.value, image_category.value, width, height, @@ -388,7 +387,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): session_id, metadata_json, is_intermediate, - show_in_gallery, ), ) self._conn.commit() diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 1bde1acfd4..dca95f673f 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -5,9 +5,9 @@ from PIL.Image import Image as PILImageType from invokeai.app.models.image import ( ImageCategory, - ImageType, + ResourceOrigin, InvalidImageCategoryException, - InvalidImageTypeException, + InvalidOriginException, ) from invokeai.app.models.metadata import ImageMetadata from invokeai.app.services.image_record_storage import ( @@ -44,12 +44,11 @@ class ImageServiceABC(ABC): def create( self, image: PILImageType, - image_type: ImageType, + image_origin: ResourceOrigin, image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, intermediate: bool = False, - show_in_gallery: bool = True, ) -> ImageDTO: """Creates an image, storing the file and its metadata.""" pass @@ -57,7 +56,7 @@ class ImageServiceABC(ABC): @abstractmethod def update( self, - image_type: ImageType, + image_origin: ResourceOrigin, image_name: str, changes: ImageRecordChanges, ) -> ImageDTO: @@ -65,22 +64,22 @@ class ImageServiceABC(ABC): pass @abstractmethod - def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: + def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: """Gets an image as a PIL image.""" pass @abstractmethod - def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: + def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: """Gets an image record.""" pass @abstractmethod - def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: + def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO: """Gets an image DTO.""" pass @abstractmethod - def get_path(self, image_type: ImageType, image_name: str) -> str: + def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str: """Gets an image's path.""" pass @@ -91,7 +90,7 @@ class ImageServiceABC(ABC): @abstractmethod def get_url( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: """Gets an image's or thumbnail's URL.""" pass @@ -101,16 +100,16 @@ class ImageServiceABC(ABC): self, page: int = 0, per_page: int = 10, - image_type: Optional[ImageType] = None, - image_category: Optional[ImageCategory] = None, + image_origin: Optional[ResourceOrigin] = None, + include_categories: Optional[list[ImageCategory]] = None, + exclude_categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - show_in_gallery: Optional[bool] = None, ) -> PaginatedResults[ImageDTO]: """Gets a paginated list of image DTOs.""" pass @abstractmethod - def delete(self, image_type: ImageType, image_name: str): + def delete(self, image_origin: ResourceOrigin, image_name: str): """Deletes an image.""" pass @@ -171,15 +170,14 @@ class ImageService(ImageServiceABC): def create( self, image: PILImageType, - image_type: ImageType, + image_origin: ResourceOrigin, image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, is_intermediate: bool = False, - show_in_gallery: bool = True, ) -> ImageDTO: - if image_type not in ImageType: - raise InvalidImageTypeException + if image_origin not in ResourceOrigin: + raise InvalidOriginException if image_category not in ImageCategory: raise InvalidImageCategoryException @@ -195,13 +193,12 @@ class ImageService(ImageServiceABC): created_at = self._services.records.save( # Non-nullable fields image_name=image_name, - image_type=image_type, + image_origin=image_origin, image_category=image_category, width=width, height=height, # Meta fields is_intermediate=is_intermediate, - show_in_gallery=show_in_gallery, # Nullable fields node_id=node_id, session_id=session_id, @@ -209,21 +206,21 @@ class ImageService(ImageServiceABC): ) self._services.files.save( - image_type=image_type, + image_origin=image_origin, image_name=image_name, image=image, metadata=metadata, ) - image_url = self._services.urls.get_image_url(image_type, image_name) + image_url = self._services.urls.get_image_url(image_origin, image_name) thumbnail_url = self._services.urls.get_image_url( - image_type, image_name, True + image_origin, image_name, True ) return ImageDTO( # Non-nullable fields image_name=image_name, - image_type=image_type, + image_origin=image_origin, image_category=image_category, width=width, height=height, @@ -236,7 +233,6 @@ class ImageService(ImageServiceABC): updated_at=created_at, # this is always the same as the created_at at this time deleted_at=None, is_intermediate=is_intermediate, - show_in_gallery=show_in_gallery, # Extra non-nullable fields for DTO image_url=image_url, thumbnail_url=thumbnail_url, @@ -253,13 +249,13 @@ class ImageService(ImageServiceABC): def update( self, - image_type: ImageType, + image_origin: ResourceOrigin, image_name: str, changes: ImageRecordChanges, ) -> ImageDTO: try: - self._services.records.update(image_name, image_type, changes) - return self.get_dto(image_type, image_name) + self._services.records.update(image_name, image_origin, changes) + return self.get_dto(image_origin, image_name) except ImageRecordSaveException: self._services.logger.error("Failed to update image record") raise @@ -267,9 +263,9 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem updating image record") raise e - def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: + def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: try: - return self._services.files.get(image_type, image_name) + return self._services.files.get(image_origin, image_name) except ImageFileNotFoundException: self._services.logger.error("Failed to get image file") raise @@ -277,9 +273,9 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting image file") raise e - def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: + def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: try: - return self._services.records.get(image_type, image_name) + return self._services.records.get(image_origin, image_name) except ImageRecordNotFoundException: self._services.logger.error("Image record not found") raise @@ -287,14 +283,14 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting image record") raise e - def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: + def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO: try: - image_record = self._services.records.get(image_type, image_name) + image_record = self._services.records.get(image_origin, image_name) image_dto = image_record_to_dto( image_record, - self._services.urls.get_image_url(image_type, image_name), - self._services.urls.get_image_url(image_type, image_name, True), + self._services.urls.get_image_url(image_origin, image_name), + self._services.urls.get_image_url(image_origin, image_name, True), ) return image_dto @@ -306,10 +302,10 @@ class ImageService(ImageServiceABC): raise e def get_path( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: try: - return self._services.files.get_path(image_type, image_name, thumbnail) + return self._services.files.get_path(image_origin, image_name, thumbnail) except Exception as e: self._services.logger.error("Problem getting image path") raise e @@ -322,10 +318,10 @@ class ImageService(ImageServiceABC): raise e def get_url( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: try: - return self._services.urls.get_image_url(image_type, image_name, thumbnail) + return self._services.urls.get_image_url(image_origin, image_name, thumbnail) except Exception as e: self._services.logger.error("Problem getting image path") raise e @@ -334,28 +330,28 @@ class ImageService(ImageServiceABC): self, page: int = 0, per_page: int = 10, - image_type: Optional[ImageType] = None, - image_category: Optional[ImageCategory] = None, + image_origin: Optional[ResourceOrigin] = None, + include_categories: Optional[list[ImageCategory]] = None, + exclude_categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - show_in_gallery: Optional[bool] = None, ) -> PaginatedResults[ImageDTO]: try: results = self._services.records.get_many( page, per_page, - image_type, - image_category, + image_origin, + include_categories, + exclude_categories, is_intermediate, - show_in_gallery, ) image_dtos = list( map( lambda r: image_record_to_dto( r, - self._services.urls.get_image_url(r.image_type, r.image_name), + self._services.urls.get_image_url(r.image_origin, r.image_name), self._services.urls.get_image_url( - r.image_type, r.image_name, True + r.image_origin, r.image_name, True ), ), results.items, @@ -373,10 +369,10 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting paginated image DTOs") raise e - def delete(self, image_type: ImageType, image_name: str): + def delete(self, image_origin: ResourceOrigin, image_name: str): try: - self._services.files.delete(image_type, image_name) - self._services.records.delete(image_type, image_name) + self._services.files.delete(image_origin, image_name) + self._services.records.delete(image_origin, image_name) except ImageRecordDeleteException: self._services.logger.error(f"Failed to delete image record") raise diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index faa6e1b41a..f143a30928 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -1,7 +1,7 @@ import datetime from typing import Optional, Union from pydantic import BaseModel, Extra, Field, StrictStr -from invokeai.app.models.image import ImageCategory, ImageType +from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.models.metadata import ImageMetadata from invokeai.app.util.misc import get_iso_timestamp @@ -11,8 +11,8 @@ class ImageRecord(BaseModel): image_name: str = Field(description="The unique name of the image.") """The unique name of the image.""" - image_type: ImageType = Field(description="The type of the image.") - """The type of the image.""" + image_origin: ResourceOrigin = Field(description="The type of the image.") + """The origin of the image.""" image_category: ImageCategory = Field(description="The category of the image.") """The category of the image.""" width: int = Field(description="The width of the image in px.") @@ -33,8 +33,6 @@ class ImageRecord(BaseModel): """The deleted timestamp of the image.""" is_intermediate: bool = Field(description="Whether this is an intermediate image.") """Whether this is an intermediate image.""" - show_in_gallery: bool = Field(description="Whether this image should be shown in the gallery.") - """Whether this image should be shown in the gallery.""" session_id: Optional[str] = Field( default=None, description="The session ID that generated this image, if it is a generated image.", @@ -76,8 +74,8 @@ class ImageUrlsDTO(BaseModel): image_name: str = Field(description="The unique name of the image.") """The unique name of the image.""" - image_type: ImageType = Field(description="The type of the image.") - """The type of the image.""" + image_origin: ResourceOrigin = Field(description="The type of the image.") + """The origin of the image.""" image_url: str = Field(description="The URL of the image.") """The URL of the image.""" thumbnail_url: str = Field(description="The URL of the image's thumbnail.") @@ -107,7 +105,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: # Retrieve all the values, setting "reasonable" defaults if they are not present. image_name = image_dict.get("image_name", "unknown") - image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value)) + image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)) image_category = ImageCategory( image_dict.get("image_category", ImageCategory.GENERAL.value) ) @@ -119,7 +117,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: updated_at = image_dict.get("updated_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) is_intermediate = image_dict.get("is_intermediate", False) - show_in_gallery = image_dict.get("show_in_gallery", True) raw_metadata = image_dict.get("metadata") @@ -130,7 +127,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: return ImageRecord( image_name=image_name, - image_type=image_type, + image_origin=image_origin, image_category=image_category, width=width, height=height, @@ -141,5 +138,4 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: updated_at=updated_at, deleted_at=deleted_at, is_intermediate=is_intermediate, - show_in_gallery=show_in_gallery, ) diff --git a/invokeai/app/services/urls.py b/invokeai/app/services/urls.py index 2716da60ad..4c8354c899 100644 --- a/invokeai/app/services/urls.py +++ b/invokeai/app/services/urls.py @@ -1,7 +1,7 @@ import os from abc import ABC, abstractmethod -from invokeai.app.models.image import ImageType +from invokeai.app.models.image import ResourceOrigin from invokeai.app.util.thumbnails import get_thumbnail_name @@ -10,7 +10,7 @@ class UrlServiceBase(ABC): @abstractmethod def get_image_url( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: """Gets the URL for an image or thumbnail.""" pass @@ -21,14 +21,14 @@ class LocalUrlService(UrlServiceBase): self._base_url = base_url def get_image_url( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: image_basename = os.path.basename(image_name) # These paths are determined by the routes in invokeai/app/api/routers/images.py if thumbnail: return ( - f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail" + f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail" ) - return f"{self._base_url}/images/{image_type.value}/{image_basename}" + return f"{self._base_url}/images/{image_origin.value}/{image_basename}" diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 963e770406..b4b9a25909 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,5 +1,5 @@ -from invokeai.app.api.models.images import ProgressImage from invokeai.app.models.exceptions import CanceledException +from invokeai.app.models.image import ProgressImage from ..invocations.baseinvocation import InvocationContext from ...backend.util.util import image_to_dataURL from ...backend.generator.base import Generator From d78e3572e3e1feb29a42106d955434b8a43c4a7a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 21:40:02 +1000 Subject: [PATCH 10/34] chore(ui): regen api client --- .../frontend/web/src/services/api/index.ts | 2 +- .../src/services/api/models/ImageCategory.ts | 10 +- .../web/src/services/api/models/ImageDTO.ts | 8 +- .../web/src/services/api/models/ImageField.ts | 4 +- .../web/src/services/api/models/ImageType.ts | 8 -- .../src/services/api/models/ImageUrlsDTO.ts | 4 +- .../src/services/api/models/ResourceOrigin.ts | 12 +++ .../services/api/services/ImagesService.ts | 94 +++++++++---------- 8 files changed, 71 insertions(+), 71 deletions(-) delete mode 100644 invokeai/frontend/web/src/services/api/models/ImageType.ts create mode 100644 invokeai/frontend/web/src/services/api/models/ResourceOrigin.ts diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index e75aeac6cb..d9f00becd9 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -42,7 +42,6 @@ export type { ImagePasteInvocation } from './models/ImagePasteInvocation'; export type { ImageRecordChanges } from './models/ImageRecordChanges'; export type { ImageToImageInvocation } from './models/ImageToImageInvocation'; export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation'; -export type { ImageType } from './models/ImageType'; export type { ImageUrlsDTO } from './models/ImageUrlsDTO'; export type { InfillColorInvocation } from './models/InfillColorInvocation'; export type { InfillPatchMatchInvocation } from './models/InfillPatchMatchInvocation'; @@ -72,6 +71,7 @@ export type { RandomRangeInvocation } from './models/RandomRangeInvocation'; export type { RangeInvocation } from './models/RangeInvocation'; export type { RangeOfSizeInvocation } from './models/RangeOfSizeInvocation'; export type { ResizeLatentsInvocation } from './models/ResizeLatentsInvocation'; +export type { ResourceOrigin } from './models/ResourceOrigin'; export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation'; export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation'; export type { ShowImageInvocation } from './models/ShowImageInvocation'; diff --git a/invokeai/frontend/web/src/services/api/models/ImageCategory.ts b/invokeai/frontend/web/src/services/api/models/ImageCategory.ts index 6b04a0b864..84551d3cd6 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageCategory.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageCategory.ts @@ -3,6 +3,12 @@ /* eslint-disable */ /** - * The category of an image. Use ImageCategory.OTHER for non-default categories. + * The category of an image. + * + * - GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose. + * - MASK: The image is a mask image. + * - CONTROL: The image is a ControlNet control image. + * - USER: The image is a user-provide image. + * - OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes. */ -export type ImageCategory = 'general' | 'control' | 'mask' | 'other'; +export type ImageCategory = 'general' | 'mask' | 'control' | 'user' | 'other'; diff --git a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts index 599ca51de4..f5f2603b03 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts @@ -4,7 +4,7 @@ import type { ImageCategory } from './ImageCategory'; import type { ImageMetadata } from './ImageMetadata'; -import type { ImageType } from './ImageType'; +import type { ResourceOrigin } from './ResourceOrigin'; /** * Deserialized image record, enriched for the frontend with URLs. @@ -17,7 +17,7 @@ export type ImageDTO = { /** * The type of the image. */ - image_type: ImageType; + image_origin: ResourceOrigin; /** * The URL of the image. */ @@ -54,10 +54,6 @@ export type ImageDTO = { * Whether this is an intermediate image. */ is_intermediate: boolean; - /** - * Whether this image should be shown in the gallery. - */ - show_in_gallery: boolean; /** * The session ID that generated this image, if it is a generated image. */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageField.ts b/invokeai/frontend/web/src/services/api/models/ImageField.ts index fa22ae8007..63a12f4730 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageField.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageField.ts @@ -2,7 +2,7 @@ /* tslint:disable */ /* eslint-disable */ -import type { ImageType } from './ImageType'; +import type { ResourceOrigin } from './ResourceOrigin'; /** * An image field used for passing image objects between invocations @@ -11,7 +11,7 @@ export type ImageField = { /** * The type of the image */ - image_type: ImageType; + image_origin: ResourceOrigin; /** * The name of the image */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageType.ts b/invokeai/frontend/web/src/services/api/models/ImageType.ts deleted file mode 100644 index dfc10bf455..0000000000 --- a/invokeai/frontend/web/src/services/api/models/ImageType.ts +++ /dev/null @@ -1,8 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -/** - * The type of an image. - */ -export type ImageType = 'results' | 'uploads'; diff --git a/invokeai/frontend/web/src/services/api/models/ImageUrlsDTO.ts b/invokeai/frontend/web/src/services/api/models/ImageUrlsDTO.ts index af80519ef2..81639be9b3 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageUrlsDTO.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageUrlsDTO.ts @@ -2,7 +2,7 @@ /* tslint:disable */ /* eslint-disable */ -import type { ImageType } from './ImageType'; +import type { ResourceOrigin } from './ResourceOrigin'; /** * The URLs for an image and its thumbnail. @@ -15,7 +15,7 @@ export type ImageUrlsDTO = { /** * The type of the image. */ - image_type: ImageType; + image_origin: ResourceOrigin; /** * The URL of the image. */ diff --git a/invokeai/frontend/web/src/services/api/models/ResourceOrigin.ts b/invokeai/frontend/web/src/services/api/models/ResourceOrigin.ts new file mode 100644 index 0000000000..a82edda0c1 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ResourceOrigin.ts @@ -0,0 +1,12 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * The origin of a resource (eg image). + * + * - INTERNAL: The resource was created by the application. + * - EXTERNAL: The resource was not created by the application. + * This may be a user-initiated upload, or an internal application upload (eg Canvas init image). + */ +export type ResourceOrigin = 'internal' | 'external'; diff --git a/invokeai/frontend/web/src/services/api/services/ImagesService.ts b/invokeai/frontend/web/src/services/api/services/ImagesService.ts index a8d22e802b..379f9f7dd7 100644 --- a/invokeai/frontend/web/src/services/api/services/ImagesService.ts +++ b/invokeai/frontend/web/src/services/api/services/ImagesService.ts @@ -5,9 +5,9 @@ import type { Body_upload_image } from '../models/Body_upload_image'; import type { ImageCategory } from '../models/ImageCategory'; import type { ImageDTO } from '../models/ImageDTO'; import type { ImageRecordChanges } from '../models/ImageRecordChanges'; -import type { ImageType } from '../models/ImageType'; import type { ImageUrlsDTO } from '../models/ImageUrlsDTO'; import type { PaginatedResults_ImageDTO_ } from '../models/PaginatedResults_ImageDTO_'; +import type { ResourceOrigin } from '../models/ResourceOrigin'; import type { CancelablePromise } from '../core/CancelablePromise'; import { OpenAPI } from '../core/OpenAPI'; @@ -22,29 +22,29 @@ export class ImagesService { * @throws ApiError */ public static listImagesWithMetadata({ - imageType, - imageCategory, + imageOrigin, + includeCategories, + excludeCategories, isIntermediate, - showInGallery, page, perPage = 10, }: { /** - * The type of images to list + * The origin of images to list */ - imageType?: ImageType, + imageOrigin?: ResourceOrigin, /** - * The kind of images to list + * The categories of image to include */ - imageCategory?: ImageCategory, + includeCategories?: Array, + /** + * The categories of image to exclude + */ + excludeCategories?: Array, /** * Whether to list intermediate images */ isIntermediate?: boolean, - /** - * Whether to list images that show in the gallery - */ - showInGallery?: boolean, /** * The page of images to get */ @@ -58,10 +58,10 @@ export class ImagesService { method: 'GET', url: '/api/v1/images/', query: { - 'image_type': imageType, - 'image_category': imageCategory, + 'image_origin': imageOrigin, + 'include_categories': includeCategories, + 'exclude_categories': excludeCategories, 'is_intermediate': isIntermediate, - 'show_in_gallery': showInGallery, 'page': page, 'per_page': perPage, }, @@ -80,7 +80,6 @@ export class ImagesService { public static uploadImage({ imageCategory, isIntermediate, - showInGallery, formData, sessionId, }: { @@ -92,10 +91,6 @@ export class ImagesService { * Whether this is an intermediate image */ isIntermediate: boolean, - /** - * Whether this image should be shown in the gallery - */ - showInGallery: boolean, formData: Body_upload_image, /** * The session ID associated with this upload, if any @@ -108,7 +103,6 @@ export class ImagesService { query: { 'image_category': imageCategory, 'is_intermediate': isIntermediate, - 'show_in_gallery': showInGallery, 'session_id': sessionId, }, formData: formData, @@ -127,13 +121,13 @@ export class ImagesService { * @throws ApiError */ public static getImageFull({ - imageType, + imageOrigin, imageName, }: { /** * The type of full-resolution image file to get */ - imageType: ImageType, + imageOrigin: ResourceOrigin, /** * The name of full-resolution image file to get */ @@ -141,9 +135,9 @@ export class ImagesService { }): CancelablePromise { return __request(OpenAPI, { method: 'GET', - url: '/api/v1/images/{image_type}/{image_name}', + url: '/api/v1/images/{image_origin}/{image_name}', path: { - 'image_type': imageType, + 'image_origin': imageOrigin, 'image_name': imageName, }, errors: { @@ -160,13 +154,13 @@ export class ImagesService { * @throws ApiError */ public static deleteImage({ - imageType, + imageOrigin, imageName, }: { /** - * The type of image to delete + * The origin of image to delete */ - imageType: ImageType, + imageOrigin: ResourceOrigin, /** * The name of the image to delete */ @@ -174,9 +168,9 @@ export class ImagesService { }): CancelablePromise { return __request(OpenAPI, { method: 'DELETE', - url: '/api/v1/images/{image_type}/{image_name}', + url: '/api/v1/images/{image_origin}/{image_name}', path: { - 'image_type': imageType, + 'image_origin': imageOrigin, 'image_name': imageName, }, errors: { @@ -192,14 +186,14 @@ export class ImagesService { * @throws ApiError */ public static updateImage({ - imageType, + imageOrigin, imageName, requestBody, }: { /** - * The type of image to update + * The origin of image to update */ - imageType: ImageType, + imageOrigin: ResourceOrigin, /** * The name of the image to update */ @@ -208,9 +202,9 @@ export class ImagesService { }): CancelablePromise { return __request(OpenAPI, { method: 'PATCH', - url: '/api/v1/images/{image_type}/{image_name}', + url: '/api/v1/images/{image_origin}/{image_name}', path: { - 'image_type': imageType, + 'image_origin': imageOrigin, 'image_name': imageName, }, body: requestBody, @@ -228,13 +222,13 @@ export class ImagesService { * @throws ApiError */ public static getImageMetadata({ - imageType, + imageOrigin, imageName, }: { /** - * The type of image to get + * The origin of image to get */ - imageType: ImageType, + imageOrigin: ResourceOrigin, /** * The name of image to get */ @@ -242,9 +236,9 @@ export class ImagesService { }): CancelablePromise { return __request(OpenAPI, { method: 'GET', - url: '/api/v1/images/{image_type}/{image_name}/metadata', + url: '/api/v1/images/{image_origin}/{image_name}/metadata', path: { - 'image_type': imageType, + 'image_origin': imageOrigin, 'image_name': imageName, }, errors: { @@ -260,13 +254,13 @@ export class ImagesService { * @throws ApiError */ public static getImageThumbnail({ - imageType, + imageOrigin, imageName, }: { /** - * The type of thumbnail image file to get + * The origin of thumbnail image file to get */ - imageType: ImageType, + imageOrigin: ResourceOrigin, /** * The name of thumbnail image file to get */ @@ -274,9 +268,9 @@ export class ImagesService { }): CancelablePromise { return __request(OpenAPI, { method: 'GET', - url: '/api/v1/images/{image_type}/{image_name}/thumbnail', + url: '/api/v1/images/{image_origin}/{image_name}/thumbnail', path: { - 'image_type': imageType, + 'image_origin': imageOrigin, 'image_name': imageName, }, errors: { @@ -293,13 +287,13 @@ export class ImagesService { * @throws ApiError */ public static getImageUrls({ - imageType, + imageOrigin, imageName, }: { /** - * The type of the image whose URL to get + * The origin of the image whose URL to get */ - imageType: ImageType, + imageOrigin: ResourceOrigin, /** * The name of the image whose URL to get */ @@ -307,9 +301,9 @@ export class ImagesService { }): CancelablePromise { return __request(OpenAPI, { method: 'GET', - url: '/api/v1/images/{image_type}/{image_name}/urls', + url: '/api/v1/images/{image_origin}/{image_name}/urls', path: { - 'image_type': imageType, + 'image_origin': imageOrigin, 'image_name': imageName, }, errors: { From 29fcc92da9e235ba1c87a2e04e69f457cdce839b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 21:46:03 +1000 Subject: [PATCH 11/34] feat(ui): handle new image origin/category setup - Update all thunks & network related things - Update gallery What I have not done yet is rename the gallery tabs and the relevant slices, but I believe the functionality is all there. Also I fixed several bugs along the way but couldn't really commit them separately bc I was refactoring. Can't remember what they were, but related to the gallery image switching. --- .../middleware/listenerMiddleware/index.ts | 9 + .../listeners/canvasMerged.ts | 1 - .../listeners/canvasSavedToGallery.ts | 3 +- .../listeners/imageDeleted.ts | 41 ++- .../listeners/imageMetadataReceived.ts | 23 +- .../listeners/imageUpdated.ts | 26 ++ .../listeners/imageUploaded.ts | 27 +- .../listeners/imageUrlsReceived.ts | 6 +- .../listeners/initialImageSelected.ts | 6 +- .../listeners/socketio/invocationComplete.ts | 8 +- .../listeners/userInvokedCanvas.ts | 10 +- .../frontend/web/src/app/types/invokeai.ts | 4 +- .../src/common/components/ImageUploader.tsx | 3 +- .../web/src/common/util/parseMetadata.ts | 239 ------------------ .../components/CurrentImagePreview.tsx | 2 +- .../gallery/components/HoverableImage.tsx | 2 +- .../gallery/hooks/useGetImageByName.ts | 12 +- .../web/src/features/gallery/store/actions.ts | 4 +- .../features/gallery/store/gallerySlice.ts | 17 ++ .../features/gallery/store/resultsSlice.ts | 9 +- .../features/gallery/store/uploadsSlice.ts | 13 +- .../fields/ImageInputFieldComponent.tsx | 10 +- .../graphBuilders/buildImageToImageGraph.ts | 2 +- .../nodeBuilders/buildImageToImageNode.ts | 2 +- .../util/nodeBuilders/buildInpaintNode.ts | 2 +- .../ImageToImage/InitialImagePreview.tsx | 8 +- .../src/features/parameters/store/actions.ts | 12 +- .../web/src/services/thunks/gallery.ts | 5 +- .../frontend/web/src/services/types/guards.ts | 20 +- 29 files changed, 181 insertions(+), 345 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUpdated.ts delete mode 100644 invokeai/frontend/web/src/common/util/parseMetadata.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index b669becfe6..7159957efa 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -67,6 +67,10 @@ import { addReceivedUploadImagesPageFulfilledListener, addReceivedUploadImagesPageRejectedListener, } from './listeners/receivedUploadImages'; +import { + addImageUpdatedFulfilledListener, + addImageUpdatedRejectedListener, +} from './listeners/imageUpdated'; export const listenerMiddleware = createListenerMiddleware(); @@ -90,6 +94,11 @@ export type AppListenerEffect = ListenerEffect< addImageUploadedFulfilledListener(); addImageUploadedRejectedListener(); +// Image updated +addImageUpdatedFulfilledListener(); +addImageUpdatedRejectedListener(); + +// Image selected addInitialImageSelectedListener(); // Image deleted diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts index fc4c7247cd..80865f3126 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts @@ -57,7 +57,6 @@ export const addCanvasMergedListener = () => { }, imageCategory: 'general', isIntermediate: true, - showInGallery: false, }) ); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts index 7656e58b57..01f097cdd1 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -37,8 +37,7 @@ export const addCanvasSavedToGalleryListener = () => { file: new File([blob], filename, { type: 'image/png' }), }, imageCategory: 'general', - isIntermediate: false, - showInGallery: true, + isIntermediate: true, }) ); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index cd4771b96a..7bd92e7e13 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -4,8 +4,15 @@ import { imageDeleted } from 'services/thunks/image'; import { log } from 'app/logging/useLogger'; import { clamp } from 'lodash-es'; import { imageSelected } from 'features/gallery/store/gallerySlice'; -import { uploadsAdapter } from 'features/gallery/store/uploadsSlice'; -import { resultsAdapter } from 'features/gallery/store/resultsSlice'; +import { + uploadRemoved, + uploadsAdapter, +} from 'features/gallery/store/uploadsSlice'; +import { + resultRemoved, + resultsAdapter, +} from 'features/gallery/store/resultsSlice'; +import { isUploadsImageDTO } from 'services/types/guards'; const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); @@ -22,13 +29,17 @@ export const addRequestedImageDeletionListener = () => { return; } - const { image_name, image_type } = image; + const { image_name, image_origin } = image; - const selectedImageName = getState().gallery.selectedImage?.image_name; + const state = getState(); + const selectedImage = state.gallery.selectedImage; + const isUserImage = isUploadsImageDTO(selectedImage); + if (selectedImage && selectedImage.image_name === image_name) { + const allIds = isUserImage ? state.uploads.ids : state.results.ids; - if (selectedImageName === image_name) { - const allIds = getState()[image_type].ids; - const allEntities = getState()[image_type].entities; + const allEntities = isUserImage + ? state.uploads.entities + : state.results.entities; const deletedImageIndex = allIds.findIndex( (result) => result.toString() === image_name @@ -53,7 +64,15 @@ export const addRequestedImageDeletionListener = () => { } } - dispatch(imageDeleted({ imageName: image_name, imageType: image_type })); + if (isUserImage) { + dispatch(uploadRemoved(image_name)); + } else { + dispatch(resultRemoved(image_name)); + } + + dispatch( + imageDeleted({ imageName: image_name, imageOrigin: image_origin }) + ); }, }); }; @@ -65,12 +84,12 @@ export const addImageDeletedPendingListener = () => { startAppListening({ actionCreator: imageDeleted.pending, effect: (action, { dispatch, getState }) => { - const { imageName, imageType } = action.meta.arg; + const { imageName, imageOrigin } = action.meta.arg; // Preemptively remove the image from the gallery - if (imageType === 'uploads') { + if (imageOrigin === 'external') { uploadsAdapter.removeOne(getState().uploads, imageName); } - if (imageType === 'results') { + if (imageOrigin === 'internal') { resultsAdapter.removeOne(getState().results, imageName); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts index c93ed2820f..276ef7be6c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts @@ -1,14 +1,9 @@ import { log } from 'app/logging/useLogger'; import { startAppListening } from '..'; import { imageMetadataReceived } from 'services/thunks/image'; -import { - ResultsImageDTO, - resultUpserted, -} from 'features/gallery/store/resultsSlice'; -import { - UploadsImageDTO, - uploadUpserted, -} from 'features/gallery/store/uploadsSlice'; +import { resultUpserted } from 'features/gallery/store/resultsSlice'; +import { uploadUpserted } from 'features/gallery/store/uploadsSlice'; +import { imageSelected } from 'features/gallery/store/gallerySlice'; const moduleLog = log.child({ namespace: 'image' }); @@ -16,15 +11,15 @@ export const addImageMetadataReceivedFulfilledListener = () => { startAppListening({ actionCreator: imageMetadataReceived.fulfilled, effect: (action, { getState, dispatch }) => { - const image = action.payload; - moduleLog.debug({ data: { image } }, 'Image metadata received'); + const imageDTO = action.payload; + moduleLog.debug({ data: { imageDTO } }, 'Image metadata received'); - if (image.image_type === 'results') { - dispatch(resultUpserted(action.payload as ResultsImageDTO)); + if (imageDTO.image_origin === 'internal') { + dispatch(resultUpserted(imageDTO)); } - if (image.image_type === 'uploads') { - dispatch(uploadUpserted(action.payload as UploadsImageDTO)); + if (imageDTO.image_origin === 'external') { + dispatch(uploadUpserted(imageDTO)); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUpdated.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUpdated.ts new file mode 100644 index 0000000000..6f8b46ec23 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUpdated.ts @@ -0,0 +1,26 @@ +import { startAppListening } from '..'; +import { imageUpdated } from 'services/thunks/image'; +import { log } from 'app/logging/useLogger'; + +const moduleLog = log.child({ namespace: 'image' }); + +export const addImageUpdatedFulfilledListener = () => { + startAppListening({ + actionCreator: imageUpdated.fulfilled, + effect: (action, { dispatch, getState }) => { + moduleLog.debug( + { oldImage: action.meta.arg, updatedImage: action.payload }, + 'Image updated' + ); + }, + }); +}; + +export const addImageUpdatedRejectedListener = () => { + startAppListening({ + actionCreator: imageUpdated.rejected, + effect: (action, { dispatch }) => { + moduleLog.debug({ oldImage: action.meta.arg }, 'Image update failed'); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index 3d69eb8f9a..dcce86017e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -1,6 +1,9 @@ import { startAppListening } from '..'; import { uploadUpserted } from 'features/gallery/store/uploadsSlice'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; +import { + imageSelected, + setCurrentCategory, +} from 'features/gallery/store/gallerySlice'; import { imageUploaded } from 'services/thunks/image'; import { addToast } from 'features/system/store/systemSlice'; import { resultUpserted } from 'features/gallery/store/resultsSlice'; @@ -10,31 +13,30 @@ const moduleLog = log.child({ namespace: 'image' }); export const addImageUploadedFulfilledListener = () => { startAppListening({ - predicate: (action): action is ReturnType => - imageUploaded.fulfilled.match(action) && - action.payload.is_intermediate === false, + actionCreator: imageUploaded.fulfilled, effect: (action, { dispatch, getState }) => { const image = action.payload; moduleLog.debug({ arg: '', image }, 'Image uploaded'); + if (action.payload.is_intermediate) { + // No further actions needed for intermediate images + return; + } + const state = getState(); // Handle uploads - if (!image.show_in_gallery && image.image_type === 'uploads') { + if (image.image_category === 'user' && !image.is_intermediate) { dispatch(uploadUpserted(image)); - dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); - - if (state.gallery.shouldAutoSwitchToNewImages) { - dispatch(imageSelected(image)); - } } // Handle results // TODO: Can this ever happen? I don't think so... - if (image.show_in_gallery) { + if (image.image_category !== 'user' && !image.is_intermediate) { dispatch(resultUpserted(image)); + dispatch(setCurrentCategory('results')); } }, }); @@ -44,6 +46,9 @@ export const addImageUploadedRejectedListener = () => { startAppListening({ actionCreator: imageUploaded.rejected, effect: (action, { dispatch }) => { + const { formData, ...rest } = action.meta.arg; + const sanitizedData = { arg: { ...rest, formData: { file: '' } } }; + moduleLog.error({ data: sanitizedData }, 'Image upload failed'); dispatch( addToast({ title: 'Image Upload Failed', diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts index 4ff2a02118..588d7611cc 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts @@ -13,9 +13,9 @@ export const addImageUrlsReceivedFulfilledListener = () => { const image = action.payload; moduleLog.debug({ data: { image } }, 'Image URLs received'); - const { image_type, image_name, image_url, thumbnail_url } = image; + const { image_origin, image_name, image_url, thumbnail_url } = image; - if (image_type === 'results') { + if (image_origin === 'results') { resultsAdapter.updateOne(getState().results, { id: image_name, changes: { @@ -25,7 +25,7 @@ export const addImageUrlsReceivedFulfilledListener = () => { }); } - if (image_type === 'uploads') { + if (image_origin === 'uploads') { uploadsAdapter.updateOne(getState().uploads, { id: image_name, changes: { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts index d6cfc260f3..a2e783a38a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts @@ -30,14 +30,14 @@ export const addInitialImageSelectedListener = () => { return; } - const { image_name, image_type } = action.payload; + const { image_name, image_origin } = action.payload; let image: ImageDTO | undefined; const state = getState(); - if (image_type === 'results') { + if (image_origin === 'results') { image = selectResultsById(state, image_name); - } else if (image_type === 'uploads') { + } else if (image_origin === 'uploads') { image = selectUploadsById(state, image_name); } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts index 60a3cdfedf..81c0286e3b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts @@ -34,13 +34,13 @@ export const addInvocationCompleteListener = () => { // This complete event has an associated image output if (isImageOutput(result) && !nodeDenylist.includes(node.type)) { - const { image_name, image_type } = result.image; + const { image_name, image_origin } = result.image; // Get its metadata dispatch( imageMetadataReceived({ imageName: image_name, - imageType: image_type, + imageOrigin: image_origin, }) ); @@ -48,10 +48,6 @@ export const addInvocationCompleteListener = () => { imageMetadataReceived.fulfilled.match ); - if (getState().gallery.shouldAutoSwitchToNewImages) { - dispatch(imageSelected(imageDTO)); - } - // Handle canvas image if ( graph_execution_state_id === diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts index bc1d5d5f8a..0ee3016bdb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts @@ -103,7 +103,6 @@ export const addUserInvokedCanvasListener = () => { }, imageCategory: 'general', isIntermediate: true, - showInGallery: false, }) ); @@ -117,7 +116,7 @@ export const addUserInvokedCanvasListener = () => { // Update the base node with the image name and type baseNode.image = { image_name: baseImageDTO.image_name, - image_type: baseImageDTO.image_type, + image_origin: baseImageDTO.image_origin, }; } @@ -131,7 +130,6 @@ export const addUserInvokedCanvasListener = () => { }, imageCategory: 'mask', isIntermediate: true, - showInGallery: false, }) ); @@ -145,7 +143,7 @@ export const addUserInvokedCanvasListener = () => { // Update the base node with the image name and type baseNode.mask = { image_name: maskImageDTO.image_name, - image_type: maskImageDTO.image_type, + image_origin: maskImageDTO.image_origin, }; } @@ -162,7 +160,7 @@ export const addUserInvokedCanvasListener = () => { dispatch( imageUpdated({ imageName: baseNode.image.image_name, - imageType: baseNode.image.image_type, + imageOrigin: baseNode.image.image_origin, requestBody: { session_id: sessionId }, }) ); @@ -173,7 +171,7 @@ export const addUserInvokedCanvasListener = () => { dispatch( imageUpdated({ imageName: baseNode.mask.image_name, - imageType: baseNode.mask.image_type, + imageOrigin: baseNode.mask.image_origin, requestBody: { session_id: sessionId }, }) ); diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index 68f7568779..0de1d8c84b 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -15,7 +15,7 @@ import { SelectedImage } from 'features/parameters/store/actions'; import { InvokeTabName } from 'features/ui/store/tabMap'; import { IRect } from 'konva/lib/types'; -import { ImageResponseMetadata, ImageType } from 'services/api'; +import { ImageResponseMetadata, ResourceOrigin } from 'services/api'; import { O } from 'ts-toolbelt'; /** @@ -124,7 +124,7 @@ export type PostProcessedImageMetadata = ESRGANMetadata | FacetoolMetadata; */ // export ty`pe Image = { // name: string; -// type: ImageType; +// type: image_origin; // url: string; // thumbnail: string; // metadata: ImageResponseMetadata; diff --git a/invokeai/frontend/web/src/common/components/ImageUploader.tsx b/invokeai/frontend/web/src/common/components/ImageUploader.tsx index a4e6e52cb8..17f6d68633 100644 --- a/invokeai/frontend/web/src/common/components/ImageUploader.tsx +++ b/invokeai/frontend/web/src/common/components/ImageUploader.tsx @@ -69,9 +69,8 @@ const ImageUploader = (props: ImageUploaderProps) => { dispatch( imageUploaded({ formData: { file }, - imageCategory: 'general', + imageCategory: 'user', isIntermediate: false, - showInGallery: false, }) ); }, diff --git a/invokeai/frontend/web/src/common/util/parseMetadata.ts b/invokeai/frontend/web/src/common/util/parseMetadata.ts deleted file mode 100644 index bb3999d6d0..0000000000 --- a/invokeai/frontend/web/src/common/util/parseMetadata.ts +++ /dev/null @@ -1,239 +0,0 @@ -import { forEach, size } from 'lodash-es'; -import { - ImageField, - LatentsField, - ConditioningField, - ControlField, -} from 'services/api'; - -const OBJECT_TYPESTRING = '[object Object]'; -const STRING_TYPESTRING = '[object String]'; -const NUMBER_TYPESTRING = '[object Number]'; -const BOOLEAN_TYPESTRING = '[object Boolean]'; -const ARRAY_TYPESTRING = '[object Array]'; - -const isObject = (obj: unknown): obj is Record => - Object.prototype.toString.call(obj) === OBJECT_TYPESTRING; - -const isString = (obj: unknown): obj is string => - Object.prototype.toString.call(obj) === STRING_TYPESTRING; - -const isNumber = (obj: unknown): obj is number => - Object.prototype.toString.call(obj) === NUMBER_TYPESTRING; - -const isBoolean = (obj: unknown): obj is boolean => - Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING; - -const isArray = (obj: unknown): obj is Array => - Object.prototype.toString.call(obj) === ARRAY_TYPESTRING; - -const parseImageField = (imageField: unknown): ImageField | undefined => { - // Must be an object - if (!isObject(imageField)) { - return; - } - - // An ImageField must have both `image_name` and `image_type` - if (!('image_name' in imageField && 'image_type' in imageField)) { - return; - } - - // An ImageField's `image_type` must be one of the allowed values - if ( - !['results', 'uploads', 'intermediates'].includes(imageField.image_type) - ) { - return; - } - - // An ImageField's `image_name` must be a string - if (typeof imageField.image_name !== 'string') { - return; - } - - // Build a valid ImageField - return { - image_type: imageField.image_type, - image_name: imageField.image_name, - }; -}; - -const parseLatentsField = (latentsField: unknown): LatentsField | undefined => { - // Must be an object - if (!isObject(latentsField)) { - return; - } - - // A LatentsField must have a `latents_name` - if (!('latents_name' in latentsField)) { - return; - } - - // A LatentsField's `latents_name` must be a string - if (typeof latentsField.latents_name !== 'string') { - return; - } - - // Build a valid LatentsField - return { - latents_name: latentsField.latents_name, - }; -}; - -const parseConditioningField = ( - conditioningField: unknown -): ConditioningField | undefined => { - // Must be an object - if (!isObject(conditioningField)) { - return; - } - - // A ConditioningField must have a `conditioning_name` - if (!('conditioning_name' in conditioningField)) { - return; - } - - // A ConditioningField's `conditioning_name` must be a string - if (typeof conditioningField.conditioning_name !== 'string') { - return; - } - - // Build a valid ConditioningField - return { - conditioning_name: conditioningField.conditioning_name, - }; -}; - -const parseControlField = (controlField: unknown): ControlField | undefined => { - // Must be an object - if (!isObject(controlField)) { - return; - } - - // A ControlField must have a `control` - if (!('control' in controlField)) { - return; - } - // console.log(typeof controlField.control); - - // Build a valid ControlField - return { - control: controlField.control, - }; -}; - -type NodeMetadata = { - [key: string]: - | string - | number - | boolean - | ImageField - | LatentsField - | ConditioningField - | ControlField; -}; - -type InvokeAIMetadata = { - session_id?: string; - node?: NodeMetadata; -}; - -export const parseNodeMetadata = ( - nodeMetadata: Record -): NodeMetadata | undefined => { - if (!isObject(nodeMetadata)) { - return; - } - - const parsed: NodeMetadata = {}; - - forEach(nodeMetadata, (nodeItem, nodeKey) => { - // `id` and `type` must be strings if they are present - if (['id', 'type'].includes(nodeKey)) { - if (isString(nodeItem)) { - parsed[nodeKey] = nodeItem; - } - return; - } - - // the only valid object types are ImageField, LatentsField, ConditioningField, ControlField - if (isObject(nodeItem)) { - if ('image_name' in nodeItem || 'image_type' in nodeItem) { - const imageField = parseImageField(nodeItem); - if (imageField) { - parsed[nodeKey] = imageField; - } - return; - } - - if ('latents_name' in nodeItem) { - const latentsField = parseLatentsField(nodeItem); - if (latentsField) { - parsed[nodeKey] = latentsField; - } - return; - } - - if ('conditioning_name' in nodeItem) { - const conditioningField = parseConditioningField(nodeItem); - if (conditioningField) { - parsed[nodeKey] = conditioningField; - } - return; - } - - if ('control' in nodeItem) { - const controlField = parseControlField(nodeItem); - if (controlField) { - parsed[nodeKey] = controlField; - } - return; - } - } - - // otherwise we accept any string, number or boolean - if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) { - parsed[nodeKey] = nodeItem; - return; - } - }); - - if (size(parsed) === 0) { - return; - } - - return parsed; -}; - -export const parseInvokeAIMetadata = ( - metadata: Record | undefined -): InvokeAIMetadata | undefined => { - if (metadata === undefined) { - return; - } - - if (!isObject(metadata)) { - return; - } - - const parsed: InvokeAIMetadata = {}; - - forEach(metadata, (item, key) => { - if (key === 'session_id' && isString(item)) { - parsed['session_id'] = item; - } - - if (key === 'node' && isObject(item)) { - const nodeMetadata = parseNodeMetadata(item); - - if (nodeMetadata) { - parsed['node'] = nodeMetadata; - } - } - }); - - if (size(parsed) === 0) { - return; - } - - return parsed; -}; diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx index 4562e3458d..38c104a83d 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx @@ -62,7 +62,7 @@ const CurrentImagePreview = () => { return; } e.dataTransfer.setData('invokeai/imageName', image.image_name); - e.dataTransfer.setData('invokeai/imageType', image.image_type); + e.dataTransfer.setData('invokeai/imageOrigin', image.image_origin); e.dataTransfer.effectAllowed = 'move'; }, [image] diff --git a/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx b/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx index ed427f4984..4a51580650 100644 --- a/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx @@ -147,7 +147,7 @@ const HoverableImage = memo((props: HoverableImageProps) => { const handleDragStart = useCallback( (e: DragEvent) => { e.dataTransfer.setData('invokeai/imageName', image.image_name); - e.dataTransfer.setData('invokeai/imageType', image.image_type); + e.dataTransfer.setData('invokeai/imageOrigin', image.image_origin); e.dataTransfer.effectAllowed = 'move'; }, [image] diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts b/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts index ad0870e7a4..1a73971774 100644 --- a/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts +++ b/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts @@ -1,6 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; -import { ImageType } from 'services/api'; +import { ResourceOrigin } from 'services/api'; import { selectResultsEntities } from '../store/resultsSlice'; import { selectUploadsEntities } from '../store/uploadsSlice'; @@ -11,17 +11,17 @@ const useGetImageByNameSelector = createSelector( } ); -const useGetImageByNameAndType = () => { +const useGetImageByNameAndOrigin = () => { const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector); - return (name: string, type: ImageType) => { - if (type === 'results') { + return (name: string, origin: ResourceOrigin) => { + if (origin === 'internal') { const resultImagesResult = allResults[name]; if (resultImagesResult) { return resultImagesResult; } } - if (type === 'uploads') { + if (origin === 'external') { const userImagesResult = allUploads[name]; if (userImagesResult) { return userImagesResult; @@ -30,4 +30,4 @@ const useGetImageByNameAndType = () => { }; }; -export default useGetImageByNameAndType; +export default useGetImageByNameAndOrigin; diff --git a/invokeai/frontend/web/src/features/gallery/store/actions.ts b/invokeai/frontend/web/src/features/gallery/store/actions.ts index 7e071f279d..7c00201da9 100644 --- a/invokeai/frontend/web/src/features/gallery/store/actions.ts +++ b/invokeai/frontend/web/src/features/gallery/store/actions.ts @@ -1,9 +1,9 @@ import { createAction } from '@reduxjs/toolkit'; -import { ImageNameAndType } from 'features/parameters/store/actions'; +import { ImageNameAndOrigin } from 'features/parameters/store/actions'; import { ImageDTO } from 'services/api'; export const requestedImageDeletion = createAction< - ImageDTO | ImageNameAndType | undefined + ImageDTO | ImageNameAndOrigin | undefined >('gallery/requestedImageDeletion'); export const sentImageToCanvas = createAction('gallery/sentImageToCanvas'); diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index 1a49aeac1e..e904620d90 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -5,6 +5,8 @@ import { receivedUploadImages, } from '../../../services/thunks/gallery'; import { ImageDTO } from 'services/api'; +import { resultUpserted } from './resultsSlice'; +import { uploadUpserted } from './uploadsSlice'; type GalleryImageObjectFitType = 'contain' | 'cover'; @@ -76,6 +78,7 @@ export const gallerySlice = createSlice({ } } }); + builder.addCase(receivedUploadImages.fulfilled, (state, action) => { // rehydrate selectedImage URL when results list comes in // solves case when outdated URL is in local storage @@ -92,6 +95,20 @@ export const gallerySlice = createSlice({ } } }); + + builder.addCase(resultUpserted, (state, action) => { + if (state.shouldAutoSwitchToNewImages) { + state.selectedImage = action.payload; + state.currentCategory = 'results'; + } + }); + + builder.addCase(uploadUpserted, (state, action) => { + if (state.shouldAutoSwitchToNewImages) { + state.selectedImage = action.payload; + state.currentCategory = 'uploads'; + } + }); }, }); diff --git a/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts b/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts index ad05284119..5bc7bd14dd 100644 --- a/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts @@ -11,8 +11,8 @@ import { import { ImageDTO } from 'services/api'; import { dateComparator } from 'common/util/dateComparator'; -export type ResultsImageDTO = Omit & { - image_type: 'results'; +export type ResultsImageDTO = Omit & { + image_origin: 'results'; }; export const resultsAdapter = createEntityAdapter({ @@ -47,6 +47,9 @@ const resultsSlice = createSlice({ resultsAdapter.upsertOne(state, action.payload); state.upsertedImageCount += 1; }, + resultRemoved: (state, action: PayloadAction) => { + resultsAdapter.removeOne(state, action.payload); + }, }, extraReducers: (builder) => { /** @@ -83,6 +86,6 @@ export const { selectTotal: selectResultsTotal, } = resultsAdapter.getSelectors((state) => state.results); -export const { resultUpserted } = resultsSlice.actions; +export const { resultUpserted, resultRemoved } = resultsSlice.actions; export default resultsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts b/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts index 49e4d7e3ff..e7620cbc31 100644 --- a/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts @@ -9,8 +9,12 @@ import { receivedUploadImages, IMAGES_PER_PAGE } from 'services/thunks/gallery'; import { ImageDTO } from 'services/api'; import { dateComparator } from 'common/util/dateComparator'; -export type UploadsImageDTO = Omit & { - image_type: 'uploads'; +export type UploadsImageDTO = Omit< + ImageDTO, + 'image_origin' | 'image_category' +> & { + image_origin: 'external'; + image_category: 'user'; }; export const uploadsAdapter = createEntityAdapter({ @@ -45,6 +49,9 @@ const uploadsSlice = createSlice({ uploadsAdapter.upsertOne(state, action.payload); state.upsertedImageCount += 1; }, + uploadRemoved: (state, action: PayloadAction) => { + uploadsAdapter.removeOne(state, action.payload); + }, }, extraReducers: (builder) => { /** @@ -81,6 +88,6 @@ export const { selectTotal: selectUploadsTotal, } = uploadsAdapter.getSelectors((state) => state.uploads); -export const { uploadUpserted } = uploadsSlice.actions; +export const { uploadUpserted, uploadRemoved } = uploadsSlice.actions; export default uploadsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx index 18be021625..e4a0f41ee1 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx @@ -2,7 +2,7 @@ import { Box, Image } from '@chakra-ui/react'; import { useAppDispatch } from 'app/store/storeHooks'; import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder'; import { useGetUrl } from 'common/util/getUrl'; -import useGetImageByNameAndType from 'features/gallery/hooks/useGetImageByName'; +import useGetImageByNameAndOrigin from 'features/gallery/hooks/useGetImageByName'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { @@ -11,7 +11,7 @@ import { } from 'features/nodes/types/types'; import { DragEvent, memo, useCallback, useState } from 'react'; -import { ImageType } from 'services/api'; +import { ResourceOrigin } from 'services/api'; import { FieldComponentProps } from './types'; const ImageInputFieldComponent = ( @@ -19,7 +19,7 @@ const ImageInputFieldComponent = ( ) => { const { nodeId, field } = props; - const getImageByNameAndType = useGetImageByNameAndType(); + const getImageByNameAndType = useGetImageByNameAndOrigin(); const dispatch = useAppDispatch(); const [url, setUrl] = useState(field.value?.image_url); const { getUrl } = useGetUrl(); @@ -27,7 +27,9 @@ const ImageInputFieldComponent = ( const handleDrop = useCallback( (e: DragEvent) => { const name = e.dataTransfer.getData('invokeai/imageName'); - const type = e.dataTransfer.getData('invokeai/imageType') as ImageType; + const type = e.dataTransfer.getData( + 'invokeai/imageOrigin' + ) as ResourceOrigin; if (!name || !type) { return; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts index d9eb80d654..bd3d8a5460 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts @@ -64,7 +64,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => { model, image: { image_name: initialImage?.image_name, - image_type: initialImage?.image_type, + image_origin: initialImage?.image_origin, }, }; diff --git a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts index 5f00d12a23..558f937837 100644 --- a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts +++ b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts @@ -58,7 +58,7 @@ export const buildImg2ImgNode = ( imageToImageNode.image = { image_name: initialImage.name, - image_type: initialImage.type, + image_origin: initialImage.type, }; } diff --git a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildInpaintNode.ts b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildInpaintNode.ts index b3f6cca933..0556a499be 100644 --- a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildInpaintNode.ts +++ b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildInpaintNode.ts @@ -51,7 +51,7 @@ export const buildInpaintNode = ( inpaintNode.image = { image_name: initialImage.name, - image_type: initialImage.type, + image_origin: initialImage.type, }; } diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx index be40f548e6..a5b106163f 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx @@ -5,7 +5,7 @@ import { useGetUrl } from 'common/util/getUrl'; import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { DragEvent, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { ImageType } from 'services/api'; +import { ResourceOrigin } from 'services/api'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { initialImageSelected } from 'features/parameters/store/actions'; @@ -55,9 +55,11 @@ const InitialImagePreview = () => { const handleDrop = useCallback( (e: DragEvent) => { const name = e.dataTransfer.getData('invokeai/imageName'); - const type = e.dataTransfer.getData('invokeai/imageType') as ImageType; + const type = e.dataTransfer.getData( + 'invokeai/imageOrigin' + ) as ResourceOrigin; - dispatch(initialImageSelected({ image_name: name, image_type: type })); + dispatch(initialImageSelected({ image_name: name, image_origin: type })); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/parameters/store/actions.ts b/invokeai/frontend/web/src/features/parameters/store/actions.ts index 853597c809..6c1030b7b0 100644 --- a/invokeai/frontend/web/src/features/parameters/store/actions.ts +++ b/invokeai/frontend/web/src/features/parameters/store/actions.ts @@ -1,10 +1,10 @@ import { createAction } from '@reduxjs/toolkit'; import { isObject } from 'lodash-es'; -import { ImageDTO, ImageType } from 'services/api'; +import { ImageDTO, ResourceOrigin } from 'services/api'; -export type ImageNameAndType = { +export type ImageNameAndOrigin = { image_name: string; - image_type: ImageType; + image_origin: ResourceOrigin; }; export const isImageDTO = (image: any): image is ImageDTO => { @@ -13,8 +13,8 @@ export const isImageDTO = (image: any): image is ImageDTO => { isObject(image) && 'image_name' in image && image?.image_name !== undefined && - 'image_type' in image && - image?.image_type !== undefined && + 'image_origin' in image && + image?.image_origin !== undefined && 'image_url' in image && image?.image_url !== undefined && 'thumbnail_url' in image && @@ -27,5 +27,5 @@ export const isImageDTO = (image: any): image is ImageDTO => { }; export const initialImageSelected = createAction< - ImageDTO | ImageNameAndType | undefined + ImageDTO | ImageNameAndOrigin | undefined >('generation/initialImageSelected'); diff --git a/invokeai/frontend/web/src/services/thunks/gallery.ts b/invokeai/frontend/web/src/services/thunks/gallery.ts index 03032a60ef..e6bb163167 100644 --- a/invokeai/frontend/web/src/services/thunks/gallery.ts +++ b/invokeai/frontend/web/src/services/thunks/gallery.ts @@ -23,8 +23,8 @@ export const receivedGalleryImages = createAppAsyncThunk< const pageOffset = Math.floor(upsertedImageCount / IMAGES_PER_PAGE); const response = await ImagesService.listImagesWithMetadata({ + excludeCategories: ['user'], isIntermediate: false, - showInGallery: true, page: nextPage + pageOffset, perPage: IMAGES_PER_PAGE, }); @@ -53,9 +53,8 @@ export const receivedUploadImages = createAppAsyncThunk< const pageOffset = Math.floor(upsertedImageCount / IMAGES_PER_PAGE); const response = await ImagesService.listImagesWithMetadata({ - imageType: 'uploads', + includeCategories: ['user'], isIntermediate: false, - showInGallery: false, page: nextPage + pageOffset, perPage: IMAGES_PER_PAGE, }); diff --git a/invokeai/frontend/web/src/services/types/guards.ts b/invokeai/frontend/web/src/services/types/guards.ts index 266e991f4d..1231a38b4d 100644 --- a/invokeai/frontend/web/src/services/types/guards.ts +++ b/invokeai/frontend/web/src/services/types/guards.ts @@ -1,4 +1,3 @@ -import { ResultsImageDTO } from 'features/gallery/store/resultsSlice'; import { UploadsImageDTO } from 'features/gallery/store/uploadsSlice'; import { get, isObject, isString } from 'lodash-es'; import { @@ -9,17 +8,18 @@ import { PromptOutput, IterateInvocationOutput, CollectInvocationOutput, - ImageType, ImageField, LatentsOutput, ImageDTO, + ResourceOrigin, } from 'services/api'; -export const isUploadsImageDTO = (image: ImageDTO): image is UploadsImageDTO => - image.image_type === 'uploads'; - -export const isResultsImageDTO = (image: ImageDTO): image is ResultsImageDTO => - image.image_type === 'results'; +export const isUploadsImageDTO = ( + image: ImageDTO | undefined +): image is UploadsImageDTO => + image !== undefined && + image.image_origin === 'external' && + image.image_category === 'user'; export const isImageOutput = ( output: GraphExecutionState['results'][string] @@ -49,10 +49,10 @@ export const isCollectOutput = ( output: GraphExecutionState['results'][string] ): output is CollectInvocationOutput => output.type === 'collect_output'; -export const isImageType = (t: unknown): t is ImageType => - isString(t) && ['results', 'uploads', 'intermediates'].includes(t); +export const isResourceOrigin = (t: unknown): t is ResourceOrigin => + isString(t) && ['internal', 'external'].includes(t); export const isImageField = (imageField: unknown): imageField is ImageField => isObject(imageField) && isString(get(imageField, 'image_name')) && - isImageType(get(imageField, 'image_type')); + isResourceOrigin(get(imageField, 'image_origin')); From 08a14ee6d5ceeb962224a19fb2146c6045788933 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 21:55:29 +1000 Subject: [PATCH 12/34] fix(nodes): fix conflicts with controlnet --- .../app/invocations/controlnet_image_processors.py | 10 +++++----- invokeai/app/invocations/generate.py | 4 ++-- invokeai/app/invocations/latent.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 187784b29e..7d5160a491 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -7,7 +7,7 @@ from typing import Literal, Optional, Union, List from PIL import Image, ImageFilter, ImageOps from pydantic import BaseModel, Field -from ..models.image import ImageField, ImageType, ImageCategory +from ..models.image import ImageField, ImageCategory, ResourceOrigin from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -163,7 +163,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: raw_image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) # image type should be PIL.PngImagePlugin.PngImageFile ? processed_image = self.run_processor(raw_image) @@ -177,8 +177,8 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery image_dto = context.services.images.create( image=processed_image, - image_type=ImageType.RESULT, - image_category=ImageCategory.GENERAL, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.CONTROL, session_id=context.graph_execution_state_id, node_id=self.id, is_intermediate=self.is_intermediate @@ -187,7 +187,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): """Builds an ImageOutput and its ImageField""" processed_image_field = ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ) return ImageOutput( image=processed_image_field, diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index d2ce59d247..370e99d5b4 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -86,8 +86,8 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): # loading controlnet image (currently requires pre-processed image) control_image = ( None if self.control_image is None - else context.services.images.get( - self.control_image.image_type, self.control_image.image_name + else context.services.images.get_pil_image( + self.control_image.image_origin, self.control_image.image_name ) ) # loading controlnet model diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 7085cfd308..58b0fdccbc 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -297,7 +297,7 @@ class TextToLatentsInvocation(BaseInvocation): torch_dtype=model.unet.dtype).to(model.device) control_models.append(control_model) control_image_field = control_info.image - input_image = context.services.images.get_pil_image(control_image_field.image_type, + input_image = context.services.images.get_pil_image(control_image_field.image_origin, control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes From 05b99b53776e8b468185c8f58f614889e98ec134 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 21:56:25 +1000 Subject: [PATCH 13/34] fix(ui): fix erroneously displays `is_intermediate` field on nodes --- invokeai/frontend/web/src/features/nodes/util/parseSchema.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index 631552414d..c77fdeca5e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -13,7 +13,7 @@ import { buildOutputFieldTemplates, } from './fieldTemplateBuilders'; -const RESERVED_FIELD_NAMES = ['id', 'type', 'meta']; +const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate']; const invocationDenylist = ['Graph', 'InvocationMeta']; From 38fd2ad45d3a39fede35559731216a7b6fda797b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 27 May 2023 22:00:25 +1000 Subject: [PATCH 14/34] fix(ui): fix metadata viewer crash --- .../components/ImageMetaDataViewer/ImageMetadataViewer.tsx | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx index b01191105e..df52a06c90 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx @@ -53,6 +53,11 @@ const MetadataItem = ({ withCopy = false, }: MetadataItemProps) => { const { t } = useTranslation(); + + if (!value) { + return null; + } + return ( {onClick && ( From f31e62afad15f8b9983cf5d6f4754e90ca94bc7b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 28 May 2023 18:59:14 +1000 Subject: [PATCH 15/34] feat(nodes): make list images route use offset pagination Because we dynamically insert images into the DB and UI's images state, `page`/`per_page` pagination makes loading the images awkward. Using `offset`/`limit` pagination lets us query for images with an offset equal to the number of images already loaded (which match the query parameters). The result is that we always get the correct next page of images when loading more. --- invokeai/app/api/routers/images.py | 24 +++--- invokeai/app/services/image_record_storage.py | 81 ++++++++++--------- invokeai/app/services/images.py | 33 ++++---- invokeai/app/services/models/image_record.py | 11 ++- 4 files changed, 78 insertions(+), 71 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index f0399a2d07..ae10cce140 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -8,6 +8,7 @@ from invokeai.app.models.image import ( ImageCategory, ResourceOrigin, ) +from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.services.models.image_record import ( ImageDTO, ImageRecordChanges, @@ -221,35 +222,28 @@ async def get_image_urls( @images_router.get( "/", operation_id="list_images_with_metadata", - response_model=PaginatedResults[ImageDTO], + response_model=OffsetPaginatedResults[ImageDTO], ) async def list_images_with_metadata( image_origin: Optional[ResourceOrigin] = Query( default=None, description="The origin of images to list" ), - include_categories: Optional[list[ImageCategory]] = Query( + categories: Optional[list[ImageCategory]] = Query( default=None, description="The categories of image to include" ), - exclude_categories: Optional[list[ImageCategory]] = Query( - default=None, description="The categories of image to exclude" - ), is_intermediate: Optional[bool] = Query( default=None, description="Whether to list intermediate images" ), - page: int = Query(default=0, description="The page of images to get"), - per_page: int = Query(default=10, description="The number of images per page"), -) -> PaginatedResults[ImageDTO]: + offset: int = Query(default=0, description="The page offset"), + limit: int = Query(default=10, description="The number of images per page"), +) -> OffsetPaginatedResults[ImageDTO]: """Gets a list of images""" - if include_categories is not None and exclude_categories is not None: - raise HTTPException(status_code=400, detail="Cannot use both 'include_category' and 'exclude_category' at the same time.") - image_dtos = ApiDependencies.invoker.services.images.get_many( - page, - per_page, + offset, + limit, image_origin, - include_categories, - exclude_categories, + categories, is_intermediate, ) diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 6b6d1ce7b2..c27596afac 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -1,10 +1,13 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import Optional, cast +from typing import Generic, Optional, TypeVar, cast import sqlite3 import threading from typing import Optional, Union +from pydantic import BaseModel, Field +from pydantic.generics import GenericModel + from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.image import ( ImageCategory, @@ -15,7 +18,18 @@ from invokeai.app.services.models.image_record import ( ImageRecordChanges, deserialize_image_record, ) -from invokeai.app.services.item_storage import PaginatedResults + +T = TypeVar("T", bound=BaseModel) + +class OffsetPaginatedResults(GenericModel, Generic[T]): + """Offset-paginated results""" + + # fmt: off + items: list[T] = Field(description="Items") + offset: int = Field(description="Offset from which to retrieve items") + limit: int = Field(description="Limit of items to get") + total: int = Field(description="Total number of items in result") + # fmt: on # TODO: Should these excpetions subclass existing python exceptions? @@ -63,13 +77,12 @@ class ImageRecordStorageBase(ABC): @abstractmethod def get_many( self, - page: int = 0, - per_page: int = 10, + offset: int = 0, + limit: int = 10, image_origin: Optional[ResourceOrigin] = None, - include_categories: Optional[list[ImageCategory]] = None, - exclude_categories: Optional[list[ImageCategory]] = None, + categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - ) -> PaginatedResults[ImageRecord]: + ) -> OffsetPaginatedResults[ImageRecord]: """Gets a page of image records.""" pass @@ -238,6 +251,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """, (changes.session_id, image_name), ) + + # Change the image's `is_intermediate`` flag + if changes.session_id is not None: + self._cursor.execute( + f"""--sql + UPDATE images + SET is_intermediate = ? + WHERE image_name = ?; + """, + (changes.is_intermediate, image_name), + ) self._conn.commit() except sqlite3.Error as e: self._conn.rollback() @@ -247,13 +271,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): def get_many( self, - page: int = 0, - per_page: int = 10, + offset: int = 0, + limit: int = 10, image_origin: Optional[ResourceOrigin] = None, - include_categories: Optional[list[ImageCategory]] = None, - exclude_categories: Optional[list[ImageCategory]] = None, + categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - ) -> PaginatedResults[ImageRecord]: + ) -> OffsetPaginatedResults[ImageRecord]: try: self._lock.acquire() @@ -269,30 +292,18 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): query_conditions += f"""AND image_origin = ?\n""" query_params.append(image_origin.value) - if include_categories is not None: + if categories is not None: ## Convert the enum values to unique list of strings - include_category_strings = list( - map(lambda c: c.value, set(include_categories)) + category_strings = list( + map(lambda c: c.value, set(categories)) ) # Create the correct length of placeholders - placeholders = ",".join("?" * len(include_category_strings)) + placeholders = ",".join("?" * len(category_strings)) query_conditions += f"AND image_category IN ( {placeholders} )\n" # Unpack the included categories into the query params - query_params.append(*include_category_strings) - - if exclude_categories is not None: - ## Convert the enum values to unique list of strings - exclude_category_strings = list( - map(lambda c: c.value, set(exclude_categories)) - ) - - # Create the correct length of placeholders - placeholders = ",".join("?" * len(exclude_category_strings)) - query_conditions += f"AND image_category NOT IN ( {placeholders} )\n" - - # Unpack the included categories into the query params - query_params.append(*exclude_category_strings) + for c in category_strings: + query_params.append(c) if is_intermediate is not None: query_conditions += f"""AND is_intermediate = ?\n""" @@ -304,8 +315,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): images_query += query_conditions + query_pagination + ";" # Add all the parameters images_params = query_params.copy() - images_params.append(per_page) - images_params.append(page * per_page) + images_params.append(limit) + images_params.append(offset) # Build the list of images, deserializing each row self._cursor.execute(images_query, images_params) result = cast(list[sqlite3.Row], self._cursor.fetchall()) @@ -322,10 +333,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): finally: self._lock.release() - pageCount = int(count / per_page) + 1 - - return PaginatedResults( - items=images, page=page, pages=pageCount, per_page=per_page, total=count + return OffsetPaginatedResults( + items=images, offset=offset, limit=limit, total=count ) def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index dca95f673f..2618a9763e 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -15,6 +15,7 @@ from invokeai.app.services.image_record_storage import ( ImageRecordNotFoundException, ImageRecordSaveException, ImageRecordStorageBase, + OffsetPaginatedResults, ) from invokeai.app.services.models.image_record import ( ImageRecord, @@ -98,13 +99,12 @@ class ImageServiceABC(ABC): @abstractmethod def get_many( self, - page: int = 0, - per_page: int = 10, + offset: int = 0, + limit: int = 10, image_origin: Optional[ResourceOrigin] = None, - include_categories: Optional[list[ImageCategory]] = None, - exclude_categories: Optional[list[ImageCategory]] = None, + categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - ) -> PaginatedResults[ImageDTO]: + ) -> OffsetPaginatedResults[ImageDTO]: """Gets a paginated list of image DTOs.""" pass @@ -328,20 +328,18 @@ class ImageService(ImageServiceABC): def get_many( self, - page: int = 0, - per_page: int = 10, + offset: int = 0, + limit: int = 10, image_origin: Optional[ResourceOrigin] = None, - include_categories: Optional[list[ImageCategory]] = None, - exclude_categories: Optional[list[ImageCategory]] = None, + categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - ) -> PaginatedResults[ImageDTO]: + ) -> OffsetPaginatedResults[ImageDTO]: try: results = self._services.records.get_many( - page, - per_page, + offset, + limit, image_origin, - include_categories, - exclude_categories, + categories, is_intermediate, ) @@ -358,11 +356,10 @@ class ImageService(ImageServiceABC): ) ) - return PaginatedResults[ImageDTO]( + return OffsetPaginatedResults[ImageDTO]( items=image_dtos, - page=results.page, - pages=results.pages, - per_page=results.per_page, + offset=results.offset, + limit=results.limit, total=results.total, ) except Exception as e: diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index f143a30928..051236b12b 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -1,6 +1,6 @@ import datetime from typing import Optional, Union -from pydantic import BaseModel, Extra, Field, StrictStr +from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.models.metadata import ImageMetadata from invokeai.app.util.misc import get_iso_timestamp @@ -56,6 +56,7 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid): Only limited changes are valid: - `image_category`: change the category of an image - `session_id`: change the session associated with an image + - `is_intermediate`: change the image's `is_intermediate` flag """ image_category: Optional[ImageCategory] = Field( @@ -67,6 +68,10 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid): description="The image's new session ID.", ) """The image's new session ID.""" + is_intermediate: Optional[StrictBool] = Field( + default=None, description="The image's new `is_intermediate` flag." + ) + """The image's new `is_intermediate` flag.""" class ImageUrlsDTO(BaseModel): @@ -105,7 +110,9 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: # Retrieve all the values, setting "reasonable" defaults if they are not present. image_name = image_dict.get("image_name", "unknown") - image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)) + image_origin = ResourceOrigin( + image_dict.get("image_origin", ResourceOrigin.INTERNAL.value) + ) image_category = ImageCategory( image_dict.get("image_category", ImageCategory.GENERAL.value) ) From 6cc00ef4b7da13b5f2f06d3fdd3ba6ab57664d47 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 28 May 2023 18:59:38 +1000 Subject: [PATCH 16/34] chore(ui): regen api client --- .../frontend/web/src/services/api/index.ts | 21 +++++++++++- .../models/CannyImageProcessorInvocation.ts | 4 +++ .../ContentShuffleImageProcessorInvocation.ts | 4 +++ .../api/models/ControlNetInvocation.ts | 6 +++- .../api/models/FloatCollectionOutput.ts | 15 +++++++++ .../src/services/api/models/FloatOutput.ts | 15 +++++++++ .../web/src/services/api/models/Graph.ts | 17 +++++++++- .../api/models/GraphExecutionState.ts | 5 ++- .../api/models/HedImageprocessorInvocation.ts | 4 +++ .../api/models/ImageProcessorInvocation.ts | 4 +++ .../services/api/models/ImageRecordChanges.ts | 5 +++ .../api/models/LatentsToLatentsInvocation.ts | 5 +++ .../LineartAnimeImageProcessorInvocation.ts | 4 +++ .../models/LineartImageProcessorInvocation.ts | 4 +++ .../MediapipeFaceProcessorInvocation.ts | 33 +++++++++++++++++++ .../MidasDepthImageProcessorInvocation.ts | 4 +++ .../models/MlsdImageProcessorInvocation.ts | 4 +++ .../NormalbaeImageProcessorInvocation.ts | 4 +++ ...ts => OffsetPaginatedResults_ImageDTO_.ts} | 16 ++++----- .../OpenposeImageProcessorInvocation.ts | 4 +++ .../api/models/ParamFloatInvocation.ts | 23 +++++++++++++ .../models/PidiImageProcessorInvocation.ts | 4 +++ .../api/models/TextToLatentsInvocation.ts | 5 +++ .../ZoeDepthImageProcessorInvocation.ts | 25 ++++++++++++++ .../services/api/services/ImagesService.ts | 32 ++++++++---------- .../services/api/services/SessionsService.ts | 19 +++++++++-- 26 files changed, 251 insertions(+), 35 deletions(-) create mode 100644 invokeai/frontend/web/src/services/api/models/FloatCollectionOutput.ts create mode 100644 invokeai/frontend/web/src/services/api/models/FloatOutput.ts create mode 100644 invokeai/frontend/web/src/services/api/models/MediapipeFaceProcessorInvocation.ts rename invokeai/frontend/web/src/services/api/models/{PaginatedResults_ImageDTO_.ts => OffsetPaginatedResults_ImageDTO_.ts} (56%) create mode 100644 invokeai/frontend/web/src/services/api/models/ParamFloatInvocation.ts create mode 100644 invokeai/frontend/web/src/services/api/models/ZoeDepthImageProcessorInvocation.ts diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index d9f00becd9..292cd5ce4d 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -8,6 +8,7 @@ export type { OpenAPIConfig } from './core/OpenAPI'; export type { AddInvocation } from './models/AddInvocation'; export type { Body_upload_image } from './models/Body_upload_image'; +export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation'; export type { CkptModelInfo } from './models/CkptModelInfo'; export type { CollectInvocation } from './models/CollectInvocation'; export type { CollectInvocationOutput } from './models/CollectInvocationOutput'; @@ -15,16 +16,23 @@ export type { ColorField } from './models/ColorField'; export type { CompelInvocation } from './models/CompelInvocation'; export type { CompelOutput } from './models/CompelOutput'; export type { ConditioningField } from './models/ConditioningField'; +export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation'; +export type { ControlField } from './models/ControlField'; +export type { ControlNetInvocation } from './models/ControlNetInvocation'; +export type { ControlOutput } from './models/ControlOutput'; export type { CreateModelRequest } from './models/CreateModelRequest'; export type { CvInpaintInvocation } from './models/CvInpaintInvocation'; export type { DiffusersModelInfo } from './models/DiffusersModelInfo'; export type { DivideInvocation } from './models/DivideInvocation'; export type { Edge } from './models/Edge'; export type { EdgeConnection } from './models/EdgeConnection'; +export type { FloatCollectionOutput } from './models/FloatCollectionOutput'; +export type { FloatOutput } from './models/FloatOutput'; export type { Graph } from './models/Graph'; export type { GraphExecutionState } from './models/GraphExecutionState'; export type { GraphInvocation } from './models/GraphInvocation'; export type { GraphInvocationOutput } from './models/GraphInvocationOutput'; +export type { HedImageprocessorInvocation } from './models/HedImageprocessorInvocation'; export type { HTTPValidationError } from './models/HTTPValidationError'; export type { ImageBlurInvocation } from './models/ImageBlurInvocation'; export type { ImageCategory } from './models/ImageCategory'; @@ -39,6 +47,7 @@ export type { ImageMetadata } from './models/ImageMetadata'; export type { ImageMultiplyInvocation } from './models/ImageMultiplyInvocation'; export type { ImageOutput } from './models/ImageOutput'; export type { ImagePasteInvocation } from './models/ImagePasteInvocation'; +export type { ImageProcessorInvocation } from './models/ImageProcessorInvocation'; export type { ImageRecordChanges } from './models/ImageRecordChanges'; export type { ImageToImageInvocation } from './models/ImageToImageInvocation'; export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation'; @@ -55,16 +64,25 @@ export type { LatentsField } from './models/LatentsField'; export type { LatentsOutput } from './models/LatentsOutput'; export type { LatentsToImageInvocation } from './models/LatentsToImageInvocation'; export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvocation'; +export type { LineartAnimeImageProcessorInvocation } from './models/LineartAnimeImageProcessorInvocation'; +export type { LineartImageProcessorInvocation } from './models/LineartImageProcessorInvocation'; export type { LoadImageInvocation } from './models/LoadImageInvocation'; export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation'; export type { MaskOutput } from './models/MaskOutput'; +export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation'; +export type { MidasDepthImageProcessorInvocation } from './models/MidasDepthImageProcessorInvocation'; +export type { MlsdImageProcessorInvocation } from './models/MlsdImageProcessorInvocation'; export type { ModelsList } from './models/ModelsList'; export type { MultiplyInvocation } from './models/MultiplyInvocation'; export type { NoiseInvocation } from './models/NoiseInvocation'; export type { NoiseOutput } from './models/NoiseOutput'; +export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation'; +export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_'; +export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation'; export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_'; -export type { PaginatedResults_ImageDTO_ } from './models/PaginatedResults_ImageDTO_'; +export type { ParamFloatInvocation } from './models/ParamFloatInvocation'; export type { ParamIntInvocation } from './models/ParamIntInvocation'; +export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation'; export type { PromptOutput } from './models/PromptOutput'; export type { RandomIntInvocation } from './models/RandomIntInvocation'; export type { RandomRangeInvocation } from './models/RandomRangeInvocation'; @@ -81,6 +99,7 @@ export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation'; export type { UpscaleInvocation } from './models/UpscaleInvocation'; export type { VaeRepo } from './models/VaeRepo'; export type { ValidationError } from './models/ValidationError'; +export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation'; export { ImagesService } from './services/ImagesService'; export { ModelsService } from './services/ModelsService'; diff --git a/invokeai/frontend/web/src/services/api/models/CannyImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/CannyImageProcessorInvocation.ts index 474f1d3f3c..3a8b0b21e7 100644 --- a/invokeai/frontend/web/src/services/api/models/CannyImageProcessorInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/CannyImageProcessorInvocation.ts @@ -12,6 +12,10 @@ export type CannyImageProcessorInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'canny_image_processor'; /** * image to process diff --git a/invokeai/frontend/web/src/services/api/models/ContentShuffleImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/ContentShuffleImageProcessorInvocation.ts index 4a07508be7..d8bc3fe58e 100644 --- a/invokeai/frontend/web/src/services/api/models/ContentShuffleImageProcessorInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ContentShuffleImageProcessorInvocation.ts @@ -12,6 +12,10 @@ export type ContentShuffleImageProcessorInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'content_shuffle_image_processor'; /** * image to process diff --git a/invokeai/frontend/web/src/services/api/models/ControlNetInvocation.ts b/invokeai/frontend/web/src/services/api/models/ControlNetInvocation.ts index e8372f43dd..fad3af911b 100644 --- a/invokeai/frontend/web/src/services/api/models/ControlNetInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ControlNetInvocation.ts @@ -12,6 +12,10 @@ export type ControlNetInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'controlnet'; /** * image to process @@ -20,7 +24,7 @@ export type ControlNetInvocation = { /** * control model used */ - control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace'; + control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15' | 'CrucibleAI/ControlNetMediaPipeFace'; /** * weight given to controlnet */ diff --git a/invokeai/frontend/web/src/services/api/models/FloatCollectionOutput.ts b/invokeai/frontend/web/src/services/api/models/FloatCollectionOutput.ts new file mode 100644 index 0000000000..a3f08247a4 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/FloatCollectionOutput.ts @@ -0,0 +1,15 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * A collection of floats + */ +export type FloatCollectionOutput = { + type?: 'float_collection'; + /** + * The float collection + */ + collection?: Array; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/FloatOutput.ts b/invokeai/frontend/web/src/services/api/models/FloatOutput.ts new file mode 100644 index 0000000000..2331936b30 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/FloatOutput.ts @@ -0,0 +1,15 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * A float output + */ +export type FloatOutput = { + type?: 'float_output'; + /** + * The output float + */ + param?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/Graph.ts b/invokeai/frontend/web/src/services/api/models/Graph.ts index 6be925841b..af8a3ed0e6 100644 --- a/invokeai/frontend/web/src/services/api/models/Graph.ts +++ b/invokeai/frontend/web/src/services/api/models/Graph.ts @@ -3,12 +3,16 @@ /* eslint-disable */ import type { AddInvocation } from './AddInvocation'; +import type { CannyImageProcessorInvocation } from './CannyImageProcessorInvocation'; import type { CollectInvocation } from './CollectInvocation'; import type { CompelInvocation } from './CompelInvocation'; +import type { ContentShuffleImageProcessorInvocation } from './ContentShuffleImageProcessorInvocation'; +import type { ControlNetInvocation } from './ControlNetInvocation'; import type { CvInpaintInvocation } from './CvInpaintInvocation'; import type { DivideInvocation } from './DivideInvocation'; import type { Edge } from './Edge'; import type { GraphInvocation } from './GraphInvocation'; +import type { HedImageprocessorInvocation } from './HedImageprocessorInvocation'; import type { ImageBlurInvocation } from './ImageBlurInvocation'; import type { ImageChannelInvocation } from './ImageChannelInvocation'; import type { ImageConvertInvocation } from './ImageConvertInvocation'; @@ -17,6 +21,7 @@ import type { ImageInverseLerpInvocation } from './ImageInverseLerpInvocation'; import type { ImageLerpInvocation } from './ImageLerpInvocation'; import type { ImageMultiplyInvocation } from './ImageMultiplyInvocation'; import type { ImagePasteInvocation } from './ImagePasteInvocation'; +import type { ImageProcessorInvocation } from './ImageProcessorInvocation'; import type { ImageToImageInvocation } from './ImageToImageInvocation'; import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation'; import type { InfillColorInvocation } from './InfillColorInvocation'; @@ -26,11 +31,20 @@ import type { InpaintInvocation } from './InpaintInvocation'; import type { IterateInvocation } from './IterateInvocation'; import type { LatentsToImageInvocation } from './LatentsToImageInvocation'; import type { LatentsToLatentsInvocation } from './LatentsToLatentsInvocation'; +import type { LineartAnimeImageProcessorInvocation } from './LineartAnimeImageProcessorInvocation'; +import type { LineartImageProcessorInvocation } from './LineartImageProcessorInvocation'; import type { LoadImageInvocation } from './LoadImageInvocation'; import type { MaskFromAlphaInvocation } from './MaskFromAlphaInvocation'; +import type { MediapipeFaceProcessorInvocation } from './MediapipeFaceProcessorInvocation'; +import type { MidasDepthImageProcessorInvocation } from './MidasDepthImageProcessorInvocation'; +import type { MlsdImageProcessorInvocation } from './MlsdImageProcessorInvocation'; import type { MultiplyInvocation } from './MultiplyInvocation'; import type { NoiseInvocation } from './NoiseInvocation'; +import type { NormalbaeImageProcessorInvocation } from './NormalbaeImageProcessorInvocation'; +import type { OpenposeImageProcessorInvocation } from './OpenposeImageProcessorInvocation'; +import type { ParamFloatInvocation } from './ParamFloatInvocation'; import type { ParamIntInvocation } from './ParamIntInvocation'; +import type { PidiImageProcessorInvocation } from './PidiImageProcessorInvocation'; import type { RandomIntInvocation } from './RandomIntInvocation'; import type { RandomRangeInvocation } from './RandomRangeInvocation'; import type { RangeInvocation } from './RangeInvocation'; @@ -43,6 +57,7 @@ import type { SubtractInvocation } from './SubtractInvocation'; import type { TextToImageInvocation } from './TextToImageInvocation'; import type { TextToLatentsInvocation } from './TextToLatentsInvocation'; import type { UpscaleInvocation } from './UpscaleInvocation'; +import type { ZoeDepthImageProcessorInvocation } from './ZoeDepthImageProcessorInvocation'; export type Graph = { /** @@ -52,7 +67,7 @@ export type Graph = { /** * The nodes in this graph */ - nodes?: Record; + nodes?: Record; /** * The connections between nodes and their fields in this graph */ diff --git a/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts b/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts index 8c2eb05657..ea41ce055b 100644 --- a/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts +++ b/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts @@ -4,6 +4,9 @@ import type { CollectInvocationOutput } from './CollectInvocationOutput'; import type { CompelOutput } from './CompelOutput'; +import type { ControlOutput } from './ControlOutput'; +import type { FloatCollectionOutput } from './FloatCollectionOutput'; +import type { FloatOutput } from './FloatOutput'; import type { Graph } from './Graph'; import type { GraphInvocationOutput } from './GraphInvocationOutput'; import type { ImageOutput } from './ImageOutput'; @@ -42,7 +45,7 @@ export type GraphExecutionState = { /** * The results of node executions */ - results: Record; + results: Record; /** * Errors raised when executing nodes */ diff --git a/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts index 6dea43dc32..f975f18968 100644 --- a/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts @@ -12,6 +12,10 @@ export type HedImageprocessorInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'hed_image_processor'; /** * image to process diff --git a/invokeai/frontend/web/src/services/api/models/ImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageProcessorInvocation.ts index 90639a0569..f972582e2f 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageProcessorInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageProcessorInvocation.ts @@ -12,6 +12,10 @@ export type ImageProcessorInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'image_processor'; /** * image to process diff --git a/invokeai/frontend/web/src/services/api/models/ImageRecordChanges.ts b/invokeai/frontend/web/src/services/api/models/ImageRecordChanges.ts index 51f0ee2079..e597cd907d 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageRecordChanges.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageRecordChanges.ts @@ -10,6 +10,7 @@ import type { ImageCategory } from './ImageCategory'; * Only limited changes are valid: * - `image_category`: change the category of an image * - `session_id`: change the session associated with an image + * - `is_intermediate`: change the image's `is_intermediate` flag */ export type ImageRecordChanges = { /** @@ -20,5 +21,9 @@ export type ImageRecordChanges = { * The image's new session ID. */ session_id?: string; + /** + * The image's new `is_intermediate` flag. + */ + is_intermediate?: boolean; }; diff --git a/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts index 6436557f64..f5b4912141 100644 --- a/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts @@ -3,6 +3,7 @@ /* eslint-disable */ import type { ConditioningField } from './ConditioningField'; +import type { ControlField } from './ControlField'; import type { LatentsField } from './LatentsField'; /** @@ -46,6 +47,10 @@ export type LatentsToLatentsInvocation = { * The model to use (currently ignored) */ model?: string; + /** + * The control to use + */ + control?: (ControlField | Array); /** * The latents to use as a base image */ diff --git a/invokeai/frontend/web/src/services/api/models/LineartAnimeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/LineartAnimeImageProcessorInvocation.ts index a9bdab56ec..4796d2a049 100644 --- a/invokeai/frontend/web/src/services/api/models/LineartAnimeImageProcessorInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/LineartAnimeImageProcessorInvocation.ts @@ -12,6 +12,10 @@ export type LineartAnimeImageProcessorInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'lineart_anime_image_processor'; /** * image to process diff --git a/invokeai/frontend/web/src/services/api/models/LineartImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/LineartImageProcessorInvocation.ts index 1aa931525f..8328849b50 100644 --- a/invokeai/frontend/web/src/services/api/models/LineartImageProcessorInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/LineartImageProcessorInvocation.ts @@ -12,6 +12,10 @@ export type LineartImageProcessorInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'lineart_image_processor'; /** * image to process diff --git a/invokeai/frontend/web/src/services/api/models/MediapipeFaceProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/MediapipeFaceProcessorInvocation.ts new file mode 100644 index 0000000000..bd223eed7d --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/MediapipeFaceProcessorInvocation.ts @@ -0,0 +1,33 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies mediapipe face processing to image + */ +export type MediapipeFaceProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'mediapipe_face_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * maximum number of faces to detect + */ + max_faces?: number; + /** + * minimum confidence for face detection + */ + min_confidence?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/MidasDepthImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/MidasDepthImageProcessorInvocation.ts index 71283b0614..11023086a2 100644 --- a/invokeai/frontend/web/src/services/api/models/MidasDepthImageProcessorInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/MidasDepthImageProcessorInvocation.ts @@ -12,6 +12,10 @@ export type MidasDepthImageProcessorInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'midas_depth_image_processor'; /** * image to process diff --git a/invokeai/frontend/web/src/services/api/models/MlsdImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/MlsdImageProcessorInvocation.ts index 85a2ad15cc..c2d4a61b9a 100644 --- a/invokeai/frontend/web/src/services/api/models/MlsdImageProcessorInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/MlsdImageProcessorInvocation.ts @@ -12,6 +12,10 @@ export type MlsdImageProcessorInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'mlsd_image_processor'; /** * image to process diff --git a/invokeai/frontend/web/src/services/api/models/NormalbaeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/NormalbaeImageProcessorInvocation.ts index 519ea7a89d..ecfb50a09f 100644 --- a/invokeai/frontend/web/src/services/api/models/NormalbaeImageProcessorInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/NormalbaeImageProcessorInvocation.ts @@ -12,6 +12,10 @@ export type NormalbaeImageProcessorInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'normalbae_image_processor'; /** * image to process diff --git a/invokeai/frontend/web/src/services/api/models/PaginatedResults_ImageDTO_.ts b/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_ImageDTO_.ts similarity index 56% rename from invokeai/frontend/web/src/services/api/models/PaginatedResults_ImageDTO_.ts rename to invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_ImageDTO_.ts index 5d2bdae5ab..3408bea6db 100644 --- a/invokeai/frontend/web/src/services/api/models/PaginatedResults_ImageDTO_.ts +++ b/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_ImageDTO_.ts @@ -5,25 +5,21 @@ import type { ImageDTO } from './ImageDTO'; /** - * Paginated results + * Offset-paginated results */ -export type PaginatedResults_ImageDTO_ = { +export type OffsetPaginatedResults_ImageDTO_ = { /** * Items */ items: Array; /** - * Current Page + * Offset from which to retrieve items */ - page: number; + offset: number; /** - * Total number of pages + * Limit of items to get */ - pages: number; - /** - * Number of items per page - */ - per_page: number; + limit: number; /** * Total number of items in result */ diff --git a/invokeai/frontend/web/src/services/api/models/OpenposeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/OpenposeImageProcessorInvocation.ts index 44947df15b..5af21d542e 100644 --- a/invokeai/frontend/web/src/services/api/models/OpenposeImageProcessorInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/OpenposeImageProcessorInvocation.ts @@ -12,6 +12,10 @@ export type OpenposeImageProcessorInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'openpose_image_processor'; /** * image to process diff --git a/invokeai/frontend/web/src/services/api/models/ParamFloatInvocation.ts b/invokeai/frontend/web/src/services/api/models/ParamFloatInvocation.ts new file mode 100644 index 0000000000..87c01f847f --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ParamFloatInvocation.ts @@ -0,0 +1,23 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * A float parameter + */ +export type ParamFloatInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'param_float'; + /** + * The float value + */ + param?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/PidiImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/PidiImageProcessorInvocation.ts index 59076cb2e1..a08bf6a920 100644 --- a/invokeai/frontend/web/src/services/api/models/PidiImageProcessorInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/PidiImageProcessorInvocation.ts @@ -12,6 +12,10 @@ export type PidiImageProcessorInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'pidi_image_processor'; /** * image to process diff --git a/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts index 33eedc0f02..f1831b2b59 100644 --- a/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts @@ -3,6 +3,7 @@ /* eslint-disable */ import type { ConditioningField } from './ConditioningField'; +import type { ControlField } from './ControlField'; import type { LatentsField } from './LatentsField'; /** @@ -46,5 +47,9 @@ export type TextToLatentsInvocation = { * The model to use (currently ignored) */ model?: string; + /** + * The control to use + */ + control?: (ControlField | Array); }; diff --git a/invokeai/frontend/web/src/services/api/models/ZoeDepthImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/ZoeDepthImageProcessorInvocation.ts new file mode 100644 index 0000000000..55d05f3167 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ZoeDepthImageProcessorInvocation.ts @@ -0,0 +1,25 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies Zoe depth processing to image + */ +export type ZoeDepthImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'zoe_depth_image_processor'; + /** + * image to process + */ + image?: ImageField; +}; + diff --git a/invokeai/frontend/web/src/services/api/services/ImagesService.ts b/invokeai/frontend/web/src/services/api/services/ImagesService.ts index 379f9f7dd7..51fe6c820f 100644 --- a/invokeai/frontend/web/src/services/api/services/ImagesService.ts +++ b/invokeai/frontend/web/src/services/api/services/ImagesService.ts @@ -6,7 +6,7 @@ import type { ImageCategory } from '../models/ImageCategory'; import type { ImageDTO } from '../models/ImageDTO'; import type { ImageRecordChanges } from '../models/ImageRecordChanges'; import type { ImageUrlsDTO } from '../models/ImageUrlsDTO'; -import type { PaginatedResults_ImageDTO_ } from '../models/PaginatedResults_ImageDTO_'; +import type { OffsetPaginatedResults_ImageDTO_ } from '../models/OffsetPaginatedResults_ImageDTO_'; import type { ResourceOrigin } from '../models/ResourceOrigin'; import type { CancelablePromise } from '../core/CancelablePromise'; @@ -18,16 +18,15 @@ export class ImagesService { /** * List Images With Metadata * Gets a list of images - * @returns PaginatedResults_ImageDTO_ Successful Response + * @returns OffsetPaginatedResults_ImageDTO_ Successful Response * @throws ApiError */ public static listImagesWithMetadata({ imageOrigin, - includeCategories, - excludeCategories, + categories, isIntermediate, - page, - perPage = 10, + offset, + limit = 10, }: { /** * The origin of images to list @@ -36,34 +35,29 @@ export class ImagesService { /** * The categories of image to include */ - includeCategories?: Array, - /** - * The categories of image to exclude - */ - excludeCategories?: Array, + categories?: Array, /** * Whether to list intermediate images */ isIntermediate?: boolean, /** - * The page of images to get + * The page offset */ - page?: number, + offset?: number, /** * The number of images per page */ - perPage?: number, - }): CancelablePromise { + limit?: number, + }): CancelablePromise { return __request(OpenAPI, { method: 'GET', url: '/api/v1/images/', query: { 'image_origin': imageOrigin, - 'include_categories': includeCategories, - 'exclude_categories': excludeCategories, + 'categories': categories, 'is_intermediate': isIntermediate, - 'page': page, - 'per_page': perPage, + 'offset': offset, + 'limit': limit, }, errors: { 422: `Validation Error`, diff --git a/invokeai/frontend/web/src/services/api/services/SessionsService.ts b/invokeai/frontend/web/src/services/api/services/SessionsService.ts index 1c55d36502..de46d8fd3e 100644 --- a/invokeai/frontend/web/src/services/api/services/SessionsService.ts +++ b/invokeai/frontend/web/src/services/api/services/SessionsService.ts @@ -2,14 +2,18 @@ /* tslint:disable */ /* eslint-disable */ import type { AddInvocation } from '../models/AddInvocation'; +import type { CannyImageProcessorInvocation } from '../models/CannyImageProcessorInvocation'; import type { CollectInvocation } from '../models/CollectInvocation'; import type { CompelInvocation } from '../models/CompelInvocation'; +import type { ContentShuffleImageProcessorInvocation } from '../models/ContentShuffleImageProcessorInvocation'; +import type { ControlNetInvocation } from '../models/ControlNetInvocation'; import type { CvInpaintInvocation } from '../models/CvInpaintInvocation'; import type { DivideInvocation } from '../models/DivideInvocation'; import type { Edge } from '../models/Edge'; import type { Graph } from '../models/Graph'; import type { GraphExecutionState } from '../models/GraphExecutionState'; import type { GraphInvocation } from '../models/GraphInvocation'; +import type { HedImageprocessorInvocation } from '../models/HedImageprocessorInvocation'; import type { ImageBlurInvocation } from '../models/ImageBlurInvocation'; import type { ImageChannelInvocation } from '../models/ImageChannelInvocation'; import type { ImageConvertInvocation } from '../models/ImageConvertInvocation'; @@ -18,6 +22,7 @@ import type { ImageInverseLerpInvocation } from '../models/ImageInverseLerpInvoc import type { ImageLerpInvocation } from '../models/ImageLerpInvocation'; import type { ImageMultiplyInvocation } from '../models/ImageMultiplyInvocation'; import type { ImagePasteInvocation } from '../models/ImagePasteInvocation'; +import type { ImageProcessorInvocation } from '../models/ImageProcessorInvocation'; import type { ImageToImageInvocation } from '../models/ImageToImageInvocation'; import type { ImageToLatentsInvocation } from '../models/ImageToLatentsInvocation'; import type { InfillColorInvocation } from '../models/InfillColorInvocation'; @@ -27,12 +32,21 @@ import type { InpaintInvocation } from '../models/InpaintInvocation'; import type { IterateInvocation } from '../models/IterateInvocation'; import type { LatentsToImageInvocation } from '../models/LatentsToImageInvocation'; import type { LatentsToLatentsInvocation } from '../models/LatentsToLatentsInvocation'; +import type { LineartAnimeImageProcessorInvocation } from '../models/LineartAnimeImageProcessorInvocation'; +import type { LineartImageProcessorInvocation } from '../models/LineartImageProcessorInvocation'; import type { LoadImageInvocation } from '../models/LoadImageInvocation'; import type { MaskFromAlphaInvocation } from '../models/MaskFromAlphaInvocation'; +import type { MediapipeFaceProcessorInvocation } from '../models/MediapipeFaceProcessorInvocation'; +import type { MidasDepthImageProcessorInvocation } from '../models/MidasDepthImageProcessorInvocation'; +import type { MlsdImageProcessorInvocation } from '../models/MlsdImageProcessorInvocation'; import type { MultiplyInvocation } from '../models/MultiplyInvocation'; import type { NoiseInvocation } from '../models/NoiseInvocation'; +import type { NormalbaeImageProcessorInvocation } from '../models/NormalbaeImageProcessorInvocation'; +import type { OpenposeImageProcessorInvocation } from '../models/OpenposeImageProcessorInvocation'; import type { PaginatedResults_GraphExecutionState_ } from '../models/PaginatedResults_GraphExecutionState_'; +import type { ParamFloatInvocation } from '../models/ParamFloatInvocation'; import type { ParamIntInvocation } from '../models/ParamIntInvocation'; +import type { PidiImageProcessorInvocation } from '../models/PidiImageProcessorInvocation'; import type { RandomIntInvocation } from '../models/RandomIntInvocation'; import type { RandomRangeInvocation } from '../models/RandomRangeInvocation'; import type { RangeInvocation } from '../models/RangeInvocation'; @@ -45,6 +59,7 @@ import type { SubtractInvocation } from '../models/SubtractInvocation'; import type { TextToImageInvocation } from '../models/TextToImageInvocation'; import type { TextToLatentsInvocation } from '../models/TextToLatentsInvocation'; import type { UpscaleInvocation } from '../models/UpscaleInvocation'; +import type { ZoeDepthImageProcessorInvocation } from '../models/ZoeDepthImageProcessorInvocation'; import type { CancelablePromise } from '../core/CancelablePromise'; import { OpenAPI } from '../core/OpenAPI'; @@ -154,7 +169,7 @@ export class SessionsService { * The id of the session */ sessionId: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageprocessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'POST', @@ -191,7 +206,7 @@ export class SessionsService { * The path to the node in the graph */ nodePath: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageprocessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'PUT', From 89aa06e0144f54d37b408258698ab30d45d66c4b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 28 May 2023 19:05:34 +1000 Subject: [PATCH 17/34] feat(ui): consolidate images slice Now that images are in a database and we can make filtered queries, we can do away with the cumbersome `resultsSlice` and `uploadsSlice`. - Remove `resultsSlice` and `uploadsSlice` entirely - Add `imagesSlice` fills the same role - Convert the application to use `imagesSlice`, reducing a lot of messy logic where we had to check which category was selected - Add a simple filter popover to the gallery, which lets you select any number of image categories --- .../enhancers/reduxRemember/serialize.ts | 4 - .../enhancers/reduxRemember/unserialize.ts | 6 +- .../middleware/listenerMiddleware/index.ts | 22 +-- .../listeners/canvasSavedToGallery.ts | 4 +- .../listeners/imageDeleted.ts | 41 ++-- .../listeners/imageMetadataReceived.ts | 19 +- .../listeners/imageUploaded.ts | 21 +- .../listeners/imageUrlsReceived.ts | 31 +-- .../listeners/initialImageSelected.ts | 16 +- ...lleryImages.ts => receivedPageOfImages.ts} | 16 +- .../listeners/receivedUploadImages.ts | 33 ---- .../listeners/socketio/invocationComplete.ts | 1 - .../listeners/socketio/socketConnected.ts | 16 +- .../listeners/stagingAreaImageSaved.ts | 54 ++++++ invokeai/frontend/web/src/app/store/store.ts | 8 +- .../IAICanvasStagingAreaToolbar.tsx | 12 +- .../web/src/features/canvas/store/actions.ts | 5 + .../components/CurrentImagePreview.tsx | 1 - .../gallery/components/HoverableImage.tsx | 1 - .../components/ImageGalleryContent.tsx | 182 ++++++------------ .../components/NextPrevImageButtons.tsx | 31 +-- .../gallery/hooks/useGetImageByName.ts | 41 ++-- .../gallery/store/galleryPersistDenylist.ts | 1 - .../features/gallery/store/gallerySlice.ts | 64 ------ .../src/features/gallery/store/imagesSlice.ts | 127 ++++++++++++ .../gallery/store/resultsPersistDenylist.ts | 8 - .../features/gallery/store/resultsSlice.ts | 91 --------- .../gallery/store/uploadsPersistDenylist.ts | 8 - .../features/gallery/store/uploadsSlice.ts | 93 --------- .../fields/ImageInputFieldComponent.tsx | 17 +- .../util/graphBuilders/buildCanvasGraph.ts | 11 +- .../ImageToImage/InitialImagePreview.tsx | 7 +- .../parameters/hooks/useParameters.ts | 2 +- .../src/features/parameters/store/actions.ts | 6 +- .../parameters/store/generationSelectors.ts | 31 --- .../web/src/services/thunks/gallery.ts | 64 ------ .../frontend/web/src/services/thunks/image.ts | 31 ++- .../frontend/web/src/services/types/guards.ts | 9 - 38 files changed, 395 insertions(+), 740 deletions(-) rename invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/{receivedGalleryImages.ts => receivedPageOfImages.ts} (53%) delete mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedUploadImages.ts create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts create mode 100644 invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts delete mode 100644 invokeai/frontend/web/src/features/gallery/store/resultsPersistDenylist.ts delete mode 100644 invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts delete mode 100644 invokeai/frontend/web/src/features/gallery/store/uploadsPersistDenylist.ts delete mode 100644 invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts delete mode 100644 invokeai/frontend/web/src/services/thunks/gallery.ts diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts index 52995e0da3..9fb4ceae32 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts @@ -1,7 +1,5 @@ import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist'; import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist'; -import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist'; -import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist'; import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist'; import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist'; @@ -22,11 +20,9 @@ const serializationDenylist: { models: modelsPersistDenylist, nodes: nodesPersistDenylist, postprocessing: postprocessingPersistDenylist, - results: resultsPersistDenylist, system: systemPersistDenylist, // config: configPersistDenyList, ui: uiPersistDenylist, - uploads: uploadsPersistDenylist, // hotkeys: hotkeysPersistDenylist, }; diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts index 155a7786b3..c6ae4946f2 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts @@ -1,7 +1,6 @@ import { initialCanvasState } from 'features/canvas/store/canvasSlice'; import { initialGalleryState } from 'features/gallery/store/gallerySlice'; -import { initialResultsState } from 'features/gallery/store/resultsSlice'; -import { initialUploadsState } from 'features/gallery/store/uploadsSlice'; +import { initialImagesState } from 'features/gallery/store/imagesSlice'; import { initialLightboxState } from 'features/lightbox/store/lightboxSlice'; import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice'; @@ -24,12 +23,11 @@ const initialStates: { models: initialModelsState, nodes: initialNodesState, postprocessing: initialPostprocessingState, - results: initialResultsState, system: initialSystemState, config: initialConfigState, ui: initialUIState, - uploads: initialUploadsState, hotkeys: initialHotkeysState, + images: initialImagesState, }; export const unserialize: UnserializeFunction = (data, key) => { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 7159957efa..6cc9867bfd 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -59,18 +59,15 @@ import { addSessionCanceledPendingListener, addSessionCanceledRejectedListener, } from './listeners/sessionCanceled'; -import { - addReceivedGalleryImagesFulfilledListener, - addReceivedGalleryImagesRejectedListener, -} from './listeners/receivedGalleryImages'; -import { - addReceivedUploadImagesPageFulfilledListener, - addReceivedUploadImagesPageRejectedListener, -} from './listeners/receivedUploadImages'; import { addImageUpdatedFulfilledListener, addImageUpdatedRejectedListener, } from './listeners/imageUpdated'; +import { + addReceivedPageOfImagesFulfilledListener, + addReceivedPageOfImagesRejectedListener, +} from './listeners/receivedPageOfImages'; +import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved'; export const listenerMiddleware = createListenerMiddleware(); @@ -127,6 +124,7 @@ addCanvasSavedToGalleryListener(); addCanvasDownloadedAsImageListener(); addCanvasCopiedToClipboardListener(); addCanvasMergedListener(); +addStagingAreaImageSavedListener(); // socketio addGeneratorProgressListener(); @@ -154,8 +152,6 @@ addSessionCanceledPendingListener(); addSessionCanceledFulfilledListener(); addSessionCanceledRejectedListener(); -// Gallery pages -addReceivedGalleryImagesFulfilledListener(); -addReceivedGalleryImagesRejectedListener(); -addReceivedUploadImagesPageFulfilledListener(); -addReceivedUploadImagesPageRejectedListener(); +// Images +addReceivedPageOfImagesFulfilledListener(); +addReceivedPageOfImagesRejectedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts index 01f097cdd1..a692a90670 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -5,7 +5,7 @@ import { imageUploaded } from 'services/thunks/image'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { addToast } from 'features/system/store/systemSlice'; import { v4 as uuidv4 } from 'uuid'; -import { resultUpserted } from 'features/gallery/store/resultsSlice'; +import { imageUpserted } from 'features/gallery/store/imagesSlice'; const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' }); @@ -47,7 +47,7 @@ export const addCanvasSavedToGalleryListener = () => { action.meta.arg.formData.file.name === filename ); - dispatch(resultUpserted(uploadedImageDTO)); + dispatch(imageUpserted(uploadedImageDTO)); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index 7bd92e7e13..bf7ca4020c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -5,14 +5,11 @@ import { log } from 'app/logging/useLogger'; import { clamp } from 'lodash-es'; import { imageSelected } from 'features/gallery/store/gallerySlice'; import { - uploadRemoved, - uploadsAdapter, -} from 'features/gallery/store/uploadsSlice'; -import { - resultRemoved, - resultsAdapter, -} from 'features/gallery/store/resultsSlice'; -import { isUploadsImageDTO } from 'services/types/guards'; + imageRemoved, + imagesAdapter, + selectImagesEntities, + selectImagesIds, +} from 'features/gallery/store/imagesSlice'; const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); @@ -33,19 +30,16 @@ export const addRequestedImageDeletionListener = () => { const state = getState(); const selectedImage = state.gallery.selectedImage; - const isUserImage = isUploadsImageDTO(selectedImage); + if (selectedImage && selectedImage.image_name === image_name) { - const allIds = isUserImage ? state.uploads.ids : state.results.ids; + const ids = selectImagesIds(state); + const entities = selectImagesEntities(state); - const allEntities = isUserImage - ? state.uploads.entities - : state.results.entities; - - const deletedImageIndex = allIds.findIndex( + const deletedImageIndex = ids.findIndex( (result) => result.toString() === image_name ); - const filteredIds = allIds.filter((id) => id.toString() !== image_name); + const filteredIds = ids.filter((id) => id.toString() !== image_name); const newSelectedImageIndex = clamp( deletedImageIndex, @@ -55,7 +49,7 @@ export const addRequestedImageDeletionListener = () => { const newSelectedImageId = filteredIds[newSelectedImageIndex]; - const newSelectedImage = allEntities[newSelectedImageId]; + const newSelectedImage = entities[newSelectedImageId]; if (newSelectedImageId) { dispatch(imageSelected(newSelectedImage)); @@ -64,11 +58,7 @@ export const addRequestedImageDeletionListener = () => { } } - if (isUserImage) { - dispatch(uploadRemoved(image_name)); - } else { - dispatch(resultRemoved(image_name)); - } + dispatch(imageRemoved(image_name)); dispatch( imageDeleted({ imageName: image_name, imageOrigin: image_origin }) @@ -86,12 +76,7 @@ export const addImageDeletedPendingListener = () => { effect: (action, { dispatch, getState }) => { const { imageName, imageOrigin } = action.meta.arg; // Preemptively remove the image from the gallery - if (imageOrigin === 'external') { - uploadsAdapter.removeOne(getState().uploads, imageName); - } - if (imageOrigin === 'internal') { - resultsAdapter.removeOne(getState().results, imageName); - } + imagesAdapter.removeOne(getState().images, imageName); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts index 276ef7be6c..63aeecb95e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts @@ -1,9 +1,7 @@ import { log } from 'app/logging/useLogger'; import { startAppListening } from '..'; import { imageMetadataReceived } from 'services/thunks/image'; -import { resultUpserted } from 'features/gallery/store/resultsSlice'; -import { uploadUpserted } from 'features/gallery/store/uploadsSlice'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; +import { imageUpserted } from 'features/gallery/store/imagesSlice'; const moduleLog = log.child({ namespace: 'image' }); @@ -11,16 +9,13 @@ export const addImageMetadataReceivedFulfilledListener = () => { startAppListening({ actionCreator: imageMetadataReceived.fulfilled, effect: (action, { getState, dispatch }) => { - const imageDTO = action.payload; - moduleLog.debug({ data: { imageDTO } }, 'Image metadata received'); - - if (imageDTO.image_origin === 'internal') { - dispatch(resultUpserted(imageDTO)); - } - - if (imageDTO.image_origin === 'external') { - dispatch(uploadUpserted(imageDTO)); + const image = action.payload; + if (image.is_intermediate) { + // No further actions needed for intermediate images + return; } + moduleLog.debug({ data: { image } }, 'Image metadata received'); + dispatch(imageUpserted(image)); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index dcce86017e..6d84431f80 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -1,13 +1,8 @@ import { startAppListening } from '..'; -import { uploadUpserted } from 'features/gallery/store/uploadsSlice'; -import { - imageSelected, - setCurrentCategory, -} from 'features/gallery/store/gallerySlice'; import { imageUploaded } from 'services/thunks/image'; import { addToast } from 'features/system/store/systemSlice'; -import { resultUpserted } from 'features/gallery/store/resultsSlice'; import { log } from 'app/logging/useLogger'; +import { imageUpserted } from 'features/gallery/store/imagesSlice'; const moduleLog = log.child({ namespace: 'image' }); @@ -26,18 +21,8 @@ export const addImageUploadedFulfilledListener = () => { const state = getState(); - // Handle uploads - if (image.image_category === 'user' && !image.is_intermediate) { - dispatch(uploadUpserted(image)); - dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); - } - - // Handle results - // TODO: Can this ever happen? I don't think so... - if (image.image_category !== 'user' && !image.is_intermediate) { - dispatch(resultUpserted(image)); - dispatch(setCurrentCategory('results')); - } + dispatch(imageUpserted(image)); + dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts index 588d7611cc..fd0461f893 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts @@ -1,8 +1,7 @@ import { log } from 'app/logging/useLogger'; import { startAppListening } from '..'; import { imageUrlsReceived } from 'services/thunks/image'; -import { resultsAdapter } from 'features/gallery/store/resultsSlice'; -import { uploadsAdapter } from 'features/gallery/store/uploadsSlice'; +import { imagesAdapter } from 'features/gallery/store/imagesSlice'; const moduleLog = log.child({ namespace: 'image' }); @@ -13,27 +12,15 @@ export const addImageUrlsReceivedFulfilledListener = () => { const image = action.payload; moduleLog.debug({ data: { image } }, 'Image URLs received'); - const { image_origin, image_name, image_url, thumbnail_url } = image; + const { image_name, image_url, thumbnail_url } = image; - if (image_origin === 'results') { - resultsAdapter.updateOne(getState().results, { - id: image_name, - changes: { - image_url, - thumbnail_url, - }, - }); - } - - if (image_origin === 'uploads') { - uploadsAdapter.updateOne(getState().uploads, { - id: image_name, - changes: { - image_url, - thumbnail_url, - }, - }); - } + imagesAdapter.updateOne(getState().images, { + id: image_name, + changes: { + image_url, + thumbnail_url, + }, + }); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts index a2e783a38a..940cc84c1e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts @@ -1,6 +1,4 @@ import { initialImageChanged } from 'features/parameters/store/generationSlice'; -import { selectResultsById } from 'features/gallery/store/resultsSlice'; -import { selectUploadsById } from 'features/gallery/store/uploadsSlice'; import { t } from 'i18next'; import { addToast } from 'features/system/store/systemSlice'; import { startAppListening } from '..'; @@ -9,7 +7,7 @@ import { isImageDTO, } from 'features/parameters/store/actions'; import { makeToast } from 'app/components/Toaster'; -import { ImageDTO } from 'services/api'; +import { selectImagesById } from 'features/gallery/store/imagesSlice'; export const addInitialImageSelectedListener = () => { startAppListening({ @@ -30,16 +28,8 @@ export const addInitialImageSelectedListener = () => { return; } - const { image_name, image_origin } = action.payload; - - let image: ImageDTO | undefined; - const state = getState(); - - if (image_origin === 'results') { - image = selectResultsById(state, image_name); - } else if (image_origin === 'uploads') { - image = selectUploadsById(state, image_name); - } + const imageName = action.payload; + const image = selectImagesById(getState(), imageName); if (!image) { dispatch( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedGalleryImages.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts similarity index 53% rename from invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedGalleryImages.ts rename to invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts index aba81e1e72..9a2ec0e7a5 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedGalleryImages.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts @@ -1,31 +1,31 @@ import { log } from 'app/logging/useLogger'; import { startAppListening } from '..'; -import { receivedGalleryImages } from 'services/thunks/gallery'; import { serializeError } from 'serialize-error'; +import { receivedPageOfImages } from 'services/thunks/image'; const moduleLog = log.child({ namespace: 'gallery' }); -export const addReceivedGalleryImagesFulfilledListener = () => { +export const addReceivedPageOfImagesFulfilledListener = () => { startAppListening({ - actionCreator: receivedGalleryImages.fulfilled, + actionCreator: receivedPageOfImages.fulfilled, effect: (action, { getState, dispatch }) => { const page = action.payload; moduleLog.debug( { data: { page } }, - `Received ${page.items.length} gallery images` + `Received ${page.items.length} images` ); }, }); }; -export const addReceivedGalleryImagesRejectedListener = () => { +export const addReceivedPageOfImagesRejectedListener = () => { startAppListening({ - actionCreator: receivedGalleryImages.rejected, + actionCreator: receivedPageOfImages.rejected, effect: (action, { getState, dispatch }) => { if (action.payload) { moduleLog.debug( - { data: { error: serializeError(action.payload.error) } }, - 'Problem receiving gallery images' + { data: { error: serializeError(action.payload) } }, + 'Problem receiving images' ); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedUploadImages.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedUploadImages.ts deleted file mode 100644 index 602fccf847..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedUploadImages.ts +++ /dev/null @@ -1,33 +0,0 @@ -import { log } from 'app/logging/useLogger'; -import { startAppListening } from '..'; -import { receivedUploadImages } from 'services/thunks/gallery'; -import { serializeError } from 'serialize-error'; - -const moduleLog = log.child({ namespace: 'gallery' }); - -export const addReceivedUploadImagesPageFulfilledListener = () => { - startAppListening({ - actionCreator: receivedUploadImages.fulfilled, - effect: (action, { getState, dispatch }) => { - const page = action.payload; - moduleLog.debug( - { data: { page } }, - `Received ${page.items.length} uploaded images` - ); - }, - }); -}; - -export const addReceivedUploadImagesPageRejectedListener = () => { - startAppListening({ - actionCreator: receivedUploadImages.rejected, - effect: (action, { getState, dispatch }) => { - if (action.payload) { - moduleLog.debug( - { data: { error: serializeError(action.payload.error) } }, - 'Problem receiving uploaded images' - ); - } - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts index 81c0286e3b..fb2056ae35 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts @@ -6,7 +6,6 @@ import { imageMetadataReceived } from 'services/thunks/image'; import { sessionCanceled } from 'services/thunks/session'; import { isImageOutput } from 'services/types/guards'; import { progressImageSet } from 'features/system/store/systemSlice'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; const moduleLog = log.child({ namespace: 'socketio' }); const nodeDenylist = ['dataURL_image']; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index 650918ba3c..85035e6bf9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -1,10 +1,7 @@ import { startAppListening } from '../..'; import { log } from 'app/logging/useLogger'; import { socketConnected } from 'services/events/actions'; -import { - receivedGalleryImages, - receivedUploadImages, -} from 'services/thunks/gallery'; +import { receivedPageOfImages } from 'services/thunks/image'; import { receivedModels } from 'services/thunks/model'; import { receivedOpenAPISchema } from 'services/thunks/schema'; @@ -18,17 +15,12 @@ export const addSocketConnectedListener = () => { moduleLog.debug({ timestamp }, 'Connected'); - const { results, uploads, models, nodes, config } = getState(); + const { models, nodes, config, images } = getState(); const { disabledTabs } = config; - // These thunks need to be dispatch in middleware; cannot handle in a reducer - if (!results.ids.length) { - dispatch(receivedGalleryImages()); - } - - if (!uploads.ids.length) { - dispatch(receivedUploadImages()); + if (!images.ids.length) { + dispatch(receivedPageOfImages()); } if (!models.ids.length) { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts new file mode 100644 index 0000000000..9bd3cd6dd2 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts @@ -0,0 +1,54 @@ +import { stagingAreaImageSaved } from 'features/canvas/store/actions'; +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { imageUpdated } from 'services/thunks/image'; +import { imageUpserted } from 'features/gallery/store/imagesSlice'; +import { addToast } from 'features/system/store/systemSlice'; + +const moduleLog = log.child({ namespace: 'canvas' }); + +export const addStagingAreaImageSavedListener = () => { + startAppListening({ + actionCreator: stagingAreaImageSaved, + effect: async (action, { dispatch, getState, take }) => { + const { image_name, image_origin } = action.payload; + + dispatch( + imageUpdated({ + imageName: image_name, + imageOrigin: image_origin, + requestBody: { + is_intermediate: false, + }, + }) + ); + + const [imageUpdatedAction] = await take( + (action) => + (imageUpdated.fulfilled.match(action) || + imageUpdated.rejected.match(action)) && + action.meta.arg.imageName === image_name + ); + + if (imageUpdated.rejected.match(imageUpdatedAction)) { + moduleLog.error( + { data: { arg: imageUpdatedAction.meta.arg } }, + 'Image saving failed' + ); + dispatch( + addToast({ + title: 'Image Saving Failed', + description: imageUpdatedAction.error.message, + status: 'error', + }) + ); + return; + } + + if (imageUpdated.fulfilled.match(imageUpdatedAction)) { + dispatch(imageUpserted(imageUpdatedAction.payload)); + dispatch(addToast({ title: 'Image Saved', status: 'success' })); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 4e9c154f3a..521610adcc 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -10,8 +10,7 @@ import dynamicMiddlewares from 'redux-dynamic-middlewares'; import canvasReducer from 'features/canvas/store/canvasSlice'; import galleryReducer from 'features/gallery/store/gallerySlice'; -import resultsReducer from 'features/gallery/store/resultsSlice'; -import uploadsReducer from 'features/gallery/store/uploadsSlice'; +import imagesReducer from 'features/gallery/store/imagesSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import generationReducer from 'features/parameters/store/generationSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; @@ -41,12 +40,11 @@ const allReducers = { models: modelsReducer, nodes: nodesReducer, postprocessing: postprocessingReducer, - results: resultsReducer, system: systemReducer, config: configReducer, ui: uiReducer, - uploads: uploadsReducer, hotkeys: hotkeysReducer, + images: imagesReducer, // session: sessionReducer, }; @@ -65,8 +63,6 @@ const rememberedKeys: (keyof typeof allReducers)[] = [ 'system', 'ui', // 'hotkeys', - // 'results', - // 'uploads', // 'config', ]; diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx index 64c752fce0..68bc15bbaa 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx @@ -1,6 +1,5 @@ import { ButtonGroup, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -// import { saveStagingAreaImageToGallery } from 'app/socketio/actions'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIIconButton from 'common/components/IAIIconButton'; import { canvasSelector } from 'features/canvas/store/canvasSelectors'; @@ -26,6 +25,7 @@ import { FaPlus, FaSave, } from 'react-icons/fa'; +import { stagingAreaImageSaved } from '../store/actions'; const selector = createSelector( [canvasSelector], @@ -157,19 +157,15 @@ const IAICanvasStagingAreaToolbar = () => { } colorScheme="accent" /> - {/* } onClick={() => - dispatch( - saveStagingAreaImageToGallery( - currentStagingAreaImage.image.image_url - ) - ) + dispatch(stagingAreaImageSaved(currentStagingAreaImage.image)) } colorScheme="accent" - /> */} + /> ( + 'canvas/stagingAreaImageSaved' +); diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx index 38c104a83d..280d859b87 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx @@ -62,7 +62,6 @@ const CurrentImagePreview = () => { return; } e.dataTransfer.setData('invokeai/imageName', image.image_name); - e.dataTransfer.setData('invokeai/imageOrigin', image.image_origin); e.dataTransfer.effectAllowed = 'move'; }, [image] diff --git a/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx b/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx index 4a51580650..94b653af1c 100644 --- a/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx @@ -147,7 +147,6 @@ const HoverableImage = memo((props: HoverableImageProps) => { const handleDragStart = useCallback( (e: DragEvent) => { e.dataTransfer.setData('invokeai/imageName', image.image_name); - e.dataTransfer.setData('invokeai/imageOrigin', image.image_origin); e.dataTransfer.effectAllowed = 'move'; }, [image] diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx index 7c7fd29038..4b1786168d 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx @@ -1,6 +1,7 @@ import { Box, - ButtonGroup, + Checkbox, + CheckboxGroup, Flex, FlexProps, Grid, @@ -16,7 +17,6 @@ import IAIPopover from 'common/components/IAIPopover'; import IAISlider from 'common/components/IAISlider'; import { gallerySelector } from 'features/gallery/store/gallerySelectors'; import { - setCurrentCategory, setGalleryImageMinimumWidth, setGalleryImageObjectFit, setShouldAutoSwitchToNewImages, @@ -36,54 +36,53 @@ import { } from 'react'; import { useTranslation } from 'react-i18next'; import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs'; -import { FaImage, FaUser, FaWrench } from 'react-icons/fa'; +import { FaFilter, FaWrench } from 'react-icons/fa'; import { MdPhotoLibrary } from 'react-icons/md'; import HoverableImage from './HoverableImage'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; -import { resultsAdapter } from '../store/resultsSlice'; -import { - receivedGalleryImages, - receivedUploadImages, -} from 'services/thunks/gallery'; -import { uploadsAdapter } from '../store/uploadsSlice'; import { createSelector } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; import { Virtuoso, VirtuosoGrid } from 'react-virtuoso'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import GalleryProgressImage from './GalleryProgressImage'; import { uiSelector } from 'features/ui/store/uiSelectors'; -import { ImageDTO } from 'services/api'; +import { ImageCategory, ImageDTO } from 'services/api'; +import { imageCategoriesChanged, selectImagesAll } from '../store/imagesSlice'; +import { receivedPageOfImages } from 'services/thunks/image'; +import { capitalize } from 'lodash-es'; -const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290; const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER'; +const IMAGE_CATEGORIES: ImageCategory[] = [ + 'general', + 'control', + 'mask', + 'user', + 'other', +]; const categorySelector = createSelector( [(state: RootState) => state], (state) => { - const { results, uploads, system, gallery } = state; - const { currentCategory } = gallery; + const { system, images } = state; + const tempImages: (ImageDTO | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = []; - if (currentCategory === 'results') { - const tempImages: (ImageDTO | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = []; - - if (system.progressImage) { - tempImages.push(PROGRESS_IMAGE_PLACEHOLDER); - } - - return { - images: tempImages.concat( - resultsAdapter.getSelectors().selectAll(results) - ), - isLoading: results.isLoading, - areMoreImagesAvailable: results.page < results.pages - 1, - }; + if (system.progressImage) { + tempImages.push(PROGRESS_IMAGE_PLACEHOLDER); } + const { categories } = images; + + const allImages = selectImagesAll(state); + const filteredImages = allImages.filter((i) => + categories.includes(i.image_category) + ); + return { - images: uploadsAdapter.getSelectors().selectAll(uploads), - isLoading: uploads.isLoading, - areMoreImagesAvailable: uploads.page < uploads.pages - 1, + images: tempImages.concat(filteredImages), + isLoading: images.isLoading, + areMoreImagesAvailable: filteredImages.length < images.total, + categories: images.categories, }; }, defaultSelectorOptions @@ -93,7 +92,6 @@ const mainSelector = createSelector( [gallerySelector, uiSelector], (gallery, ui) => { const { - currentCategory, galleryImageMinimumWidth, galleryImageObjectFit, shouldAutoSwitchToNewImages, @@ -104,7 +102,6 @@ const mainSelector = createSelector( const { shouldPinGallery } = ui; return { - currentCategory, shouldPinGallery, galleryImageMinimumWidth, galleryImageObjectFit, @@ -120,7 +117,6 @@ const ImageGalleryContent = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); const resizeObserverRef = useRef(null); - const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true); const rootRef = useRef(null); const [scroller, setScroller] = useState(null); const [initialize, osInstance] = useOverlayScrollbars({ @@ -137,7 +133,6 @@ const ImageGalleryContent = () => { }); const { - currentCategory, shouldPinGallery, galleryImageMinimumWidth, galleryImageObjectFit, @@ -146,18 +141,12 @@ const ImageGalleryContent = () => { selectedImage, } = useAppSelector(mainSelector); - const { images, areMoreImagesAvailable, isLoading } = + const { images, areMoreImagesAvailable, isLoading, categories } = useAppSelector(categorySelector); - const handleClickLoadMore = () => { - if (currentCategory === 'results') { - dispatch(receivedGalleryImages()); - } - - if (currentCategory === 'uploads') { - dispatch(receivedUploadImages()); - } - }; + const handleLoadMoreImages = useCallback(() => { + dispatch(receivedPageOfImages()); + }, [dispatch]); const handleChangeGalleryImageMinimumWidth = (v: number) => { dispatch(setGalleryImageMinimumWidth(v)); @@ -168,28 +157,6 @@ const ImageGalleryContent = () => { dispatch(requestCanvasRescale()); }; - useEffect(() => { - if (!resizeObserverRef.current) { - return; - } - const resizeObserver = new ResizeObserver(() => { - if (!resizeObserverRef.current) { - return; - } - - if ( - resizeObserverRef.current.clientWidth < GALLERY_SHOW_BUTTONS_MIN_WIDTH - ) { - setShouldShouldIconButtons(true); - return; - } - - setShouldShouldIconButtons(false); - }); - resizeObserver.observe(resizeObserverRef.current); - return () => resizeObserver.disconnect(); // clean up - }, []); - useEffect(() => { const { current: root } = rootRef; if (scroller && root) { @@ -210,12 +177,15 @@ const ImageGalleryContent = () => { }, []); const handleEndReached = useCallback(() => { - if (currentCategory === 'results') { - dispatch(receivedGalleryImages()); - } else if (currentCategory === 'uploads') { - dispatch(receivedUploadImages()); - } - }, [dispatch, currentCategory]); + handleLoadMoreImages(); + }, [handleLoadMoreImages]); + + const handleCategoriesChanged = useCallback( + (newCategories: ImageCategory[]) => { + dispatch(imageCategoriesChanged(newCategories)); + }, + [dispatch] + ); return ( { alignItems="center" justifyContent="space-between" > - } + /> + } > - {shouldShouldIconButtons ? ( - <> - } - onClick={() => dispatch(setCurrentCategory('results'))} - /> - } - onClick={() => dispatch(setCurrentCategory('uploads'))} - /> - - ) : ( - <> - dispatch(setCurrentCategory('results'))} - flexGrow={1} - > - {t('gallery.generations')} - - dispatch(setCurrentCategory('uploads'))} - flexGrow={1} - > - {t('gallery.uploads')} - - - )} - + + + {IMAGE_CATEGORIES.map((c) => ( + + {capitalize(c)} + + ))} + + + { )} state, gallerySelector], - (state, gallery) => { - const { selectedImage, currentCategory } = gallery; + [ + (state: RootState) => state, + gallerySelector, + selectFilteredImagesAsObject, + selectFilteredImagesIds, + ], + (state, gallery, filteredImagesAsObject, filteredImageIds) => { + const { selectedImage } = gallery; if (!selectedImage) { return { @@ -32,29 +41,29 @@ export const nextPrevImageButtonsSelector = createSelector( }; } - const currentImageIndex = state[currentCategory].ids.findIndex( + const currentImageIndex = filteredImageIds.findIndex( (i) => i === selectedImage.image_name ); const nextImageIndex = clamp( currentImageIndex + 1, 0, - state[currentCategory].ids.length - 1 + filteredImageIds.length - 1 ); const prevImageIndex = clamp( currentImageIndex - 1, 0, - state[currentCategory].ids.length - 1 + filteredImageIds.length - 1 ); - const nextImageId = state[currentCategory].ids[nextImageIndex]; - const prevImageId = state[currentCategory].ids[prevImageIndex]; + const nextImageId = filteredImageIds[nextImageIndex]; + const prevImageId = filteredImageIds[prevImageIndex]; - const nextImage = state[currentCategory].entities[nextImageId]; - const prevImage = state[currentCategory].entities[prevImageId]; + const nextImage = filteredImagesAsObject[nextImageId]; + const prevImage = filteredImagesAsObject[prevImageId]; - const imagesLength = state[currentCategory].ids.length; + const imagesLength = filteredImageIds.length; return { isOnFirstImage: currentImageIndex === 0, diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts b/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts index 1a73971774..89709b322a 100644 --- a/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts +++ b/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts @@ -1,33 +1,18 @@ -import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; -import { ResourceOrigin } from 'services/api'; -import { selectResultsEntities } from '../store/resultsSlice'; -import { selectUploadsEntities } from '../store/uploadsSlice'; +import { selectImagesEntities } from '../store/imagesSlice'; +import { useCallback } from 'react'; -const useGetImageByNameSelector = createSelector( - [selectResultsEntities, selectUploadsEntities], - (allResults, allUploads) => { - return { allResults, allUploads }; - } -); - -const useGetImageByNameAndOrigin = () => { - const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector); - return (name: string, origin: ResourceOrigin) => { - if (origin === 'internal') { - const resultImagesResult = allResults[name]; - if (resultImagesResult) { - return resultImagesResult; +const useGetImageByName = () => { + const images = useAppSelector(selectImagesEntities); + return useCallback( + (name: string | undefined) => { + if (!name) { + return; } - } - - if (origin === 'external') { - const userImagesResult = allUploads[name]; - if (userImagesResult) { - return userImagesResult; - } - } - }; + return images[name]; + }, + [images] + ); }; -export default useGetImageByNameAndOrigin; +export default useGetImageByName; diff --git a/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts b/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts index 49f51d5a80..44e03f9f71 100644 --- a/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts @@ -4,6 +4,5 @@ import { GalleryState } from './gallerySlice'; * Gallery slice persist denylist */ export const galleryPersistDenylist: (keyof GalleryState)[] = [ - 'currentCategory', 'shouldAutoSwitchToNewImages', ]; diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index e904620d90..16121b6e38 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -1,12 +1,6 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; -import { - receivedGalleryImages, - receivedUploadImages, -} from '../../../services/thunks/gallery'; import { ImageDTO } from 'services/api'; -import { resultUpserted } from './resultsSlice'; -import { uploadUpserted } from './uploadsSlice'; type GalleryImageObjectFitType = 'contain' | 'cover'; @@ -16,7 +10,6 @@ export interface GalleryState { galleryImageObjectFit: GalleryImageObjectFitType; shouldAutoSwitchToNewImages: boolean; shouldUseSingleGalleryColumn: boolean; - currentCategory: 'results' | 'uploads'; } export const initialGalleryState: GalleryState = { @@ -24,7 +17,6 @@ export const initialGalleryState: GalleryState = { galleryImageObjectFit: 'cover', shouldAutoSwitchToNewImages: true, shouldUseSingleGalleryColumn: false, - currentCategory: 'results', }; export const gallerySlice = createSlice({ @@ -48,12 +40,6 @@ export const gallerySlice = createSlice({ setShouldAutoSwitchToNewImages: (state, action: PayloadAction) => { state.shouldAutoSwitchToNewImages = action.payload; }, - setCurrentCategory: ( - state, - action: PayloadAction<'results' | 'uploads'> - ) => { - state.currentCategory = action.payload; - }, setShouldUseSingleGalleryColumn: ( state, action: PayloadAction @@ -61,55 +47,6 @@ export const gallerySlice = createSlice({ state.shouldUseSingleGalleryColumn = action.payload; }, }, - extraReducers(builder) { - builder.addCase(receivedGalleryImages.fulfilled, (state, action) => { - // rehydrate selectedImage URL when results list comes in - // solves case when outdated URL is in local storage - const selectedImage = state.selectedImage; - if (selectedImage) { - const selectedImageInResults = action.payload.items.find( - (image) => image.image_name === selectedImage.image_name - ); - - if (selectedImageInResults) { - selectedImage.image_url = selectedImageInResults.image_url; - selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url; - state.selectedImage = selectedImage; - } - } - }); - - builder.addCase(receivedUploadImages.fulfilled, (state, action) => { - // rehydrate selectedImage URL when results list comes in - // solves case when outdated URL is in local storage - const selectedImage = state.selectedImage; - if (selectedImage) { - const selectedImageInResults = action.payload.items.find( - (image) => image.image_name === selectedImage.image_name - ); - - if (selectedImageInResults) { - selectedImage.image_url = selectedImageInResults.image_url; - selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url; - state.selectedImage = selectedImage; - } - } - }); - - builder.addCase(resultUpserted, (state, action) => { - if (state.shouldAutoSwitchToNewImages) { - state.selectedImage = action.payload; - state.currentCategory = 'results'; - } - }); - - builder.addCase(uploadUpserted, (state, action) => { - if (state.shouldAutoSwitchToNewImages) { - state.selectedImage = action.payload; - state.currentCategory = 'uploads'; - } - }); - }, }); export const { @@ -118,7 +55,6 @@ export const { setGalleryImageObjectFit, setShouldAutoSwitchToNewImages, setShouldUseSingleGalleryColumn, - setCurrentCategory, } = gallerySlice.actions; export default gallerySlice.reducer; diff --git a/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts new file mode 100644 index 0000000000..8ab34fccf0 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts @@ -0,0 +1,127 @@ +import { + PayloadAction, + createEntityAdapter, + createSelector, + createSlice, +} from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { ImageCategory, ImageDTO } from 'services/api'; +import { dateComparator } from 'common/util/dateComparator'; +import { isString, keyBy } from 'lodash-es'; +import { receivedPageOfImages } from 'services/thunks/image'; + +export const imagesAdapter = createEntityAdapter({ + selectId: (image) => image.image_name, + sortComparer: (a, b) => dateComparator(b.created_at, a.created_at), +}); + +type AdditionaImagesState = { + offset: number; + limit: number; + total: number; + isLoading: boolean; + categories: ImageCategory[]; +}; + +export const initialImagesState = + imagesAdapter.getInitialState({ + offset: 0, + limit: 0, + total: 0, + isLoading: false, + categories: ['general', 'control', 'mask', 'other', 'user'], + }); + +export type ImagesState = typeof initialImagesState; + +const imagesSlice = createSlice({ + name: 'images', + initialState: initialImagesState, + reducers: { + imageUpserted: (state, action: PayloadAction) => { + imagesAdapter.upsertOne(state, action.payload); + }, + imageRemoved: (state, action: PayloadAction) => { + if (isString(action.payload)) { + imagesAdapter.removeOne(state, action.payload); + return; + } + + imagesAdapter.removeOne(state, action.payload.image_name); + }, + imageCategoriesChanged: (state, action: PayloadAction) => { + state.categories = action.payload; + }, + }, + extraReducers: (builder) => { + builder.addCase(receivedPageOfImages.pending, (state) => { + state.isLoading = true; + }); + builder.addCase(receivedPageOfImages.rejected, (state) => { + state.isLoading = false; + }); + builder.addCase(receivedPageOfImages.fulfilled, (state, action) => { + state.isLoading = false; + const { items, offset, limit, total } = action.payload; + state.offset = offset; + state.limit = limit; + state.total = total; + imagesAdapter.upsertMany(state, items); + }); + }, +}); + +export const { + selectAll: selectImagesAll, + selectById: selectImagesById, + selectEntities: selectImagesEntities, + selectIds: selectImagesIds, + selectTotal: selectImagesTotal, +} = imagesAdapter.getSelectors((state) => state.images); + +export const { imageUpserted, imageRemoved, imageCategoriesChanged } = + imagesSlice.actions; + +export default imagesSlice.reducer; + +export const selectFilteredImagesAsArray = createSelector( + (state: RootState) => state, + (state) => { + const { + images: { categories }, + } = state; + + return selectImagesAll(state).filter((i) => + categories.includes(i.image_category) + ); + } +); + +export const selectFilteredImagesAsObject = createSelector( + (state: RootState) => state, + (state) => { + const { + images: { categories }, + } = state; + + return keyBy( + selectImagesAll(state).filter((i) => + categories.includes(i.image_category) + ), + 'image_name' + ); + } +); + +export const selectFilteredImagesIds = createSelector( + (state: RootState) => state, + (state) => { + const { + images: { categories }, + } = state; + + return selectImagesAll(state) + .filter((i) => categories.includes(i.image_category)) + .map((i) => i.image_name); + } +); diff --git a/invokeai/frontend/web/src/features/gallery/store/resultsPersistDenylist.ts b/invokeai/frontend/web/src/features/gallery/store/resultsPersistDenylist.ts deleted file mode 100644 index 1c3d8aaaec..0000000000 --- a/invokeai/frontend/web/src/features/gallery/store/resultsPersistDenylist.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { ResultsState } from './resultsSlice'; - -/** - * Results slice persist denylist - * - * Currently denylisting results slice entirely, see `serialize.ts` - */ -export const resultsPersistDenylist: (keyof ResultsState)[] = []; diff --git a/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts b/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts deleted file mode 100644 index 5bc7bd14dd..0000000000 --- a/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts +++ /dev/null @@ -1,91 +0,0 @@ -import { - PayloadAction, - createEntityAdapter, - createSlice, -} from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { - receivedGalleryImages, - IMAGES_PER_PAGE, -} from 'services/thunks/gallery'; -import { ImageDTO } from 'services/api'; -import { dateComparator } from 'common/util/dateComparator'; - -export type ResultsImageDTO = Omit & { - image_origin: 'results'; -}; - -export const resultsAdapter = createEntityAdapter({ - selectId: (image) => image.image_name, - sortComparer: (a, b) => dateComparator(b.created_at, a.created_at), -}); - -type AdditionalResultsState = { - page: number; - pages: number; - isLoading: boolean; - nextPage: number; - upsertedImageCount: number; -}; - -export const initialResultsState = - resultsAdapter.getInitialState({ - page: 0, - pages: 0, - isLoading: false, - nextPage: 0, - upsertedImageCount: 0, - }); - -export type ResultsState = typeof initialResultsState; - -const resultsSlice = createSlice({ - name: 'results', - initialState: initialResultsState, - reducers: { - resultUpserted: (state, action: PayloadAction) => { - resultsAdapter.upsertOne(state, action.payload); - state.upsertedImageCount += 1; - }, - resultRemoved: (state, action: PayloadAction) => { - resultsAdapter.removeOne(state, action.payload); - }, - }, - extraReducers: (builder) => { - /** - * Received Result Images Page - PENDING - */ - builder.addCase(receivedGalleryImages.pending, (state) => { - state.isLoading = true; - }); - - /** - * Received Result Images Page - FULFILLED - */ - builder.addCase(receivedGalleryImages.fulfilled, (state, action) => { - const { page, pages } = action.payload; - - // We know these will all be of the results type, but it's not represented in the API types - const items = action.payload.items; - - resultsAdapter.setMany(state, items); - - state.page = page; - state.pages = pages; - state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1; - state.isLoading = false; - }); - }, -}); - -export const { - selectAll: selectResultsAll, - selectById: selectResultsById, - selectEntities: selectResultsEntities, - selectIds: selectResultsIds, - selectTotal: selectResultsTotal, -} = resultsAdapter.getSelectors((state) => state.results); - -export const { resultUpserted, resultRemoved } = resultsSlice.actions; - -export default resultsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/gallery/store/uploadsPersistDenylist.ts b/invokeai/frontend/web/src/features/gallery/store/uploadsPersistDenylist.ts deleted file mode 100644 index 296e8b2057..0000000000 --- a/invokeai/frontend/web/src/features/gallery/store/uploadsPersistDenylist.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { UploadsState } from './uploadsSlice'; - -/** - * Uploads slice persist denylist - * - * Currently denylisting uploads slice entirely, see `serialize.ts` - */ -export const uploadsPersistDenylist: (keyof UploadsState)[] = []; diff --git a/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts b/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts deleted file mode 100644 index e7620cbc31..0000000000 --- a/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts +++ /dev/null @@ -1,93 +0,0 @@ -import { - PayloadAction, - createEntityAdapter, - createSlice, -} from '@reduxjs/toolkit'; - -import { RootState } from 'app/store/store'; -import { receivedUploadImages, IMAGES_PER_PAGE } from 'services/thunks/gallery'; -import { ImageDTO } from 'services/api'; -import { dateComparator } from 'common/util/dateComparator'; - -export type UploadsImageDTO = Omit< - ImageDTO, - 'image_origin' | 'image_category' -> & { - image_origin: 'external'; - image_category: 'user'; -}; - -export const uploadsAdapter = createEntityAdapter({ - selectId: (image) => image.image_name, - sortComparer: (a, b) => dateComparator(b.created_at, a.created_at), -}); - -type AdditionalUploadsState = { - page: number; - pages: number; - isLoading: boolean; - nextPage: number; - upsertedImageCount: number; -}; - -export const initialUploadsState = - uploadsAdapter.getInitialState({ - page: 0, - pages: 0, - nextPage: 0, - isLoading: false, - upsertedImageCount: 0, - }); - -export type UploadsState = typeof initialUploadsState; - -const uploadsSlice = createSlice({ - name: 'uploads', - initialState: initialUploadsState, - reducers: { - uploadUpserted: (state, action: PayloadAction) => { - uploadsAdapter.upsertOne(state, action.payload); - state.upsertedImageCount += 1; - }, - uploadRemoved: (state, action: PayloadAction) => { - uploadsAdapter.removeOne(state, action.payload); - }, - }, - extraReducers: (builder) => { - /** - * Received Upload Images Page - PENDING - */ - builder.addCase(receivedUploadImages.pending, (state) => { - state.isLoading = true; - }); - - /** - * Received Upload Images Page - FULFILLED - */ - builder.addCase(receivedUploadImages.fulfilled, (state, action) => { - const { page, pages } = action.payload; - - // We know these will all be of the uploads type, but it's not represented in the API types - const items = action.payload.items; - - uploadsAdapter.setMany(state, items); - - state.page = page; - state.pages = pages; - state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1; - state.isLoading = false; - }); - }, -}); - -export const { - selectAll: selectUploadsAll, - selectById: selectUploadsById, - selectEntities: selectUploadsEntities, - selectIds: selectUploadsIds, - selectTotal: selectUploadsTotal, -} = uploadsAdapter.getSelectors((state) => state.uploads); - -export const { uploadUpserted, uploadRemoved } = uploadsSlice.actions; - -export default uploadsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx index e4a0f41ee1..57cefb0a9c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx @@ -2,7 +2,7 @@ import { Box, Image } from '@chakra-ui/react'; import { useAppDispatch } from 'app/store/storeHooks'; import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder'; import { useGetUrl } from 'common/util/getUrl'; -import useGetImageByNameAndOrigin from 'features/gallery/hooks/useGetImageByName'; +import useGetImageByName from 'features/gallery/hooks/useGetImageByName'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { @@ -11,7 +11,6 @@ import { } from 'features/nodes/types/types'; import { DragEvent, memo, useCallback, useState } from 'react'; -import { ResourceOrigin } from 'services/api'; import { FieldComponentProps } from './types'; const ImageInputFieldComponent = ( @@ -19,7 +18,7 @@ const ImageInputFieldComponent = ( ) => { const { nodeId, field } = props; - const getImageByNameAndType = useGetImageByNameAndOrigin(); + const getImageByName = useGetImageByName(); const dispatch = useAppDispatch(); const [url, setUrl] = useState(field.value?.image_url); const { getUrl } = useGetUrl(); @@ -27,15 +26,7 @@ const ImageInputFieldComponent = ( const handleDrop = useCallback( (e: DragEvent) => { const name = e.dataTransfer.getData('invokeai/imageName'); - const type = e.dataTransfer.getData( - 'invokeai/imageOrigin' - ) as ResourceOrigin; - - if (!name || !type) { - return; - } - - const image = getImageByNameAndType(name, type); + const image = getImageByName(name); if (!image) { return; @@ -51,7 +42,7 @@ const ImageInputFieldComponent = ( }) ); }, - [getImageByNameAndType, dispatch, field.name, nodeId] + [getImageByName, dispatch, field.name, nodeId] ); return ( diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts index 3615f7d298..2e741443cf 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts @@ -26,18 +26,21 @@ const buildBaseNode = ( | ImageToImageInvocation | InpaintInvocation | undefined => { - const dimensionsOverride = state.canvas.boundingBoxDimensions; + const overrides = { + ...state.canvas.boundingBoxDimensions, + is_intermediate: true, + }; if (nodeType === 'txt2img') { - return buildTxt2ImgNode(state, dimensionsOverride); + return buildTxt2ImgNode(state, overrides); } if (nodeType === 'img2img') { - return buildImg2ImgNode(state, dimensionsOverride); + return buildImg2ImgNode(state, overrides); } if (nodeType === 'inpaint' || nodeType === 'outpaint') { - return buildInpaintNode(state, dimensionsOverride); + return buildInpaintNode(state, overrides); } }; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx index a5b106163f..cfe1513420 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx @@ -5,7 +5,6 @@ import { useGetUrl } from 'common/util/getUrl'; import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { DragEvent, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { ResourceOrigin } from 'services/api'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { initialImageSelected } from 'features/parameters/store/actions'; @@ -55,11 +54,7 @@ const InitialImagePreview = () => { const handleDrop = useCallback( (e: DragEvent) => { const name = e.dataTransfer.getData('invokeai/imageName'); - const type = e.dataTransfer.getData( - 'invokeai/imageOrigin' - ) as ResourceOrigin; - - dispatch(initialImageSelected({ image_name: name, image_origin: type })); + dispatch(initialImageSelected(name)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts index 27ae63e5dd..ca9826693d 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts @@ -88,7 +88,7 @@ export const useParameters = () => { return; } - dispatch(initialImageSelected(image)); + dispatch(initialImageSelected(image.image_name)); toaster({ title: t('toast.initialImageSet'), status: 'info', diff --git a/invokeai/frontend/web/src/features/parameters/store/actions.ts b/invokeai/frontend/web/src/features/parameters/store/actions.ts index 6c1030b7b0..e9b90134e1 100644 --- a/invokeai/frontend/web/src/features/parameters/store/actions.ts +++ b/invokeai/frontend/web/src/features/parameters/store/actions.ts @@ -26,6 +26,6 @@ export const isImageDTO = (image: any): image is ImageDTO => { ); }; -export const initialImageSelected = createAction< - ImageDTO | ImageNameAndOrigin | undefined ->('generation/initialImageSelected'); +export const initialImageSelected = createAction( + 'generation/initialImageSelected' +); diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSelectors.ts b/invokeai/frontend/web/src/features/parameters/store/generationSelectors.ts index dbf5eec791..b7322740ef 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSelectors.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSelectors.ts @@ -1,34 +1,3 @@ -import { createSelector } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; -import { selectResultsById } from 'features/gallery/store/resultsSlice'; -import { selectUploadsById } from 'features/gallery/store/uploadsSlice'; -import { isEqual } from 'lodash-es'; export const generationSelector = (state: RootState) => state.generation; - -export const mayGenerateMultipleImagesSelector = createSelector( - generationSelector, - ({ shouldRandomizeSeed, shouldGenerateVariations }) => { - return shouldRandomizeSeed || shouldGenerateVariations; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); - -export const initialImageSelector = createSelector( - [(state: RootState) => state, generationSelector], - (state, generation) => { - const { initialImage } = generation; - - if (initialImage?.type === 'results') { - return selectResultsById(state, initialImage.name); - } - - if (initialImage?.type === 'uploads') { - return selectUploadsById(state, initialImage.name); - } - } -); diff --git a/invokeai/frontend/web/src/services/thunks/gallery.ts b/invokeai/frontend/web/src/services/thunks/gallery.ts deleted file mode 100644 index e6bb163167..0000000000 --- a/invokeai/frontend/web/src/services/thunks/gallery.ts +++ /dev/null @@ -1,64 +0,0 @@ -import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { ImagesService, PaginatedResults_ImageDTO_ } from 'services/api'; - -export const IMAGES_PER_PAGE = 20; - -type ReceivedResultImagesPageThunkConfig = { - rejectValue: { - error: unknown; - }; -}; - -export const receivedGalleryImages = createAppAsyncThunk< - PaginatedResults_ImageDTO_, - void, - ReceivedResultImagesPageThunkConfig ->( - 'results/receivedResultImagesPage', - async (_arg, { getState, rejectWithValue }) => { - const { page, pages, nextPage, upsertedImageCount } = getState().results; - - // If many images have been upserted, we need to offset the page number - // TODO: add an offset param to the list images endpoint - const pageOffset = Math.floor(upsertedImageCount / IMAGES_PER_PAGE); - - const response = await ImagesService.listImagesWithMetadata({ - excludeCategories: ['user'], - isIntermediate: false, - page: nextPage + pageOffset, - perPage: IMAGES_PER_PAGE, - }); - - return response; - } -); - -type ReceivedUploadImagesPageThunkConfig = { - rejectValue: { - error: unknown; - }; -}; - -export const receivedUploadImages = createAppAsyncThunk< - PaginatedResults_ImageDTO_, - void, - ReceivedUploadImagesPageThunkConfig ->( - 'uploads/receivedUploadImagesPage', - async (_arg, { getState, rejectWithValue }) => { - const { page, pages, nextPage, upsertedImageCount } = getState().uploads; - - // If many images have been upserted, we need to offset the page number - // TODO: add an offset param to the list images endpoint - const pageOffset = Math.floor(upsertedImageCount / IMAGES_PER_PAGE); - - const response = await ImagesService.listImagesWithMetadata({ - includeCategories: ['user'], - isIntermediate: false, - page: nextPage + pageOffset, - perPage: IMAGES_PER_PAGE, - }); - - return response; - } -); diff --git a/invokeai/frontend/web/src/services/thunks/image.ts b/invokeai/frontend/web/src/services/thunks/image.ts index f324edad2b..87832c6b1e 100644 --- a/invokeai/frontend/web/src/services/thunks/image.ts +++ b/invokeai/frontend/web/src/services/thunks/image.ts @@ -1,5 +1,5 @@ import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { InvokeTabName } from 'features/ui/store/tabMap'; +import { selectImagesAll } from 'features/gallery/store/imagesSlice'; import { ImagesService } from 'services/api'; type imageUrlsReceivedArg = Parameters< @@ -71,3 +71,32 @@ export const imageUpdated = createAppAsyncThunk( return response; } ); + +type ImagesListedArg = Parameters< + (typeof ImagesService)['listImagesWithMetadata'] +>[0]; + +export const IMAGES_PER_PAGE = 20; + +/** + * `ImagesService.listImagesWithMetadata()` thunk + */ +export const receivedPageOfImages = createAppAsyncThunk( + 'api/receivedPageOfImages', + async (_, { getState }) => { + const state = getState(); + const { categories } = state.images; + + const totalImagesInFilter = selectImagesAll(state).filter((i) => + categories.includes(i.image_category) + ).length; + + const response = await ImagesService.listImagesWithMetadata({ + categories, + isIntermediate: false, + offset: totalImagesInFilter, + limit: IMAGES_PER_PAGE, + }); + return response; + } +); diff --git a/invokeai/frontend/web/src/services/types/guards.ts b/invokeai/frontend/web/src/services/types/guards.ts index 1231a38b4d..4d33cfa246 100644 --- a/invokeai/frontend/web/src/services/types/guards.ts +++ b/invokeai/frontend/web/src/services/types/guards.ts @@ -1,4 +1,3 @@ -import { UploadsImageDTO } from 'features/gallery/store/uploadsSlice'; import { get, isObject, isString } from 'lodash-es'; import { GraphExecutionState, @@ -10,17 +9,9 @@ import { CollectInvocationOutput, ImageField, LatentsOutput, - ImageDTO, ResourceOrigin, } from 'services/api'; -export const isUploadsImageDTO = ( - image: ImageDTO | undefined -): image is UploadsImageDTO => - image !== undefined && - image.image_origin === 'external' && - image.image_category === 'user'; - export const isImageOutput = ( output: GraphExecutionState['results'][string] ): output is ImageOutput => output.type === 'image_output'; From 3e3dd39ae401bfa2294f348d40c43a7789d90563 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 09:32:12 +1000 Subject: [PATCH 18/34] fix(nodes): fix images service update() for `is_intermediate` --- invokeai/app/services/image_record_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index c27596afac..6907ac3952 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -253,7 +253,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): ) # Change the image's `is_intermediate`` flag - if changes.session_id is not None: + if changes.is_intermediate is not None: self._cursor.execute( f"""--sql UPDATE images From 6f82801d0762d03d153794f71a3f7c767c68d8d6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 09:32:30 +1000 Subject: [PATCH 19/34] fix(ui): fix canvas save to gallery incorrect `is_intermediate` flag --- .../listenerMiddleware/listeners/canvasSavedToGallery.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts index a692a90670..59cbf92a77 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -37,7 +37,7 @@ export const addCanvasSavedToGalleryListener = () => { file: new File([blob], filename, { type: 'image/png' }), }, imageCategory: 'general', - isIntermediate: true, + isIntermediate: false, }) ); From 043f9d9ba423d3763095a0b51a8d3893f7ef1836 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 16:03:31 +1000 Subject: [PATCH 20/34] fix(ui): fix auto-switch to new images --- .../web/src/features/gallery/store/gallerySlice.ts | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index 16121b6e38..ab62646c0f 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -1,6 +1,7 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import { ImageDTO } from 'services/api'; +import { imageUpserted } from './imagesSlice'; type GalleryImageObjectFitType = 'contain' | 'cover'; @@ -47,6 +48,13 @@ export const gallerySlice = createSlice({ state.shouldUseSingleGalleryColumn = action.payload; }, }, + extraReducers: (builder) => { + builder.addCase(imageUpserted, (state, action) => { + if (state.shouldAutoSwitchToNewImages) { + state.selectedImage = action.payload; + } + }); + }, }); export const { From 970340cf6252df37d38a4dd974badb93058e3e46 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 16:05:54 +1000 Subject: [PATCH 21/34] fix(ui): infill and scaling options label --- .../Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx index 78a8995bee..ed01da9876 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx @@ -15,7 +15,7 @@ const ParamInfillCollapse = () => { return ( From 6764b2a854bb993d525eecd74a77feddf1400fc9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 16:17:26 +1000 Subject: [PATCH 22/34] fix(ui): fix save to gallery without bounding box --- .../listenerMiddleware/listeners/canvasSavedToGallery.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts index 59cbf92a77..b89620775b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -15,7 +15,7 @@ export const addCanvasSavedToGalleryListener = () => { effect: async (action, { dispatch, getState, take }) => { const state = getState(); - const blob = await getBaseLayerBlob(state); + const blob = await getBaseLayerBlob(state, true); if (!blob) { moduleLog.error('Problem getting base layer blob'); From e4705d5ce71df1d9377c75cb4cbae67b28647a90 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 17:18:13 +1000 Subject: [PATCH 23/34] fix(ui): add additional socket event layer to gate handling socket events Some socket events should not be handled by the slice reducers. For example generation progress should not be handled for a canceled session. Added another layer of socket actions. Example: - `socketGeneratorProgress` is dispatched when the actual socket event is received - Listener middleware exclusively handles this event and determines if the application should also handle it - If so, it dispatches `appSocketGeneratorProgress`, which the slices can handle Needed to fix issues related to canceling invocations. --- .../middleware/listenerMiddleware/index.ts | 32 +++-- .../listeners/socketio/socketConnected.ts | 7 +- .../listeners/socketio/socketDisconnected.ts | 9 +- ...Progress.ts => socketGeneratorProgress.ts} | 12 +- ...s => socketGraphExecutionStateComplete.ts} | 11 +- ...omplete.ts => socketInvocationComplete.ts} | 11 +- ...ationError.ts => socketInvocationError.ts} | 10 +- ...nStarted.ts => socketInvocationStarted.ts} | 10 +- .../listeners/socketio/socketSubscribed.ts | 5 +- .../listeners/socketio/socketUnsubscribed.ts | 8 +- .../src/features/system/store/systemSlice.ts | 62 ++++---- .../web/src/services/events/actions.ts | 135 ++++++++++++++++-- .../services/events/util/setEventListeners.ts | 25 ++-- 13 files changed, 249 insertions(+), 88 deletions(-) rename invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/{generatorProgress.ts => socketGeneratorProgress.ts} (65%) rename invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/{graphExecutionStateComplete.ts => socketGraphExecutionStateComplete.ts} (50%) rename invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/{invocationComplete.ts => socketInvocationComplete.ts} (85%) rename invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/{invocationError.ts => socketInvocationError.ts} (58%) rename invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/{invocationStarted.ts => socketInvocationStarted.ts} (70%) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 6cc9867bfd..cf4544e4ea 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -26,15 +26,15 @@ import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGaller import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage'; import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; import { addCanvasMergedListener } from './listeners/canvasMerged'; -import { addGeneratorProgressListener } from './listeners/socketio/generatorProgress'; -import { addGraphExecutionStateCompleteListener } from './listeners/socketio/graphExecutionStateComplete'; -import { addInvocationCompleteListener } from './listeners/socketio/invocationComplete'; -import { addInvocationErrorListener } from './listeners/socketio/invocationError'; -import { addInvocationStartedListener } from './listeners/socketio/invocationStarted'; -import { addSocketConnectedListener } from './listeners/socketio/socketConnected'; -import { addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected'; -import { addSocketSubscribedListener } from './listeners/socketio/socketSubscribed'; -import { addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed'; +import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress'; +import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete'; +import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete'; +import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError'; +import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted'; +import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected'; +import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected'; +import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed'; +import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed'; import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke'; import { addImageMetadataReceivedFulfilledListener, @@ -126,7 +126,19 @@ addCanvasCopiedToClipboardListener(); addCanvasMergedListener(); addStagingAreaImageSavedListener(); -// socketio +/** + * Socket.IO Events - these handle SIO events directly and pass on internal application actions. + * We don't handle SIO events in slices via `extraReducers` because some of these events shouldn't + * actually be handled at all. + * + * For example, we don't want to respond to progress events for canceled sessions. To avoid + * duplicating the logic to determine if an event should be responded to, we handle all of that + * "is this session canceled?" logic in these listeners. + * + * The `socketGeneratorProgress` listener will then only dispatch the `appSocketGeneratorProgress` + * action if it should be handled by the rest of the application. It is this `appSocketGeneratorProgress` + * action that is handled by reducers in slices. + */ addGeneratorProgressListener(); addGraphExecutionStateCompleteListener(); addInvocationCompleteListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index 85035e6bf9..3049d2c933 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -1,13 +1,13 @@ import { startAppListening } from '../..'; import { log } from 'app/logging/useLogger'; -import { socketConnected } from 'services/events/actions'; +import { appSocketConnected, socketConnected } from 'services/events/actions'; import { receivedPageOfImages } from 'services/thunks/image'; import { receivedModels } from 'services/thunks/model'; import { receivedOpenAPISchema } from 'services/thunks/schema'; const moduleLog = log.child({ namespace: 'socketio' }); -export const addSocketConnectedListener = () => { +export const addSocketConnectedEventListener = () => { startAppListening({ actionCreator: socketConnected, effect: (action, { dispatch, getState }) => { @@ -30,6 +30,9 @@ export const addSocketConnectedListener = () => { if (!nodes.schema && !disabledTabs.includes('nodes')) { dispatch(receivedOpenAPISchema()); } + + // pass along the socket event as an application action + dispatch(appSocketConnected(action.payload)); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts index 131c3ba18f..d5e8914cef 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts @@ -1,14 +1,19 @@ import { startAppListening } from '../..'; import { log } from 'app/logging/useLogger'; -import { socketDisconnected } from 'services/events/actions'; +import { + socketDisconnected, + appSocketDisconnected, +} from 'services/events/actions'; const moduleLog = log.child({ namespace: 'socketio' }); -export const addSocketDisconnectedListener = () => { +export const addSocketDisconnectedEventListener = () => { startAppListening({ actionCreator: socketDisconnected, effect: (action, { dispatch, getState }) => { moduleLog.debug(action.payload, 'Disconnected'); + // pass along the socket event as an application action + dispatch(appSocketDisconnected(action.payload)); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/generatorProgress.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts similarity index 65% rename from invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/generatorProgress.ts rename to invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts index 341b5e46d3..756444d644 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/generatorProgress.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts @@ -1,12 +1,15 @@ import { startAppListening } from '../..'; import { log } from 'app/logging/useLogger'; -import { generatorProgress } from 'services/events/actions'; +import { + appSocketGeneratorProgress, + socketGeneratorProgress, +} from 'services/events/actions'; const moduleLog = log.child({ namespace: 'socketio' }); -export const addGeneratorProgressListener = () => { +export const addGeneratorProgressEventListener = () => { startAppListening({ - actionCreator: generatorProgress, + actionCreator: socketGeneratorProgress, effect: (action, { dispatch, getState }) => { if ( getState().system.canceledSession === @@ -23,6 +26,9 @@ export const addGeneratorProgressListener = () => { action.payload, `Generator progress (${action.payload.data.node.type})` ); + + // pass along the socket event as an application action + dispatch(appSocketGeneratorProgress(action.payload)); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/graphExecutionStateComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts similarity index 50% rename from invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/graphExecutionStateComplete.ts rename to invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts index a66a7fb547..7297825e32 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/graphExecutionStateComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts @@ -1,17 +1,22 @@ import { log } from 'app/logging/useLogger'; -import { graphExecutionStateComplete } from 'services/events/actions'; +import { + appSocketGraphExecutionStateComplete, + socketGraphExecutionStateComplete, +} from 'services/events/actions'; import { startAppListening } from '../..'; const moduleLog = log.child({ namespace: 'socketio' }); -export const addGraphExecutionStateCompleteListener = () => { +export const addGraphExecutionStateCompleteEventListener = () => { startAppListening({ - actionCreator: graphExecutionStateComplete, + actionCreator: socketGraphExecutionStateComplete, effect: (action, { dispatch, getState }) => { moduleLog.debug( action.payload, `Session invocation complete (${action.payload.data.graph_execution_state_id})` ); + // pass along the socket event as an application action + dispatch(appSocketGraphExecutionStateComplete(action.payload)); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts similarity index 85% rename from invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts rename to invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index fb2056ae35..0b47f7a1be 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -1,7 +1,10 @@ import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; import { startAppListening } from '../..'; import { log } from 'app/logging/useLogger'; -import { invocationComplete } from 'services/events/actions'; +import { + appSocketInvocationComplete, + socketInvocationComplete, +} from 'services/events/actions'; import { imageMetadataReceived } from 'services/thunks/image'; import { sessionCanceled } from 'services/thunks/session'; import { isImageOutput } from 'services/types/guards'; @@ -10,9 +13,9 @@ import { progressImageSet } from 'features/system/store/systemSlice'; const moduleLog = log.child({ namespace: 'socketio' }); const nodeDenylist = ['dataURL_image']; -export const addInvocationCompleteListener = () => { +export const addInvocationCompleteEventListener = () => { startAppListening({ - actionCreator: invocationComplete, + actionCreator: socketInvocationComplete, effect: async (action, { dispatch, getState, take }) => { moduleLog.debug( action.payload, @@ -57,6 +60,8 @@ export const addInvocationCompleteListener = () => { dispatch(progressImageSet(null)); } + // pass along the socket event as an application action + dispatch(appSocketInvocationComplete(action.payload)); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts similarity index 58% rename from invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationError.ts rename to invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts index 3a98af120a..51480bbad4 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationError.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts @@ -1,17 +1,21 @@ import { startAppListening } from '../..'; import { log } from 'app/logging/useLogger'; -import { invocationError } from 'services/events/actions'; +import { + appSocketInvocationError, + socketInvocationError, +} from 'services/events/actions'; const moduleLog = log.child({ namespace: 'socketio' }); -export const addInvocationErrorListener = () => { +export const addInvocationErrorEventListener = () => { startAppListening({ - actionCreator: invocationError, + actionCreator: socketInvocationError, effect: (action, { dispatch, getState }) => { moduleLog.error( action.payload, `Invocation error (${action.payload.data.node.type})` ); + dispatch(appSocketInvocationError(action.payload)); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationStarted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts similarity index 70% rename from invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationStarted.ts rename to invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts index f898c62b23..978be2fef5 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/invocationStarted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts @@ -1,12 +1,15 @@ import { startAppListening } from '../..'; import { log } from 'app/logging/useLogger'; -import { invocationStarted } from 'services/events/actions'; +import { + appSocketInvocationStarted, + socketInvocationStarted, +} from 'services/events/actions'; const moduleLog = log.child({ namespace: 'socketio' }); -export const addInvocationStartedListener = () => { +export const addInvocationStartedEventListener = () => { startAppListening({ - actionCreator: invocationStarted, + actionCreator: socketInvocationStarted, effect: (action, { dispatch, getState }) => { if ( getState().system.canceledSession === @@ -23,6 +26,7 @@ export const addInvocationStartedListener = () => { action.payload, `Invocation started (${action.payload.data.node.type})` ); + dispatch(appSocketInvocationStarted(action.payload)); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts index 400f8a1689..871222981d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts @@ -1,10 +1,10 @@ import { startAppListening } from '../..'; import { log } from 'app/logging/useLogger'; -import { socketSubscribed } from 'services/events/actions'; +import { appSocketSubscribed, socketSubscribed } from 'services/events/actions'; const moduleLog = log.child({ namespace: 'socketio' }); -export const addSocketSubscribedListener = () => { +export const addSocketSubscribedEventListener = () => { startAppListening({ actionCreator: socketSubscribed, effect: (action, { dispatch, getState }) => { @@ -12,6 +12,7 @@ export const addSocketSubscribedListener = () => { action.payload, `Subscribed (${action.payload.sessionId}))` ); + dispatch(appSocketSubscribed(action.payload)); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts index af15c55d42..ff85379907 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts @@ -1,10 +1,13 @@ import { startAppListening } from '../..'; import { log } from 'app/logging/useLogger'; -import { socketUnsubscribed } from 'services/events/actions'; +import { + appSocketUnsubscribed, + socketUnsubscribed, +} from 'services/events/actions'; const moduleLog = log.child({ namespace: 'socketio' }); -export const addSocketUnsubscribedListener = () => { +export const addSocketUnsubscribedEventListener = () => { startAppListening({ actionCreator: socketUnsubscribed, effect: (action, { dispatch, getState }) => { @@ -12,6 +15,7 @@ export const addSocketUnsubscribedListener = () => { action.payload, `Unsubscribed (${action.payload.sessionId})` ); + dispatch(appSocketUnsubscribed(action.payload)); }, }); }; diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 403fd60501..65e1161200 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -2,17 +2,6 @@ import { UseToastOptions } from '@chakra-ui/react'; import { PayloadAction, isAnyOf } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import * as InvokeAI from 'app/types/invokeai'; -import { - generatorProgress, - graphExecutionStateComplete, - invocationComplete, - invocationError, - invocationStarted, - socketConnected, - socketDisconnected, - socketSubscribed, - socketUnsubscribed, -} from 'services/events/actions'; import { ProgressImage } from 'services/events/types'; import { makeToast } from '../../../app/components/Toaster'; @@ -30,6 +19,17 @@ import { t } from 'i18next'; import { userInvoked } from 'app/store/actions'; import { LANGUAGES } from '../components/LanguagePicker'; import { imageUploaded } from 'services/thunks/image'; +import { + appSocketConnected, + appSocketDisconnected, + appSocketGeneratorProgress, + appSocketGraphExecutionStateComplete, + appSocketInvocationComplete, + appSocketInvocationError, + appSocketInvocationStarted, + appSocketSubscribed, + appSocketUnsubscribed, +} from 'services/events/actions'; export type CancelStrategy = 'immediate' | 'scheduled'; @@ -227,7 +227,7 @@ export const systemSlice = createSlice({ /** * Socket Subscribed */ - builder.addCase(socketSubscribed, (state, action) => { + builder.addCase(appSocketSubscribed, (state, action) => { state.sessionId = action.payload.sessionId; state.canceledSession = ''; }); @@ -235,14 +235,14 @@ export const systemSlice = createSlice({ /** * Socket Unsubscribed */ - builder.addCase(socketUnsubscribed, (state) => { + builder.addCase(appSocketUnsubscribed, (state) => { state.sessionId = null; }); /** * Socket Connected */ - builder.addCase(socketConnected, (state) => { + builder.addCase(appSocketConnected, (state) => { state.isConnected = true; state.isCancelable = true; state.isProcessing = false; @@ -257,7 +257,7 @@ export const systemSlice = createSlice({ /** * Socket Disconnected */ - builder.addCase(socketDisconnected, (state) => { + builder.addCase(appSocketDisconnected, (state) => { state.isConnected = false; state.isProcessing = false; state.isCancelable = true; @@ -272,7 +272,7 @@ export const systemSlice = createSlice({ /** * Invocation Started */ - builder.addCase(invocationStarted, (state) => { + builder.addCase(appSocketInvocationStarted, (state) => { state.isCancelable = true; state.isProcessing = true; state.currentStatusHasSteps = false; @@ -286,7 +286,7 @@ export const systemSlice = createSlice({ /** * Generator Progress */ - builder.addCase(generatorProgress, (state, action) => { + builder.addCase(appSocketGeneratorProgress, (state, action) => { const { step, total_steps, progress_image } = action.payload.data; state.isProcessing = true; @@ -303,7 +303,7 @@ export const systemSlice = createSlice({ /** * Invocation Complete */ - builder.addCase(invocationComplete, (state, action) => { + builder.addCase(appSocketInvocationComplete, (state, action) => { const { data } = action.payload; // state.currentIteration = 0; @@ -322,7 +322,7 @@ export const systemSlice = createSlice({ /** * Invocation Error */ - builder.addCase(invocationError, (state) => { + builder.addCase(appSocketInvocationError, (state) => { state.isProcessing = false; state.isCancelable = true; // state.currentIteration = 0; @@ -338,6 +338,18 @@ export const systemSlice = createSlice({ ); }); + /** + * Graph Execution State Complete + */ + builder.addCase(appSocketGraphExecutionStateComplete, (state) => { + state.isProcessing = false; + state.isCancelable = false; + state.isCancelScheduled = false; + state.currentStep = 0; + state.totalSteps = 0; + state.statusTranslationKey = 'common.statusConnected'; + }); + /** * Session Invoked - PENDING */ @@ -367,18 +379,6 @@ export const systemSlice = createSlice({ ); }); - /** - * Session Canceled - */ - builder.addCase(graphExecutionStateComplete, (state) => { - state.isProcessing = false; - state.isCancelable = false; - state.isCancelScheduled = false; - state.currentStep = 0; - state.totalSteps = 0; - state.statusTranslationKey = 'common.statusConnected'; - }); - /** * Received available models from the backend */ diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts index 76bffeaa49..5832cb24b1 100644 --- a/invokeai/frontend/web/src/services/events/actions.ts +++ b/invokeai/frontend/web/src/services/events/actions.ts @@ -12,46 +12,153 @@ type BaseSocketPayload = { timestamp: string; }; -// Create actions for each socket event +// Create actions for each socket // Middleware and redux can then respond to them as needed +/** + * Socket.IO Connected + * + * Do not use. Only for use in middleware. + */ export const socketConnected = createAction( 'socket/socketConnected' ); +/** + * App-level Socket.IO Connected + */ +export const appSocketConnected = createAction( + 'socket/appSocketConnected' +); + +/** + * Socket.IO Disconnect + * + * Do not use. Only for use in middleware. + */ export const socketDisconnected = createAction( 'socket/socketDisconnected' ); +/** + * App-level Socket.IO Disconnected + */ +export const appSocketDisconnected = createAction( + 'socket/appSocketDisconnected' +); + +/** + * Socket.IO Subscribed + * + * Do not use. Only for use in middleware. + */ export const socketSubscribed = createAction< BaseSocketPayload & { sessionId: string } >('socket/socketSubscribed'); +/** + * App-level Socket.IO Subscribed + */ +export const appSocketSubscribed = createAction< + BaseSocketPayload & { sessionId: string } +>('socket/appSocketSubscribed'); + +/** + * Socket.IO Unsubscribed + * + * Do not use. Only for use in middleware. + */ export const socketUnsubscribed = createAction< BaseSocketPayload & { sessionId: string } >('socket/socketUnsubscribed'); -export const invocationStarted = createAction< - BaseSocketPayload & { data: InvocationStartedEvent } ->('socket/invocationStarted'); +/** + * App-level Socket.IO Unsubscribed + */ +export const appSocketUnsubscribed = createAction< + BaseSocketPayload & { sessionId: string } +>('socket/appSocketUnsubscribed'); -export const invocationComplete = createAction< +/** + * Socket.IO Invocation Started + * + * Do not use. Only for use in middleware. + */ +export const socketInvocationStarted = createAction< + BaseSocketPayload & { data: InvocationStartedEvent } +>('socket/socketInvocationStarted'); + +/** + * App-level Socket.IO Invocation Started + */ +export const appSocketInvocationStarted = createAction< + BaseSocketPayload & { data: InvocationStartedEvent } +>('socket/appSocketInvocationStarted'); + +/** + * Socket.IO Invocation Complete + * + * Do not use. Only for use in middleware. + */ +export const socketInvocationComplete = createAction< BaseSocketPayload & { data: InvocationCompleteEvent; } ->('socket/invocationComplete'); +>('socket/socketInvocationComplete'); -export const invocationError = createAction< +/** + * App-level Socket.IO Invocation Complete + */ +export const appSocketInvocationComplete = createAction< + BaseSocketPayload & { + data: InvocationCompleteEvent; + } +>('socket/appSocketInvocationComplete'); + +/** + * Socket.IO Invocation Error + * + * Do not use. Only for use in middleware. + */ +export const socketInvocationError = createAction< BaseSocketPayload & { data: InvocationErrorEvent } ->('socket/invocationError'); +>('socket/socketInvocationError'); -export const graphExecutionStateComplete = createAction< +/** + * App-level Socket.IO Invocation Error + */ +export const appSocketInvocationError = createAction< + BaseSocketPayload & { data: InvocationErrorEvent } +>('socket/appSocketInvocationError'); + +/** + * Socket.IO Graph Execution State Complete + * + * Do not use. Only for use in middleware. + */ +export const socketGraphExecutionStateComplete = createAction< BaseSocketPayload & { data: GraphExecutionStateCompleteEvent } ->('socket/graphExecutionStateComplete'); +>('socket/socketGraphExecutionStateComplete'); -export const generatorProgress = createAction< +/** + * App-level Socket.IO Graph Execution State Complete + */ +export const appSocketGraphExecutionStateComplete = createAction< + BaseSocketPayload & { data: GraphExecutionStateCompleteEvent } +>('socket/appSocketGraphExecutionStateComplete'); + +/** + * Socket.IO Generator Progress + * + * Do not use. Only for use in middleware. + */ +export const socketGeneratorProgress = createAction< BaseSocketPayload & { data: GeneratorProgressEvent } ->('socket/generatorProgress'); +>('socket/socketGeneratorProgress'); -// dispatch this when we need to fully reset the socket connection -export const socketReset = createAction('socket/socketReset'); +/** + * App-level Socket.IO Generator Progress + */ +export const appSocketGeneratorProgress = createAction< + BaseSocketPayload & { data: GeneratorProgressEvent } +>('socket/appSocketGeneratorProgress'); diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index 5262b26d1e..2c4cba510a 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -3,11 +3,11 @@ import { AppDispatch, RootState } from 'app/store/store'; import { getTimestamp } from 'common/util/getTimestamp'; import { Socket } from 'socket.io-client'; import { - generatorProgress, - graphExecutionStateComplete, - invocationComplete, - invocationError, - invocationStarted, + socketGeneratorProgress, + socketGraphExecutionStateComplete, + socketInvocationComplete, + socketInvocationError, + socketInvocationStarted, socketConnected, socketDisconnected, socketSubscribed, @@ -77,21 +77,21 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Invocation started */ socket.on('invocation_started', (data) => { - dispatch(invocationStarted({ data, timestamp: getTimestamp() })); + dispatch(socketInvocationStarted({ data, timestamp: getTimestamp() })); }); /** * Generator progress */ socket.on('generator_progress', (data) => { - dispatch(generatorProgress({ data, timestamp: getTimestamp() })); + dispatch(socketGeneratorProgress({ data, timestamp: getTimestamp() })); }); /** * Invocation error */ socket.on('invocation_error', (data) => { - dispatch(invocationError({ data, timestamp: getTimestamp() })); + dispatch(socketInvocationError({ data, timestamp: getTimestamp() })); }); /** @@ -99,7 +99,7 @@ export const setEventListeners = (arg: SetEventListenersArg) => { */ socket.on('invocation_complete', (data) => { dispatch( - invocationComplete({ + socketInvocationComplete({ data, timestamp: getTimestamp(), }) @@ -110,6 +110,11 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Graph complete */ socket.on('graph_execution_state_complete', (data) => { - dispatch(graphExecutionStateComplete({ data, timestamp: getTimestamp() })); + dispatch( + socketGraphExecutionStateComplete({ + data, + timestamp: getTimestamp(), + }) + ); }); }; From bce33ea62e3dfbad45ce4cfc634e2e7229a68406 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 17:26:57 +1000 Subject: [PATCH 24/34] fix(ui): when session is complete, null out progress image This may cause minor gallery jumpiness at the very end of processing, but is necessary to prevent the progress image from sticking around if the last node in a session did not have an image output. --- invokeai/frontend/web/src/features/system/store/systemSlice.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 65e1161200..b35dcd8e29 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -348,6 +348,7 @@ export const systemSlice = createSlice({ state.currentStep = 0; state.totalSteps = 0; state.statusTranslationKey = 'common.statusConnected'; + state.progressImage = null; }); /** From bbb4e8f5ef8268008b7718c5971e76df185146f0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 17:51:21 +1000 Subject: [PATCH 25/34] feat(nodes): add resize image and scale image nodes --- invokeai/app/invocations/image.py | 109 ++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 7633bfbc16..d048410468 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -416,6 +416,115 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): ) +PIL_RESAMPLING_MODES = Literal[ + "nearest", + "box", + "bilinear", + "hamming", + "bicubic", + "lanczos", +] + + +PIL_RESAMPLING_MAP = { + "nearest": Image.Resampling.NEAREST, + "box": Image.Resampling.BOX, + "bilinear": Image.Resampling.BILINEAR, + "hamming": Image.Resampling.HAMMING, + "bicubic": Image.Resampling.BICUBIC, + "lanczos": Image.Resampling.LANCZOS, +} + + +class ImageResizeInvocation(BaseInvocation, PILInvocationConfig): + """Resizes an image to specific dimensions""" + + # fmt: off + type: Literal["img_resize"] = "img_resize" + + # Inputs + image: Union[ImageField, None] = Field(default=None, description="The image to resize") + width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") + height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") + resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") + # fmt: on + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get_pil_image( + self.image.image_origin, self.image.image_name + ) + + resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] + + resize_image = image.resize( + (self.width, self.height), + resample=resample_mode, + ) + + image_dto = context.services.images.create( + image=resize_image, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, + ) + + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_origin=image_dto.image_origin, + ), + width=image_dto.width, + height=image_dto.height, + ) + + +class ImageScaleInvocation(BaseInvocation, PILInvocationConfig): + """Scales an image by a factor""" + + # fmt: off + type: Literal["img_scale"] = "img_scale" + + # Inputs + image: Union[ImageField, None] = Field(default=None, description="The image to scale") + scale_factor: float = Field(gt=0, description="The factor by which to scale the image") + resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") + # fmt: on + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get_pil_image( + self.image.image_origin, self.image.image_name + ) + + resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] + width = int(image.width * self.scale_factor) + height = int(image.height * self.scale_factor) + + resize_image = image.resize( + (width, height), + resample=resample_mode, + ) + + image_dto = context.services.images.create( + image=resize_image, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, + ) + + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_origin=image_dto.image_origin, + ), + width=image_dto.width, + height=image_dto.height, + ) + + class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): """Linear interpolation of all pixels of an image""" From 4aec5d8ffc2639b27895b206211666c1e16d3ab1 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 18:26:10 +1000 Subject: [PATCH 26/34] fix(ui): typo --- invokeai/frontend/web/src/features/system/store/systemSlice.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index b35dcd8e29..6bc8d7106a 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -352,7 +352,7 @@ export const systemSlice = createSlice({ }); /** - * Session Invoked - PENDING + * User Invoked */ builder.addCase(userInvoked, (state) => { From 6fe28980b0f2546a3e64ab4354ffe66d61bcaadf Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 18:26:49 +1000 Subject: [PATCH 27/34] feat(ui): revert in-gallery progress wasn't fully baked. will revisist in the future. --- .../components/ImageGalleryContent.tsx | 72 ++++++------------- 1 file changed, 23 insertions(+), 49 deletions(-) diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx index 4b1786168d..758c077b54 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx @@ -45,14 +45,12 @@ import { createSelector } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; import { Virtuoso, VirtuosoGrid } from 'react-virtuoso'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import GalleryProgressImage from './GalleryProgressImage'; import { uiSelector } from 'features/ui/store/uiSelectors'; -import { ImageCategory, ImageDTO } from 'services/api'; +import { ImageCategory } from 'services/api'; import { imageCategoriesChanged, selectImagesAll } from '../store/imagesSlice'; import { receivedPageOfImages } from 'services/thunks/image'; import { capitalize } from 'lodash-es'; -const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER'; const IMAGE_CATEGORIES: ImageCategory[] = [ 'general', 'control', @@ -64,13 +62,7 @@ const IMAGE_CATEGORIES: ImageCategory[] = [ const categorySelector = createSelector( [(state: RootState) => state], (state) => { - const { system, images } = state; - const tempImages: (ImageDTO | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = []; - - if (system.progressImage) { - tempImages.push(PROGRESS_IMAGE_PLACEHOLDER); - } - + const { images } = state; const { categories } = images; const allImages = selectImagesAll(state); @@ -79,7 +71,7 @@ const categorySelector = createSelector( ); return { - images: tempImages.concat(filteredImages), + images: filteredImages, isLoading: images.isLoading, areMoreImagesAvailable: filteredImages.length < images.total, categories: images.categories, @@ -293,28 +285,17 @@ const ImageGalleryContent = () => { data={images} endReached={handleEndReached} scrollerRef={(ref) => setScrollerRef(ref)} - itemContent={(index, image) => { - const isSelected = - image === PROGRESS_IMAGE_PLACEHOLDER - ? false - : selectedImage?.image_name === image?.image_name; - - return ( - - {image === PROGRESS_IMAGE_PLACEHOLDER ? ( - - ) : ( - - )} - - ); - }} + itemContent={(index, image) => ( + + + + )} /> ) : ( { List: ListContainer, }} scrollerRef={setScroller} - itemContent={(index, image) => { - const isSelected = - image === PROGRESS_IMAGE_PLACEHOLDER - ? false - : selectedImage?.image_name === image?.image_name; - - return image === PROGRESS_IMAGE_PLACEHOLDER ? ( - - ) : ( - - ); - }} + itemContent={(index, image) => ( + + )} /> )} From 4522f3f4c922d26850e9cfebf227113e1ad42450 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 18:40:01 +1000 Subject: [PATCH 28/34] fix(ui): fix progress images in canvas --- .../components/IAICanvasIntermediateImage.tsx | 45 +++++++++++-------- .../components/IAICanvasStagingArea.tsx | 2 +- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx index 745825a975..ea5e9a6486 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx @@ -1,18 +1,24 @@ import { createSelector } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; -import { useGetUrl } from 'common/util/getUrl'; -import { GalleryState } from 'features/gallery/store/gallerySlice'; +import { systemSelector } from 'features/system/store/systemSelectors'; import { ImageConfig } from 'konva/lib/shapes/Image'; import { isEqual } from 'lodash-es'; import { useEffect, useState } from 'react'; import { Image as KonvaImage } from 'react-konva'; +import { canvasSelector } from '../store/canvasSelectors'; const selector = createSelector( - [(state: RootState) => state.gallery], - (gallery: GalleryState) => { - return gallery.intermediateImage ? gallery.intermediateImage : null; + [systemSelector, canvasSelector], + (system, canvas) => { + const { progressImage, sessionId } = system; + const { sessionId: canvasSessionId, boundingBox } = + canvas.layerState.stagingArea; + + return { + boundingBox, + progressImage: sessionId === canvasSessionId ? progressImage : undefined, + }; }, { memoizeOptions: { @@ -25,33 +31,34 @@ type Props = Omit; const IAICanvasIntermediateImage = (props: Props) => { const { ...rest } = props; - const intermediateImage = useAppSelector(selector); - const { getUrl } = useGetUrl(); + const { progressImage, boundingBox } = useAppSelector(selector); const [loadedImageElement, setLoadedImageElement] = useState(null); useEffect(() => { - if (!intermediateImage) return; + if (!progressImage) { + return; + } + const tempImage = new Image(); tempImage.onload = () => { setLoadedImageElement(tempImage); }; - tempImage.src = getUrl(intermediateImage.url); - }, [intermediateImage, getUrl]); - if (!intermediateImage?.boundingBox) return null; + tempImage.src = progressImage.dataURL; + }, [progressImage]); - const { - boundingBox: { x, y, width, height }, - } = intermediateImage; + if (!(progressImage && boundingBox)) { + return null; + } return loadedImageElement ? ( { {shouldShowStagingImage && currentStagingAreaImage && ( From d97438b0b34da45327b6c45a02289a8e51cdb66b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 19:05:22 +1000 Subject: [PATCH 29/34] fix(ui): fix typo in actionsDenylist --- .../web/src/app/store/middleware/devtools/actionsDenylist.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts index 743537d7ea..eb54868735 100644 --- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts +++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts @@ -7,5 +7,6 @@ export const actionsDenylist = [ 'canvas/setBoundingBoxDimensions', 'canvas/setIsDrawing', 'canvas/addPointToCurrentLine', - 'socket/generatorProgress', + 'socket/socketGeneratorProgress', + 'socket/appSocketGeneratorProgress', ]; From a7cebbd970f821d72371cfb036841d848f4bfe65 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 19:06:37 +1000 Subject: [PATCH 30/34] feat(ui): cancel session when staging image accepted --- .../middleware/listenerMiddleware/index.ts | 2 + .../addCommitStagingAreaImageListener.ts | 37 +++++++++++++++++++ .../IAICanvasStagingAreaToolbar.tsx | 21 +++++++++-- .../src/features/canvas/store/canvasSlice.ts | 5 ++- 4 files changed, 60 insertions(+), 5 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index cf4544e4ea..26f32252d1 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -68,6 +68,7 @@ import { addReceivedPageOfImagesRejectedListener, } from './listeners/receivedPageOfImages'; import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved'; +import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener'; export const listenerMiddleware = createListenerMiddleware(); @@ -125,6 +126,7 @@ addCanvasDownloadedAsImageListener(); addCanvasCopiedToClipboardListener(); addCanvasMergedListener(); addStagingAreaImageSavedListener(); +addCommitStagingAreaImageListener(); /** * Socket.IO Events - these handle SIO events directly and pass on internal application actions. diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts new file mode 100644 index 0000000000..428ecf9c62 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts @@ -0,0 +1,37 @@ +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { commitStagingAreaImage } from 'features/canvas/store/canvasSlice'; +import { sessionCanceled } from 'services/thunks/session'; + +const moduleLog = log.child({ namespace: 'canvas' }); + +export const addCommitStagingAreaImageListener = () => { + startAppListening({ + actionCreator: commitStagingAreaImage, + effect: async (action, { dispatch, getState }) => { + const state = getState(); + const { sessionId } = state.system; + const canvasSessionId = action.payload; + + if (!canvasSessionId) { + moduleLog.debug('No canvas session, skipping cancel'); + return; + } + + if (canvasSessionId !== sessionId) { + moduleLog.debug( + { + data: { + canvasSessionId, + sessionId, + }, + }, + 'Canvas session does not match global session, skipping cancel' + ); + return; + } + + dispatch(sessionCanceled({ sessionId })); + }, + }); +}; diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx index 68bc15bbaa..76ffdcf082 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx @@ -32,7 +32,7 @@ const selector = createSelector( (canvas) => { const { layerState: { - stagingArea: { images, selectedImageIndex }, + stagingArea: { images, selectedImageIndex, sessionId }, }, shouldShowStagingOutline, shouldShowStagingImage, @@ -45,6 +45,7 @@ const selector = createSelector( isOnLastImage: selectedImageIndex === images.length - 1, shouldShowStagingImage, shouldShowStagingOutline, + sessionId, }; }, { @@ -61,6 +62,7 @@ const IAICanvasStagingAreaToolbar = () => { isOnLastImage, currentStagingAreaImage, shouldShowStagingImage, + sessionId, } = useAppSelector(selector); const { t } = useTranslation(); @@ -106,9 +108,20 @@ const IAICanvasStagingAreaToolbar = () => { } ); - const handlePrevImage = () => dispatch(prevStagingAreaImage()); - const handleNextImage = () => dispatch(nextStagingAreaImage()); - const handleAccept = () => dispatch(commitStagingAreaImage()); + const handlePrevImage = useCallback( + () => dispatch(prevStagingAreaImage()), + [dispatch] + ); + + const handleNextImage = useCallback( + () => dispatch(nextStagingAreaImage()), + [dispatch] + ); + + const handleAccept = useCallback( + () => dispatch(commitStagingAreaImage(sessionId)), + [dispatch, sessionId] + ); if (!currentStagingAreaImage) return null; diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index 0ebe5b264c..ad0581e42f 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -696,7 +696,10 @@ export const canvasSlice = createSlice({ 0 ); }, - commitStagingAreaImage: (state) => { + commitStagingAreaImage: ( + state, + action: PayloadAction + ) => { if (!state.layerState.stagingArea.images.length) { return; } From 1ddc620192f15edbb0fdbf8974fe37c6f589a7ac Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 19:15:28 +1000 Subject: [PATCH 31/34] feat(ui): only cancel on staging commit if processing --- .../listeners/addCommitStagingAreaImageListener.ts | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts index 428ecf9c62..90f71879a1 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts @@ -10,9 +10,14 @@ export const addCommitStagingAreaImageListener = () => { actionCreator: commitStagingAreaImage, effect: async (action, { dispatch, getState }) => { const state = getState(); - const { sessionId } = state.system; + const { sessionId, isProcessing } = state.system; const canvasSessionId = action.payload; + if (!isProcessing) { + // Only need to cancel if we are processing + return; + } + if (!canvasSessionId) { moduleLog.debug('No canvas session, skipping cancel'); return; From 7004430380919d153839ce2d29f983b6b4ef7082 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 19:54:04 +1000 Subject: [PATCH 32/34] feat(ui): gallery filter dropdown -> Images/Assets toggle --- invokeai/frontend/web/public/locales/en.json | 4 +- .../middleware/listenerMiddleware/index.ts | 6 +- .../listeners/imageCategoriesChanged.ts | 24 ++++++ .../listeners/receivedPageOfImages.ts | 2 +- .../components/ImageGalleryContent.tsx | 76 ++++++++++--------- .../src/features/gallery/store/imagesSlice.ts | 10 ++- .../frontend/web/src/services/thunks/image.ts | 2 + 7 files changed, 86 insertions(+), 38 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 94dff3934a..94ad949023 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -122,7 +122,9 @@ "noImagesInGallery": "No Images In Gallery", "deleteImage": "Delete Image", "deleteImageBin": "Deleted images will be sent to your operating system's Bin.", - "deleteImagePermanent": "Deleted images cannot be restored." + "deleteImagePermanent": "Deleted images cannot be restored.", + "images": "Images", + "assets": "Assets" }, "hotkeys": { "keyboardShortcuts": "Keyboard Shortcuts", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 26f32252d1..ba16e56371 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -69,6 +69,7 @@ import { } from './listeners/receivedPageOfImages'; import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved'; import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener'; +import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged'; export const listenerMiddleware = createListenerMiddleware(); @@ -166,6 +167,9 @@ addSessionCanceledPendingListener(); addSessionCanceledFulfilledListener(); addSessionCanceledRejectedListener(); -// Images +// Fetching images addReceivedPageOfImagesFulfilledListener(); addReceivedPageOfImagesRejectedListener(); + +// Gallery +addImageCategoriesChangedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts new file mode 100644 index 0000000000..85d56d3913 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts @@ -0,0 +1,24 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { receivedPageOfImages } from 'services/thunks/image'; +import { + imageCategoriesChanged, + selectFilteredImagesAsArray, +} from 'features/gallery/store/imagesSlice'; + +const moduleLog = log.child({ namespace: 'gallery' }); + +export const addImageCategoriesChangedListener = () => { + startAppListening({ + actionCreator: imageCategoriesChanged, + effect: (action, { getState, dispatch }) => { + const filteredImagesCount = selectFilteredImagesAsArray( + getState() + ).length; + + if (!filteredImagesCount) { + dispatch(receivedPageOfImages()); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts index 9a2ec0e7a5..cde7e22e3d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts @@ -11,7 +11,7 @@ export const addReceivedPageOfImagesFulfilledListener = () => { effect: (action, { getState, dispatch }) => { const page = action.payload; moduleLog.debug( - { data: { page } }, + { data: { payload: action.payload } }, `Received ${page.items.length} images` ); }, diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx index 758c077b54..ce7eb00404 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx @@ -1,5 +1,6 @@ import { Box, + ButtonGroup, Checkbox, CheckboxGroup, Flex, @@ -36,7 +37,13 @@ import { } from 'react'; import { useTranslation } from 'react-i18next'; import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs'; -import { FaFilter, FaWrench } from 'react-icons/fa'; +import { + FaFilter, + FaImage, + FaImages, + FaServer, + FaWrench, +} from 'react-icons/fa'; import { MdPhotoLibrary } from 'react-icons/md'; import HoverableImage from './HoverableImage'; @@ -47,18 +54,15 @@ import { Virtuoso, VirtuosoGrid } from 'react-virtuoso'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { uiSelector } from 'features/ui/store/uiSelectors'; import { ImageCategory } from 'services/api'; -import { imageCategoriesChanged, selectImagesAll } from '../store/imagesSlice'; +import { + ASSETS_CATEGORIES, + IMAGE_CATEGORIES, + imageCategoriesChanged, + selectImagesAll, +} from '../store/imagesSlice'; import { receivedPageOfImages } from 'services/thunks/image'; import { capitalize } from 'lodash-es'; -const IMAGE_CATEGORIES: ImageCategory[] = [ - 'general', - 'control', - 'mask', - 'user', - 'other', -]; - const categorySelector = createSelector( [(state: RootState) => state], (state) => { @@ -179,6 +183,14 @@ const ImageGalleryContent = () => { [dispatch] ); + const handleClickImagesCategory = useCallback(() => { + dispatch(imageCategoriesChanged(IMAGE_CATEGORIES)); + }, [dispatch]); + + const handleClickAssetsCategory = useCallback(() => { + dispatch(imageCategoriesChanged(ASSETS_CATEGORIES)); + }, [dispatch]); + return ( { alignItems="center" justifyContent="space-between" > - } - /> - } - > - - - {IMAGE_CATEGORIES.map((c) => ( - - {capitalize(c)} - - ))} - - - - + + } + /> + } + /> + } /> } diff --git a/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts index 8ab34fccf0..cb6469aeb4 100644 --- a/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts @@ -15,6 +15,14 @@ export const imagesAdapter = createEntityAdapter({ sortComparer: (a, b) => dateComparator(b.created_at, a.created_at), }); +export const IMAGE_CATEGORIES: ImageCategory[] = ['general']; +export const ASSETS_CATEGORIES: ImageCategory[] = [ + 'control', + 'mask', + 'user', + 'other', +]; + type AdditionaImagesState = { offset: number; limit: number; @@ -29,7 +37,7 @@ export const initialImagesState = limit: 0, total: 0, isLoading: false, - categories: ['general', 'control', 'mask', 'other', 'user'], + categories: IMAGE_CATEGORIES, }); export type ImagesState = typeof initialImagesState; diff --git a/invokeai/frontend/web/src/services/thunks/image.ts b/invokeai/frontend/web/src/services/thunks/image.ts index 87832c6b1e..4ef492f1c6 100644 --- a/invokeai/frontend/web/src/services/thunks/image.ts +++ b/invokeai/frontend/web/src/services/thunks/image.ts @@ -91,6 +91,8 @@ export const receivedPageOfImages = createAppAsyncThunk( categories.includes(i.image_category) ).length; + console.log(categories); + const response = await ImagesService.listImagesWithMetadata({ categories, isIntermediate: false, From f1c226b1714a2bca7f365af9d1feca357b197b6f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 19:55:43 +1000 Subject: [PATCH 33/34] fix(ui): remove `console.log()` --- invokeai/frontend/web/src/services/thunks/image.ts | 2 -- 1 file changed, 2 deletions(-) diff --git a/invokeai/frontend/web/src/services/thunks/image.ts b/invokeai/frontend/web/src/services/thunks/image.ts index 4ef492f1c6..87832c6b1e 100644 --- a/invokeai/frontend/web/src/services/thunks/image.ts +++ b/invokeai/frontend/web/src/services/thunks/image.ts @@ -91,8 +91,6 @@ export const receivedPageOfImages = createAppAsyncThunk( categories.includes(i.image_category) ).length; - console.log(categories); - const response = await ImagesService.listImagesWithMetadata({ categories, isIntermediate: false, From 070218aba7c87a89ee15fadc5aa052d13136788c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 29 May 2023 22:19:08 +1000 Subject: [PATCH 34/34] feat(ui): add progress image toggle to current image buttons --- invokeai/frontend/web/public/locales/en.json | 2 +- .../components/CurrentImageButtons.tsx | 54 +++++++++---------- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 94ad949023..4bd1e5aab3 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -526,7 +526,7 @@ }, "settings": { "models": "Models", - "displayInProgress": "Display In-Progress Images", + "displayInProgress": "Display Progress Images", "saveSteps": "Save images every n steps", "confirmOnDelete": "Confirm On Delete", "displayHelpIcons": "Display Help Icons", diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx index c19a404a37..dc3022efb2 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx @@ -1,5 +1,5 @@ import { createSelector } from '@reduxjs/toolkit'; -import { isEqual, isString } from 'lodash-es'; +import { isEqual } from 'lodash-es'; import { ButtonGroup, @@ -25,8 +25,8 @@ import { } from 'features/ui/store/uiSelectors'; import { setActiveTab, - setShouldHidePreview, setShouldShowImageDetails, + setShouldShowProgressInViewer, } from 'features/ui/store/uiSlice'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; @@ -37,18 +37,14 @@ import { FaDownload, FaExpand, FaExpandArrowsAlt, - FaEye, - FaEyeSlash, FaGrinStars, + FaHourglassHalf, FaQuoteRight, FaSeedling, FaShare, FaShareAlt, - FaTrash, - FaWrench, } from 'react-icons/fa'; import { gallerySelector } from '../store/gallerySelectors'; -import DeleteImageModal from './DeleteImageModal'; import { useCallback } from 'react'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { useGetUrl } from 'common/util/getUrl'; @@ -90,7 +86,11 @@ const currentImageButtonsSelector = createSelector( const { isLightboxOpen } = lightbox; - const { shouldShowImageDetails, shouldHidePreview } = ui; + const { + shouldShowImageDetails, + shouldHidePreview, + shouldShowProgressInViewer, + } = ui; const { selectedImage } = gallery; @@ -112,6 +112,7 @@ const currentImageButtonsSelector = createSelector( seed: selectedImage?.metadata?.seed, prompt: selectedImage?.metadata?.positive_conditioning, negativePrompt: selectedImage?.metadata?.negative_conditioning, + shouldShowProgressInViewer, }; }, { @@ -145,6 +146,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { image, canDeleteImage, shouldConfirmOnDelete, + shouldShowProgressInViewer, } = useAppSelector(currentImageButtonsSelector); const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled; @@ -229,10 +231,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { }); }, [toaster, shouldTransformUrls, getUrl, t, image]); - const handlePreviewVisibility = useCallback(() => { - dispatch(setShouldHidePreview(!shouldHidePreview)); - }, [dispatch, shouldHidePreview]); - const handleClickUseAllParameters = useCallback(() => { recallAllParameters(image); }, [image, recallAllParameters]); @@ -386,6 +384,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { } }, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]); + const handleClickProgressImagesToggle = useCallback(() => { + dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer)); + }, [dispatch, shouldShowProgressInViewer]); + useHotkeys('delete', handleInitiateDelete, [ image, shouldConfirmOnDelete, @@ -412,8 +414,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { } /> } @@ -465,21 +468,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { - {/* : } - tooltip={ - !shouldHidePreview - ? t('parameters.hidePreview') - : t('parameters.showPreview') - } - aria-label={ - !shouldHidePreview - ? t('parameters.hidePreview') - : t('parameters.showPreview') - } - isChecked={shouldHidePreview} - onClick={handlePreviewVisibility} - /> */} {isLightboxEnabled && ( } @@ -604,6 +592,16 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { /> + + } + isChecked={shouldShowProgressInViewer} + onClick={handleClickProgressImagesToggle} + /> + +