Merge branch 'main' into release/make-web-dist-startable

This commit is contained in:
Lincoln Stein 2023-05-29 14:16:10 -04:00 committed by GitHub
commit dc54cbb1fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
123 changed files with 1788 additions and 1628 deletions

View File

@ -5,6 +5,7 @@ import os
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -65,7 +66,7 @@ class ApiDependencies:
metadata = CoreMetadataService() metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
latents = ForwardCacheLatentsStorage( latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents") DiskLatentsStorage(f"{output_folder}/latents")
) )
@ -76,6 +77,7 @@ class ApiDependencies:
metadata=metadata, metadata=metadata,
url=urls, url=urls,
logger=logger, logger=logger,
names=names,
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
) )

View File

@ -1,39 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType
class ImageResponseMetadata(BaseModel):
"""An image's metadata. Used only in HTTP responses."""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# invokeai: Optional[InvokeAIMetadata] = Field(
# description="The image's InvokeAI-specific metadata"
# )
class ImageResponse(BaseModel):
"""The response type for images"""
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
image_url: str = Field(description="The url of the image")
thumbnail_url: str = Field(description="The url of the image's thumbnail")
metadata: ImageResponseMetadata = Field(description="The image's metadata")
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class SavedImage(BaseModel):
image_name: str = Field(description="The name of the saved image")
thumbnail_name: str = Field(description="The name of the saved thumbnail")
created: int = Field(description="The created timestamp of the saved image")

View File

@ -6,8 +6,9 @@ from fastapi.responses import FileResponse
from PIL import Image from PIL import Image
from invokeai.app.models.image import ( from invokeai.app.models.image import (
ImageCategory, ImageCategory,
ImageType, ResourceOrigin,
) )
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageDTO, ImageDTO,
ImageRecordChanges, ImageRecordChanges,
@ -34,12 +35,8 @@ async def upload_image(
file: UploadFile, file: UploadFile,
request: Request, request: Request,
response: Response, response: Response,
image_category: ImageCategory = Query( image_category: ImageCategory = Query(description="The category of the image"),
default=ImageCategory.GENERAL, description="The category of the image" is_intermediate: bool = Query(description="Whether this is an intermediate image"),
),
is_intermediate: bool = Query(
default=False, description="Whether this is an intermediate image"
),
session_id: Optional[str] = Query( session_id: Optional[str] = Query(
default=None, description="The session ID associated with this upload, if any" default=None, description="The session ID associated with this upload, if any"
), ),
@ -59,7 +56,7 @@ async def upload_image(
try: try:
image_dto = ApiDependencies.invoker.services.images.create( image_dto = ApiDependencies.invoker.services.images.create(
image=pil_image, image=pil_image,
image_type=ImageType.UPLOAD, image_origin=ResourceOrigin.EXTERNAL,
image_category=image_category, image_category=image_category,
session_id=session_id, session_id=session_id,
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
@ -73,27 +70,27 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image") raise HTTPException(status_code=500, detail="Failed to create image")
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image") @images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
async def delete_image( async def delete_image(
image_type: ImageType = Path(description="The type of image to delete"), image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
image_name: str = Path(description="The name of the image to delete"), image_name: str = Path(description="The name of the image to delete"),
) -> None: ) -> None:
"""Deletes an image""" """Deletes an image"""
try: try:
ApiDependencies.invoker.services.images.delete(image_type, image_name) ApiDependencies.invoker.services.images.delete(image_origin, image_name)
except Exception as e: except Exception as e:
# TODO: Does this need any exception handling at all? # TODO: Does this need any exception handling at all?
pass pass
@images_router.patch( @images_router.patch(
"/{image_type}/{image_name}", "/{image_origin}/{image_name}",
operation_id="update_image", operation_id="update_image",
response_model=ImageDTO, response_model=ImageDTO,
) )
async def update_image( async def update_image(
image_type: ImageType = Path(description="The type of image to update"), image_origin: ResourceOrigin = Path(description="The origin of image to update"),
image_name: str = Path(description="The name of the image to update"), image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body( image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image" description="The changes to apply to the image"
@ -103,31 +100,31 @@ async def update_image(
try: try:
return ApiDependencies.invoker.services.images.update( return ApiDependencies.invoker.services.images.update(
image_type, image_name, image_changes image_origin, image_name, image_changes
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image") raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get( @images_router.get(
"/{image_type}/{image_name}/metadata", "/{image_origin}/{image_name}/metadata",
operation_id="get_image_metadata", operation_id="get_image_metadata",
response_model=ImageDTO, response_model=ImageDTO,
) )
async def get_image_metadata( async def get_image_metadata(
image_type: ImageType = Path(description="The type of image to get"), image_origin: ResourceOrigin = Path(description="The origin of image to get"),
image_name: str = Path(description="The name of image to get"), image_name: str = Path(description="The name of image to get"),
) -> ImageDTO: ) -> ImageDTO:
"""Gets an image's metadata""" """Gets an image's metadata"""
try: try:
return ApiDependencies.invoker.services.images.get_dto(image_type, image_name) return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
except Exception as e: except Exception as e:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@images_router.get( @images_router.get(
"/{image_type}/{image_name}", "/{image_origin}/{image_name}",
operation_id="get_image_full", operation_id="get_image_full",
response_class=Response, response_class=Response,
responses={ responses={
@ -139,7 +136,7 @@ async def get_image_metadata(
}, },
) )
async def get_image_full( async def get_image_full(
image_type: ImageType = Path( image_origin: ResourceOrigin = Path(
description="The type of full-resolution image file to get" description="The type of full-resolution image file to get"
), ),
image_name: str = Path(description="The name of full-resolution image file to get"), image_name: str = Path(description="The name of full-resolution image file to get"),
@ -147,7 +144,7 @@ async def get_image_full(
"""Gets a full-resolution image file""" """Gets a full-resolution image file"""
try: try:
path = ApiDependencies.invoker.services.images.get_path(image_type, image_name) path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
if not ApiDependencies.invoker.services.images.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -163,7 +160,7 @@ async def get_image_full(
@images_router.get( @images_router.get(
"/{image_type}/{image_name}/thumbnail", "/{image_origin}/{image_name}/thumbnail",
operation_id="get_image_thumbnail", operation_id="get_image_thumbnail",
response_class=Response, response_class=Response,
responses={ responses={
@ -175,14 +172,14 @@ async def get_image_full(
}, },
) )
async def get_image_thumbnail( async def get_image_thumbnail(
image_type: ImageType = Path(description="The type of thumbnail image file to get"), image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
image_name: str = Path(description="The name of thumbnail image file to get"), image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse: ) -> FileResponse:
"""Gets a thumbnail image file""" """Gets a thumbnail image file"""
try: try:
path = ApiDependencies.invoker.services.images.get_path( path = ApiDependencies.invoker.services.images.get_path(
image_type, image_name, thumbnail=True image_origin, image_name, thumbnail=True
) )
if not ApiDependencies.invoker.services.images.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -195,25 +192,25 @@ async def get_image_thumbnail(
@images_router.get( @images_router.get(
"/{image_type}/{image_name}/urls", "/{image_origin}/{image_name}/urls",
operation_id="get_image_urls", operation_id="get_image_urls",
response_model=ImageUrlsDTO, response_model=ImageUrlsDTO,
) )
async def get_image_urls( async def get_image_urls(
image_type: ImageType = Path(description="The type of the image whose URL to get"), image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
image_name: str = Path(description="The name of the image whose URL to get"), image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO: ) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL""" """Gets an image and thumbnail URL"""
try: try:
image_url = ApiDependencies.invoker.services.images.get_url( image_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name image_origin, image_name
) )
thumbnail_url = ApiDependencies.invoker.services.images.get_url( thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name, thumbnail=True image_origin, image_name, thumbnail=True
) )
return ImageUrlsDTO( return ImageUrlsDTO(
image_type=image_type, image_origin=image_origin,
image_name=image_name, image_name=image_name,
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
@ -225,23 +222,29 @@ async def get_image_urls(
@images_router.get( @images_router.get(
"/", "/",
operation_id="list_images_with_metadata", operation_id="list_images_with_metadata",
response_model=PaginatedResults[ImageDTO], response_model=OffsetPaginatedResults[ImageDTO],
) )
async def list_images_with_metadata( async def list_images_with_metadata(
image_type: ImageType = Query(description="The type of images to list"), image_origin: Optional[ResourceOrigin] = Query(
image_category: ImageCategory = Query(description="The kind of images to list"), default=None, description="The origin of images to list"
page: int = Query(default=0, description="The page of image metadata to get"),
per_page: int = Query(
default=10, description="The number of image metadata per page"
), ),
) -> PaginatedResults[ImageDTO]: categories: Optional[list[ImageCategory]] = Query(
"""Gets a list of images with metadata""" default=None, description="The categories of image to include"
),
is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images"
),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of images"""
image_dtos = ApiDependencies.invoker.services.images.get_many( image_dtos = ApiDependencies.invoker.services.images.get_many(
image_type, offset,
image_category, limit,
page, image_origin,
per_page, categories,
is_intermediate,
) )
return image_dtos return image_dtos

View File

@ -16,6 +16,7 @@ from pydantic.fields import Field
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
@ -229,6 +230,7 @@ def invoke_cli():
metadata = CoreMetadataService() metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
images = ImageService( images = ImageService(
image_record_storage=image_record_storage, image_record_storage=image_record_storage,
@ -236,6 +238,7 @@ def invoke_cli():
metadata=metadata, metadata=metadata,
url=urls, url=urls,
logger=logger, logger=logger,
names=names,
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
) )

View File

@ -7,7 +7,7 @@ from typing import Literal, Optional, Union, List
from PIL import Image, ImageFilter, ImageOps from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType, ImageCategory from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
@ -163,7 +163,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image( raw_image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
# image type should be PIL.PngImagePlugin.PngImageFile ? # image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image) processed_image = self.run_processor(raw_image)
@ -177,8 +177,8 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=processed_image, image=processed_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.CONTROL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
is_intermediate=self.is_intermediate is_intermediate=self.is_intermediate
@ -187,7 +187,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
"""Builds an ImageOutput and its ImageField""" """Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField( processed_image_field = ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
) )
return ImageOutput( return ImageOutput(
image=processed_image_field, image=processed_image_field,

View File

@ -7,7 +7,7 @@ import numpy
from PIL import Image, ImageOps from PIL import Image, ImageOps
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageCategory, ImageField, ImageType from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput from .image import ImageOutput
@ -37,10 +37,10 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
mask = context.services.images.get_pil_image( mask = context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name self.mask.image_origin, self.mask.image_name
) )
# Convert to cv image/mask # Convert to cv image/mask
@ -57,7 +57,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image_inpainted, image=image_inpainted,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -67,7 +67,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -10,9 +10,9 @@ import torch
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.models.image import ColorField, ImageField, ImageType from invokeai.app.models.image import ColorField, ImageField, ResourceOrigin
from invokeai.app.invocations.util.choose_model import choose_model from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
@ -86,8 +86,8 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# loading controlnet image (currently requires pre-processed image) # loading controlnet image (currently requires pre-processed image)
control_image = ( control_image = (
None if self.control_image is None None if self.control_image is None
else context.services.images.get( else context.services.images.get_pil_image(
self.control_image.image_type, self.control_image.image_name self.control_image.image_origin, self.control_image.image_name
) )
) )
# loading controlnet model # loading controlnet model
@ -120,7 +120,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=generate_output.image, image=generate_output.image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
@ -130,7 +130,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -170,7 +170,7 @@ class ImageToImageInvocation(TextToImageInvocation):
None None
if self.image is None if self.image is None
else context.services.images.get_pil_image( else context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
) )
@ -201,7 +201,7 @@ class ImageToImageInvocation(TextToImageInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=generator_output.image, image=generator_output.image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
@ -211,7 +211,7 @@ class ImageToImageInvocation(TextToImageInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -283,13 +283,13 @@ class InpaintInvocation(ImageToImageInvocation):
None None
if self.image is None if self.image is None
else context.services.images.get_pil_image( else context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
) )
mask = ( mask = (
None None
if self.mask is None if self.mask is None
else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name) else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
) )
# Handle invalid model parameter # Handle invalid model parameter
@ -317,7 +317,7 @@ class InpaintInvocation(ImageToImageInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=generator_output.image, image=generator_output.image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
@ -327,7 +327,7 @@ class InpaintInvocation(ImageToImageInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -7,7 +7,7 @@ import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..models.image import ImageCategory, ImageField, ImageType from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
@ -72,12 +72,12 @@ class LoadImageInvocation(BaseInvocation):
) )
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name) image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=self.image.image_name, image_name=self.image.image_name,
image_type=self.image.image_type, image_origin=self.image.image_origin,
), ),
width=image.width, width=image.width,
height=image.height, height=image.height,
@ -96,7 +96,7 @@ class ShowImageInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
if image: if image:
image.show() image.show()
@ -106,7 +106,7 @@ class ShowImageInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=self.image.image_name, image_name=self.image.image_name,
image_type=self.image.image_type, image_origin=self.image.image_origin,
), ),
width=image.width, width=image.width,
height=image.height, height=image.height,
@ -129,7 +129,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
image_crop = Image.new( image_crop = Image.new(
@ -139,7 +139,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image_crop, image=image_crop,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -149,7 +149,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -172,17 +172,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image( base_image = context.services.images.get_pil_image(
self.base_image.image_type, self.base_image.image_name self.base_image.image_origin, self.base_image.image_name
) )
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
mask = ( mask = (
None None
if self.mask is None if self.mask is None
else ImageOps.invert( else ImageOps.invert(
context.services.images.get_pil_image( context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name self.mask.image_origin, self.mask.image_name
) )
) )
) )
@ -201,7 +201,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=new_image, image=new_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -211,7 +211,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -231,7 +231,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> MaskOutput: def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
image_mask = image.split()[-1] image_mask = image.split()[-1]
@ -240,7 +240,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image_mask, image=image_mask,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.MASK, image_category=ImageCategory.MASK,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -249,7 +249,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
return MaskOutput( return MaskOutput(
mask=ImageField( mask=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name image_origin=image_dto.image_origin, image_name=image_dto.image_name
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -269,17 +269,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image( image1 = context.services.images.get_pil_image(
self.image1.image_type, self.image1.image_name self.image1.image_origin, self.image1.image_name
) )
image2 = context.services.images.get_pil_image( image2 = context.services.images.get_pil_image(
self.image2.image_type, self.image2.image_name self.image2.image_origin, self.image2.image_name
) )
multiply_image = ImageChops.multiply(image1, image2) multiply_image = ImageChops.multiply(image1, image2)
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=multiply_image, image=multiply_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -288,7 +288,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name image_origin=image_dto.image_origin, image_name=image_dto.image_name
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -311,14 +311,14 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
channel_image = image.getchannel(self.channel) channel_image = image.getchannel(self.channel)
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=channel_image, image=channel_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -327,7 +327,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name image_origin=image_dto.image_origin, image_name=image_dto.image_name
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -350,14 +350,14 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
converted_image = image.convert(self.mode) converted_image = image.convert(self.mode)
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=converted_image, image=converted_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -366,7 +366,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name image_origin=image_dto.image_origin, image_name=image_dto.image_name
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -387,7 +387,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
blur = ( blur = (
@ -399,7 +399,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=blur_image, image=blur_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -409,7 +409,116 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
PIL_RESAMPLING_MODES = Literal[
"nearest",
"box",
"bilinear",
"hamming",
"bicubic",
"lanczos",
]
PIL_RESAMPLING_MAP = {
"nearest": Image.Resampling.NEAREST,
"box": Image.Resampling.BOX,
"bilinear": Image.Resampling.BILINEAR,
"hamming": Image.Resampling.HAMMING,
"bicubic": Image.Resampling.BICUBIC,
"lanczos": Image.Resampling.LANCZOS,
}
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
"""Resizes an image to specific dimensions"""
# fmt: off
type: Literal["img_resize"] = "img_resize"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
resize_image = image.resize(
(self.width, self.height),
resample=resample_mode,
)
image_dto = context.services.images.create(
image=resize_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
"""Scales an image by a factor"""
# fmt: off
type: Literal["img_scale"] = "img_scale"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width * self.scale_factor)
height = int(image.height * self.scale_factor)
resize_image = image.resize(
(width, height),
resample=resample_mode,
)
image_dto = context.services.images.create(
image=resize_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -430,7 +539,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
@ -440,7 +549,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=lerp_image, image=lerp_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -450,7 +559,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -471,7 +580,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
image_arr = numpy.asarray(image, dtype=numpy.float32) image_arr = numpy.asarray(image, dtype=numpy.float32)
@ -486,7 +595,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=ilerp_image, image=ilerp_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -496,7 +605,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -11,7 +11,7 @@ from invokeai.app.invocations.image import ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageCategory, ImageField, ImageType from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
InvocationContext, InvocationContext,
@ -135,7 +135,7 @@ class InfillColorInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
solid_bg = Image.new("RGBA", image.size, self.color.tuple()) solid_bg = Image.new("RGBA", image.size, self.color.tuple())
@ -145,7 +145,7 @@ class InfillColorInvocation(BaseInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=infilled, image=infilled,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -155,7 +155,7 @@ class InfillColorInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -180,7 +180,7 @@ class InfillTileInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
infilled = tile_fill_missing( infilled = tile_fill_missing(
@ -190,7 +190,7 @@ class InfillTileInvocation(BaseInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=infilled, image=infilled,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -200,7 +200,7 @@ class InfillTileInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -218,7 +218,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
if PatchMatch.patchmatch_available(): if PatchMatch.patchmatch_available():
@ -228,7 +228,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=infilled, image=infilled,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -238,7 +238,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -28,7 +28,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np import numpy as np
from ..services.image_file_storage import ImageType from ..services.image_file_storage import ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput from .image import ImageField, ImageOutput
from .compel import ConditioningField from .compel import ConditioningField
@ -297,7 +297,7 @@ class TextToLatentsInvocation(BaseInvocation):
torch_dtype=model.unet.dtype).to(model.device) torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model) control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_type, input_image = context.services.images.get_pil_image(control_image_field.image_origin,
control_image_field.image_name) control_image_field.image_name)
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes # FIXME: still need to test with different widths, heights, devices, dtypes
@ -468,7 +468,7 @@ class LatentsToImageInvocation(BaseInvocation):
# and gnenerate unique image_name # and gnenerate unique image_name
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image, image=image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
@ -478,7 +478,7 @@ class LatentsToImageInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -576,7 +576,7 @@ class ImageToLatentsInvocation(BaseInvocation):
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# ) # )
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
# TODO: this only really needs the vae # TODO: this only really needs the vae

View File

@ -2,7 +2,7 @@ from typing import Literal, Union
from pydantic import Field from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ImageType from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput from .image import ImageOutput
@ -29,7 +29,7 @@ class RestoreFaceInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
results = context.services.restoration.upscale_and_reconstruct( results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]], image_list=[[image, 0]],
@ -43,7 +43,7 @@ class RestoreFaceInvocation(BaseInvocation):
# TODO: can this return multiple results? # TODO: can this return multiple results?
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=results[0][0], image=results[0][0],
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -53,7 +53,7 @@ class RestoreFaceInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -4,7 +4,7 @@ from typing import Literal, Union
from pydantic import Field from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ImageType from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput from .image import ImageOutput
@ -31,7 +31,7 @@ class UpscaleInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
results = context.services.restoration.upscale_and_reconstruct( results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]], image_list=[[image, 0]],
@ -45,7 +45,7 @@ class UpscaleInvocation(BaseInvocation):
# TODO: can this return multiple results? # TODO: can this return multiple results?
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=results[0][0], image=results[0][0],
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -55,7 +55,7 @@ class UpscaleInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -5,30 +5,52 @@ from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.metaenum import MetaEnum
class ImageType(str, Enum, metaclass=MetaEnum): class ResourceOrigin(str, Enum, metaclass=MetaEnum):
"""The type of an image.""" """The origin of a resource (eg image).
RESULT = "results" - INTERNAL: The resource was created by the application.
UPLOAD = "uploads" - EXTERNAL: The resource was not created by the application.
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
"""
INTERNAL = "internal"
"""The resource was created by the application."""
EXTERNAL = "external"
"""The resource was not created by the application.
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
"""
class InvalidImageTypeException(ValueError): class InvalidOriginException(ValueError):
"""Raised when a provided value is not a valid ImageType. """Raised when a provided value is not a valid ResourceOrigin.
Subclasses `ValueError`. Subclasses `ValueError`.
""" """
def __init__(self, message="Invalid image type."): def __init__(self, message="Invalid resource origin."):
super().__init__(message) super().__init__(message)
class ImageCategory(str, Enum, metaclass=MetaEnum): class ImageCategory(str, Enum, metaclass=MetaEnum):
"""The category of an image. Use ImageCategory.OTHER for non-default categories.""" """The category of an image.
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
- MASK: The image is a mask image.
- CONTROL: The image is a ControlNet control image.
- USER: The image is a user-provide image.
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
"""
GENERAL = "general" GENERAL = "general"
CONTROL = "control" """GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
MASK = "mask" MASK = "mask"
"""MASK: The image is a mask image."""
CONTROL = "control"
"""CONTROL: The image is a ControlNet control image."""
USER = "user"
"""USER: The image is a user-provide image."""
OTHER = "other" OTHER = "other"
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
class InvalidImageCategoryException(ValueError): class InvalidImageCategoryException(ValueError):
@ -44,13 +66,13 @@ class InvalidImageCategoryException(ValueError):
class ImageField(BaseModel): class ImageField(BaseModel):
"""An image field used for passing image objects between invocations""" """An image field used for passing image objects between invocations"""
image_type: ImageType = Field( image_origin: ResourceOrigin = Field(
default=ImageType.RESULT, description="The type of the image" default=ResourceOrigin.INTERNAL, description="The type of the image"
) )
image_name: Optional[str] = Field(default=None, description="The name of the image") image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config: class Config:
schema_extra = {"required": ["image_type", "image_name"]} schema_extra = {"required": ["image_origin", "image_name"]}
class ColorField(BaseModel): class ColorField(BaseModel):
@ -61,3 +83,11 @@ class ColorField(BaseModel):
def tuple(self) -> Tuple[int, int, int, int]: def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a) return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

View File

@ -1,7 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Optional from typing import Any
from invokeai.app.api.models.images import ProgressImage from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp

View File

@ -9,7 +9,7 @@ from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.models.image import ImageType from invokeai.app.models.image import ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -40,13 +40,13 @@ class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files.""" """Low-level service responsible for storing and retrieving image files."""
@abstractmethod @abstractmethod
def get(self, image_type: ImageType, image_name: str) -> PILImageType: def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image.""" """Retrieves an image as PIL Image."""
pass pass
@abstractmethod @abstractmethod
def get_path( def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
"""Gets the internal path to an image or thumbnail.""" """Gets the internal path to an image or thumbnail."""
pass pass
@ -62,7 +62,7 @@ class ImageFileStorageBase(ABC):
def save( def save(
self, self,
image: PILImageType, image: PILImageType,
image_type: ImageType, image_origin: ResourceOrigin,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
@ -71,7 +71,7 @@ class ImageFileStorageBase(ABC):
pass pass
@abstractmethod @abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists).""" """Deletes an image and its thumbnail (if one exists)."""
pass pass
@ -93,17 +93,17 @@ class DiskImageFileStorage(ImageFileStorageBase):
Path(output_folder).mkdir(parents=True, exist_ok=True) Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath? # TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_type in ImageType: for image_origin in ResourceOrigin:
Path(os.path.join(output_folder, image_type)).mkdir( Path(os.path.join(output_folder, image_origin)).mkdir(
parents=True, exist_ok=True parents=True, exist_ok=True
) )
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir( Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
parents=True, exist_ok=True parents=True, exist_ok=True
) )
def get(self, image_type: ImageType, image_name: str) -> PILImageType: def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
try: try:
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_origin, image_name)
cache_item = self.__get_cache(image_path) cache_item = self.__get_cache(image_path)
if cache_item: if cache_item:
return cache_item return cache_item
@ -117,13 +117,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
def save( def save(
self, self,
image: PILImageType, image: PILImageType,
image_type: ImageType, image_origin: ResourceOrigin,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_origin, image_name)
if metadata is not None: if metadata is not None:
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
@ -133,7 +133,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image.save(image_path, "PNG") image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True) thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size) thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path) thumbnail_image.save(thumbnail_path)
@ -142,10 +142,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e: except Exception as e:
raise ImageFileSaveException from e raise ImageFileSaveException from e
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
try: try:
basename = os.path.basename(image_name) basename = os.path.basename(image_name)
image_path = self.get_path(image_type, basename) image_path = self.get_path(image_origin, basename)
if os.path.exists(image_path): if os.path.exists(image_path):
send2trash(image_path) send2trash(image_path)
@ -153,7 +153,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
del self.__cache[image_path] del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, True) thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
if os.path.exists(thumbnail_path): if os.path.exists(thumbnail_path):
send2trash(thumbnail_path) send2trash(thumbnail_path)
@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
# TODO: make this a bit more flexible for e.g. cloud storage # TODO: make this a bit more flexible for e.g. cloud storage
def get_path( def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
# strip out any relative path shenanigans # strip out any relative path shenanigans
basename = os.path.basename(image_name) basename = os.path.basename(image_name)
@ -172,10 +172,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
if thumbnail: if thumbnail:
thumbnail_name = get_thumbnail_name(basename) thumbnail_name = get_thumbnail_name(basename)
path = os.path.join( path = os.path.join(
self.__output_folder, image_type, "thumbnails", thumbnail_name self.__output_folder, image_origin, "thumbnails", thumbnail_name
) )
else: else:
path = os.path.join(self.__output_folder, image_type, basename) path = os.path.join(self.__output_folder, image_origin, basename)
abspath = os.path.abspath(path) abspath = os.path.abspath(path)

View File

@ -1,21 +1,35 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from typing import Optional, cast from typing import Generic, Optional, TypeVar, cast
import sqlite3 import sqlite3
import threading import threading
from typing import Optional, Union from typing import Optional, Union
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.models.image import ( from invokeai.app.models.image import (
ImageCategory, ImageCategory,
ImageType, ResourceOrigin,
) )
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageRecord, ImageRecord,
ImageRecordChanges, ImageRecordChanges,
deserialize_image_record, deserialize_image_record,
) )
from invokeai.app.services.item_storage import PaginatedResults
T = TypeVar("T", bound=BaseModel)
class OffsetPaginatedResults(GenericModel, Generic[T]):
"""Offset-paginated results"""
# fmt: off
items: list[T] = Field(description="Items")
offset: int = Field(description="Offset from which to retrieve items")
limit: int = Field(description="Limit of items to get")
total: int = Field(description="Total number of items in result")
# fmt: on
# TODO: Should these excpetions subclass existing python exceptions? # TODO: Should these excpetions subclass existing python exceptions?
@ -46,7 +60,7 @@ class ImageRecordStorageBase(ABC):
# TODO: Implement an `update()` method # TODO: Implement an `update()` method
@abstractmethod @abstractmethod
def get(self, image_type: ImageType, image_name: str) -> ImageRecord: def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
"""Gets an image record.""" """Gets an image record."""
pass pass
@ -54,7 +68,7 @@ class ImageRecordStorageBase(ABC):
def update( def update(
self, self,
image_name: str, image_name: str,
image_type: ImageType, image_origin: ResourceOrigin,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> None: ) -> None:
"""Updates an image record.""" """Updates an image record."""
@ -63,18 +77,19 @@ class ImageRecordStorageBase(ABC):
@abstractmethod @abstractmethod
def get_many( def get_many(
self, self,
image_type: ImageType, offset: int = 0,
image_category: ImageCategory, limit: int = 10,
page: int = 0, image_origin: Optional[ResourceOrigin] = None,
per_page: int = 10, categories: Optional[list[ImageCategory]] = None,
) -> PaginatedResults[ImageRecord]: is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records.""" """Gets a page of image records."""
pass pass
# TODO: The database has a nullable `deleted_at` column, currently unused. # TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage. # Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod @abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
"""Deletes an image record.""" """Deletes an image record."""
pass pass
@ -82,7 +97,7 @@ class ImageRecordStorageBase(ABC):
def save( def save(
self, self,
image_name: str, image_name: str,
image_type: ImageType, image_origin: ResourceOrigin,
image_category: ImageCategory, image_category: ImageCategory,
width: int, width: int,
height: int, height: int,
@ -103,7 +118,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def __init__(self, filename: str) -> None: def __init__(self, filename: str) -> None:
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False) self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
@ -129,7 +143,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
CREATE TABLE IF NOT EXISTS images ( CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY, image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility -- This is an enum in python, unrestricted string here for flexibility
image_type TEXT NOT NULL, image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility -- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL, image_category TEXT NOT NULL,
width INTEGER NOT NULL, width INTEGER NOT NULL,
@ -138,9 +152,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id TEXT, node_id TEXT,
metadata TEXT, metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE, is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger -- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused -- Soft delete, currently unused
deleted_at DATETIME deleted_at DATETIME
); );
@ -155,7 +169,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
) )
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type); CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
""" """
) )
self._cursor.execute( self._cursor.execute(
@ -182,7 +196,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""" """
) )
def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]: def get(
self, image_origin: ResourceOrigin, image_name: str
) -> Union[ImageRecord, None]:
try: try:
self._lock.acquire() self._lock.acquire()
@ -209,7 +225,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def update( def update(
self, self,
image_name: str, image_name: str,
image_type: ImageType, image_origin: ResourceOrigin,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> None: ) -> None:
try: try:
@ -235,6 +251,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""", """,
(changes.session_id, image_name), (changes.session_id, image_name),
) )
# Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
self._conn.commit() self._conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
@ -244,36 +271,61 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def get_many( def get_many(
self, self,
image_type: ImageType, offset: int = 0,
image_category: ImageCategory, limit: int = 10,
page: int = 0, image_origin: Optional[ResourceOrigin] = None,
per_page: int = 10, categories: Optional[list[ImageCategory]] = None,
) -> PaginatedResults[ImageRecord]: is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( # Manually build two queries - one for the count, one for the records
f"""--sql
SELECT * FROM images
WHERE image_type = ? AND image_category = ?
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
""",
(image_type.value, image_category.value, per_page, page * per_page),
)
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
query_conditions = ""
query_params = []
if image_origin is not None:
query_conditions += f"""AND image_origin = ?\n"""
query_params.append(image_origin.value)
if categories is not None:
## Convert the enum values to unique list of strings
category_strings = list(
map(lambda c: c.value, set(categories))
)
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n"
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += f"""AND is_intermediate = ?\n"""
query_params.append(is_intermediate)
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
images_params.append(limit)
images_params.append(offset)
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall()) result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result)) images = list(map(lambda r: deserialize_image_record(dict(r)), result))
self._cursor.execute( # Set up and execute the count query, without pagination
"""--sql count_query += query_conditions + ";"
SELECT count(*) FROM images count_params = query_params.copy()
WHERE image_type = ? AND image_category = ? self._cursor.execute(count_query, count_params)
""",
(image_type.value, image_category.value),
)
count = self._cursor.fetchone()[0] count = self._cursor.fetchone()[0]
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
@ -281,13 +333,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally: finally:
self._lock.release() self._lock.release()
pageCount = int(count / per_page) + 1 return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
return PaginatedResults(
items=images, page=page, pages=pageCount, per_page=per_page, total=count
) )
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -307,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def save( def save(
self, self,
image_name: str, image_name: str,
image_type: ImageType, image_origin: ResourceOrigin,
image_category: ImageCategory, image_category: ImageCategory,
session_id: Optional[str], session_id: Optional[str],
width: int, width: int,
@ -325,7 +375,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""--sql """--sql
INSERT OR IGNORE INTO images ( INSERT OR IGNORE INTO images (
image_name, image_name,
image_type, image_origin,
image_category, image_category,
width, width,
height, height,
@ -338,7 +388,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""", """,
( (
image_name, image_name,
image_type.value, image_origin.value,
image_category.value, image_category.value,
width, width,
height, height,

View File

@ -1,14 +1,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger from logging import Logger
from typing import Optional, TYPE_CHECKING, Union from typing import Optional, TYPE_CHECKING, Union
import uuid
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from invokeai.app.models.image import ( from invokeai.app.models.image import (
ImageCategory, ImageCategory,
ImageType, ResourceOrigin,
InvalidImageCategoryException, InvalidImageCategoryException,
InvalidImageTypeException, InvalidOriginException,
) )
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.image_record_storage import ( from invokeai.app.services.image_record_storage import (
@ -16,6 +15,7 @@ from invokeai.app.services.image_record_storage import (
ImageRecordNotFoundException, ImageRecordNotFoundException,
ImageRecordSaveException, ImageRecordSaveException,
ImageRecordStorageBase, ImageRecordStorageBase,
OffsetPaginatedResults,
) )
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageRecord, ImageRecord,
@ -31,6 +31,7 @@ from invokeai.app.services.image_file_storage import (
) )
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase from invokeai.app.services.urls import UrlServiceBase
if TYPE_CHECKING: if TYPE_CHECKING:
@ -44,7 +45,7 @@ class ImageServiceABC(ABC):
def create( def create(
self, self,
image: PILImageType, image: PILImageType,
image_type: ImageType, image_origin: ResourceOrigin,
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
@ -56,7 +57,7 @@ class ImageServiceABC(ABC):
@abstractmethod @abstractmethod
def update( def update(
self, self,
image_type: ImageType, image_origin: ResourceOrigin,
image_name: str, image_name: str,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> ImageDTO: ) -> ImageDTO:
@ -64,22 +65,22 @@ class ImageServiceABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
"""Gets an image as a PIL image.""" """Gets an image as a PIL image."""
pass pass
@abstractmethod @abstractmethod
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
"""Gets an image record.""" """Gets an image record."""
pass pass
@abstractmethod @abstractmethod
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
"""Gets an image DTO.""" """Gets an image DTO."""
pass pass
@abstractmethod @abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str: def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
"""Gets an image's path.""" """Gets an image's path."""
pass pass
@ -90,7 +91,7 @@ class ImageServiceABC(ABC):
@abstractmethod @abstractmethod
def get_url( def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
"""Gets an image's or thumbnail's URL.""" """Gets an image's or thumbnail's URL."""
pass pass
@ -98,16 +99,17 @@ class ImageServiceABC(ABC):
@abstractmethod @abstractmethod
def get_many( def get_many(
self, self,
image_type: ImageType, offset: int = 0,
image_category: ImageCategory, limit: int = 10,
page: int = 0, image_origin: Optional[ResourceOrigin] = None,
per_page: int = 10, categories: Optional[list[ImageCategory]] = None,
) -> PaginatedResults[ImageDTO]: is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs.""" """Gets a paginated list of image DTOs."""
pass pass
@abstractmethod @abstractmethod
def delete(self, image_type: ImageType, image_name: str): def delete(self, image_origin: ResourceOrigin, image_name: str):
"""Deletes an image.""" """Deletes an image."""
pass pass
@ -120,6 +122,7 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase metadata: MetadataServiceBase
urls: UrlServiceBase urls: UrlServiceBase
logger: Logger logger: Logger
names: NameServiceBase
graph_execution_manager: ItemStorageABC["GraphExecutionState"] graph_execution_manager: ItemStorageABC["GraphExecutionState"]
def __init__( def __init__(
@ -129,6 +132,7 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase, metadata: MetadataServiceBase,
url: UrlServiceBase, url: UrlServiceBase,
logger: Logger, logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
): ):
self.records = image_record_storage self.records = image_record_storage
@ -136,6 +140,7 @@ class ImageServiceDependencies:
self.metadata = metadata self.metadata = metadata
self.urls = url self.urls = url
self.logger = logger self.logger = logger
self.names = names
self.graph_execution_manager = graph_execution_manager self.graph_execution_manager = graph_execution_manager
@ -149,6 +154,7 @@ class ImageService(ImageServiceABC):
metadata: MetadataServiceBase, metadata: MetadataServiceBase,
url: UrlServiceBase, url: UrlServiceBase,
logger: Logger, logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
): ):
self._services = ImageServiceDependencies( self._services = ImageServiceDependencies(
@ -157,30 +163,26 @@ class ImageService(ImageServiceABC):
metadata=metadata, metadata=metadata,
url=url, url=url,
logger=logger, logger=logger,
names=names,
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
) )
def create( def create(
self, self,
image: PILImageType, image: PILImageType,
image_type: ImageType, image_origin: ResourceOrigin,
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
is_intermediate: bool = False, is_intermediate: bool = False,
) -> ImageDTO: ) -> ImageDTO:
if image_type not in ImageType: if image_origin not in ResourceOrigin:
raise InvalidImageTypeException raise InvalidOriginException
if image_category not in ImageCategory: if image_category not in ImageCategory:
raise InvalidImageCategoryException raise InvalidImageCategoryException
image_name = self._create_image_name( image_name = self._services.names.create_image_name()
image_type=image_type,
image_category=image_category,
node_id=node_id,
session_id=session_id,
)
metadata = self._get_metadata(session_id, node_id) metadata = self._get_metadata(session_id, node_id)
@ -191,7 +193,7 @@ class ImageService(ImageServiceABC):
created_at = self._services.records.save( created_at = self._services.records.save(
# Non-nullable fields # Non-nullable fields
image_name=image_name, image_name=image_name,
image_type=image_type, image_origin=image_origin,
image_category=image_category, image_category=image_category,
width=width, width=width,
height=height, height=height,
@ -204,21 +206,21 @@ class ImageService(ImageServiceABC):
) )
self._services.files.save( self._services.files.save(
image_type=image_type, image_origin=image_origin,
image_name=image_name, image_name=image_name,
image=image, image=image,
metadata=metadata, metadata=metadata,
) )
image_url = self._services.urls.get_image_url(image_type, image_name) image_url = self._services.urls.get_image_url(image_origin, image_name)
thumbnail_url = self._services.urls.get_image_url( thumbnail_url = self._services.urls.get_image_url(
image_type, image_name, True image_origin, image_name, True
) )
return ImageDTO( return ImageDTO(
# Non-nullable fields # Non-nullable fields
image_name=image_name, image_name=image_name,
image_type=image_type, image_origin=image_origin,
image_category=image_category, image_category=image_category,
width=width, width=width,
height=height, height=height,
@ -247,13 +249,13 @@ class ImageService(ImageServiceABC):
def update( def update(
self, self,
image_type: ImageType, image_origin: ResourceOrigin,
image_name: str, image_name: str,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> ImageDTO: ) -> ImageDTO:
try: try:
self._services.records.update(image_name, image_type, changes) self._services.records.update(image_name, image_origin, changes)
return self.get_dto(image_type, image_name) return self.get_dto(image_origin, image_name)
except ImageRecordSaveException: except ImageRecordSaveException:
self._services.logger.error("Failed to update image record") self._services.logger.error("Failed to update image record")
raise raise
@ -261,10 +263,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem updating image record") self._services.logger.error("Problem updating image record")
raise e raise e
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
try: try:
return self._services.files.get(image_type, image_name) return self._services.files.get(image_origin, image_name)
except ImageFileNotFoundException: except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file") self._services.logger.error("Failed to get image file")
raise raise
@ -272,9 +273,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image file") self._services.logger.error("Problem getting image file")
raise e raise e
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
try: try:
return self._services.records.get(image_type, image_name) return self._services.records.get(image_origin, image_name)
except ImageRecordNotFoundException: except ImageRecordNotFoundException:
self._services.logger.error("Image record not found") self._services.logger.error("Image record not found")
raise raise
@ -282,14 +283,14 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image record") self._services.logger.error("Problem getting image record")
raise e raise e
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
try: try:
image_record = self._services.records.get(image_type, image_name) image_record = self._services.records.get(image_origin, image_name)
image_dto = image_record_to_dto( image_dto = image_record_to_dto(
image_record, image_record,
self._services.urls.get_image_url(image_type, image_name), self._services.urls.get_image_url(image_origin, image_name),
self._services.urls.get_image_url(image_type, image_name, True), self._services.urls.get_image_url(image_origin, image_name, True),
) )
return image_dto return image_dto
@ -301,10 +302,10 @@ class ImageService(ImageServiceABC):
raise e raise e
def get_path( def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
try: try:
return self._services.files.get_path(image_type, image_name, thumbnail) return self._services.files.get_path(image_origin, image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self._services.logger.error("Problem getting image path")
raise e raise e
@ -317,57 +318,58 @@ class ImageService(ImageServiceABC):
raise e raise e
def get_url( def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
try: try:
return self._services.urls.get_image_url(image_type, image_name, thumbnail) return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self._services.logger.error("Problem getting image path")
raise e raise e
def get_many( def get_many(
self, self,
image_type: ImageType, offset: int = 0,
image_category: ImageCategory, limit: int = 10,
page: int = 0, image_origin: Optional[ResourceOrigin] = None,
per_page: int = 10, categories: Optional[list[ImageCategory]] = None,
) -> PaginatedResults[ImageDTO]: is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try: try:
results = self._services.records.get_many( results = self._services.records.get_many(
image_type, offset,
image_category, limit,
page, image_origin,
per_page, categories,
is_intermediate,
) )
image_dtos = list( image_dtos = list(
map( map(
lambda r: image_record_to_dto( lambda r: image_record_to_dto(
r, r,
self._services.urls.get_image_url(image_type, r.image_name), self._services.urls.get_image_url(r.image_origin, r.image_name),
self._services.urls.get_image_url( self._services.urls.get_image_url(
image_type, r.image_name, True r.image_origin, r.image_name, True
), ),
), ),
results.items, results.items,
) )
) )
return PaginatedResults[ImageDTO]( return OffsetPaginatedResults[ImageDTO](
items=image_dtos, items=image_dtos,
page=results.page, offset=results.offset,
pages=results.pages, limit=results.limit,
per_page=results.per_page,
total=results.total, total=results.total,
) )
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting paginated image DTOs") self._services.logger.error("Problem getting paginated image DTOs")
raise e raise e
def delete(self, image_type: ImageType, image_name: str): def delete(self, image_origin: ResourceOrigin, image_name: str):
try: try:
self._services.files.delete(image_type, image_name) self._services.files.delete(image_origin, image_name)
self._services.records.delete(image_type, image_name) self._services.records.delete(image_origin, image_name)
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record") self._services.logger.error(f"Failed to delete image record")
raise raise
@ -378,21 +380,6 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem deleting image record and file") self._services.logger.error("Problem deleting image record and file")
raise e raise e
def _create_image_name(
self,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
"""Create a unique image name."""
uuid_str = str(uuid.uuid4())
if node_id is not None and session_id is not None:
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
def _get_metadata( def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Union[ImageMetadata, None]: ) -> Union[ImageMetadata, None]:

View File

@ -1,7 +1,7 @@
import datetime import datetime
from typing import Optional, Union from typing import Optional, Union
from pydantic import BaseModel, Extra, Field, StrictStr from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
@ -11,8 +11,8 @@ class ImageRecord(BaseModel):
image_name: str = Field(description="The unique name of the image.") image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image.""" """The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.") image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The type of the image.""" """The origin of the image."""
image_category: ImageCategory = Field(description="The category of the image.") image_category: ImageCategory = Field(description="The category of the image.")
"""The category of the image.""" """The category of the image."""
width: int = Field(description="The width of the image in px.") width: int = Field(description="The width of the image in px.")
@ -56,6 +56,7 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
Only limited changes are valid: Only limited changes are valid:
- `image_category`: change the category of an image - `image_category`: change the category of an image
- `session_id`: change the session associated with an image - `session_id`: change the session associated with an image
- `is_intermediate`: change the image's `is_intermediate` flag
""" """
image_category: Optional[ImageCategory] = Field( image_category: Optional[ImageCategory] = Field(
@ -67,6 +68,10 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
description="The image's new session ID.", description="The image's new session ID.",
) )
"""The image's new session ID.""" """The image's new session ID."""
is_intermediate: Optional[StrictBool] = Field(
default=None, description="The image's new `is_intermediate` flag."
)
"""The image's new `is_intermediate` flag."""
class ImageUrlsDTO(BaseModel): class ImageUrlsDTO(BaseModel):
@ -74,8 +79,8 @@ class ImageUrlsDTO(BaseModel):
image_name: str = Field(description="The unique name of the image.") image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image.""" """The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.") image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The type of the image.""" """The origin of the image."""
image_url: str = Field(description="The URL of the image.") image_url: str = Field(description="The URL of the image.")
"""The URL of the image.""" """The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.") thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
@ -105,7 +110,9 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
# Retrieve all the values, setting "reasonable" defaults if they are not present. # Retrieve all the values, setting "reasonable" defaults if they are not present.
image_name = image_dict.get("image_name", "unknown") image_name = image_dict.get("image_name", "unknown")
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value)) image_origin = ResourceOrigin(
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
)
image_category = ImageCategory( image_category = ImageCategory(
image_dict.get("image_category", ImageCategory.GENERAL.value) image_dict.get("image_category", ImageCategory.GENERAL.value)
) )
@ -127,7 +134,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
return ImageRecord( return ImageRecord(
image_name=image_name, image_name=image_name,
image_type=image_type, image_origin=image_origin,
image_category=image_category, image_category=image_category,
width=width, width=width,
height=height, height=height,

View File

@ -0,0 +1,30 @@
from abc import ABC, abstractmethod
from enum import Enum, EnumMeta
import uuid
class ResourceType(str, Enum, metaclass=EnumMeta):
"""Enum for resource types."""
IMAGE = "image"
LATENT = "latent"
class NameServiceBase(ABC):
"""Low-level service responsible for naming resources (images, latents, etc)."""
# TODO: Add customizable naming schemes
@abstractmethod
def create_image_name(self) -> str:
"""Creates a name for an image."""
pass
class SimpleNameService(NameServiceBase):
"""Creates image names from UUIDs."""
# TODO: Add customizable naming schemes
def create_image_name(self) -> str:
uuid_str = str(uuid.uuid4())
filename = f"{uuid_str}.png"
return filename

View File

@ -1,7 +1,7 @@
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from invokeai.app.models.image import ImageType from invokeai.app.models.image import ResourceOrigin
from invokeai.app.util.thumbnails import get_thumbnail_name from invokeai.app.util.thumbnails import get_thumbnail_name
@ -10,7 +10,7 @@ class UrlServiceBase(ABC):
@abstractmethod @abstractmethod
def get_image_url( def get_image_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
"""Gets the URL for an image or thumbnail.""" """Gets the URL for an image or thumbnail."""
pass pass
@ -21,14 +21,14 @@ class LocalUrlService(UrlServiceBase):
self._base_url = base_url self._base_url = base_url
def get_image_url( def get_image_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
image_basename = os.path.basename(image_name) image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py # These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail: if thumbnail:
return ( return (
f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail" f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
) )
return f"{self._base_url}/images/{image_type.value}/{image_basename}" return f"{self._base_url}/images/{image_origin.value}/{image_basename}"

View File

@ -1,5 +1,5 @@
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
from invokeai.app.models.image import ProgressImage
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator from ...backend.generator.base import Generator

View File

@ -122,7 +122,9 @@
"noImagesInGallery": "No Images In Gallery", "noImagesInGallery": "No Images In Gallery",
"deleteImage": "Delete Image", "deleteImage": "Delete Image",
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.", "deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored." "deleteImagePermanent": "Deleted images cannot be restored.",
"images": "Images",
"assets": "Assets"
}, },
"hotkeys": { "hotkeys": {
"keyboardShortcuts": "Keyboard Shortcuts", "keyboardShortcuts": "Keyboard Shortcuts",
@ -524,7 +526,7 @@
}, },
"settings": { "settings": {
"models": "Models", "models": "Models",
"displayInProgress": "Display In-Progress Images", "displayInProgress": "Display Progress Images",
"saveSteps": "Save images every n steps", "saveSteps": "Save images every n steps",
"confirmOnDelete": "Confirm On Delete", "confirmOnDelete": "Confirm On Delete",
"displayHelpIcons": "Display Help Icons", "displayHelpIcons": "Display Help Icons",

View File

@ -1,7 +1,5 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist'; import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist'; import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist';
import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist';
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist'; import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist'; import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
@ -22,11 +20,9 @@ const serializationDenylist: {
models: modelsPersistDenylist, models: modelsPersistDenylist,
nodes: nodesPersistDenylist, nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist, postprocessing: postprocessingPersistDenylist,
results: resultsPersistDenylist,
system: systemPersistDenylist, system: systemPersistDenylist,
// config: configPersistDenyList, // config: configPersistDenyList,
ui: uiPersistDenylist, ui: uiPersistDenylist,
uploads: uploadsPersistDenylist,
// hotkeys: hotkeysPersistDenylist, // hotkeys: hotkeysPersistDenylist,
}; };

View File

@ -1,7 +1,6 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice'; import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice'; import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialResultsState } from 'features/gallery/store/resultsSlice'; import { initialImagesState } from 'features/gallery/store/imagesSlice';
import { initialUploadsState } from 'features/gallery/store/uploadsSlice';
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice'; import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
@ -24,12 +23,11 @@ const initialStates: {
models: initialModelsState, models: initialModelsState,
nodes: initialNodesState, nodes: initialNodesState,
postprocessing: initialPostprocessingState, postprocessing: initialPostprocessingState,
results: initialResultsState,
system: initialSystemState, system: initialSystemState,
config: initialConfigState, config: initialConfigState,
ui: initialUIState, ui: initialUIState,
uploads: initialUploadsState,
hotkeys: initialHotkeysState, hotkeys: initialHotkeysState,
images: initialImagesState,
}; };
export const unserialize: UnserializeFunction = (data, key) => { export const unserialize: UnserializeFunction = (data, key) => {

View File

@ -7,5 +7,6 @@ export const actionsDenylist = [
'canvas/setBoundingBoxDimensions', 'canvas/setBoundingBoxDimensions',
'canvas/setIsDrawing', 'canvas/setIsDrawing',
'canvas/addPointToCurrentLine', 'canvas/addPointToCurrentLine',
'socket/generatorProgress', 'socket/socketGeneratorProgress',
'socket/appSocketGeneratorProgress',
]; ];

View File

@ -26,15 +26,15 @@ import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGaller
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage'; import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasMergedListener } from './listeners/canvasMerged'; import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addGeneratorProgressListener } from './listeners/socketio/generatorProgress'; import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
import { addGraphExecutionStateCompleteListener } from './listeners/socketio/graphExecutionStateComplete'; import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
import { addInvocationCompleteListener } from './listeners/socketio/invocationComplete'; import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
import { addInvocationErrorListener } from './listeners/socketio/invocationError'; import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
import { addInvocationStartedListener } from './listeners/socketio/invocationStarted'; import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
import { addSocketConnectedListener } from './listeners/socketio/socketConnected'; import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
import { addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected'; import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
import { addSocketSubscribedListener } from './listeners/socketio/socketSubscribed'; import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed'; import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke'; import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
import { import {
addImageMetadataReceivedFulfilledListener, addImageMetadataReceivedFulfilledListener,
@ -60,13 +60,16 @@ import {
addSessionCanceledRejectedListener, addSessionCanceledRejectedListener,
} from './listeners/sessionCanceled'; } from './listeners/sessionCanceled';
import { import {
addReceivedResultImagesPageFulfilledListener, addImageUpdatedFulfilledListener,
addReceivedResultImagesPageRejectedListener, addImageUpdatedRejectedListener,
} from './listeners/receivedResultImagesPage'; } from './listeners/imageUpdated';
import { import {
addReceivedUploadImagesPageFulfilledListener, addReceivedPageOfImagesFulfilledListener,
addReceivedUploadImagesPageRejectedListener, addReceivedPageOfImagesRejectedListener,
} from './listeners/receivedUploadImagesPage'; } from './listeners/receivedPageOfImages';
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -90,6 +93,11 @@ export type AppListenerEffect = ListenerEffect<
addImageUploadedFulfilledListener(); addImageUploadedFulfilledListener();
addImageUploadedRejectedListener(); addImageUploadedRejectedListener();
// Image updated
addImageUpdatedFulfilledListener();
addImageUpdatedRejectedListener();
// Image selected
addInitialImageSelectedListener(); addInitialImageSelectedListener();
// Image deleted // Image deleted
@ -118,8 +126,22 @@ addCanvasSavedToGalleryListener();
addCanvasDownloadedAsImageListener(); addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener(); addCanvasCopiedToClipboardListener();
addCanvasMergedListener(); addCanvasMergedListener();
addStagingAreaImageSavedListener();
addCommitStagingAreaImageListener();
// socketio /**
* Socket.IO Events - these handle SIO events directly and pass on internal application actions.
* We don't handle SIO events in slices via `extraReducers` because some of these events shouldn't
* actually be handled at all.
*
* For example, we don't want to respond to progress events for canceled sessions. To avoid
* duplicating the logic to determine if an event should be responded to, we handle all of that
* "is this session canceled?" logic in these listeners.
*
* The `socketGeneratorProgress` listener will then only dispatch the `appSocketGeneratorProgress`
* action if it should be handled by the rest of the application. It is this `appSocketGeneratorProgress`
* action that is handled by reducers in slices.
*/
addGeneratorProgressListener(); addGeneratorProgressListener();
addGraphExecutionStateCompleteListener(); addGraphExecutionStateCompleteListener();
addInvocationCompleteListener(); addInvocationCompleteListener();
@ -145,8 +167,9 @@ addSessionCanceledPendingListener();
addSessionCanceledFulfilledListener(); addSessionCanceledFulfilledListener();
addSessionCanceledRejectedListener(); addSessionCanceledRejectedListener();
// Gallery pages // Fetching images
addReceivedResultImagesPageFulfilledListener(); addReceivedPageOfImagesFulfilledListener();
addReceivedResultImagesPageRejectedListener(); addReceivedPageOfImagesRejectedListener();
addReceivedUploadImagesPageFulfilledListener();
addReceivedUploadImagesPageRejectedListener(); // Gallery
addImageCategoriesChangedListener();

View File

@ -0,0 +1,42 @@
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { commitStagingAreaImage } from 'features/canvas/store/canvasSlice';
import { sessionCanceled } from 'services/thunks/session';
const moduleLog = log.child({ namespace: 'canvas' });
export const addCommitStagingAreaImageListener = () => {
startAppListening({
actionCreator: commitStagingAreaImage,
effect: async (action, { dispatch, getState }) => {
const state = getState();
const { sessionId, isProcessing } = state.system;
const canvasSessionId = action.payload;
if (!isProcessing) {
// Only need to cancel if we are processing
return;
}
if (!canvasSessionId) {
moduleLog.debug('No canvas session, skipping cancel');
return;
}
if (canvasSessionId !== sessionId) {
moduleLog.debug(
{
data: {
canvasSessionId,
sessionId,
},
},
'Canvas session does not match global session, skipping cancel'
);
return;
}
dispatch(sessionCanceled({ sessionId }));
},
});
};

View File

@ -55,6 +55,8 @@ export const addCanvasMergedListener = () => {
formData: { formData: {
file: new File([blob], filename, { type: 'image/png' }), file: new File([blob], filename, { type: 'image/png' }),
}, },
imageCategory: 'general',
isIntermediate: true,
}) })
); );

View File

@ -4,16 +4,18 @@ import { log } from 'app/logging/useLogger';
import { imageUploaded } from 'services/thunks/image'; import { imageUploaded } from 'services/thunks/image';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { v4 as uuidv4 } from 'uuid';
import { imageUpserted } from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' }); const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
export const addCanvasSavedToGalleryListener = () => { export const addCanvasSavedToGalleryListener = () => {
startAppListening({ startAppListening({
actionCreator: canvasSavedToGallery, actionCreator: canvasSavedToGallery,
effect: async (action, { dispatch, getState }) => { effect: async (action, { dispatch, getState, take }) => {
const state = getState(); const state = getState();
const blob = await getBaseLayerBlob(state); const blob = await getBaseLayerBlob(state, true);
if (!blob) { if (!blob) {
moduleLog.error('Problem getting base layer blob'); moduleLog.error('Problem getting base layer blob');
@ -27,13 +29,25 @@ export const addCanvasSavedToGalleryListener = () => {
return; return;
} }
const filename = `mergedCanvas_${uuidv4()}.png`;
dispatch( dispatch(
imageUploaded({ imageUploaded({
formData: { formData: {
file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }), file: new File([blob], filename, { type: 'image/png' }),
}, },
imageCategory: 'general',
isIntermediate: false,
}) })
); );
const [{ payload: uploadedImageDTO }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === filename
);
dispatch(imageUpserted(uploadedImageDTO));
}, },
}); });
}; };

View File

@ -0,0 +1,24 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { receivedPageOfImages } from 'services/thunks/image';
import {
imageCategoriesChanged,
selectFilteredImagesAsArray,
} from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'gallery' });
export const addImageCategoriesChangedListener = () => {
startAppListening({
actionCreator: imageCategoriesChanged,
effect: (action, { getState, dispatch }) => {
const filteredImagesCount = selectFilteredImagesAsArray(
getState()
).length;
if (!filteredImagesCount) {
dispatch(receivedPageOfImages());
}
},
});
};

View File

@ -4,8 +4,12 @@ import { imageDeleted } from 'services/thunks/image';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es'; import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageSelected } from 'features/gallery/store/gallerySlice';
import { uploadsAdapter } from 'features/gallery/store/uploadsSlice'; import {
import { resultsAdapter } from 'features/gallery/store/resultsSlice'; imageRemoved,
imagesAdapter,
selectImagesEntities,
selectImagesIds,
} from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
@ -22,19 +26,20 @@ export const addRequestedImageDeletionListener = () => {
return; return;
} }
const { image_name, image_type } = image; const { image_name, image_origin } = image;
const selectedImageName = getState().gallery.selectedImage?.image_name; const state = getState();
const selectedImage = state.gallery.selectedImage;
if (selectedImageName === image_name) { if (selectedImage && selectedImage.image_name === image_name) {
const allIds = getState()[image_type].ids; const ids = selectImagesIds(state);
const allEntities = getState()[image_type].entities; const entities = selectImagesEntities(state);
const deletedImageIndex = allIds.findIndex( const deletedImageIndex = ids.findIndex(
(result) => result.toString() === image_name (result) => result.toString() === image_name
); );
const filteredIds = allIds.filter((id) => id.toString() !== image_name); const filteredIds = ids.filter((id) => id.toString() !== image_name);
const newSelectedImageIndex = clamp( const newSelectedImageIndex = clamp(
deletedImageIndex, deletedImageIndex,
@ -44,7 +49,7 @@ export const addRequestedImageDeletionListener = () => {
const newSelectedImageId = filteredIds[newSelectedImageIndex]; const newSelectedImageId = filteredIds[newSelectedImageIndex];
const newSelectedImage = allEntities[newSelectedImageId]; const newSelectedImage = entities[newSelectedImageId];
if (newSelectedImageId) { if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImage)); dispatch(imageSelected(newSelectedImage));
@ -53,7 +58,11 @@ export const addRequestedImageDeletionListener = () => {
} }
} }
dispatch(imageDeleted({ imageName: image_name, imageType: image_type })); dispatch(imageRemoved(image_name));
dispatch(
imageDeleted({ imageName: image_name, imageOrigin: image_origin })
);
}, },
}); });
}; };
@ -65,14 +74,9 @@ export const addImageDeletedPendingListener = () => {
startAppListening({ startAppListening({
actionCreator: imageDeleted.pending, actionCreator: imageDeleted.pending,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const { imageName, imageType } = action.meta.arg; const { imageName, imageOrigin } = action.meta.arg;
// Preemptively remove the image from the gallery // Preemptively remove the image from the gallery
if (imageType === 'uploads') { imagesAdapter.removeOne(getState().images, imageName);
uploadsAdapter.removeOne(getState().uploads, imageName);
}
if (imageType === 'results') {
resultsAdapter.removeOne(getState().results, imageName);
}
}, },
}); });
}; };

View File

@ -1,14 +1,7 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/thunks/image'; import { imageMetadataReceived } from 'services/thunks/image';
import { import { imageUpserted } from 'features/gallery/store/imagesSlice';
ResultsImageDTO,
resultUpserted,
} from 'features/gallery/store/resultsSlice';
import {
UploadsImageDTO,
uploadUpserted,
} from 'features/gallery/store/uploadsSlice';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
@ -17,15 +10,12 @@ export const addImageMetadataReceivedFulfilledListener = () => {
actionCreator: imageMetadataReceived.fulfilled, actionCreator: imageMetadataReceived.fulfilled,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const image = action.payload; const image = action.payload;
if (image.is_intermediate) {
// No further actions needed for intermediate images
return;
}
moduleLog.debug({ data: { image } }, 'Image metadata received'); moduleLog.debug({ data: { image } }, 'Image metadata received');
dispatch(imageUpserted(image));
if (image.image_type === 'results') {
dispatch(resultUpserted(action.payload as ResultsImageDTO));
}
if (image.image_type === 'uploads') {
dispatch(uploadUpserted(action.payload as UploadsImageDTO));
}
}, },
}); });
}; };

View File

@ -0,0 +1,26 @@
import { startAppListening } from '..';
import { imageUpdated } from 'services/thunks/image';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ namespace: 'image' });
export const addImageUpdatedFulfilledListener = () => {
startAppListening({
actionCreator: imageUpdated.fulfilled,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
{ oldImage: action.meta.arg, updatedImage: action.payload },
'Image updated'
);
},
});
};
export const addImageUpdatedRejectedListener = () => {
startAppListening({
actionCreator: imageUpdated.rejected,
effect: (action, { dispatch }) => {
moduleLog.debug({ oldImage: action.meta.arg }, 'Image update failed');
},
});
};

View File

@ -1,52 +1,28 @@
import { startAppListening } from '..'; import { startAppListening } from '..';
import { uploadUpserted } from 'features/gallery/store/uploadsSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageUploaded } from 'services/thunks/image'; import { imageUploaded } from 'services/thunks/image';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { initialImageSelected } from 'features/parameters/store/actions';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { resultUpserted } from 'features/gallery/store/resultsSlice';
import { isResultsImageDTO, isUploadsImageDTO } from 'services/types/guards';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { imageUpserted } from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
export const addImageUploadedFulfilledListener = () => { export const addImageUploadedFulfilledListener = () => {
startAppListening({ startAppListening({
predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> => actionCreator: imageUploaded.fulfilled,
imageUploaded.fulfilled.match(action) &&
action.payload.is_intermediate === false,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const image = action.payload; const image = action.payload;
moduleLog.debug({ arg: '<Blob>', image }, 'Image uploaded'); moduleLog.debug({ arg: '<Blob>', image }, 'Image uploaded');
if (action.payload.is_intermediate) {
// No further actions needed for intermediate images
return;
}
const state = getState(); const state = getState();
// Handle uploads dispatch(imageUpserted(image));
if (isUploadsImageDTO(image)) { dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
dispatch(uploadUpserted(image));
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
if (state.gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(image));
}
if (action.meta.arg.activeTabName === 'img2img') {
dispatch(initialImageSelected(image));
}
if (action.meta.arg.activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(image));
}
}
// Handle results
// TODO: Can this ever happen? I don't think so...
if (isResultsImageDTO(image)) {
dispatch(resultUpserted(image));
}
}, },
}); });
}; };
@ -55,6 +31,9 @@ export const addImageUploadedRejectedListener = () => {
startAppListening({ startAppListening({
actionCreator: imageUploaded.rejected, actionCreator: imageUploaded.rejected,
effect: (action, { dispatch }) => { effect: (action, { dispatch }) => {
const { formData, ...rest } = action.meta.arg;
const sanitizedData = { arg: { ...rest, formData: { file: '<Blob>' } } };
moduleLog.error({ data: sanitizedData }, 'Image upload failed');
dispatch( dispatch(
addToast({ addToast({
title: 'Image Upload Failed', title: 'Image Upload Failed',

View File

@ -1,8 +1,7 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imageUrlsReceived } from 'services/thunks/image'; import { imageUrlsReceived } from 'services/thunks/image';
import { resultsAdapter } from 'features/gallery/store/resultsSlice'; import { imagesAdapter } from 'features/gallery/store/imagesSlice';
import { uploadsAdapter } from 'features/gallery/store/uploadsSlice';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
@ -13,27 +12,15 @@ export const addImageUrlsReceivedFulfilledListener = () => {
const image = action.payload; const image = action.payload;
moduleLog.debug({ data: { image } }, 'Image URLs received'); moduleLog.debug({ data: { image } }, 'Image URLs received');
const { image_type, image_name, image_url, thumbnail_url } = image; const { image_name, image_url, thumbnail_url } = image;
if (image_type === 'results') { imagesAdapter.updateOne(getState().images, {
resultsAdapter.updateOne(getState().results, { id: image_name,
id: image_name, changes: {
changes: { image_url,
image_url, thumbnail_url,
thumbnail_url, },
}, });
});
}
if (image_type === 'uploads') {
uploadsAdapter.updateOne(getState().uploads, {
id: image_name,
changes: {
image_url,
thumbnail_url,
},
});
}
}, },
}); });
}; };

View File

@ -1,6 +1,4 @@
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { selectResultsById } from 'features/gallery/store/resultsSlice';
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
import { t } from 'i18next'; import { t } from 'i18next';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..'; import { startAppListening } from '..';
@ -9,7 +7,7 @@ import {
isImageDTO, isImageDTO,
} from 'features/parameters/store/actions'; } from 'features/parameters/store/actions';
import { makeToast } from 'app/components/Toaster'; import { makeToast } from 'app/components/Toaster';
import { ImageDTO } from 'services/api'; import { selectImagesById } from 'features/gallery/store/imagesSlice';
export const addInitialImageSelectedListener = () => { export const addInitialImageSelectedListener = () => {
startAppListening({ startAppListening({
@ -30,16 +28,8 @@ export const addInitialImageSelectedListener = () => {
return; return;
} }
const { image_name, image_type } = action.payload; const imageName = action.payload;
const image = selectImagesById(getState(), imageName);
let image: ImageDTO | undefined;
const state = getState();
if (image_type === 'results') {
image = selectResultsById(state, image_name);
} else if (image_type === 'uploads') {
image = selectUploadsById(state, image_name);
}
if (!image) { if (!image) {
dispatch( dispatch(

View File

@ -1,31 +1,31 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { receivedResultImagesPage } from 'services/thunks/gallery';
import { serializeError } from 'serialize-error'; import { serializeError } from 'serialize-error';
import { receivedPageOfImages } from 'services/thunks/image';
const moduleLog = log.child({ namespace: 'gallery' }); const moduleLog = log.child({ namespace: 'gallery' });
export const addReceivedResultImagesPageFulfilledListener = () => { export const addReceivedPageOfImagesFulfilledListener = () => {
startAppListening({ startAppListening({
actionCreator: receivedResultImagesPage.fulfilled, actionCreator: receivedPageOfImages.fulfilled,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const page = action.payload; const page = action.payload;
moduleLog.debug( moduleLog.debug(
{ data: { page } }, { data: { payload: action.payload } },
`Received ${page.items.length} results` `Received ${page.items.length} images`
); );
}, },
}); });
}; };
export const addReceivedResultImagesPageRejectedListener = () => { export const addReceivedPageOfImagesRejectedListener = () => {
startAppListening({ startAppListening({
actionCreator: receivedResultImagesPage.rejected, actionCreator: receivedPageOfImages.rejected,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
if (action.payload) { if (action.payload) {
moduleLog.debug( moduleLog.debug(
{ data: { error: serializeError(action.payload.error) } }, { data: { error: serializeError(action.payload) } },
'Problem receiving results' 'Problem receiving images'
); );
} }
}, },

View File

@ -1,33 +0,0 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { receivedUploadImagesPage } from 'services/thunks/gallery';
import { serializeError } from 'serialize-error';
const moduleLog = log.child({ namespace: 'gallery' });
export const addReceivedUploadImagesPageFulfilledListener = () => {
startAppListening({
actionCreator: receivedUploadImagesPage.fulfilled,
effect: (action, { getState, dispatch }) => {
const page = action.payload;
moduleLog.debug(
{ data: { page } },
`Received ${page.items.length} uploads`
);
},
});
};
export const addReceivedUploadImagesPageRejectedListener = () => {
startAppListening({
actionCreator: receivedUploadImagesPage.rejected,
effect: (action, { getState, dispatch }) => {
if (action.payload) {
moduleLog.debug(
{ data: { error: serializeError(action.payload.error) } },
'Problem receiving uploads'
);
}
},
});
};

View File

@ -1,16 +1,13 @@
import { startAppListening } from '../..'; import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { socketConnected } from 'services/events/actions'; import { appSocketConnected, socketConnected } from 'services/events/actions';
import { import { receivedPageOfImages } from 'services/thunks/image';
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { receivedModels } from 'services/thunks/model'; import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema'; import { receivedOpenAPISchema } from 'services/thunks/schema';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketConnectedListener = () => { export const addSocketConnectedEventListener = () => {
startAppListening({ startAppListening({
actionCreator: socketConnected, actionCreator: socketConnected,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
@ -18,17 +15,12 @@ export const addSocketConnectedListener = () => {
moduleLog.debug({ timestamp }, 'Connected'); moduleLog.debug({ timestamp }, 'Connected');
const { results, uploads, models, nodes, config } = getState(); const { models, nodes, config, images } = getState();
const { disabledTabs } = config; const { disabledTabs } = config;
// These thunks need to be dispatch in middleware; cannot handle in a reducer if (!images.ids.length) {
if (!results.ids.length) { dispatch(receivedPageOfImages());
dispatch(receivedResultImagesPage());
}
if (!uploads.ids.length) {
dispatch(receivedUploadImagesPage());
} }
if (!models.ids.length) { if (!models.ids.length) {
@ -38,6 +30,9 @@ export const addSocketConnectedListener = () => {
if (!nodes.schema && !disabledTabs.includes('nodes')) { if (!nodes.schema && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema()); dispatch(receivedOpenAPISchema());
} }
// pass along the socket event as an application action
dispatch(appSocketConnected(action.payload));
}, },
}); });
}; };

View File

@ -1,14 +1,19 @@
import { startAppListening } from '../..'; import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { socketDisconnected } from 'services/events/actions'; import {
socketDisconnected,
appSocketDisconnected,
} from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketDisconnectedListener = () => { export const addSocketDisconnectedEventListener = () => {
startAppListening({ startAppListening({
actionCreator: socketDisconnected, actionCreator: socketDisconnected,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
moduleLog.debug(action.payload, 'Disconnected'); moduleLog.debug(action.payload, 'Disconnected');
// pass along the socket event as an application action
dispatch(appSocketDisconnected(action.payload));
}, },
}); });
}; };

View File

@ -1,12 +1,15 @@
import { startAppListening } from '../..'; import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { generatorProgress } from 'services/events/actions'; import {
appSocketGeneratorProgress,
socketGeneratorProgress,
} from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
export const addGeneratorProgressListener = () => { export const addGeneratorProgressEventListener = () => {
startAppListening({ startAppListening({
actionCreator: generatorProgress, actionCreator: socketGeneratorProgress,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
if ( if (
getState().system.canceledSession === getState().system.canceledSession ===
@ -23,6 +26,9 @@ export const addGeneratorProgressListener = () => {
action.payload, action.payload,
`Generator progress (${action.payload.data.node.type})` `Generator progress (${action.payload.data.node.type})`
); );
// pass along the socket event as an application action
dispatch(appSocketGeneratorProgress(action.payload));
}, },
}); });
}; };

View File

@ -1,17 +1,22 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { graphExecutionStateComplete } from 'services/events/actions'; import {
appSocketGraphExecutionStateComplete,
socketGraphExecutionStateComplete,
} from 'services/events/actions';
import { startAppListening } from '../..'; import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
export const addGraphExecutionStateCompleteListener = () => { export const addGraphExecutionStateCompleteEventListener = () => {
startAppListening({ startAppListening({
actionCreator: graphExecutionStateComplete, actionCreator: socketGraphExecutionStateComplete,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
moduleLog.debug( moduleLog.debug(
action.payload, action.payload,
`Session invocation complete (${action.payload.data.graph_execution_state_id})` `Session invocation complete (${action.payload.data.graph_execution_state_id})`
); );
// pass along the socket event as an application action
dispatch(appSocketGraphExecutionStateComplete(action.payload));
}, },
}); });
}; };

View File

@ -1,19 +1,21 @@
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import { startAppListening } from '../..'; import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { invocationComplete } from 'services/events/actions'; import {
appSocketInvocationComplete,
socketInvocationComplete,
} from 'services/events/actions';
import { imageMetadataReceived } from 'services/thunks/image'; import { imageMetadataReceived } from 'services/thunks/image';
import { sessionCanceled } from 'services/thunks/session'; import { sessionCanceled } from 'services/thunks/session';
import { isImageOutput } from 'services/types/guards'; import { isImageOutput } from 'services/types/guards';
import { progressImageSet } from 'features/system/store/systemSlice'; import { progressImageSet } from 'features/system/store/systemSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
const nodeDenylist = ['dataURL_image']; const nodeDenylist = ['dataURL_image'];
export const addInvocationCompleteListener = () => { export const addInvocationCompleteEventListener = () => {
startAppListening({ startAppListening({
actionCreator: invocationComplete, actionCreator: socketInvocationComplete,
effect: async (action, { dispatch, getState, take }) => { effect: async (action, { dispatch, getState, take }) => {
moduleLog.debug( moduleLog.debug(
action.payload, action.payload,
@ -34,13 +36,13 @@ export const addInvocationCompleteListener = () => {
// This complete event has an associated image output // This complete event has an associated image output
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) { if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const { image_name, image_type } = result.image; const { image_name, image_origin } = result.image;
// Get its metadata // Get its metadata
dispatch( dispatch(
imageMetadataReceived({ imageMetadataReceived({
imageName: image_name, imageName: image_name,
imageType: image_type, imageOrigin: image_origin,
}) })
); );
@ -48,27 +50,18 @@ export const addInvocationCompleteListener = () => {
imageMetadataReceived.fulfilled.match imageMetadataReceived.fulfilled.match
); );
if (getState().gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(imageDTO));
}
// Handle canvas image // Handle canvas image
if ( if (
graph_execution_state_id === graph_execution_state_id ===
getState().canvas.layerState.stagingArea.sessionId getState().canvas.layerState.stagingArea.sessionId
) { ) {
const [{ payload: image }] = await take( dispatch(addImageToStagingArea(imageDTO));
(
action
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
imageMetadataReceived.fulfilled.match(action) &&
action.payload.image_name === image_name
);
dispatch(addImageToStagingArea(image));
} }
dispatch(progressImageSet(null)); dispatch(progressImageSet(null));
} }
// pass along the socket event as an application action
dispatch(appSocketInvocationComplete(action.payload));
}, },
}); });
}; };

View File

@ -1,17 +1,21 @@
import { startAppListening } from '../..'; import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { invocationError } from 'services/events/actions'; import {
appSocketInvocationError,
socketInvocationError,
} from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
export const addInvocationErrorListener = () => { export const addInvocationErrorEventListener = () => {
startAppListening({ startAppListening({
actionCreator: invocationError, actionCreator: socketInvocationError,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
moduleLog.error( moduleLog.error(
action.payload, action.payload,
`Invocation error (${action.payload.data.node.type})` `Invocation error (${action.payload.data.node.type})`
); );
dispatch(appSocketInvocationError(action.payload));
}, },
}); });
}; };

View File

@ -1,12 +1,15 @@
import { startAppListening } from '../..'; import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { invocationStarted } from 'services/events/actions'; import {
appSocketInvocationStarted,
socketInvocationStarted,
} from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
export const addInvocationStartedListener = () => { export const addInvocationStartedEventListener = () => {
startAppListening({ startAppListening({
actionCreator: invocationStarted, actionCreator: socketInvocationStarted,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
if ( if (
getState().system.canceledSession === getState().system.canceledSession ===
@ -23,6 +26,7 @@ export const addInvocationStartedListener = () => {
action.payload, action.payload,
`Invocation started (${action.payload.data.node.type})` `Invocation started (${action.payload.data.node.type})`
); );
dispatch(appSocketInvocationStarted(action.payload));
}, },
}); });
}; };

View File

@ -1,10 +1,10 @@
import { startAppListening } from '../..'; import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { socketSubscribed } from 'services/events/actions'; import { appSocketSubscribed, socketSubscribed } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketSubscribedListener = () => { export const addSocketSubscribedEventListener = () => {
startAppListening({ startAppListening({
actionCreator: socketSubscribed, actionCreator: socketSubscribed,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
@ -12,6 +12,7 @@ export const addSocketSubscribedListener = () => {
action.payload, action.payload,
`Subscribed (${action.payload.sessionId}))` `Subscribed (${action.payload.sessionId}))`
); );
dispatch(appSocketSubscribed(action.payload));
}, },
}); });
}; };

View File

@ -1,10 +1,13 @@
import { startAppListening } from '../..'; import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { socketUnsubscribed } from 'services/events/actions'; import {
appSocketUnsubscribed,
socketUnsubscribed,
} from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketUnsubscribedListener = () => { export const addSocketUnsubscribedEventListener = () => {
startAppListening({ startAppListening({
actionCreator: socketUnsubscribed, actionCreator: socketUnsubscribed,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
@ -12,6 +15,7 @@ export const addSocketUnsubscribedListener = () => {
action.payload, action.payload,
`Unsubscribed (${action.payload.sessionId})` `Unsubscribed (${action.payload.sessionId})`
); );
dispatch(appSocketUnsubscribed(action.payload));
}, },
}); });
}; };

View File

@ -0,0 +1,54 @@
import { stagingAreaImageSaved } from 'features/canvas/store/actions';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { imageUpdated } from 'services/thunks/image';
import { imageUpserted } from 'features/gallery/store/imagesSlice';
import { addToast } from 'features/system/store/systemSlice';
const moduleLog = log.child({ namespace: 'canvas' });
export const addStagingAreaImageSavedListener = () => {
startAppListening({
actionCreator: stagingAreaImageSaved,
effect: async (action, { dispatch, getState, take }) => {
const { image_name, image_origin } = action.payload;
dispatch(
imageUpdated({
imageName: image_name,
imageOrigin: image_origin,
requestBody: {
is_intermediate: false,
},
})
);
const [imageUpdatedAction] = await take(
(action) =>
(imageUpdated.fulfilled.match(action) ||
imageUpdated.rejected.match(action)) &&
action.meta.arg.imageName === image_name
);
if (imageUpdated.rejected.match(imageUpdatedAction)) {
moduleLog.error(
{ data: { arg: imageUpdatedAction.meta.arg } },
'Image saving failed'
);
dispatch(
addToast({
title: 'Image Saving Failed',
description: imageUpdatedAction.error.message,
status: 'error',
})
);
return;
}
if (imageUpdated.fulfilled.match(imageUpdatedAction)) {
dispatch(imageUpserted(imageUpdatedAction.payload));
dispatch(addToast({ title: 'Image Saved', status: 'success' }));
}
},
});
};

View File

@ -101,6 +101,7 @@ export const addUserInvokedCanvasListener = () => {
formData: { formData: {
file: new File([baseBlob], baseFilename, { type: 'image/png' }), file: new File([baseBlob], baseFilename, { type: 'image/png' }),
}, },
imageCategory: 'general',
isIntermediate: true, isIntermediate: true,
}) })
); );
@ -115,7 +116,7 @@ export const addUserInvokedCanvasListener = () => {
// Update the base node with the image name and type // Update the base node with the image name and type
baseNode.image = { baseNode.image = {
image_name: baseImageDTO.image_name, image_name: baseImageDTO.image_name,
image_type: baseImageDTO.image_type, image_origin: baseImageDTO.image_origin,
}; };
} }
@ -127,6 +128,7 @@ export const addUserInvokedCanvasListener = () => {
formData: { formData: {
file: new File([maskBlob], maskFilename, { type: 'image/png' }), file: new File([maskBlob], maskFilename, { type: 'image/png' }),
}, },
imageCategory: 'mask',
isIntermediate: true, isIntermediate: true,
}) })
); );
@ -141,7 +143,7 @@ export const addUserInvokedCanvasListener = () => {
// Update the base node with the image name and type // Update the base node with the image name and type
baseNode.mask = { baseNode.mask = {
image_name: maskImageDTO.image_name, image_name: maskImageDTO.image_name,
image_type: maskImageDTO.image_type, image_origin: maskImageDTO.image_origin,
}; };
} }
@ -158,7 +160,7 @@ export const addUserInvokedCanvasListener = () => {
dispatch( dispatch(
imageUpdated({ imageUpdated({
imageName: baseNode.image.image_name, imageName: baseNode.image.image_name,
imageType: baseNode.image.image_type, imageOrigin: baseNode.image.image_origin,
requestBody: { session_id: sessionId }, requestBody: { session_id: sessionId },
}) })
); );
@ -169,7 +171,7 @@ export const addUserInvokedCanvasListener = () => {
dispatch( dispatch(
imageUpdated({ imageUpdated({
imageName: baseNode.mask.image_name, imageName: baseNode.mask.image_name,
imageType: baseNode.mask.image_type, imageOrigin: baseNode.mask.image_origin,
requestBody: { session_id: sessionId }, requestBody: { session_id: sessionId },
}) })
); );

View File

@ -10,8 +10,7 @@ import dynamicMiddlewares from 'redux-dynamic-middlewares';
import canvasReducer from 'features/canvas/store/canvasSlice'; import canvasReducer from 'features/canvas/store/canvasSlice';
import galleryReducer from 'features/gallery/store/gallerySlice'; import galleryReducer from 'features/gallery/store/gallerySlice';
import resultsReducer from 'features/gallery/store/resultsSlice'; import imagesReducer from 'features/gallery/store/imagesSlice';
import uploadsReducer from 'features/gallery/store/uploadsSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice'; import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
@ -41,12 +40,11 @@ const allReducers = {
models: modelsReducer, models: modelsReducer,
nodes: nodesReducer, nodes: nodesReducer,
postprocessing: postprocessingReducer, postprocessing: postprocessingReducer,
results: resultsReducer,
system: systemReducer, system: systemReducer,
config: configReducer, config: configReducer,
ui: uiReducer, ui: uiReducer,
uploads: uploadsReducer,
hotkeys: hotkeysReducer, hotkeys: hotkeysReducer,
images: imagesReducer,
// session: sessionReducer, // session: sessionReducer,
}; };
@ -65,8 +63,6 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'system', 'system',
'ui', 'ui',
// 'hotkeys', // 'hotkeys',
// 'results',
// 'uploads',
// 'config', // 'config',
]; ];

View File

@ -15,7 +15,7 @@
import { SelectedImage } from 'features/parameters/store/actions'; import { SelectedImage } from 'features/parameters/store/actions';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types'; import { IRect } from 'konva/lib/types';
import { ImageResponseMetadata, ImageType } from 'services/api'; import { ImageResponseMetadata, ResourceOrigin } from 'services/api';
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
/** /**
@ -124,7 +124,7 @@ export type PostProcessedImageMetadata = ESRGANMetadata | FacetoolMetadata;
*/ */
// export ty`pe Image = { // export ty`pe Image = {
// name: string; // name: string;
// type: ImageType; // type: image_origin;
// url: string; // url: string;
// thumbnail: string; // thumbnail: string;
// metadata: ImageResponseMetadata; // metadata: ImageResponseMetadata;

View File

@ -4,7 +4,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
type ImageUploadOverlayProps = { type ImageUploadOverlayProps = {
isDragAccept: boolean; isDragAccept: boolean;
isDragReject: boolean; isDragReject: boolean;
overlaySecondaryText: string;
setIsHandlingUpload: (isHandlingUpload: boolean) => void; setIsHandlingUpload: (isHandlingUpload: boolean) => void;
}; };
@ -12,7 +11,6 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
const { const {
isDragAccept, isDragAccept,
isDragReject: _isDragAccept, isDragReject: _isDragAccept,
overlaySecondaryText,
setIsHandlingUpload, setIsHandlingUpload,
} = props; } = props;
@ -48,7 +46,7 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
}} }}
> >
{isDragAccept ? ( {isDragAccept ? (
<Heading size="lg">Upload Image{overlaySecondaryText}</Heading> <Heading size="lg">Drop to Upload</Heading>
) : ( ) : (
<> <>
<Heading size="lg">Invalid Upload</Heading> <Heading size="lg">Invalid Upload</Heading>

View File

@ -69,11 +69,12 @@ const ImageUploader = (props: ImageUploaderProps) => {
dispatch( dispatch(
imageUploaded({ imageUploaded({
formData: { file }, formData: { file },
activeTabName, imageCategory: 'user',
isIntermediate: false,
}) })
); );
}, },
[dispatch, activeTabName] [dispatch]
); );
const onDrop = useCallback( const onDrop = useCallback(
@ -144,14 +145,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
}; };
}, [inputRef, open, setOpenUploaderFunction]); }, [inputRef, open, setOpenUploaderFunction]);
const overlaySecondaryText = useMemo(() => {
if (['img2img', 'unifiedCanvas'].includes(activeTabName)) {
return ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`;
}
return '';
}, [t, activeTabName]);
return ( return (
<Box <Box
{...getRootProps({ style: {} })} {...getRootProps({ style: {} })}
@ -166,7 +159,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
<ImageUploadOverlay <ImageUploadOverlay
isDragAccept={isDragAccept} isDragAccept={isDragAccept}
isDragReject={isDragReject} isDragReject={isDragReject}
overlaySecondaryText={overlaySecondaryText}
setIsHandlingUpload={setIsHandlingUpload} setIsHandlingUpload={setIsHandlingUpload}
/> />
)} )}

View File

@ -1,239 +0,0 @@
import { forEach, size } from 'lodash-es';
import {
ImageField,
LatentsField,
ConditioningField,
ControlField,
} from 'services/api';
const OBJECT_TYPESTRING = '[object Object]';
const STRING_TYPESTRING = '[object String]';
const NUMBER_TYPESTRING = '[object Number]';
const BOOLEAN_TYPESTRING = '[object Boolean]';
const ARRAY_TYPESTRING = '[object Array]';
const isObject = (obj: unknown): obj is Record<string | number, any> =>
Object.prototype.toString.call(obj) === OBJECT_TYPESTRING;
const isString = (obj: unknown): obj is string =>
Object.prototype.toString.call(obj) === STRING_TYPESTRING;
const isNumber = (obj: unknown): obj is number =>
Object.prototype.toString.call(obj) === NUMBER_TYPESTRING;
const isBoolean = (obj: unknown): obj is boolean =>
Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING;
const isArray = (obj: unknown): obj is Array<any> =>
Object.prototype.toString.call(obj) === ARRAY_TYPESTRING;
const parseImageField = (imageField: unknown): ImageField | undefined => {
// Must be an object
if (!isObject(imageField)) {
return;
}
// An ImageField must have both `image_name` and `image_type`
if (!('image_name' in imageField && 'image_type' in imageField)) {
return;
}
// An ImageField's `image_type` must be one of the allowed values
if (
!['results', 'uploads', 'intermediates'].includes(imageField.image_type)
) {
return;
}
// An ImageField's `image_name` must be a string
if (typeof imageField.image_name !== 'string') {
return;
}
// Build a valid ImageField
return {
image_type: imageField.image_type,
image_name: imageField.image_name,
};
};
const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
// Must be an object
if (!isObject(latentsField)) {
return;
}
// A LatentsField must have a `latents_name`
if (!('latents_name' in latentsField)) {
return;
}
// A LatentsField's `latents_name` must be a string
if (typeof latentsField.latents_name !== 'string') {
return;
}
// Build a valid LatentsField
return {
latents_name: latentsField.latents_name,
};
};
const parseConditioningField = (
conditioningField: unknown
): ConditioningField | undefined => {
// Must be an object
if (!isObject(conditioningField)) {
return;
}
// A ConditioningField must have a `conditioning_name`
if (!('conditioning_name' in conditioningField)) {
return;
}
// A ConditioningField's `conditioning_name` must be a string
if (typeof conditioningField.conditioning_name !== 'string') {
return;
}
// Build a valid ConditioningField
return {
conditioning_name: conditioningField.conditioning_name,
};
};
const parseControlField = (controlField: unknown): ControlField | undefined => {
// Must be an object
if (!isObject(controlField)) {
return;
}
// A ControlField must have a `control`
if (!('control' in controlField)) {
return;
}
// console.log(typeof controlField.control);
// Build a valid ControlField
return {
control: controlField.control,
};
};
type NodeMetadata = {
[key: string]:
| string
| number
| boolean
| ImageField
| LatentsField
| ConditioningField
| ControlField;
};
type InvokeAIMetadata = {
session_id?: string;
node?: NodeMetadata;
};
export const parseNodeMetadata = (
nodeMetadata: Record<string | number, any>
): NodeMetadata | undefined => {
if (!isObject(nodeMetadata)) {
return;
}
const parsed: NodeMetadata = {};
forEach(nodeMetadata, (nodeItem, nodeKey) => {
// `id` and `type` must be strings if they are present
if (['id', 'type'].includes(nodeKey)) {
if (isString(nodeItem)) {
parsed[nodeKey] = nodeItem;
}
return;
}
// the only valid object types are ImageField, LatentsField, ConditioningField, ControlField
if (isObject(nodeItem)) {
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
const imageField = parseImageField(nodeItem);
if (imageField) {
parsed[nodeKey] = imageField;
}
return;
}
if ('latents_name' in nodeItem) {
const latentsField = parseLatentsField(nodeItem);
if (latentsField) {
parsed[nodeKey] = latentsField;
}
return;
}
if ('conditioning_name' in nodeItem) {
const conditioningField = parseConditioningField(nodeItem);
if (conditioningField) {
parsed[nodeKey] = conditioningField;
}
return;
}
if ('control' in nodeItem) {
const controlField = parseControlField(nodeItem);
if (controlField) {
parsed[nodeKey] = controlField;
}
return;
}
}
// otherwise we accept any string, number or boolean
if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) {
parsed[nodeKey] = nodeItem;
return;
}
});
if (size(parsed) === 0) {
return;
}
return parsed;
};
export const parseInvokeAIMetadata = (
metadata: Record<string | number, any> | undefined
): InvokeAIMetadata | undefined => {
if (metadata === undefined) {
return;
}
if (!isObject(metadata)) {
return;
}
const parsed: InvokeAIMetadata = {};
forEach(metadata, (item, key) => {
if (key === 'session_id' && isString(item)) {
parsed['session_id'] = item;
}
if (key === 'node' && isObject(item)) {
const nodeMetadata = parseNodeMetadata(item);
if (nodeMetadata) {
parsed['node'] = nodeMetadata;
}
}
});
if (size(parsed) === 0) {
return;
}
return parsed;
};

View File

@ -1,18 +1,24 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { useGetUrl } from 'common/util/getUrl'; import { systemSelector } from 'features/system/store/systemSelectors';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { ImageConfig } from 'konva/lib/shapes/Image'; import { ImageConfig } from 'konva/lib/shapes/Image';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import { Image as KonvaImage } from 'react-konva'; import { Image as KonvaImage } from 'react-konva';
import { canvasSelector } from '../store/canvasSelectors';
const selector = createSelector( const selector = createSelector(
[(state: RootState) => state.gallery], [systemSelector, canvasSelector],
(gallery: GalleryState) => { (system, canvas) => {
return gallery.intermediateImage ? gallery.intermediateImage : null; const { progressImage, sessionId } = system;
const { sessionId: canvasSessionId, boundingBox } =
canvas.layerState.stagingArea;
return {
boundingBox,
progressImage: sessionId === canvasSessionId ? progressImage : undefined,
};
}, },
{ {
memoizeOptions: { memoizeOptions: {
@ -25,33 +31,34 @@ type Props = Omit<ImageConfig, 'image'>;
const IAICanvasIntermediateImage = (props: Props) => { const IAICanvasIntermediateImage = (props: Props) => {
const { ...rest } = props; const { ...rest } = props;
const intermediateImage = useAppSelector(selector); const { progressImage, boundingBox } = useAppSelector(selector);
const { getUrl } = useGetUrl();
const [loadedImageElement, setLoadedImageElement] = const [loadedImageElement, setLoadedImageElement] =
useState<HTMLImageElement | null>(null); useState<HTMLImageElement | null>(null);
useEffect(() => { useEffect(() => {
if (!intermediateImage) return; if (!progressImage) {
return;
}
const tempImage = new Image(); const tempImage = new Image();
tempImage.onload = () => { tempImage.onload = () => {
setLoadedImageElement(tempImage); setLoadedImageElement(tempImage);
}; };
tempImage.src = getUrl(intermediateImage.url);
}, [intermediateImage, getUrl]);
if (!intermediateImage?.boundingBox) return null; tempImage.src = progressImage.dataURL;
}, [progressImage]);
const { if (!(progressImage && boundingBox)) {
boundingBox: { x, y, width, height }, return null;
} = intermediateImage; }
return loadedImageElement ? ( return loadedImageElement ? (
<KonvaImage <KonvaImage
x={x} x={boundingBox.x}
y={y} y={boundingBox.y}
width={width} width={boundingBox.width}
height={height} height={boundingBox.height}
image={loadedImageElement} image={loadedImageElement}
listening={false} listening={false}
{...rest} {...rest}

View File

@ -62,7 +62,7 @@ const IAICanvasStagingArea = (props: Props) => {
<Group {...rest}> <Group {...rest}>
{shouldShowStagingImage && currentStagingAreaImage && ( {shouldShowStagingImage && currentStagingAreaImage && (
<IAICanvasImage <IAICanvasImage
url={getUrl(currentStagingAreaImage.image.image_url)} url={getUrl(currentStagingAreaImage.image.image_url) ?? ''}
x={x} x={x}
y={y} y={y}
/> />

View File

@ -1,6 +1,5 @@
import { ButtonGroup, Flex } from '@chakra-ui/react'; import { ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
// import { saveStagingAreaImageToGallery } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
@ -26,13 +25,14 @@ import {
FaPlus, FaPlus,
FaSave, FaSave,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { stagingAreaImageSaved } from '../store/actions';
const selector = createSelector( const selector = createSelector(
[canvasSelector], [canvasSelector],
(canvas) => { (canvas) => {
const { const {
layerState: { layerState: {
stagingArea: { images, selectedImageIndex }, stagingArea: { images, selectedImageIndex, sessionId },
}, },
shouldShowStagingOutline, shouldShowStagingOutline,
shouldShowStagingImage, shouldShowStagingImage,
@ -45,6 +45,7 @@ const selector = createSelector(
isOnLastImage: selectedImageIndex === images.length - 1, isOnLastImage: selectedImageIndex === images.length - 1,
shouldShowStagingImage, shouldShowStagingImage,
shouldShowStagingOutline, shouldShowStagingOutline,
sessionId,
}; };
}, },
{ {
@ -61,6 +62,7 @@ const IAICanvasStagingAreaToolbar = () => {
isOnLastImage, isOnLastImage,
currentStagingAreaImage, currentStagingAreaImage,
shouldShowStagingImage, shouldShowStagingImage,
sessionId,
} = useAppSelector(selector); } = useAppSelector(selector);
const { t } = useTranslation(); const { t } = useTranslation();
@ -106,9 +108,20 @@ const IAICanvasStagingAreaToolbar = () => {
} }
); );
const handlePrevImage = () => dispatch(prevStagingAreaImage()); const handlePrevImage = useCallback(
const handleNextImage = () => dispatch(nextStagingAreaImage()); () => dispatch(prevStagingAreaImage()),
const handleAccept = () => dispatch(commitStagingAreaImage()); [dispatch]
);
const handleNextImage = useCallback(
() => dispatch(nextStagingAreaImage()),
[dispatch]
);
const handleAccept = useCallback(
() => dispatch(commitStagingAreaImage(sessionId)),
[dispatch, sessionId]
);
if (!currentStagingAreaImage) return null; if (!currentStagingAreaImage) return null;
@ -157,19 +170,15 @@ const IAICanvasStagingAreaToolbar = () => {
} }
colorScheme="accent" colorScheme="accent"
/> />
{/* <IAIIconButton <IAIIconButton
tooltip={t('unifiedCanvas.saveToGallery')} tooltip={t('unifiedCanvas.saveToGallery')}
aria-label={t('unifiedCanvas.saveToGallery')} aria-label={t('unifiedCanvas.saveToGallery')}
icon={<FaSave />} icon={<FaSave />}
onClick={() => onClick={() =>
dispatch( dispatch(stagingAreaImageSaved(currentStagingAreaImage.image))
saveStagingAreaImageToGallery(
currentStagingAreaImage.image.image_url
)
)
} }
colorScheme="accent" colorScheme="accent"
/> */} />
<IAIIconButton <IAIIconButton
tooltip={t('unifiedCanvas.discardAll')} tooltip={t('unifiedCanvas.discardAll')}
aria-label={t('unifiedCanvas.discardAll')} aria-label={t('unifiedCanvas.discardAll')}

View File

@ -1,4 +1,5 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api';
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery'); export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
@ -11,3 +12,7 @@ export const canvasDownloadedAsImage = createAction(
); );
export const canvasMerged = createAction('canvas/canvasMerged'); export const canvasMerged = createAction('canvas/canvasMerged');
export const stagingAreaImageSaved = createAction<ImageDTO>(
'canvas/stagingAreaImageSaved'
);

View File

@ -696,7 +696,10 @@ export const canvasSlice = createSlice({
0 0
); );
}, },
commitStagingAreaImage: (state) => { commitStagingAreaImage: (
state,
action: PayloadAction<string | undefined>
) => {
if (!state.layerState.stagingArea.images.length) { if (!state.layerState.stagingArea.images.length) {
return; return;
} }

View File

@ -1,5 +1,5 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual, isString } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { import {
ButtonGroup, ButtonGroup,
@ -25,8 +25,8 @@ import {
} from 'features/ui/store/uiSelectors'; } from 'features/ui/store/uiSelectors';
import { import {
setActiveTab, setActiveTab,
setShouldHidePreview,
setShouldShowImageDetails, setShouldShowImageDetails,
setShouldShowProgressInViewer,
} from 'features/ui/store/uiSlice'; } from 'features/ui/store/uiSlice';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -37,18 +37,14 @@ import {
FaDownload, FaDownload,
FaExpand, FaExpand,
FaExpandArrowsAlt, FaExpandArrowsAlt,
FaEye,
FaEyeSlash,
FaGrinStars, FaGrinStars,
FaHourglassHalf,
FaQuoteRight, FaQuoteRight,
FaSeedling, FaSeedling,
FaShare, FaShare,
FaShareAlt, FaShareAlt,
FaTrash,
FaWrench,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { gallerySelector } from '../store/gallerySelectors'; import { gallerySelector } from '../store/gallerySelectors';
import DeleteImageModal from './DeleteImageModal';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { useGetUrl } from 'common/util/getUrl'; import { useGetUrl } from 'common/util/getUrl';
@ -90,7 +86,11 @@ const currentImageButtonsSelector = createSelector(
const { isLightboxOpen } = lightbox; const { isLightboxOpen } = lightbox;
const { shouldShowImageDetails, shouldHidePreview } = ui; const {
shouldShowImageDetails,
shouldHidePreview,
shouldShowProgressInViewer,
} = ui;
const { selectedImage } = gallery; const { selectedImage } = gallery;
@ -112,6 +112,7 @@ const currentImageButtonsSelector = createSelector(
seed: selectedImage?.metadata?.seed, seed: selectedImage?.metadata?.seed,
prompt: selectedImage?.metadata?.positive_conditioning, prompt: selectedImage?.metadata?.positive_conditioning,
negativePrompt: selectedImage?.metadata?.negative_conditioning, negativePrompt: selectedImage?.metadata?.negative_conditioning,
shouldShowProgressInViewer,
}; };
}, },
{ {
@ -145,6 +146,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
image, image,
canDeleteImage, canDeleteImage,
shouldConfirmOnDelete, shouldConfirmOnDelete,
shouldShowProgressInViewer,
} = useAppSelector(currentImageButtonsSelector); } = useAppSelector(currentImageButtonsSelector);
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled; const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
@ -229,10 +231,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}); });
}, [toaster, shouldTransformUrls, getUrl, t, image]); }, [toaster, shouldTransformUrls, getUrl, t, image]);
const handlePreviewVisibility = useCallback(() => {
dispatch(setShouldHidePreview(!shouldHidePreview));
}, [dispatch, shouldHidePreview]);
const handleClickUseAllParameters = useCallback(() => { const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(image); recallAllParameters(image);
}, [image, recallAllParameters]); }, [image, recallAllParameters]);
@ -386,6 +384,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
} }
}, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]); }, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]);
const handleClickProgressImagesToggle = useCallback(() => {
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
}, [dispatch, shouldShowProgressInViewer]);
useHotkeys('delete', handleInitiateDelete, [ useHotkeys('delete', handleInitiateDelete, [
image, image,
shouldConfirmOnDelete, shouldConfirmOnDelete,
@ -412,8 +414,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIPopover <IAIPopover
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
isDisabled={!image}
aria-label={`${t('parameters.sendTo')}...`} aria-label={`${t('parameters.sendTo')}...`}
tooltip={`${t('parameters.sendTo')}...`}
isDisabled={!image}
icon={<FaShareAlt />} icon={<FaShareAlt />}
/> />
} }
@ -465,21 +468,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</Link> </Link>
</Flex> </Flex>
</IAIPopover> </IAIPopover>
{/* <IAIIconButton
icon={shouldHidePreview ? <FaEyeSlash /> : <FaEye />}
tooltip={
!shouldHidePreview
? t('parameters.hidePreview')
: t('parameters.showPreview')
}
aria-label={
!shouldHidePreview
? t('parameters.hidePreview')
: t('parameters.showPreview')
}
isChecked={shouldHidePreview}
onClick={handlePreviewVisibility}
/> */}
{isLightboxEnabled && ( {isLightboxEnabled && (
<IAIIconButton <IAIIconButton
icon={<FaExpand />} icon={<FaExpand />}
@ -604,6 +592,16 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/> />
</ButtonGroup> </ButtonGroup>
<ButtonGroup isAttached={true}>
<IAIIconButton
aria-label={t('settings.displayInProgress')}
tooltip={t('settings.displayInProgress')}
icon={<FaHourglassHalf />}
isChecked={shouldShowProgressInViewer}
onClick={handleClickProgressImagesToggle}
/>
</ButtonGroup>
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true}>
<DeleteImageButton image={image} /> <DeleteImageButton image={image} />
</ButtonGroup> </ButtonGroup>

View File

@ -62,7 +62,6 @@ const CurrentImagePreview = () => {
return; return;
} }
e.dataTransfer.setData('invokeai/imageName', image.image_name); e.dataTransfer.setData('invokeai/imageName', image.image_name);
e.dataTransfer.setData('invokeai/imageType', image.image_type);
e.dataTransfer.effectAllowed = 'move'; e.dataTransfer.effectAllowed = 'move';
}, },
[image] [image]

View File

@ -147,7 +147,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleDragStart = useCallback( const handleDragStart = useCallback(
(e: DragEvent<HTMLDivElement>) => { (e: DragEvent<HTMLDivElement>) => {
e.dataTransfer.setData('invokeai/imageName', image.image_name); e.dataTransfer.setData('invokeai/imageName', image.image_name);
e.dataTransfer.setData('invokeai/imageType', image.image_type);
e.dataTransfer.effectAllowed = 'move'; e.dataTransfer.effectAllowed = 'move';
}, },
[image] [image]

View File

@ -1,6 +1,8 @@
import { import {
Box, Box,
ButtonGroup, ButtonGroup,
Checkbox,
CheckboxGroup,
Flex, Flex,
FlexProps, FlexProps,
Grid, Grid,
@ -16,7 +18,6 @@ import IAIPopover from 'common/components/IAIPopover';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { gallerySelector } from 'features/gallery/store/gallerySelectors'; import { gallerySelector } from 'features/gallery/store/gallerySelectors';
import { import {
setCurrentCategory,
setGalleryImageMinimumWidth, setGalleryImageMinimumWidth,
setGalleryImageObjectFit, setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages, setShouldAutoSwitchToNewImages,
@ -36,54 +37,48 @@ import {
} from 'react'; } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs'; import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs';
import { FaImage, FaUser, FaWrench } from 'react-icons/fa'; import {
FaFilter,
FaImage,
FaImages,
FaServer,
FaWrench,
} from 'react-icons/fa';
import { MdPhotoLibrary } from 'react-icons/md'; import { MdPhotoLibrary } from 'react-icons/md';
import HoverableImage from './HoverableImage'; import HoverableImage from './HoverableImage';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { resultsAdapter } from '../store/resultsSlice';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { uploadsAdapter } from '../store/uploadsSlice';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso'; import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import GalleryProgressImage from './GalleryProgressImage';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { uiSelector } from 'features/ui/store/uiSelectors';
import { ImageDTO } from 'services/api'; import { ImageCategory } from 'services/api';
import {
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290; ASSETS_CATEGORIES,
const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER'; IMAGE_CATEGORIES,
imageCategoriesChanged,
selectImagesAll,
} from '../store/imagesSlice';
import { receivedPageOfImages } from 'services/thunks/image';
import { capitalize } from 'lodash-es';
const categorySelector = createSelector( const categorySelector = createSelector(
[(state: RootState) => state], [(state: RootState) => state],
(state) => { (state) => {
const { results, uploads, system, gallery } = state; const { images } = state;
const { currentCategory } = gallery; const { categories } = images;
if (currentCategory === 'results') { const allImages = selectImagesAll(state);
const tempImages: (ImageDTO | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = []; const filteredImages = allImages.filter((i) =>
categories.includes(i.image_category)
if (system.progressImage) { );
tempImages.push(PROGRESS_IMAGE_PLACEHOLDER);
}
return {
images: tempImages.concat(
resultsAdapter.getSelectors().selectAll(results)
),
isLoading: results.isLoading,
areMoreImagesAvailable: results.page < results.pages - 1,
};
}
return { return {
images: uploadsAdapter.getSelectors().selectAll(uploads), images: filteredImages,
isLoading: uploads.isLoading, isLoading: images.isLoading,
areMoreImagesAvailable: uploads.page < uploads.pages - 1, areMoreImagesAvailable: filteredImages.length < images.total,
categories: images.categories,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -93,7 +88,6 @@ const mainSelector = createSelector(
[gallerySelector, uiSelector], [gallerySelector, uiSelector],
(gallery, ui) => { (gallery, ui) => {
const { const {
currentCategory,
galleryImageMinimumWidth, galleryImageMinimumWidth,
galleryImageObjectFit, galleryImageObjectFit,
shouldAutoSwitchToNewImages, shouldAutoSwitchToNewImages,
@ -104,7 +98,6 @@ const mainSelector = createSelector(
const { shouldPinGallery } = ui; const { shouldPinGallery } = ui;
return { return {
currentCategory,
shouldPinGallery, shouldPinGallery,
galleryImageMinimumWidth, galleryImageMinimumWidth,
galleryImageObjectFit, galleryImageObjectFit,
@ -120,7 +113,6 @@ const ImageGalleryContent = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const resizeObserverRef = useRef<HTMLDivElement>(null); const resizeObserverRef = useRef<HTMLDivElement>(null);
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
const rootRef = useRef(null); const rootRef = useRef(null);
const [scroller, setScroller] = useState<HTMLElement | null>(null); const [scroller, setScroller] = useState<HTMLElement | null>(null);
const [initialize, osInstance] = useOverlayScrollbars({ const [initialize, osInstance] = useOverlayScrollbars({
@ -137,7 +129,6 @@ const ImageGalleryContent = () => {
}); });
const { const {
currentCategory,
shouldPinGallery, shouldPinGallery,
galleryImageMinimumWidth, galleryImageMinimumWidth,
galleryImageObjectFit, galleryImageObjectFit,
@ -146,18 +137,12 @@ const ImageGalleryContent = () => {
selectedImage, selectedImage,
} = useAppSelector(mainSelector); } = useAppSelector(mainSelector);
const { images, areMoreImagesAvailable, isLoading } = const { images, areMoreImagesAvailable, isLoading, categories } =
useAppSelector(categorySelector); useAppSelector(categorySelector);
const handleClickLoadMore = () => { const handleLoadMoreImages = useCallback(() => {
if (currentCategory === 'results') { dispatch(receivedPageOfImages());
dispatch(receivedResultImagesPage()); }, [dispatch]);
}
if (currentCategory === 'uploads') {
dispatch(receivedUploadImagesPage());
}
};
const handleChangeGalleryImageMinimumWidth = (v: number) => { const handleChangeGalleryImageMinimumWidth = (v: number) => {
dispatch(setGalleryImageMinimumWidth(v)); dispatch(setGalleryImageMinimumWidth(v));
@ -168,28 +153,6 @@ const ImageGalleryContent = () => {
dispatch(requestCanvasRescale()); dispatch(requestCanvasRescale());
}; };
useEffect(() => {
if (!resizeObserverRef.current) {
return;
}
const resizeObserver = new ResizeObserver(() => {
if (!resizeObserverRef.current) {
return;
}
if (
resizeObserverRef.current.clientWidth < GALLERY_SHOW_BUTTONS_MIN_WIDTH
) {
setShouldShouldIconButtons(true);
return;
}
setShouldShouldIconButtons(false);
});
resizeObserver.observe(resizeObserverRef.current);
return () => resizeObserver.disconnect(); // clean up
}, []);
useEffect(() => { useEffect(() => {
const { current: root } = rootRef; const { current: root } = rootRef;
if (scroller && root) { if (scroller && root) {
@ -210,12 +173,23 @@ const ImageGalleryContent = () => {
}, []); }, []);
const handleEndReached = useCallback(() => { const handleEndReached = useCallback(() => {
if (currentCategory === 'results') { handleLoadMoreImages();
dispatch(receivedResultImagesPage()); }, [handleLoadMoreImages]);
} else if (currentCategory === 'uploads') {
dispatch(receivedUploadImagesPage()); const handleCategoriesChanged = useCallback(
} (newCategories: ImageCategory[]) => {
}, [dispatch, currentCategory]); dispatch(imageCategoriesChanged(newCategories));
},
[dispatch]
);
const handleClickImagesCategory = useCallback(() => {
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
}, [dispatch]);
const handleClickAssetsCategory = useCallback(() => {
dispatch(imageCategoriesChanged(ASSETS_CATEGORIES));
}, [dispatch]);
return ( return (
<Flex <Flex
@ -232,59 +206,31 @@ const ImageGalleryContent = () => {
alignItems="center" alignItems="center"
justifyContent="space-between" justifyContent="space-between"
> >
<ButtonGroup <ButtonGroup isAttached>
size="sm" <IAIIconButton
isAttached tooltip={t('gallery.images')}
w="max-content" aria-label={t('gallery.images')}
justifyContent="stretch" onClick={handleClickImagesCategory}
> isChecked={categories === IMAGE_CATEGORIES}
{shouldShouldIconButtons ? ( size="sm"
<> icon={<FaImage />}
<IAIIconButton />
aria-label={t('gallery.showGenerations')} <IAIIconButton
tooltip={t('gallery.showGenerations')} tooltip={t('gallery.assets')}
isChecked={currentCategory === 'results'} aria-label={t('gallery.assets')}
role="radio" onClick={handleClickAssetsCategory}
icon={<FaImage />} isChecked={categories === ASSETS_CATEGORIES}
onClick={() => dispatch(setCurrentCategory('results'))} size="sm"
/> icon={<FaServer />}
<IAIIconButton />
aria-label={t('gallery.showUploads')}
tooltip={t('gallery.showUploads')}
role="radio"
isChecked={currentCategory === 'uploads'}
icon={<FaUser />}
onClick={() => dispatch(setCurrentCategory('uploads'))}
/>
</>
) : (
<>
<IAIButton
size="sm"
isChecked={currentCategory === 'results'}
onClick={() => dispatch(setCurrentCategory('results'))}
flexGrow={1}
>
{t('gallery.generations')}
</IAIButton>
<IAIButton
size="sm"
isChecked={currentCategory === 'uploads'}
onClick={() => dispatch(setCurrentCategory('uploads'))}
flexGrow={1}
>
{t('gallery.uploads')}
</IAIButton>
</>
)}
</ButtonGroup> </ButtonGroup>
<Flex gap={2}> <Flex gap={2}>
<IAIPopover <IAIPopover
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
size="sm" tooltip={t('gallery.gallerySettings')}
aria-label={t('gallery.gallerySettings')} aria-label={t('gallery.gallerySettings')}
size="sm"
icon={<FaWrench />} icon={<FaWrench />}
/> />
} }
@ -347,28 +293,17 @@ const ImageGalleryContent = () => {
data={images} data={images}
endReached={handleEndReached} endReached={handleEndReached}
scrollerRef={(ref) => setScrollerRef(ref)} scrollerRef={(ref) => setScrollerRef(ref)}
itemContent={(index, image) => { itemContent={(index, image) => (
const isSelected = <Flex sx={{ pb: 2 }}>
image === PROGRESS_IMAGE_PLACEHOLDER <HoverableImage
? false key={`${image.image_name}-${image.thumbnail_url}`}
: selectedImage?.image_name === image?.image_name; image={image}
isSelected={
return ( selectedImage?.image_name === image?.image_name
<Flex sx={{ pb: 2 }}> }
{image === PROGRESS_IMAGE_PLACEHOLDER ? ( />
<GalleryProgressImage </Flex>
key={PROGRESS_IMAGE_PLACEHOLDER} )}
/>
) : (
<HoverableImage
key={`${image.image_name}-${image.thumbnail_url}`}
image={image}
isSelected={isSelected}
/>
)}
</Flex>
);
}}
/> />
) : ( ) : (
<VirtuosoGrid <VirtuosoGrid
@ -380,27 +315,20 @@ const ImageGalleryContent = () => {
List: ListContainer, List: ListContainer,
}} }}
scrollerRef={setScroller} scrollerRef={setScroller}
itemContent={(index, image) => { itemContent={(index, image) => (
const isSelected = <HoverableImage
image === PROGRESS_IMAGE_PLACEHOLDER key={`${image.image_name}-${image.thumbnail_url}`}
? false image={image}
: selectedImage?.image_name === image?.image_name; isSelected={
selectedImage?.image_name === image?.image_name
return image === PROGRESS_IMAGE_PLACEHOLDER ? ( }
<GalleryProgressImage key={PROGRESS_IMAGE_PLACEHOLDER} /> />
) : ( )}
<HoverableImage
key={`${image.image_name}-${image.thumbnail_url}`}
image={image}
isSelected={isSelected}
/>
);
}}
/> />
)} )}
</Box> </Box>
<IAIButton <IAIButton
onClick={handleClickLoadMore} onClick={handleLoadMoreImages}
isDisabled={!areMoreImagesAvailable} isDisabled={!areMoreImagesAvailable}
isLoading={isLoading} isLoading={isLoading}
loadingText="Loading" loadingText="Loading"

View File

@ -53,6 +53,11 @@ const MetadataItem = ({
withCopy = false, withCopy = false,
}: MetadataItemProps) => { }: MetadataItemProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
if (!value) {
return null;
}
return ( return (
<Flex gap={2}> <Flex gap={2}>
{onClick && ( {onClick && (

View File

@ -9,6 +9,10 @@ import { gallerySelector } from '../store/gallerySelectors';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { imageSelected } from '../store/gallerySlice'; import { imageSelected } from '../store/gallerySlice';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import {
selectFilteredImagesAsObject,
selectFilteredImagesIds,
} from '../store/imagesSlice';
const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = { const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = {
height: '100%', height: '100%',
@ -21,9 +25,14 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
}; };
export const nextPrevImageButtonsSelector = createSelector( export const nextPrevImageButtonsSelector = createSelector(
[(state: RootState) => state, gallerySelector], [
(state, gallery) => { (state: RootState) => state,
const { selectedImage, currentCategory } = gallery; gallerySelector,
selectFilteredImagesAsObject,
selectFilteredImagesIds,
],
(state, gallery, filteredImagesAsObject, filteredImageIds) => {
const { selectedImage } = gallery;
if (!selectedImage) { if (!selectedImage) {
return { return {
@ -32,29 +41,29 @@ export const nextPrevImageButtonsSelector = createSelector(
}; };
} }
const currentImageIndex = state[currentCategory].ids.findIndex( const currentImageIndex = filteredImageIds.findIndex(
(i) => i === selectedImage.image_name (i) => i === selectedImage.image_name
); );
const nextImageIndex = clamp( const nextImageIndex = clamp(
currentImageIndex + 1, currentImageIndex + 1,
0, 0,
state[currentCategory].ids.length - 1 filteredImageIds.length - 1
); );
const prevImageIndex = clamp( const prevImageIndex = clamp(
currentImageIndex - 1, currentImageIndex - 1,
0, 0,
state[currentCategory].ids.length - 1 filteredImageIds.length - 1
); );
const nextImageId = state[currentCategory].ids[nextImageIndex]; const nextImageId = filteredImageIds[nextImageIndex];
const prevImageId = state[currentCategory].ids[prevImageIndex]; const prevImageId = filteredImageIds[prevImageIndex];
const nextImage = state[currentCategory].entities[nextImageId]; const nextImage = filteredImagesAsObject[nextImageId];
const prevImage = state[currentCategory].entities[prevImageId]; const prevImage = filteredImagesAsObject[prevImageId];
const imagesLength = state[currentCategory].ids.length; const imagesLength = filteredImageIds.length;
return { return {
isOnFirstImage: currentImageIndex === 0, isOnFirstImage: currentImageIndex === 0,

View File

@ -1,33 +1,18 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { ImageType } from 'services/api'; import { selectImagesEntities } from '../store/imagesSlice';
import { selectResultsEntities } from '../store/resultsSlice'; import { useCallback } from 'react';
import { selectUploadsEntities } from '../store/uploadsSlice';
const useGetImageByNameSelector = createSelector( const useGetImageByName = () => {
[selectResultsEntities, selectUploadsEntities], const images = useAppSelector(selectImagesEntities);
(allResults, allUploads) => { return useCallback(
return { allResults, allUploads }; (name: string | undefined) => {
} if (!name) {
); return;
const useGetImageByNameAndType = () => {
const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector);
return (name: string, type: ImageType) => {
if (type === 'results') {
const resultImagesResult = allResults[name];
if (resultImagesResult) {
return resultImagesResult;
} }
} return images[name];
},
if (type === 'uploads') { [images]
const userImagesResult = allUploads[name]; );
if (userImagesResult) {
return userImagesResult;
}
}
};
}; };
export default useGetImageByNameAndType; export default useGetImageByName;

View File

@ -1,9 +1,9 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { ImageNameAndType } from 'features/parameters/store/actions'; import { ImageNameAndOrigin } from 'features/parameters/store/actions';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
export const requestedImageDeletion = createAction< export const requestedImageDeletion = createAction<
ImageDTO | ImageNameAndType | undefined ImageDTO | ImageNameAndOrigin | undefined
>('gallery/requestedImageDeletion'); >('gallery/requestedImageDeletion');
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas'); export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');

View File

@ -4,6 +4,5 @@ import { GalleryState } from './gallerySlice';
* Gallery slice persist denylist * Gallery slice persist denylist
*/ */
export const galleryPersistDenylist: (keyof GalleryState)[] = [ export const galleryPersistDenylist: (keyof GalleryState)[] = [
'currentCategory',
'shouldAutoSwitchToNewImages', 'shouldAutoSwitchToNewImages',
]; ];

View File

@ -1,10 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from '../../../services/thunks/gallery';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { imageUpserted } from './imagesSlice';
type GalleryImageObjectFitType = 'contain' | 'cover'; type GalleryImageObjectFitType = 'contain' | 'cover';
@ -14,7 +11,6 @@ export interface GalleryState {
galleryImageObjectFit: GalleryImageObjectFitType; galleryImageObjectFit: GalleryImageObjectFitType;
shouldAutoSwitchToNewImages: boolean; shouldAutoSwitchToNewImages: boolean;
shouldUseSingleGalleryColumn: boolean; shouldUseSingleGalleryColumn: boolean;
currentCategory: 'results' | 'uploads';
} }
export const initialGalleryState: GalleryState = { export const initialGalleryState: GalleryState = {
@ -22,7 +18,6 @@ export const initialGalleryState: GalleryState = {
galleryImageObjectFit: 'cover', galleryImageObjectFit: 'cover',
shouldAutoSwitchToNewImages: true, shouldAutoSwitchToNewImages: true,
shouldUseSingleGalleryColumn: false, shouldUseSingleGalleryColumn: false,
currentCategory: 'results',
}; };
export const gallerySlice = createSlice({ export const gallerySlice = createSlice({
@ -46,12 +41,6 @@ export const gallerySlice = createSlice({
setShouldAutoSwitchToNewImages: (state, action: PayloadAction<boolean>) => { setShouldAutoSwitchToNewImages: (state, action: PayloadAction<boolean>) => {
state.shouldAutoSwitchToNewImages = action.payload; state.shouldAutoSwitchToNewImages = action.payload;
}, },
setCurrentCategory: (
state,
action: PayloadAction<'results' | 'uploads'>
) => {
state.currentCategory = action.payload;
},
setShouldUseSingleGalleryColumn: ( setShouldUseSingleGalleryColumn: (
state, state,
action: PayloadAction<boolean> action: PayloadAction<boolean>
@ -59,37 +48,10 @@ export const gallerySlice = createSlice({
state.shouldUseSingleGalleryColumn = action.payload; state.shouldUseSingleGalleryColumn = action.payload;
}, },
}, },
extraReducers(builder) { extraReducers: (builder) => {
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => { builder.addCase(imageUpserted, (state, action) => {
// rehydrate selectedImage URL when results list comes in if (state.shouldAutoSwitchToNewImages) {
// solves case when outdated URL is in local storage state.selectedImage = action.payload;
const selectedImage = state.selectedImage;
if (selectedImage) {
const selectedImageInResults = action.payload.items.find(
(image) => image.image_name === selectedImage.image_name
);
if (selectedImageInResults) {
selectedImage.image_url = selectedImageInResults.image_url;
selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
state.selectedImage = selectedImage;
}
}
});
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
// rehydrate selectedImage URL when results list comes in
// solves case when outdated URL is in local storage
const selectedImage = state.selectedImage;
if (selectedImage) {
const selectedImageInResults = action.payload.items.find(
(image) => image.image_name === selectedImage.image_name
);
if (selectedImageInResults) {
selectedImage.image_url = selectedImageInResults.image_url;
selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
state.selectedImage = selectedImage;
}
} }
}); });
}, },
@ -101,7 +63,6 @@ export const {
setGalleryImageObjectFit, setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages, setShouldAutoSwitchToNewImages,
setShouldUseSingleGalleryColumn, setShouldUseSingleGalleryColumn,
setCurrentCategory,
} = gallerySlice.actions; } = gallerySlice.actions;
export default gallerySlice.reducer; export default gallerySlice.reducer;

View File

@ -0,0 +1,135 @@
import {
PayloadAction,
createEntityAdapter,
createSelector,
createSlice,
} from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { ImageCategory, ImageDTO } from 'services/api';
import { dateComparator } from 'common/util/dateComparator';
import { isString, keyBy } from 'lodash-es';
import { receivedPageOfImages } from 'services/thunks/image';
export const imagesAdapter = createEntityAdapter<ImageDTO>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
});
export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
export const ASSETS_CATEGORIES: ImageCategory[] = [
'control',
'mask',
'user',
'other',
];
type AdditionaImagesState = {
offset: number;
limit: number;
total: number;
isLoading: boolean;
categories: ImageCategory[];
};
export const initialImagesState =
imagesAdapter.getInitialState<AdditionaImagesState>({
offset: 0,
limit: 0,
total: 0,
isLoading: false,
categories: IMAGE_CATEGORIES,
});
export type ImagesState = typeof initialImagesState;
const imagesSlice = createSlice({
name: 'images',
initialState: initialImagesState,
reducers: {
imageUpserted: (state, action: PayloadAction<ImageDTO>) => {
imagesAdapter.upsertOne(state, action.payload);
},
imageRemoved: (state, action: PayloadAction<string | ImageDTO>) => {
if (isString(action.payload)) {
imagesAdapter.removeOne(state, action.payload);
return;
}
imagesAdapter.removeOne(state, action.payload.image_name);
},
imageCategoriesChanged: (state, action: PayloadAction<ImageCategory[]>) => {
state.categories = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(receivedPageOfImages.pending, (state) => {
state.isLoading = true;
});
builder.addCase(receivedPageOfImages.rejected, (state) => {
state.isLoading = false;
});
builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
state.isLoading = false;
const { items, offset, limit, total } = action.payload;
state.offset = offset;
state.limit = limit;
state.total = total;
imagesAdapter.upsertMany(state, items);
});
},
});
export const {
selectAll: selectImagesAll,
selectById: selectImagesById,
selectEntities: selectImagesEntities,
selectIds: selectImagesIds,
selectTotal: selectImagesTotal,
} = imagesAdapter.getSelectors<RootState>((state) => state.images);
export const { imageUpserted, imageRemoved, imageCategoriesChanged } =
imagesSlice.actions;
export default imagesSlice.reducer;
export const selectFilteredImagesAsArray = createSelector(
(state: RootState) => state,
(state) => {
const {
images: { categories },
} = state;
return selectImagesAll(state).filter((i) =>
categories.includes(i.image_category)
);
}
);
export const selectFilteredImagesAsObject = createSelector(
(state: RootState) => state,
(state) => {
const {
images: { categories },
} = state;
return keyBy(
selectImagesAll(state).filter((i) =>
categories.includes(i.image_category)
),
'image_name'
);
}
);
export const selectFilteredImagesIds = createSelector(
(state: RootState) => state,
(state) => {
const {
images: { categories },
} = state;
return selectImagesAll(state)
.filter((i) => categories.includes(i.image_category))
.map((i) => i.image_name);
}
);

View File

@ -1,8 +0,0 @@
import { ResultsState } from './resultsSlice';
/**
* Results slice persist denylist
*
* Currently denylisting results slice entirely, see `serialize.ts`
*/
export const resultsPersistDenylist: (keyof ResultsState)[] = [];

View File

@ -1,88 +0,0 @@
import {
PayloadAction,
createEntityAdapter,
createSlice,
} from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
receivedResultImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { ImageDTO } from 'services/api';
import { dateComparator } from 'common/util/dateComparator';
export type ResultsImageDTO = Omit<ImageDTO, 'image_type'> & {
image_type: 'results';
};
export const resultsAdapter = createEntityAdapter<ResultsImageDTO>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
});
type AdditionalResultsState = {
page: number;
pages: number;
isLoading: boolean;
nextPage: number;
upsertedImageCount: number;
};
export const initialResultsState =
resultsAdapter.getInitialState<AdditionalResultsState>({
page: 0,
pages: 0,
isLoading: false,
nextPage: 0,
upsertedImageCount: 0,
});
export type ResultsState = typeof initialResultsState;
const resultsSlice = createSlice({
name: 'results',
initialState: initialResultsState,
reducers: {
resultUpserted: (state, action: PayloadAction<ResultsImageDTO>) => {
resultsAdapter.upsertOne(state, action.payload);
state.upsertedImageCount += 1;
},
},
extraReducers: (builder) => {
/**
* Received Result Images Page - PENDING
*/
builder.addCase(receivedResultImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Result Images Page - FULFILLED
*/
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
const { page, pages } = action.payload;
// We know these will all be of the results type, but it's not represented in the API types
const items = action.payload.items as ResultsImageDTO[];
resultsAdapter.setMany(state, items);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
},
});
export const {
selectAll: selectResultsAll,
selectById: selectResultsById,
selectEntities: selectResultsEntities,
selectIds: selectResultsIds,
selectTotal: selectResultsTotal,
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
export const { resultUpserted } = resultsSlice.actions;
export default resultsSlice.reducer;

View File

@ -1,8 +0,0 @@
import { UploadsState } from './uploadsSlice';
/**
* Uploads slice persist denylist
*
* Currently denylisting uploads slice entirely, see `serialize.ts`
*/
export const uploadsPersistDenylist: (keyof UploadsState)[] = [];

View File

@ -1,89 +0,0 @@
import {
PayloadAction,
createEntityAdapter,
createSlice,
} from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
receivedUploadImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { ImageDTO } from 'services/api';
import { dateComparator } from 'common/util/dateComparator';
export type UploadsImageDTO = Omit<ImageDTO, 'image_type'> & {
image_type: 'uploads';
};
export const uploadsAdapter = createEntityAdapter<UploadsImageDTO>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
});
type AdditionalUploadsState = {
page: number;
pages: number;
isLoading: boolean;
nextPage: number;
upsertedImageCount: number;
};
export const initialUploadsState =
uploadsAdapter.getInitialState<AdditionalUploadsState>({
page: 0,
pages: 0,
nextPage: 0,
isLoading: false,
upsertedImageCount: 0,
});
export type UploadsState = typeof initialUploadsState;
const uploadsSlice = createSlice({
name: 'uploads',
initialState: initialUploadsState,
reducers: {
uploadUpserted: (state, action: PayloadAction<UploadsImageDTO>) => {
uploadsAdapter.upsertOne(state, action.payload);
state.upsertedImageCount += 1;
},
},
extraReducers: (builder) => {
/**
* Received Upload Images Page - PENDING
*/
builder.addCase(receivedUploadImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Upload Images Page - FULFILLED
*/
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
const { page, pages } = action.payload;
// We know these will all be of the uploads type, but it's not represented in the API types
const items = action.payload.items as UploadsImageDTO[];
uploadsAdapter.setMany(state, items);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
},
});
export const {
selectAll: selectUploadsAll,
selectById: selectUploadsById,
selectEntities: selectUploadsEntities,
selectIds: selectUploadsIds,
selectTotal: selectUploadsTotal,
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
export const { uploadUpserted } = uploadsSlice.actions;
export default uploadsSlice.reducer;

View File

@ -2,7 +2,7 @@ import { Box, Image } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder'; import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder';
import { useGetUrl } from 'common/util/getUrl'; import { useGetUrl } from 'common/util/getUrl';
import useGetImageByNameAndType from 'features/gallery/hooks/useGetImageByName'; import useGetImageByName from 'features/gallery/hooks/useGetImageByName';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
@ -11,7 +11,6 @@ import {
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { DragEvent, memo, useCallback, useState } from 'react'; import { DragEvent, memo, useCallback, useState } from 'react';
import { ImageType } from 'services/api';
import { FieldComponentProps } from './types'; import { FieldComponentProps } from './types';
const ImageInputFieldComponent = ( const ImageInputFieldComponent = (
@ -19,7 +18,7 @@ const ImageInputFieldComponent = (
) => { ) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const getImageByNameAndType = useGetImageByNameAndType(); const getImageByName = useGetImageByName();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [url, setUrl] = useState<string | undefined>(field.value?.image_url); const [url, setUrl] = useState<string | undefined>(field.value?.image_url);
const { getUrl } = useGetUrl(); const { getUrl } = useGetUrl();
@ -27,13 +26,7 @@ const ImageInputFieldComponent = (
const handleDrop = useCallback( const handleDrop = useCallback(
(e: DragEvent<HTMLDivElement>) => { (e: DragEvent<HTMLDivElement>) => {
const name = e.dataTransfer.getData('invokeai/imageName'); const name = e.dataTransfer.getData('invokeai/imageName');
const type = e.dataTransfer.getData('invokeai/imageType') as ImageType; const image = getImageByName(name);
if (!name || !type) {
return;
}
const image = getImageByNameAndType(name, type);
if (!image) { if (!image) {
return; return;
@ -49,7 +42,7 @@ const ImageInputFieldComponent = (
}) })
); );
}, },
[getImageByNameAndType, dispatch, field.name, nodeId] [getImageByName, dispatch, field.name, nodeId]
); );
return ( return (

View File

@ -26,18 +26,21 @@ const buildBaseNode = (
| ImageToImageInvocation | ImageToImageInvocation
| InpaintInvocation | InpaintInvocation
| undefined => { | undefined => {
const dimensionsOverride = state.canvas.boundingBoxDimensions; const overrides = {
...state.canvas.boundingBoxDimensions,
is_intermediate: true,
};
if (nodeType === 'txt2img') { if (nodeType === 'txt2img') {
return buildTxt2ImgNode(state, dimensionsOverride); return buildTxt2ImgNode(state, overrides);
} }
if (nodeType === 'img2img') { if (nodeType === 'img2img') {
return buildImg2ImgNode(state, dimensionsOverride); return buildImg2ImgNode(state, overrides);
} }
if (nodeType === 'inpaint' || nodeType === 'outpaint') { if (nodeType === 'inpaint' || nodeType === 'outpaint') {
return buildInpaintNode(state, dimensionsOverride); return buildInpaintNode(state, overrides);
} }
}; };

View File

@ -64,7 +64,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
model, model,
image: { image: {
image_name: initialImage?.image_name, image_name: initialImage?.image_name,
image_type: initialImage?.image_type, image_origin: initialImage?.image_origin,
}, },
}; };

View File

@ -58,7 +58,7 @@ export const buildImg2ImgNode = (
imageToImageNode.image = { imageToImageNode.image = {
image_name: initialImage.name, image_name: initialImage.name,
image_type: initialImage.type, image_origin: initialImage.type,
}; };
} }

View File

@ -51,7 +51,7 @@ export const buildInpaintNode = (
inpaintNode.image = { inpaintNode.image = {
image_name: initialImage.name, image_name: initialImage.name,
image_type: initialImage.type, image_origin: initialImage.type,
}; };
} }

View File

@ -13,7 +13,7 @@ import {
buildOutputFieldTemplates, buildOutputFieldTemplates,
} from './fieldTemplateBuilders'; } from './fieldTemplateBuilders';
const RESERVED_FIELD_NAMES = ['id', 'type', 'meta']; const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate'];
const invocationDenylist = ['Graph', 'InvocationMeta']; const invocationDenylist = ['Graph', 'InvocationMeta'];

View File

@ -15,7 +15,7 @@ const ParamInfillCollapse = () => {
return ( return (
<IAICollapse <IAICollapse
label={t('parameters.boundingBoxHeader')} label={t('parameters.infillScalingHeader')}
isOpen={isOpen} isOpen={isOpen}
onToggle={onToggle} onToggle={onToggle}
> >

View File

@ -5,7 +5,6 @@ import { useGetUrl } from 'common/util/getUrl';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { DragEvent, useCallback } from 'react'; import { DragEvent, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ImageType } from 'services/api';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
@ -55,9 +54,7 @@ const InitialImagePreview = () => {
const handleDrop = useCallback( const handleDrop = useCallback(
(e: DragEvent<HTMLDivElement>) => { (e: DragEvent<HTMLDivElement>) => {
const name = e.dataTransfer.getData('invokeai/imageName'); const name = e.dataTransfer.getData('invokeai/imageName');
const type = e.dataTransfer.getData('invokeai/imageType') as ImageType; dispatch(initialImageSelected(name));
dispatch(initialImageSelected({ image_name: name, image_type: type }));
}, },
[dispatch] [dispatch]
); );

View File

@ -88,7 +88,7 @@ export const useParameters = () => {
return; return;
} }
dispatch(initialImageSelected(image)); dispatch(initialImageSelected(image.image_name));
toaster({ toaster({
title: t('toast.initialImageSet'), title: t('toast.initialImageSet'),
status: 'info', status: 'info',

View File

@ -1,10 +1,10 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { isObject } from 'lodash-es'; import { isObject } from 'lodash-es';
import { ImageDTO, ImageType } from 'services/api'; import { ImageDTO, ResourceOrigin } from 'services/api';
export type ImageNameAndType = { export type ImageNameAndOrigin = {
image_name: string; image_name: string;
image_type: ImageType; image_origin: ResourceOrigin;
}; };
export const isImageDTO = (image: any): image is ImageDTO => { export const isImageDTO = (image: any): image is ImageDTO => {
@ -13,8 +13,8 @@ export const isImageDTO = (image: any): image is ImageDTO => {
isObject(image) && isObject(image) &&
'image_name' in image && 'image_name' in image &&
image?.image_name !== undefined && image?.image_name !== undefined &&
'image_type' in image && 'image_origin' in image &&
image?.image_type !== undefined && image?.image_origin !== undefined &&
'image_url' in image && 'image_url' in image &&
image?.image_url !== undefined && image?.image_url !== undefined &&
'thumbnail_url' in image && 'thumbnail_url' in image &&
@ -26,6 +26,6 @@ export const isImageDTO = (image: any): image is ImageDTO => {
); );
}; };
export const initialImageSelected = createAction< export const initialImageSelected = createAction<ImageDTO | string | undefined>(
ImageDTO | ImageNameAndType | undefined 'generation/initialImageSelected'
>('generation/initialImageSelected'); );

View File

@ -1,34 +1,3 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { selectResultsById } from 'features/gallery/store/resultsSlice';
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
import { isEqual } from 'lodash-es';
export const generationSelector = (state: RootState) => state.generation; export const generationSelector = (state: RootState) => state.generation;
export const mayGenerateMultipleImagesSelector = createSelector(
generationSelector,
({ shouldRandomizeSeed, shouldGenerateVariations }) => {
return shouldRandomizeSeed || shouldGenerateVariations;
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const initialImageSelector = createSelector(
[(state: RootState) => state, generationSelector],
(state, generation) => {
const { initialImage } = generation;
if (initialImage?.type === 'results') {
return selectResultsById(state, initialImage.name);
}
if (initialImage?.type === 'uploads') {
return selectUploadsById(state, initialImage.name);
}
}
);

View File

@ -2,17 +2,6 @@ import { UseToastOptions } from '@chakra-ui/react';
import { PayloadAction, isAnyOf } from '@reduxjs/toolkit'; import { PayloadAction, isAnyOf } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai'; import * as InvokeAI from 'app/types/invokeai';
import {
generatorProgress,
graphExecutionStateComplete,
invocationComplete,
invocationError,
invocationStarted,
socketConnected,
socketDisconnected,
socketSubscribed,
socketUnsubscribed,
} from 'services/events/actions';
import { ProgressImage } from 'services/events/types'; import { ProgressImage } from 'services/events/types';
import { makeToast } from '../../../app/components/Toaster'; import { makeToast } from '../../../app/components/Toaster';
@ -30,6 +19,17 @@ import { t } from 'i18next';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { LANGUAGES } from '../components/LanguagePicker'; import { LANGUAGES } from '../components/LanguagePicker';
import { imageUploaded } from 'services/thunks/image'; import { imageUploaded } from 'services/thunks/image';
import {
appSocketConnected,
appSocketDisconnected,
appSocketGeneratorProgress,
appSocketGraphExecutionStateComplete,
appSocketInvocationComplete,
appSocketInvocationError,
appSocketInvocationStarted,
appSocketSubscribed,
appSocketUnsubscribed,
} from 'services/events/actions';
export type CancelStrategy = 'immediate' | 'scheduled'; export type CancelStrategy = 'immediate' | 'scheduled';
@ -227,7 +227,7 @@ export const systemSlice = createSlice({
/** /**
* Socket Subscribed * Socket Subscribed
*/ */
builder.addCase(socketSubscribed, (state, action) => { builder.addCase(appSocketSubscribed, (state, action) => {
state.sessionId = action.payload.sessionId; state.sessionId = action.payload.sessionId;
state.canceledSession = ''; state.canceledSession = '';
}); });
@ -235,14 +235,14 @@ export const systemSlice = createSlice({
/** /**
* Socket Unsubscribed * Socket Unsubscribed
*/ */
builder.addCase(socketUnsubscribed, (state) => { builder.addCase(appSocketUnsubscribed, (state) => {
state.sessionId = null; state.sessionId = null;
}); });
/** /**
* Socket Connected * Socket Connected
*/ */
builder.addCase(socketConnected, (state) => { builder.addCase(appSocketConnected, (state) => {
state.isConnected = true; state.isConnected = true;
state.isCancelable = true; state.isCancelable = true;
state.isProcessing = false; state.isProcessing = false;
@ -257,7 +257,7 @@ export const systemSlice = createSlice({
/** /**
* Socket Disconnected * Socket Disconnected
*/ */
builder.addCase(socketDisconnected, (state) => { builder.addCase(appSocketDisconnected, (state) => {
state.isConnected = false; state.isConnected = false;
state.isProcessing = false; state.isProcessing = false;
state.isCancelable = true; state.isCancelable = true;
@ -272,7 +272,7 @@ export const systemSlice = createSlice({
/** /**
* Invocation Started * Invocation Started
*/ */
builder.addCase(invocationStarted, (state) => { builder.addCase(appSocketInvocationStarted, (state) => {
state.isCancelable = true; state.isCancelable = true;
state.isProcessing = true; state.isProcessing = true;
state.currentStatusHasSteps = false; state.currentStatusHasSteps = false;
@ -286,7 +286,7 @@ export const systemSlice = createSlice({
/** /**
* Generator Progress * Generator Progress
*/ */
builder.addCase(generatorProgress, (state, action) => { builder.addCase(appSocketGeneratorProgress, (state, action) => {
const { step, total_steps, progress_image } = action.payload.data; const { step, total_steps, progress_image } = action.payload.data;
state.isProcessing = true; state.isProcessing = true;
@ -303,7 +303,7 @@ export const systemSlice = createSlice({
/** /**
* Invocation Complete * Invocation Complete
*/ */
builder.addCase(invocationComplete, (state, action) => { builder.addCase(appSocketInvocationComplete, (state, action) => {
const { data } = action.payload; const { data } = action.payload;
// state.currentIteration = 0; // state.currentIteration = 0;
@ -322,7 +322,7 @@ export const systemSlice = createSlice({
/** /**
* Invocation Error * Invocation Error
*/ */
builder.addCase(invocationError, (state) => { builder.addCase(appSocketInvocationError, (state) => {
state.isProcessing = false; state.isProcessing = false;
state.isCancelable = true; state.isCancelable = true;
// state.currentIteration = 0; // state.currentIteration = 0;
@ -339,7 +339,20 @@ export const systemSlice = createSlice({
}); });
/** /**
* Session Invoked - PENDING * Graph Execution State Complete
*/
builder.addCase(appSocketGraphExecutionStateComplete, (state) => {
state.isProcessing = false;
state.isCancelable = false;
state.isCancelScheduled = false;
state.currentStep = 0;
state.totalSteps = 0;
state.statusTranslationKey = 'common.statusConnected';
state.progressImage = null;
});
/**
* User Invoked
*/ */
builder.addCase(userInvoked, (state) => { builder.addCase(userInvoked, (state) => {
@ -367,18 +380,6 @@ export const systemSlice = createSlice({
); );
}); });
/**
* Session Canceled
*/
builder.addCase(graphExecutionStateComplete, (state) => {
state.isProcessing = false;
state.isCancelable = false;
state.isCancelScheduled = false;
state.currentStep = 0;
state.totalSteps = 0;
state.statusTranslationKey = 'common.statusConnected';
});
/** /**
* Received available models from the backend * Received available models from the backend
*/ */

View File

@ -8,6 +8,7 @@ export type { OpenAPIConfig } from './core/OpenAPI';
export type { AddInvocation } from './models/AddInvocation'; export type { AddInvocation } from './models/AddInvocation';
export type { Body_upload_image } from './models/Body_upload_image'; export type { Body_upload_image } from './models/Body_upload_image';
export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';
export type { CkptModelInfo } from './models/CkptModelInfo'; export type { CkptModelInfo } from './models/CkptModelInfo';
export type { CollectInvocation } from './models/CollectInvocation'; export type { CollectInvocation } from './models/CollectInvocation';
export type { CollectInvocationOutput } from './models/CollectInvocationOutput'; export type { CollectInvocationOutput } from './models/CollectInvocationOutput';
@ -15,16 +16,23 @@ export type { ColorField } from './models/ColorField';
export type { CompelInvocation } from './models/CompelInvocation'; export type { CompelInvocation } from './models/CompelInvocation';
export type { CompelOutput } from './models/CompelOutput'; export type { CompelOutput } from './models/CompelOutput';
export type { ConditioningField } from './models/ConditioningField'; export type { ConditioningField } from './models/ConditioningField';
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
export type { ControlField } from './models/ControlField';
export type { ControlNetInvocation } from './models/ControlNetInvocation';
export type { ControlOutput } from './models/ControlOutput';
export type { CreateModelRequest } from './models/CreateModelRequest'; export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation'; export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
export type { DiffusersModelInfo } from './models/DiffusersModelInfo'; export type { DiffusersModelInfo } from './models/DiffusersModelInfo';
export type { DivideInvocation } from './models/DivideInvocation'; export type { DivideInvocation } from './models/DivideInvocation';
export type { Edge } from './models/Edge'; export type { Edge } from './models/Edge';
export type { EdgeConnection } from './models/EdgeConnection'; export type { EdgeConnection } from './models/EdgeConnection';
export type { FloatCollectionOutput } from './models/FloatCollectionOutput';
export type { FloatOutput } from './models/FloatOutput';
export type { Graph } from './models/Graph'; export type { Graph } from './models/Graph';
export type { GraphExecutionState } from './models/GraphExecutionState'; export type { GraphExecutionState } from './models/GraphExecutionState';
export type { GraphInvocation } from './models/GraphInvocation'; export type { GraphInvocation } from './models/GraphInvocation';
export type { GraphInvocationOutput } from './models/GraphInvocationOutput'; export type { GraphInvocationOutput } from './models/GraphInvocationOutput';
export type { HedImageprocessorInvocation } from './models/HedImageprocessorInvocation';
export type { HTTPValidationError } from './models/HTTPValidationError'; export type { HTTPValidationError } from './models/HTTPValidationError';
export type { ImageBlurInvocation } from './models/ImageBlurInvocation'; export type { ImageBlurInvocation } from './models/ImageBlurInvocation';
export type { ImageCategory } from './models/ImageCategory'; export type { ImageCategory } from './models/ImageCategory';
@ -39,10 +47,10 @@ export type { ImageMetadata } from './models/ImageMetadata';
export type { ImageMultiplyInvocation } from './models/ImageMultiplyInvocation'; export type { ImageMultiplyInvocation } from './models/ImageMultiplyInvocation';
export type { ImageOutput } from './models/ImageOutput'; export type { ImageOutput } from './models/ImageOutput';
export type { ImagePasteInvocation } from './models/ImagePasteInvocation'; export type { ImagePasteInvocation } from './models/ImagePasteInvocation';
export type { ImageProcessorInvocation } from './models/ImageProcessorInvocation';
export type { ImageRecordChanges } from './models/ImageRecordChanges'; export type { ImageRecordChanges } from './models/ImageRecordChanges';
export type { ImageToImageInvocation } from './models/ImageToImageInvocation'; export type { ImageToImageInvocation } from './models/ImageToImageInvocation';
export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation'; export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
export type { ImageType } from './models/ImageType';
export type { ImageUrlsDTO } from './models/ImageUrlsDTO'; export type { ImageUrlsDTO } from './models/ImageUrlsDTO';
export type { InfillColorInvocation } from './models/InfillColorInvocation'; export type { InfillColorInvocation } from './models/InfillColorInvocation';
export type { InfillPatchMatchInvocation } from './models/InfillPatchMatchInvocation'; export type { InfillPatchMatchInvocation } from './models/InfillPatchMatchInvocation';
@ -56,22 +64,32 @@ export type { LatentsField } from './models/LatentsField';
export type { LatentsOutput } from './models/LatentsOutput'; export type { LatentsOutput } from './models/LatentsOutput';
export type { LatentsToImageInvocation } from './models/LatentsToImageInvocation'; export type { LatentsToImageInvocation } from './models/LatentsToImageInvocation';
export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvocation'; export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvocation';
export type { LineartAnimeImageProcessorInvocation } from './models/LineartAnimeImageProcessorInvocation';
export type { LineartImageProcessorInvocation } from './models/LineartImageProcessorInvocation';
export type { LoadImageInvocation } from './models/LoadImageInvocation'; export type { LoadImageInvocation } from './models/LoadImageInvocation';
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation'; export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput'; export type { MaskOutput } from './models/MaskOutput';
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
export type { MidasDepthImageProcessorInvocation } from './models/MidasDepthImageProcessorInvocation';
export type { MlsdImageProcessorInvocation } from './models/MlsdImageProcessorInvocation';
export type { ModelsList } from './models/ModelsList'; export type { ModelsList } from './models/ModelsList';
export type { MultiplyInvocation } from './models/MultiplyInvocation'; export type { MultiplyInvocation } from './models/MultiplyInvocation';
export type { NoiseInvocation } from './models/NoiseInvocation'; export type { NoiseInvocation } from './models/NoiseInvocation';
export type { NoiseOutput } from './models/NoiseOutput'; export type { NoiseOutput } from './models/NoiseOutput';
export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation';
export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_';
export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation';
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_'; export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
export type { PaginatedResults_ImageDTO_ } from './models/PaginatedResults_ImageDTO_'; export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
export type { ParamIntInvocation } from './models/ParamIntInvocation'; export type { ParamIntInvocation } from './models/ParamIntInvocation';
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
export type { PromptOutput } from './models/PromptOutput'; export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation'; export type { RandomIntInvocation } from './models/RandomIntInvocation';
export type { RandomRangeInvocation } from './models/RandomRangeInvocation'; export type { RandomRangeInvocation } from './models/RandomRangeInvocation';
export type { RangeInvocation } from './models/RangeInvocation'; export type { RangeInvocation } from './models/RangeInvocation';
export type { RangeOfSizeInvocation } from './models/RangeOfSizeInvocation'; export type { RangeOfSizeInvocation } from './models/RangeOfSizeInvocation';
export type { ResizeLatentsInvocation } from './models/ResizeLatentsInvocation'; export type { ResizeLatentsInvocation } from './models/ResizeLatentsInvocation';
export type { ResourceOrigin } from './models/ResourceOrigin';
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation'; export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation'; export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
export type { ShowImageInvocation } from './models/ShowImageInvocation'; export type { ShowImageInvocation } from './models/ShowImageInvocation';
@ -81,6 +99,7 @@ export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
export type { UpscaleInvocation } from './models/UpscaleInvocation'; export type { UpscaleInvocation } from './models/UpscaleInvocation';
export type { VaeRepo } from './models/VaeRepo'; export type { VaeRepo } from './models/VaeRepo';
export type { ValidationError } from './models/ValidationError'; export type { ValidationError } from './models/ValidationError';
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
export { ImagesService } from './services/ImagesService'; export { ImagesService } from './services/ImagesService';
export { ModelsService } from './services/ModelsService'; export { ModelsService } from './services/ModelsService';

View File

@ -12,6 +12,10 @@ export type CannyImageProcessorInvocation = {
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'canny_image_processor'; type?: 'canny_image_processor';
/** /**
* image to process * image to process

View File

@ -12,6 +12,10 @@ export type ContentShuffleImageProcessorInvocation = {
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'content_shuffle_image_processor'; type?: 'content_shuffle_image_processor';
/** /**
* image to process * image to process

View File

@ -12,6 +12,10 @@ export type ControlNetInvocation = {
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'controlnet'; type?: 'controlnet';
/** /**
* image to process * image to process
@ -20,7 +24,7 @@ export type ControlNetInvocation = {
/** /**
* control model used * control model used
*/ */
control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace'; control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15' | 'CrucibleAI/ControlNetMediaPipeFace';
/** /**
* weight given to controlnet * weight given to controlnet
*/ */

View File

@ -0,0 +1,15 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* A collection of floats
*/
export type FloatCollectionOutput = {
type?: 'float_collection';
/**
* The float collection
*/
collection?: Array<number>;
};

View File

@ -0,0 +1,15 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* A float output
*/
export type FloatOutput = {
type?: 'float_output';
/**
* The output float
*/
param?: number;
};

View File

@ -3,12 +3,16 @@
/* eslint-disable */ /* eslint-disable */
import type { AddInvocation } from './AddInvocation'; import type { AddInvocation } from './AddInvocation';
import type { CannyImageProcessorInvocation } from './CannyImageProcessorInvocation';
import type { CollectInvocation } from './CollectInvocation'; import type { CollectInvocation } from './CollectInvocation';
import type { CompelInvocation } from './CompelInvocation'; import type { CompelInvocation } from './CompelInvocation';
import type { ContentShuffleImageProcessorInvocation } from './ContentShuffleImageProcessorInvocation';
import type { ControlNetInvocation } from './ControlNetInvocation';
import type { CvInpaintInvocation } from './CvInpaintInvocation'; import type { CvInpaintInvocation } from './CvInpaintInvocation';
import type { DivideInvocation } from './DivideInvocation'; import type { DivideInvocation } from './DivideInvocation';
import type { Edge } from './Edge'; import type { Edge } from './Edge';
import type { GraphInvocation } from './GraphInvocation'; import type { GraphInvocation } from './GraphInvocation';
import type { HedImageprocessorInvocation } from './HedImageprocessorInvocation';
import type { ImageBlurInvocation } from './ImageBlurInvocation'; import type { ImageBlurInvocation } from './ImageBlurInvocation';
import type { ImageChannelInvocation } from './ImageChannelInvocation'; import type { ImageChannelInvocation } from './ImageChannelInvocation';
import type { ImageConvertInvocation } from './ImageConvertInvocation'; import type { ImageConvertInvocation } from './ImageConvertInvocation';
@ -17,6 +21,7 @@ import type { ImageInverseLerpInvocation } from './ImageInverseLerpInvocation';
import type { ImageLerpInvocation } from './ImageLerpInvocation'; import type { ImageLerpInvocation } from './ImageLerpInvocation';
import type { ImageMultiplyInvocation } from './ImageMultiplyInvocation'; import type { ImageMultiplyInvocation } from './ImageMultiplyInvocation';
import type { ImagePasteInvocation } from './ImagePasteInvocation'; import type { ImagePasteInvocation } from './ImagePasteInvocation';
import type { ImageProcessorInvocation } from './ImageProcessorInvocation';
import type { ImageToImageInvocation } from './ImageToImageInvocation'; import type { ImageToImageInvocation } from './ImageToImageInvocation';
import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation'; import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation';
import type { InfillColorInvocation } from './InfillColorInvocation'; import type { InfillColorInvocation } from './InfillColorInvocation';
@ -26,11 +31,20 @@ import type { InpaintInvocation } from './InpaintInvocation';
import type { IterateInvocation } from './IterateInvocation'; import type { IterateInvocation } from './IterateInvocation';
import type { LatentsToImageInvocation } from './LatentsToImageInvocation'; import type { LatentsToImageInvocation } from './LatentsToImageInvocation';
import type { LatentsToLatentsInvocation } from './LatentsToLatentsInvocation'; import type { LatentsToLatentsInvocation } from './LatentsToLatentsInvocation';
import type { LineartAnimeImageProcessorInvocation } from './LineartAnimeImageProcessorInvocation';
import type { LineartImageProcessorInvocation } from './LineartImageProcessorInvocation';
import type { LoadImageInvocation } from './LoadImageInvocation'; import type { LoadImageInvocation } from './LoadImageInvocation';
import type { MaskFromAlphaInvocation } from './MaskFromAlphaInvocation'; import type { MaskFromAlphaInvocation } from './MaskFromAlphaInvocation';
import type { MediapipeFaceProcessorInvocation } from './MediapipeFaceProcessorInvocation';
import type { MidasDepthImageProcessorInvocation } from './MidasDepthImageProcessorInvocation';
import type { MlsdImageProcessorInvocation } from './MlsdImageProcessorInvocation';
import type { MultiplyInvocation } from './MultiplyInvocation'; import type { MultiplyInvocation } from './MultiplyInvocation';
import type { NoiseInvocation } from './NoiseInvocation'; import type { NoiseInvocation } from './NoiseInvocation';
import type { NormalbaeImageProcessorInvocation } from './NormalbaeImageProcessorInvocation';
import type { OpenposeImageProcessorInvocation } from './OpenposeImageProcessorInvocation';
import type { ParamFloatInvocation } from './ParamFloatInvocation';
import type { ParamIntInvocation } from './ParamIntInvocation'; import type { ParamIntInvocation } from './ParamIntInvocation';
import type { PidiImageProcessorInvocation } from './PidiImageProcessorInvocation';
import type { RandomIntInvocation } from './RandomIntInvocation'; import type { RandomIntInvocation } from './RandomIntInvocation';
import type { RandomRangeInvocation } from './RandomRangeInvocation'; import type { RandomRangeInvocation } from './RandomRangeInvocation';
import type { RangeInvocation } from './RangeInvocation'; import type { RangeInvocation } from './RangeInvocation';
@ -43,6 +57,7 @@ import type { SubtractInvocation } from './SubtractInvocation';
import type { TextToImageInvocation } from './TextToImageInvocation'; import type { TextToImageInvocation } from './TextToImageInvocation';
import type { TextToLatentsInvocation } from './TextToLatentsInvocation'; import type { TextToLatentsInvocation } from './TextToLatentsInvocation';
import type { UpscaleInvocation } from './UpscaleInvocation'; import type { UpscaleInvocation } from './UpscaleInvocation';
import type { ZoeDepthImageProcessorInvocation } from './ZoeDepthImageProcessorInvocation';
export type Graph = { export type Graph = {
/** /**
@ -52,7 +67,7 @@ export type Graph = {
/** /**
* The nodes in this graph * The nodes in this graph
*/ */
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>; nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageprocessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
/** /**
* The connections between nodes and their fields in this graph * The connections between nodes and their fields in this graph
*/ */

View File

@ -4,6 +4,9 @@
import type { CollectInvocationOutput } from './CollectInvocationOutput'; import type { CollectInvocationOutput } from './CollectInvocationOutput';
import type { CompelOutput } from './CompelOutput'; import type { CompelOutput } from './CompelOutput';
import type { ControlOutput } from './ControlOutput';
import type { FloatCollectionOutput } from './FloatCollectionOutput';
import type { FloatOutput } from './FloatOutput';
import type { Graph } from './Graph'; import type { Graph } from './Graph';
import type { GraphInvocationOutput } from './GraphInvocationOutput'; import type { GraphInvocationOutput } from './GraphInvocationOutput';
import type { ImageOutput } from './ImageOutput'; import type { ImageOutput } from './ImageOutput';
@ -42,7 +45,7 @@ export type GraphExecutionState = {
/** /**
* The results of node executions * The results of node executions
*/ */
results: Record<string, (ImageOutput | MaskOutput | PromptOutput | CompelOutput | IntOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>; results: Record<string, (ImageOutput | MaskOutput | ControlOutput | PromptOutput | CompelOutput | IntOutput | FloatOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | FloatCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
/** /**
* Errors raised when executing nodes * Errors raised when executing nodes
*/ */

View File

@ -12,6 +12,10 @@ export type HedImageprocessorInvocation = {
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'hed_image_processor'; type?: 'hed_image_processor';
/** /**
* image to process * image to process

View File

@ -3,6 +3,12 @@
/* eslint-disable */ /* eslint-disable */
/** /**
* The category of an image. Use ImageCategory.OTHER for non-default categories. * The category of an image.
*
* - GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
* - MASK: The image is a mask image.
* - CONTROL: The image is a ControlNet control image.
* - USER: The image is a user-provide image.
* - OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
*/ */
export type ImageCategory = 'general' | 'control' | 'mask' | 'other'; export type ImageCategory = 'general' | 'mask' | 'control' | 'user' | 'other';

View File

@ -4,7 +4,7 @@
import type { ImageCategory } from './ImageCategory'; import type { ImageCategory } from './ImageCategory';
import type { ImageMetadata } from './ImageMetadata'; import type { ImageMetadata } from './ImageMetadata';
import type { ImageType } from './ImageType'; import type { ResourceOrigin } from './ResourceOrigin';
/** /**
* Deserialized image record, enriched for the frontend with URLs. * Deserialized image record, enriched for the frontend with URLs.
@ -17,7 +17,7 @@ export type ImageDTO = {
/** /**
* The type of the image. * The type of the image.
*/ */
image_type: ImageType; image_origin: ResourceOrigin;
/** /**
* The URL of the image. * The URL of the image.
*/ */

View File

@ -2,7 +2,7 @@
/* tslint:disable */ /* tslint:disable */
/* eslint-disable */ /* eslint-disable */
import type { ImageType } from './ImageType'; import type { ResourceOrigin } from './ResourceOrigin';
/** /**
* An image field used for passing image objects between invocations * An image field used for passing image objects between invocations
@ -11,7 +11,7 @@ export type ImageField = {
/** /**
* The type of the image * The type of the image
*/ */
image_type: ImageType; image_origin: ResourceOrigin;
/** /**
* The name of the image * The name of the image
*/ */

View File

@ -12,6 +12,10 @@ export type ImageProcessorInvocation = {
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'image_processor'; type?: 'image_processor';
/** /**
* image to process * image to process

View File

@ -10,6 +10,7 @@ import type { ImageCategory } from './ImageCategory';
* Only limited changes are valid: * Only limited changes are valid:
* - `image_category`: change the category of an image * - `image_category`: change the category of an image
* - `session_id`: change the session associated with an image * - `session_id`: change the session associated with an image
* - `is_intermediate`: change the image's `is_intermediate` flag
*/ */
export type ImageRecordChanges = { export type ImageRecordChanges = {
/** /**
@ -20,5 +21,9 @@ export type ImageRecordChanges = {
* The image's new session ID. * The image's new session ID.
*/ */
session_id?: string; session_id?: string;
/**
* The image's new `is_intermediate` flag.
*/
is_intermediate?: boolean;
}; };

Some files were not shown because too many files have changed in this diff Show More