mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into release/make-web-dist-startable
This commit is contained in:
commit
dc54cbb1fc
@ -5,6 +5,7 @@ import os
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
@ -65,7 +66,7 @@ class ApiDependencies:
|
||||
metadata = CoreMetadataService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
names = SimpleNameService()
|
||||
latents = ForwardCacheLatentsStorage(
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
@ -76,6 +77,7 @@ class ApiDependencies:
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
|
||||
|
@ -1,39 +0,0 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
|
||||
|
||||
class ImageResponseMetadata(BaseModel):
|
||||
"""An image's metadata. Used only in HTTP responses."""
|
||||
|
||||
created: int = Field(description="The creation timestamp of the image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# invokeai: Optional[InvokeAIMetadata] = Field(
|
||||
# description="The image's InvokeAI-specific metadata"
|
||||
# )
|
||||
|
||||
|
||||
class ImageResponse(BaseModel):
|
||||
"""The response type for images"""
|
||||
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
image_url: str = Field(description="The url of the image")
|
||||
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
||||
metadata: ImageResponseMetadata = Field(description="The image's metadata")
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
|
||||
class SavedImage(BaseModel):
|
||||
image_name: str = Field(description="The name of the saved image")
|
||||
thumbnail_name: str = Field(description="The name of the saved thumbnail")
|
||||
created: int = Field(description="The created timestamp of the saved image")
|
@ -6,8 +6,9 @@ from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ImageType,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageDTO,
|
||||
ImageRecordChanges,
|
||||
@ -34,12 +35,8 @@ async def upload_image(
|
||||
file: UploadFile,
|
||||
request: Request,
|
||||
response: Response,
|
||||
image_category: ImageCategory = Query(
|
||||
default=ImageCategory.GENERAL, description="The category of the image"
|
||||
),
|
||||
is_intermediate: bool = Query(
|
||||
default=False, description="Whether this is an intermediate image"
|
||||
),
|
||||
image_category: ImageCategory = Query(description="The category of the image"),
|
||||
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
|
||||
session_id: Optional[str] = Query(
|
||||
default=None, description="The session ID associated with this upload, if any"
|
||||
),
|
||||
@ -59,7 +56,7 @@ async def upload_image(
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.create(
|
||||
image=pil_image,
|
||||
image_type=ImageType.UPLOAD,
|
||||
image_origin=ResourceOrigin.EXTERNAL,
|
||||
image_category=image_category,
|
||||
session_id=session_id,
|
||||
is_intermediate=is_intermediate,
|
||||
@ -73,27 +70,27 @@ async def upload_image(
|
||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||
|
||||
|
||||
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
|
||||
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_type: ImageType = Path(description="The type of image to delete"),
|
||||
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
"""Deletes an image"""
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.delete(image_type, image_name)
|
||||
ApiDependencies.invoker.services.images.delete(image_origin, image_name)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
|
||||
@images_router.patch(
|
||||
"/{image_type}/{image_name}",
|
||||
"/{image_origin}/{image_name}",
|
||||
operation_id="update_image",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def update_image(
|
||||
image_type: ImageType = Path(description="The type of image to update"),
|
||||
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
|
||||
image_name: str = Path(description="The name of the image to update"),
|
||||
image_changes: ImageRecordChanges = Body(
|
||||
description="The changes to apply to the image"
|
||||
@ -103,31 +100,31 @@ async def update_image(
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.update(
|
||||
image_type, image_name, image_changes
|
||||
image_origin, image_name, image_changes
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail="Failed to update image")
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_type}/{image_name}/metadata",
|
||||
"/{image_origin}/{image_name}/metadata",
|
||||
operation_id="get_image_metadata",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def get_image_metadata(
|
||||
image_type: ImageType = Path(description="The type of image to get"),
|
||||
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
|
||||
image_name: str = Path(description="The name of image to get"),
|
||||
) -> ImageDTO:
|
||||
"""Gets an image's metadata"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_dto(image_type, image_name)
|
||||
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_type}/{image_name}",
|
||||
"/{image_origin}/{image_name}",
|
||||
operation_id="get_image_full",
|
||||
response_class=Response,
|
||||
responses={
|
||||
@ -139,7 +136,7 @@ async def get_image_metadata(
|
||||
},
|
||||
)
|
||||
async def get_image_full(
|
||||
image_type: ImageType = Path(
|
||||
image_origin: ResourceOrigin = Path(
|
||||
description="The type of full-resolution image file to get"
|
||||
),
|
||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||
@ -147,7 +144,7 @@ async def get_image_full(
|
||||
"""Gets a full-resolution image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
|
||||
|
||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||
raise HTTPException(status_code=404)
|
||||
@ -163,7 +160,7 @@ async def get_image_full(
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_type}/{image_name}/thumbnail",
|
||||
"/{image_origin}/{image_name}/thumbnail",
|
||||
operation_id="get_image_thumbnail",
|
||||
response_class=Response,
|
||||
responses={
|
||||
@ -175,14 +172,14 @@ async def get_image_full(
|
||||
},
|
||||
)
|
||||
async def get_image_thumbnail(
|
||||
image_type: ImageType = Path(description="The type of thumbnail image file to get"),
|
||||
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
|
||||
image_name: str = Path(description="The name of thumbnail image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a thumbnail image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_type, image_name, thumbnail=True
|
||||
image_origin, image_name, thumbnail=True
|
||||
)
|
||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||
raise HTTPException(status_code=404)
|
||||
@ -195,25 +192,25 @@ async def get_image_thumbnail(
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_type}/{image_name}/urls",
|
||||
"/{image_origin}/{image_name}/urls",
|
||||
operation_id="get_image_urls",
|
||||
response_model=ImageUrlsDTO,
|
||||
)
|
||||
async def get_image_urls(
|
||||
image_type: ImageType = Path(description="The type of the image whose URL to get"),
|
||||
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
|
||||
image_name: str = Path(description="The name of the image whose URL to get"),
|
||||
) -> ImageUrlsDTO:
|
||||
"""Gets an image and thumbnail URL"""
|
||||
|
||||
try:
|
||||
image_url = ApiDependencies.invoker.services.images.get_url(
|
||||
image_type, image_name
|
||||
image_origin, image_name
|
||||
)
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
||||
image_type, image_name, thumbnail=True
|
||||
image_origin, image_name, thumbnail=True
|
||||
)
|
||||
return ImageUrlsDTO(
|
||||
image_type=image_type,
|
||||
image_origin=image_origin,
|
||||
image_name=image_name,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
@ -225,23 +222,29 @@ async def get_image_urls(
|
||||
@images_router.get(
|
||||
"/",
|
||||
operation_id="list_images_with_metadata",
|
||||
response_model=PaginatedResults[ImageDTO],
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_images_with_metadata(
|
||||
image_type: ImageType = Query(description="The type of images to list"),
|
||||
image_category: ImageCategory = Query(description="The kind of images to list"),
|
||||
page: int = Query(default=0, description="The page of image metadata to get"),
|
||||
per_page: int = Query(
|
||||
default=10, description="The number of image metadata per page"
|
||||
image_origin: Optional[ResourceOrigin] = Query(
|
||||
default=None, description="The origin of images to list"
|
||||
),
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
"""Gets a list of images with metadata"""
|
||||
categories: Optional[list[ImageCategory]] = Query(
|
||||
default=None, description="The categories of image to include"
|
||||
),
|
||||
is_intermediate: Optional[bool] = Query(
|
||||
default=None, description="Whether to list intermediate images"
|
||||
),
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of images per page"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of images"""
|
||||
|
||||
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||
image_type,
|
||||
image_category,
|
||||
page,
|
||||
per_page,
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
|
@ -16,6 +16,7 @@ from pydantic.fields import Field
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
|
||||
|
||||
@ -229,6 +230,7 @@ def invoke_cli():
|
||||
metadata = CoreMetadataService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
|
||||
images = ImageService(
|
||||
image_record_storage=image_record_storage,
|
||||
@ -236,6 +238,7 @@ def invoke_cli():
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
|
||||
|
@ -7,7 +7,7 @@ from typing import Literal, Optional, Union, List
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..models.image import ImageField, ImageType, ImageCategory
|
||||
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -163,7 +163,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
|
||||
raw_image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
@ -177,8 +177,8 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||
image_dto = context.services.images.create(
|
||||
image=processed_image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.CONTROL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate
|
||||
@ -187,7 +187,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
processed_image_field = ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
)
|
||||
return ImageOutput(
|
||||
image=processed_image_field,
|
||||
|
@ -7,7 +7,7 @@ import numpy
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
|
||||
@ -37,10 +37,10 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
mask = context.services.images.get_pil_image(
|
||||
self.mask.image_type, self.mask.image_name
|
||||
self.mask.image_origin, self.mask.image_name
|
||||
)
|
||||
|
||||
# Convert to cv image/mask
|
||||
@ -57,7 +57,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image_inpainted,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -67,7 +67,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
|
@ -10,9 +10,9 @@ import torch
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ColorField, ImageField, ImageType
|
||||
from invokeai.app.models.image import ColorField, ImageField, ResourceOrigin
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.models.image import ImageCategory, ImageType
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.generator.inpaint import infill_methods
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
@ -86,8 +86,8 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
# loading controlnet image (currently requires pre-processed image)
|
||||
control_image = (
|
||||
None if self.control_image is None
|
||||
else context.services.images.get(
|
||||
self.control_image.image_type, self.control_image.image_name
|
||||
else context.services.images.get_pil_image(
|
||||
self.control_image.image_origin, self.control_image.image_name
|
||||
)
|
||||
)
|
||||
# loading controlnet model
|
||||
@ -120,7 +120,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generate_output.image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
@ -130,7 +130,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
@ -170,7 +170,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
)
|
||||
|
||||
@ -201,7 +201,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generator_output.image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
@ -211,7 +211,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
@ -283,13 +283,13 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name)
|
||||
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
|
||||
)
|
||||
|
||||
# Handle invalid model parameter
|
||||
@ -317,7 +317,7 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generator_output.image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
@ -327,7 +327,7 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
|
@ -7,7 +7,7 @@ import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..models.image import ImageCategory, ImageField, ImageType
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -72,12 +72,12 @@ class LoadImageInvocation(BaseInvocation):
|
||||
)
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=self.image.image_name,
|
||||
image_type=self.image.image_type,
|
||||
image_origin=self.image.image_origin,
|
||||
),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
@ -96,7 +96,7 @@ class ShowImageInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
if image:
|
||||
image.show()
|
||||
@ -106,7 +106,7 @@ class ShowImageInvocation(BaseInvocation):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=self.image.image_name,
|
||||
image_type=self.image.image_type,
|
||||
image_origin=self.image.image_origin,
|
||||
),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
@ -129,7 +129,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
image_crop = Image.new(
|
||||
@ -139,7 +139,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image_crop,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -149,7 +149,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
@ -172,17 +172,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get_pil_image(
|
||||
self.base_image.image_type, self.base_image.image_name
|
||||
self.base_image.image_origin, self.base_image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else ImageOps.invert(
|
||||
context.services.images.get_pil_image(
|
||||
self.mask.image_type, self.mask.image_name
|
||||
self.mask.image_origin, self.mask.image_name
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -201,7 +201,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=new_image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -211,7 +211,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
@ -231,7 +231,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
image_mask = image.split()[-1]
|
||||
@ -240,7 +240,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image_mask,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.MASK,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -249,7 +249,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
return MaskOutput(
|
||||
mask=ImageField(
|
||||
image_type=image_dto.image_type, image_name=image_dto.image_name
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
@ -269,17 +269,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image1 = context.services.images.get_pil_image(
|
||||
self.image1.image_type, self.image1.image_name
|
||||
self.image1.image_origin, self.image1.image_name
|
||||
)
|
||||
image2 = context.services.images.get_pil_image(
|
||||
self.image2.image_type, self.image2.image_name
|
||||
self.image2.image_origin, self.image2.image_name
|
||||
)
|
||||
|
||||
multiply_image = ImageChops.multiply(image1, image2)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=multiply_image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -288,7 +288,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_type=image_dto.image_type, image_name=image_dto.image_name
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
@ -311,14 +311,14 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
channel_image = image.getchannel(self.channel)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=channel_image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -327,7 +327,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_type=image_dto.image_type, image_name=image_dto.image_name
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
@ -350,14 +350,14 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
converted_image = image.convert(self.mode)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=converted_image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -366,7 +366,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_type=image_dto.image_type, image_name=image_dto.image_name
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
@ -387,7 +387,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
blur = (
|
||||
@ -399,7 +399,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=blur_image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -409,7 +409,116 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
PIL_RESAMPLING_MODES = Literal[
|
||||
"nearest",
|
||||
"box",
|
||||
"bilinear",
|
||||
"hamming",
|
||||
"bicubic",
|
||||
"lanczos",
|
||||
]
|
||||
|
||||
|
||||
PIL_RESAMPLING_MAP = {
|
||||
"nearest": Image.Resampling.NEAREST,
|
||||
"box": Image.Resampling.BOX,
|
||||
"bilinear": Image.Resampling.BILINEAR,
|
||||
"hamming": Image.Resampling.HAMMING,
|
||||
"bicubic": Image.Resampling.BICUBIC,
|
||||
"lanczos": Image.Resampling.LANCZOS,
|
||||
}
|
||||
|
||||
|
||||
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Resizes an image to specific dimensions"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_resize"] = "img_resize"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
resize_image = image.resize(
|
||||
(self.width, self.height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=resize_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Scales an image by a factor"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_scale"] = "img_scale"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
width = int(image.width * self.scale_factor)
|
||||
height = int(image.height * self.scale_factor)
|
||||
|
||||
resize_image = image.resize(
|
||||
(width, height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=resize_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
@ -430,7 +539,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
||||
@ -440,7 +549,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=lerp_image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -450,7 +559,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
@ -471,7 +580,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||
@ -486,7 +595,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=ilerp_image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -496,7 +605,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
|
@ -11,7 +11,7 @@ from invokeai.app.invocations.image import ImageOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from ..models.image import ColorField, ImageCategory, ImageField, ImageType
|
||||
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationContext,
|
||||
@ -135,7 +135,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
@ -145,7 +145,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -155,7 +155,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
@ -180,7 +180,7 @@ class InfillTileInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
infilled = tile_fill_missing(
|
||||
@ -190,7 +190,7 @@ class InfillTileInvocation(BaseInvocation):
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -200,7 +200,7 @@ class InfillTileInvocation(BaseInvocation):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
@ -218,7 +218,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
@ -228,7 +228,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -238,7 +238,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
|
@ -28,7 +28,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
import numpy as np
|
||||
from ..services.image_file_storage import ImageType
|
||||
from ..services.image_file_storage import ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from .compel import ConditioningField
|
||||
@ -297,7 +297,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
torch_dtype=model.unet.dtype).to(model.device)
|
||||
control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_type,
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_origin,
|
||||
control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
@ -468,7 +468,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
# and gnenerate unique image_name
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
@ -478,7 +478,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
@ -576,7 +576,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
# self.image.image_type, self.image.image_name
|
||||
# )
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
|
@ -2,7 +2,7 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
@ -29,7 +29,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
@ -43,7 +43,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
# TODO: can this return multiple results?
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -53,7 +53,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
|
@ -4,7 +4,7 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
|
||||
@ -31,7 +31,7 @@ class UpscaleInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_type, self.image.image_name
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
@ -45,7 +45,7 @@ class UpscaleInvocation(BaseInvocation):
|
||||
# TODO: can this return multiple results?
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_type=ImageType.RESULT,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
@ -55,7 +55,7 @@ class UpscaleInvocation(BaseInvocation):
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_type=image_dto.image_type,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
|
@ -5,30 +5,52 @@ from pydantic import BaseModel, Field
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
|
||||
|
||||
class ImageType(str, Enum, metaclass=MetaEnum):
|
||||
"""The type of an image."""
|
||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||
"""The origin of a resource (eg image).
|
||||
|
||||
RESULT = "results"
|
||||
UPLOAD = "uploads"
|
||||
- INTERNAL: The resource was created by the application.
|
||||
- EXTERNAL: The resource was not created by the application.
|
||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||
"""
|
||||
|
||||
INTERNAL = "internal"
|
||||
"""The resource was created by the application."""
|
||||
EXTERNAL = "external"
|
||||
"""The resource was not created by the application.
|
||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||
"""
|
||||
|
||||
|
||||
class InvalidImageTypeException(ValueError):
|
||||
"""Raised when a provided value is not a valid ImageType.
|
||||
class InvalidOriginException(ValueError):
|
||||
"""Raised when a provided value is not a valid ResourceOrigin.
|
||||
|
||||
Subclasses `ValueError`.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Invalid image type."):
|
||||
def __init__(self, message="Invalid resource origin."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||
"""The category of an image. Use ImageCategory.OTHER for non-default categories."""
|
||||
"""The category of an image.
|
||||
|
||||
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
|
||||
- MASK: The image is a mask image.
|
||||
- CONTROL: The image is a ControlNet control image.
|
||||
- USER: The image is a user-provide image.
|
||||
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
|
||||
"""
|
||||
|
||||
GENERAL = "general"
|
||||
CONTROL = "control"
|
||||
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
|
||||
MASK = "mask"
|
||||
"""MASK: The image is a mask image."""
|
||||
CONTROL = "control"
|
||||
"""CONTROL: The image is a ControlNet control image."""
|
||||
USER = "user"
|
||||
"""USER: The image is a user-provide image."""
|
||||
OTHER = "other"
|
||||
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
|
||||
|
||||
|
||||
class InvalidImageCategoryException(ValueError):
|
||||
@ -44,13 +66,13 @@ class InvalidImageCategoryException(ValueError):
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_type: ImageType = Field(
|
||||
default=ImageType.RESULT, description="The type of the image"
|
||||
image_origin: ResourceOrigin = Field(
|
||||
default=ResourceOrigin.INTERNAL, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_type", "image_name"]}
|
||||
schema_extra = {"required": ["image_origin", "image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
@ -61,3 +83,11 @@ class ColorField(BaseModel):
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any, Optional
|
||||
from invokeai.app.api.models.images import ProgressImage
|
||||
from typing import Any
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
|
||||
|
||||
|
@ -9,7 +9,7 @@ from PIL.Image import Image as PILImageType
|
||||
from PIL import Image, PngImagePlugin
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.models.image import ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
@ -40,13 +40,13 @@ class ImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the internal path to an image or thumbnail."""
|
||||
pass
|
||||
@ -62,7 +62,7 @@ class ImageFileStorageBase(ABC):
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
@ -71,7 +71,7 @@ class ImageFileStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
|
||||
@ -93,17 +93,17 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||
for image_type in ImageType:
|
||||
Path(os.path.join(output_folder, image_type)).mkdir(
|
||||
for image_origin in ResourceOrigin:
|
||||
Path(os.path.join(output_folder, image_origin)).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
|
||||
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
try:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
image_path = self.get_path(image_origin, image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
@ -117,13 +117,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
try:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
image_path = self.get_path(image_origin, image_name)
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
@ -133,7 +133,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
image.save(image_path, "PNG")
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
|
||||
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
|
||||
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||
thumbnail_image.save(thumbnail_path)
|
||||
|
||||
@ -142,10 +142,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
except Exception as e:
|
||||
raise ImageFileSaveException from e
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
try:
|
||||
basename = os.path.basename(image_name)
|
||||
image_path = self.get_path(image_type, basename)
|
||||
image_path = self.get_path(image_origin, basename)
|
||||
|
||||
if os.path.exists(image_path):
|
||||
send2trash(image_path)
|
||||
@ -153,7 +153,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
del self.__cache[image_path]
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
|
||||
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
send2trash(thumbnail_path)
|
||||
@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
@ -172,10 +172,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
if thumbnail:
|
||||
thumbnail_name = get_thumbnail_name(basename)
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_type, "thumbnails", thumbnail_name
|
||||
self.__output_folder, image_origin, "thumbnails", thumbnail_name
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_type, basename)
|
||||
path = os.path.join(self.__output_folder, image_origin, basename)
|
||||
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
|
@ -1,21 +1,35 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Optional, cast
|
||||
from typing import Generic, Optional, TypeVar, cast
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ImageType,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
deserialize_image_record,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
class OffsetPaginatedResults(GenericModel, Generic[T]):
|
||||
"""Offset-paginated results"""
|
||||
|
||||
# fmt: off
|
||||
items: list[T] = Field(description="Items")
|
||||
offset: int = Field(description="Offset from which to retrieve items")
|
||||
limit: int = Field(description="Limit of items to get")
|
||||
total: int = Field(description="Total number of items in result")
|
||||
# fmt: on
|
||||
|
||||
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
@ -46,7 +60,7 @@ class ImageRecordStorageBase(ABC):
|
||||
# TODO: Implement an `update()` method
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@ -54,7 +68,7 @@ class ImageRecordStorageBase(ABC):
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
"""Updates an image record."""
|
||||
@ -63,18 +77,19 @@ class ImageRecordStorageBase(ABC):
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
) -> PaginatedResults[ImageRecord]:
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
|
||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
"""Deletes an image record."""
|
||||
pass
|
||||
|
||||
@ -82,7 +97,7 @@ class ImageRecordStorageBase(ABC):
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
width: int,
|
||||
height: int,
|
||||
@ -103,7 +118,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
@ -129,7 +143,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
image_name TEXT NOT NULL PRIMARY KEY,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_type TEXT NOT NULL,
|
||||
image_origin TEXT NOT NULL,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_category TEXT NOT NULL,
|
||||
width INTEGER NOT NULL,
|
||||
@ -138,9 +152,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id TEXT,
|
||||
metadata TEXT,
|
||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME
|
||||
);
|
||||
@ -155,7 +169,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
@ -182,7 +196,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]:
|
||||
def get(
|
||||
self, image_origin: ResourceOrigin, image_name: str
|
||||
) -> Union[ImageRecord, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
@ -209,7 +225,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
try:
|
||||
@ -224,7 +240,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
""",
|
||||
(changes.image_category, image_name),
|
||||
)
|
||||
|
||||
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
self._cursor.execute(
|
||||
@ -235,6 +251,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
""",
|
||||
(changes.session_id, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.is_intermediate is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@ -244,36 +271,61 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
) -> PaginatedResults[ImageRecord]:
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT * FROM images
|
||||
WHERE image_type = ? AND image_category = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(image_type.value, image_category.value, per_page, page * per_page),
|
||||
)
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
|
||||
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
|
||||
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += f"""AND image_origin = ?\n"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
## Convert the enum values to unique list of strings
|
||||
category_strings = list(
|
||||
map(lambda c: c.value, set(categories))
|
||||
)
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += f"""AND is_intermediate = ?\n"""
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
images_params.append(limit)
|
||||
images_params.append(offset)
|
||||
# Build the list of images, deserializing each row
|
||||
self._cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
|
||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*) FROM images
|
||||
WHERE image_type = ? AND image_category = ?
|
||||
""",
|
||||
(image_type.value, image_category.value),
|
||||
)
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
self._cursor.execute(count_query, count_params)
|
||||
count = self._cursor.fetchone()[0]
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@ -281,13 +333,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults(
|
||||
items=images, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
return OffsetPaginatedResults(
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -307,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
session_id: Optional[str],
|
||||
width: int,
|
||||
@ -325,7 +375,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
image_name,
|
||||
image_type,
|
||||
image_origin,
|
||||
image_category,
|
||||
width,
|
||||
height,
|
||||
@ -338,7 +388,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
image_type.value,
|
||||
image_origin.value,
|
||||
image_category.value,
|
||||
width,
|
||||
height,
|
||||
|
@ -1,14 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
import uuid
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ImageType,
|
||||
ResourceOrigin,
|
||||
InvalidImageCategoryException,
|
||||
InvalidImageTypeException,
|
||||
InvalidOriginException,
|
||||
)
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
@ -16,6 +15,7 @@ from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordNotFoundException,
|
||||
ImageRecordSaveException,
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
@ -31,6 +31,7 @@ from invokeai.app.services.image_file_storage import (
|
||||
)
|
||||
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.app.services.resource_name import NameServiceBase
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -44,7 +45,7 @@ class ImageServiceABC(ABC):
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
@ -56,7 +57,7 @@ class ImageServiceABC(ABC):
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
@ -64,22 +65,22 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
"""Gets an image as a PIL image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
|
||||
"""Gets an image's path."""
|
||||
pass
|
||||
|
||||
@ -90,7 +91,7 @@ class ImageServiceABC(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets an image's or thumbnail's URL."""
|
||||
pass
|
||||
@ -98,16 +99,17 @@ class ImageServiceABC(ABC):
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str):
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
@ -120,6 +122,7 @@ class ImageServiceDependencies:
|
||||
metadata: MetadataServiceBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
names: NameServiceBase
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
|
||||
def __init__(
|
||||
@ -129,6 +132,7 @@ class ImageServiceDependencies:
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self.records = image_record_storage
|
||||
@ -136,6 +140,7 @@ class ImageServiceDependencies:
|
||||
self.metadata = metadata
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
self.names = names
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
|
||||
|
||||
@ -149,6 +154,7 @@ class ImageService(ImageServiceABC):
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self._services = ImageServiceDependencies(
|
||||
@ -157,30 +163,26 @@ class ImageService(ImageServiceABC):
|
||||
metadata=metadata,
|
||||
url=url,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_intermediate: bool = False,
|
||||
) -> ImageDTO:
|
||||
if image_type not in ImageType:
|
||||
raise InvalidImageTypeException
|
||||
if image_origin not in ResourceOrigin:
|
||||
raise InvalidOriginException
|
||||
|
||||
if image_category not in ImageCategory:
|
||||
raise InvalidImageCategoryException
|
||||
|
||||
image_name = self._create_image_name(
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
image_name = self._services.names.create_image_name()
|
||||
|
||||
metadata = self._get_metadata(session_id, node_id)
|
||||
|
||||
@ -191,7 +193,7 @@ class ImageService(ImageServiceABC):
|
||||
created_at = self._services.records.save(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
@ -204,21 +206,21 @@ class ImageService(ImageServiceABC):
|
||||
)
|
||||
|
||||
self._services.files.save(
|
||||
image_type=image_type,
|
||||
image_origin=image_origin,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
image_url = self._services.urls.get_image_url(image_type, image_name)
|
||||
image_url = self._services.urls.get_image_url(image_origin, image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(
|
||||
image_type, image_name, True
|
||||
image_origin, image_name, True
|
||||
)
|
||||
|
||||
return ImageDTO(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
@ -247,24 +249,23 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self._services.records.update(image_name, image_type, changes)
|
||||
return self.get_dto(image_type, image_name)
|
||||
self._services.records.update(image_name, image_origin, changes)
|
||||
return self.get_dto(image_origin, image_name)
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to update image record")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem updating image record")
|
||||
raise e
|
||||
|
||||
|
||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.files.get(image_type, image_name)
|
||||
return self._services.files.get(image_origin, image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
@ -272,9 +273,9 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting image file")
|
||||
raise e
|
||||
|
||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.records.get(image_type, image_name)
|
||||
return self._services.records.get(image_origin, image_name)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
@ -282,14 +283,14 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting image record")
|
||||
raise e
|
||||
|
||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self._services.records.get(image_type, image_name)
|
||||
image_record = self._services.records.get(image_origin, image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self._services.urls.get_image_url(image_type, image_name),
|
||||
self._services.urls.get_image_url(image_type, image_name, True),
|
||||
self._services.urls.get_image_url(image_origin, image_name),
|
||||
self._services.urls.get_image_url(image_origin, image_name, True),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
@ -301,10 +302,10 @@ class ImageService(ImageServiceABC):
|
||||
raise e
|
||||
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.files.get_path(image_type, image_name, thumbnail)
|
||||
return self._services.files.get_path(image_origin, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
@ -317,57 +318,58 @@ class ImageService(ImageServiceABC):
|
||||
raise e
|
||||
|
||||
def get_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
|
||||
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self._services.records.get_many(
|
||||
image_type,
|
||||
image_category,
|
||||
page,
|
||||
per_page,
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
)
|
||||
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self._services.urls.get_image_url(image_type, r.image_name),
|
||||
self._services.urls.get_image_url(r.image_origin, r.image_name),
|
||||
self._services.urls.get_image_url(
|
||||
image_type, r.image_name, True
|
||||
r.image_origin, r.image_name, True
|
||||
),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
)
|
||||
|
||||
return PaginatedResults[ImageDTO](
|
||||
return OffsetPaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
page=results.page,
|
||||
pages=results.pages,
|
||||
per_page=results.per_page,
|
||||
offset=results.offset,
|
||||
limit=results.limit,
|
||||
total=results.total,
|
||||
)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting paginated image DTOs")
|
||||
raise e
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str):
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||
try:
|
||||
self._services.files.delete(image_type, image_name)
|
||||
self._services.records.delete(image_type, image_name)
|
||||
self._services.files.delete(image_origin, image_name)
|
||||
self._services.records.delete(image_origin, image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image record")
|
||||
raise
|
||||
@ -378,21 +380,6 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem deleting image record and file")
|
||||
raise e
|
||||
|
||||
def _create_image_name(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Create a unique image name."""
|
||||
uuid_str = str(uuid.uuid4())
|
||||
|
||||
if node_id is not None and session_id is not None:
|
||||
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
|
||||
|
||||
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
|
||||
|
||||
def _get_metadata(
|
||||
self, session_id: Optional[str] = None, node_id: Optional[str] = None
|
||||
) -> Union[ImageMetadata, None]:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
from typing import Optional, Union
|
||||
from pydantic import BaseModel, Extra, Field, StrictStr
|
||||
from invokeai.app.models.image import ImageCategory, ImageType
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
@ -11,8 +11,8 @@ class ImageRecord(BaseModel):
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_type: ImageType = Field(description="The type of the image.")
|
||||
"""The type of the image."""
|
||||
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||
"""The origin of the image."""
|
||||
image_category: ImageCategory = Field(description="The category of the image.")
|
||||
"""The category of the image."""
|
||||
width: int = Field(description="The width of the image in px.")
|
||||
@ -56,6 +56,7 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||
Only limited changes are valid:
|
||||
- `image_category`: change the category of an image
|
||||
- `session_id`: change the session associated with an image
|
||||
- `is_intermediate`: change the image's `is_intermediate` flag
|
||||
"""
|
||||
|
||||
image_category: Optional[ImageCategory] = Field(
|
||||
@ -67,6 +68,10 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||
description="The image's new session ID.",
|
||||
)
|
||||
"""The image's new session ID."""
|
||||
is_intermediate: Optional[StrictBool] = Field(
|
||||
default=None, description="The image's new `is_intermediate` flag."
|
||||
)
|
||||
"""The image's new `is_intermediate` flag."""
|
||||
|
||||
|
||||
class ImageUrlsDTO(BaseModel):
|
||||
@ -74,8 +79,8 @@ class ImageUrlsDTO(BaseModel):
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_type: ImageType = Field(description="The type of the image.")
|
||||
"""The type of the image."""
|
||||
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||
"""The origin of the image."""
|
||||
image_url: str = Field(description="The URL of the image.")
|
||||
"""The URL of the image."""
|
||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||
@ -105,7 +110,9 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
image_name = image_dict.get("image_name", "unknown")
|
||||
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
|
||||
image_origin = ResourceOrigin(
|
||||
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
|
||||
)
|
||||
image_category = ImageCategory(
|
||||
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||
)
|
||||
@ -127,7 +134,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
|
||||
return ImageRecord(
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
|
30
invokeai/app/services/resource_name.py
Normal file
30
invokeai/app/services/resource_name.py
Normal file
@ -0,0 +1,30 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, EnumMeta
|
||||
import uuid
|
||||
|
||||
|
||||
class ResourceType(str, Enum, metaclass=EnumMeta):
|
||||
"""Enum for resource types."""
|
||||
|
||||
IMAGE = "image"
|
||||
LATENT = "latent"
|
||||
|
||||
|
||||
class NameServiceBase(ABC):
|
||||
"""Low-level service responsible for naming resources (images, latents, etc)."""
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
@abstractmethod
|
||||
def create_image_name(self) -> str:
|
||||
"""Creates a name for an image."""
|
||||
pass
|
||||
|
||||
|
||||
class SimpleNameService(NameServiceBase):
|
||||
"""Creates image names from UUIDs."""
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
def create_image_name(self) -> str:
|
||||
uuid_str = str(uuid.uuid4())
|
||||
filename = f"{uuid_str}.png"
|
||||
return filename
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.models.image import ResourceOrigin
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ class UrlServiceBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_image_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the URL for an image or thumbnail."""
|
||||
pass
|
||||
@ -21,14 +21,14 @@ class LocalUrlService(UrlServiceBase):
|
||||
self._base_url = base_url
|
||||
|
||||
def get_image_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
image_basename = os.path.basename(image_name)
|
||||
|
||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||
if thumbnail:
|
||||
return (
|
||||
f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail"
|
||||
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
|
||||
)
|
||||
|
||||
return f"{self._base_url}/images/{image_type.value}/{image_basename}"
|
||||
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
|
||||
|
@ -1,5 +1,5 @@
|
||||
from invokeai.app.api.models.images import ProgressImage
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.generator.base import Generator
|
||||
|
@ -122,7 +122,9 @@
|
||||
"noImagesInGallery": "No Images In Gallery",
|
||||
"deleteImage": "Delete Image",
|
||||
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
||||
"deleteImagePermanent": "Deleted images cannot be restored."
|
||||
"deleteImagePermanent": "Deleted images cannot be restored.",
|
||||
"images": "Images",
|
||||
"assets": "Assets"
|
||||
},
|
||||
"hotkeys": {
|
||||
"keyboardShortcuts": "Keyboard Shortcuts",
|
||||
@ -524,7 +526,7 @@
|
||||
},
|
||||
"settings": {
|
||||
"models": "Models",
|
||||
"displayInProgress": "Display In-Progress Images",
|
||||
"displayInProgress": "Display Progress Images",
|
||||
"saveSteps": "Save images every n steps",
|
||||
"confirmOnDelete": "Confirm On Delete",
|
||||
"displayHelpIcons": "Display Help Icons",
|
||||
|
@ -1,7 +1,5 @@
|
||||
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
||||
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
||||
import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist';
|
||||
import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist';
|
||||
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
|
||||
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||
@ -22,11 +20,9 @@ const serializationDenylist: {
|
||||
models: modelsPersistDenylist,
|
||||
nodes: nodesPersistDenylist,
|
||||
postprocessing: postprocessingPersistDenylist,
|
||||
results: resultsPersistDenylist,
|
||||
system: systemPersistDenylist,
|
||||
// config: configPersistDenyList,
|
||||
ui: uiPersistDenylist,
|
||||
uploads: uploadsPersistDenylist,
|
||||
// hotkeys: hotkeysPersistDenylist,
|
||||
};
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
|
||||
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
|
||||
import { initialResultsState } from 'features/gallery/store/resultsSlice';
|
||||
import { initialUploadsState } from 'features/gallery/store/uploadsSlice';
|
||||
import { initialImagesState } from 'features/gallery/store/imagesSlice';
|
||||
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
|
||||
import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
@ -24,12 +23,11 @@ const initialStates: {
|
||||
models: initialModelsState,
|
||||
nodes: initialNodesState,
|
||||
postprocessing: initialPostprocessingState,
|
||||
results: initialResultsState,
|
||||
system: initialSystemState,
|
||||
config: initialConfigState,
|
||||
ui: initialUIState,
|
||||
uploads: initialUploadsState,
|
||||
hotkeys: initialHotkeysState,
|
||||
images: initialImagesState,
|
||||
};
|
||||
|
||||
export const unserialize: UnserializeFunction = (data, key) => {
|
||||
|
@ -7,5 +7,6 @@ export const actionsDenylist = [
|
||||
'canvas/setBoundingBoxDimensions',
|
||||
'canvas/setIsDrawing',
|
||||
'canvas/addPointToCurrentLine',
|
||||
'socket/generatorProgress',
|
||||
'socket/socketGeneratorProgress',
|
||||
'socket/appSocketGeneratorProgress',
|
||||
];
|
||||
|
@ -26,15 +26,15 @@ import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGaller
|
||||
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
|
||||
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
|
||||
import { addCanvasMergedListener } from './listeners/canvasMerged';
|
||||
import { addGeneratorProgressListener } from './listeners/socketio/generatorProgress';
|
||||
import { addGraphExecutionStateCompleteListener } from './listeners/socketio/graphExecutionStateComplete';
|
||||
import { addInvocationCompleteListener } from './listeners/socketio/invocationComplete';
|
||||
import { addInvocationErrorListener } from './listeners/socketio/invocationError';
|
||||
import { addInvocationStartedListener } from './listeners/socketio/invocationStarted';
|
||||
import { addSocketConnectedListener } from './listeners/socketio/socketConnected';
|
||||
import { addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
|
||||
import { addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
|
||||
import { addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
|
||||
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
|
||||
import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
|
||||
import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
|
||||
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
|
||||
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
|
||||
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
|
||||
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
|
||||
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
|
||||
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
|
||||
import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
|
||||
import {
|
||||
addImageMetadataReceivedFulfilledListener,
|
||||
@ -60,13 +60,16 @@ import {
|
||||
addSessionCanceledRejectedListener,
|
||||
} from './listeners/sessionCanceled';
|
||||
import {
|
||||
addReceivedResultImagesPageFulfilledListener,
|
||||
addReceivedResultImagesPageRejectedListener,
|
||||
} from './listeners/receivedResultImagesPage';
|
||||
addImageUpdatedFulfilledListener,
|
||||
addImageUpdatedRejectedListener,
|
||||
} from './listeners/imageUpdated';
|
||||
import {
|
||||
addReceivedUploadImagesPageFulfilledListener,
|
||||
addReceivedUploadImagesPageRejectedListener,
|
||||
} from './listeners/receivedUploadImagesPage';
|
||||
addReceivedPageOfImagesFulfilledListener,
|
||||
addReceivedPageOfImagesRejectedListener,
|
||||
} from './listeners/receivedPageOfImages';
|
||||
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
|
||||
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
|
||||
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -90,6 +93,11 @@ export type AppListenerEffect = ListenerEffect<
|
||||
addImageUploadedFulfilledListener();
|
||||
addImageUploadedRejectedListener();
|
||||
|
||||
// Image updated
|
||||
addImageUpdatedFulfilledListener();
|
||||
addImageUpdatedRejectedListener();
|
||||
|
||||
// Image selected
|
||||
addInitialImageSelectedListener();
|
||||
|
||||
// Image deleted
|
||||
@ -118,8 +126,22 @@ addCanvasSavedToGalleryListener();
|
||||
addCanvasDownloadedAsImageListener();
|
||||
addCanvasCopiedToClipboardListener();
|
||||
addCanvasMergedListener();
|
||||
addStagingAreaImageSavedListener();
|
||||
addCommitStagingAreaImageListener();
|
||||
|
||||
// socketio
|
||||
/**
|
||||
* Socket.IO Events - these handle SIO events directly and pass on internal application actions.
|
||||
* We don't handle SIO events in slices via `extraReducers` because some of these events shouldn't
|
||||
* actually be handled at all.
|
||||
*
|
||||
* For example, we don't want to respond to progress events for canceled sessions. To avoid
|
||||
* duplicating the logic to determine if an event should be responded to, we handle all of that
|
||||
* "is this session canceled?" logic in these listeners.
|
||||
*
|
||||
* The `socketGeneratorProgress` listener will then only dispatch the `appSocketGeneratorProgress`
|
||||
* action if it should be handled by the rest of the application. It is this `appSocketGeneratorProgress`
|
||||
* action that is handled by reducers in slices.
|
||||
*/
|
||||
addGeneratorProgressListener();
|
||||
addGraphExecutionStateCompleteListener();
|
||||
addInvocationCompleteListener();
|
||||
@ -145,8 +167,9 @@ addSessionCanceledPendingListener();
|
||||
addSessionCanceledFulfilledListener();
|
||||
addSessionCanceledRejectedListener();
|
||||
|
||||
// Gallery pages
|
||||
addReceivedResultImagesPageFulfilledListener();
|
||||
addReceivedResultImagesPageRejectedListener();
|
||||
addReceivedUploadImagesPageFulfilledListener();
|
||||
addReceivedUploadImagesPageRejectedListener();
|
||||
// Fetching images
|
||||
addReceivedPageOfImagesFulfilledListener();
|
||||
addReceivedPageOfImagesRejectedListener();
|
||||
|
||||
// Gallery
|
||||
addImageCategoriesChangedListener();
|
||||
|
@ -0,0 +1,42 @@
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { commitStagingAreaImage } from 'features/canvas/store/canvasSlice';
|
||||
import { sessionCanceled } from 'services/thunks/session';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvas' });
|
||||
|
||||
export const addCommitStagingAreaImageListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: commitStagingAreaImage,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
const { sessionId, isProcessing } = state.system;
|
||||
const canvasSessionId = action.payload;
|
||||
|
||||
if (!isProcessing) {
|
||||
// Only need to cancel if we are processing
|
||||
return;
|
||||
}
|
||||
|
||||
if (!canvasSessionId) {
|
||||
moduleLog.debug('No canvas session, skipping cancel');
|
||||
return;
|
||||
}
|
||||
|
||||
if (canvasSessionId !== sessionId) {
|
||||
moduleLog.debug(
|
||||
{
|
||||
data: {
|
||||
canvasSessionId,
|
||||
sessionId,
|
||||
},
|
||||
},
|
||||
'Canvas session does not match global session, skipping cancel'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(sessionCanceled({ sessionId }));
|
||||
},
|
||||
});
|
||||
};
|
@ -55,6 +55,8 @@ export const addCanvasMergedListener = () => {
|
||||
formData: {
|
||||
file: new File([blob], filename, { type: 'image/png' }),
|
||||
},
|
||||
imageCategory: 'general',
|
||||
isIntermediate: true,
|
||||
})
|
||||
);
|
||||
|
||||
|
@ -4,16 +4,18 @@ import { log } from 'app/logging/useLogger';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
|
||||
|
||||
export const addCanvasSavedToGalleryListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: canvasSavedToGallery,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
effect: async (action, { dispatch, getState, take }) => {
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state);
|
||||
const blob = await getBaseLayerBlob(state, true);
|
||||
|
||||
if (!blob) {
|
||||
moduleLog.error('Problem getting base layer blob');
|
||||
@ -27,13 +29,25 @@ export const addCanvasSavedToGalleryListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const filename = `mergedCanvas_${uuidv4()}.png`;
|
||||
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
formData: {
|
||||
file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }),
|
||||
file: new File([blob], filename, { type: 'image/png' }),
|
||||
},
|
||||
imageCategory: 'general',
|
||||
isIntermediate: false,
|
||||
})
|
||||
);
|
||||
|
||||
const [{ payload: uploadedImageDTO }] = await take(
|
||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.meta.arg.formData.file.name === filename
|
||||
);
|
||||
|
||||
dispatch(imageUpserted(uploadedImageDTO));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -0,0 +1,24 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { receivedPageOfImages } from 'services/thunks/image';
|
||||
import {
|
||||
imageCategoriesChanged,
|
||||
selectFilteredImagesAsArray,
|
||||
} from 'features/gallery/store/imagesSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'gallery' });
|
||||
|
||||
export const addImageCategoriesChangedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: imageCategoriesChanged,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const filteredImagesCount = selectFilteredImagesAsArray(
|
||||
getState()
|
||||
).length;
|
||||
|
||||
if (!filteredImagesCount) {
|
||||
dispatch(receivedPageOfImages());
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -4,8 +4,12 @@ import { imageDeleted } from 'services/thunks/image';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { uploadsAdapter } from 'features/gallery/store/uploadsSlice';
|
||||
import { resultsAdapter } from 'features/gallery/store/resultsSlice';
|
||||
import {
|
||||
imageRemoved,
|
||||
imagesAdapter,
|
||||
selectImagesEntities,
|
||||
selectImagesIds,
|
||||
} from 'features/gallery/store/imagesSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
||||
|
||||
@ -22,19 +26,20 @@ export const addRequestedImageDeletionListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const { image_name, image_type } = image;
|
||||
const { image_name, image_origin } = image;
|
||||
|
||||
const selectedImageName = getState().gallery.selectedImage?.image_name;
|
||||
const state = getState();
|
||||
const selectedImage = state.gallery.selectedImage;
|
||||
|
||||
if (selectedImageName === image_name) {
|
||||
const allIds = getState()[image_type].ids;
|
||||
const allEntities = getState()[image_type].entities;
|
||||
if (selectedImage && selectedImage.image_name === image_name) {
|
||||
const ids = selectImagesIds(state);
|
||||
const entities = selectImagesEntities(state);
|
||||
|
||||
const deletedImageIndex = allIds.findIndex(
|
||||
const deletedImageIndex = ids.findIndex(
|
||||
(result) => result.toString() === image_name
|
||||
);
|
||||
|
||||
const filteredIds = allIds.filter((id) => id.toString() !== image_name);
|
||||
const filteredIds = ids.filter((id) => id.toString() !== image_name);
|
||||
|
||||
const newSelectedImageIndex = clamp(
|
||||
deletedImageIndex,
|
||||
@ -44,7 +49,7 @@ export const addRequestedImageDeletionListener = () => {
|
||||
|
||||
const newSelectedImageId = filteredIds[newSelectedImageIndex];
|
||||
|
||||
const newSelectedImage = allEntities[newSelectedImageId];
|
||||
const newSelectedImage = entities[newSelectedImageId];
|
||||
|
||||
if (newSelectedImageId) {
|
||||
dispatch(imageSelected(newSelectedImage));
|
||||
@ -53,7 +58,11 @@ export const addRequestedImageDeletionListener = () => {
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(imageDeleted({ imageName: image_name, imageType: image_type }));
|
||||
dispatch(imageRemoved(image_name));
|
||||
|
||||
dispatch(
|
||||
imageDeleted({ imageName: image_name, imageOrigin: image_origin })
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
@ -65,14 +74,9 @@ export const addImageDeletedPendingListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: imageDeleted.pending,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { imageName, imageType } = action.meta.arg;
|
||||
const { imageName, imageOrigin } = action.meta.arg;
|
||||
// Preemptively remove the image from the gallery
|
||||
if (imageType === 'uploads') {
|
||||
uploadsAdapter.removeOne(getState().uploads, imageName);
|
||||
}
|
||||
if (imageType === 'results') {
|
||||
resultsAdapter.removeOne(getState().results, imageName);
|
||||
}
|
||||
imagesAdapter.removeOne(getState().images, imageName);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,14 +1,7 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived } from 'services/thunks/image';
|
||||
import {
|
||||
ResultsImageDTO,
|
||||
resultUpserted,
|
||||
} from 'features/gallery/store/resultsSlice';
|
||||
import {
|
||||
UploadsImageDTO,
|
||||
uploadUpserted,
|
||||
} from 'features/gallery/store/uploadsSlice';
|
||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
||||
@ -17,15 +10,12 @@ export const addImageMetadataReceivedFulfilledListener = () => {
|
||||
actionCreator: imageMetadataReceived.fulfilled,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const image = action.payload;
|
||||
if (image.is_intermediate) {
|
||||
// No further actions needed for intermediate images
|
||||
return;
|
||||
}
|
||||
moduleLog.debug({ data: { image } }, 'Image metadata received');
|
||||
|
||||
if (image.image_type === 'results') {
|
||||
dispatch(resultUpserted(action.payload as ResultsImageDTO));
|
||||
}
|
||||
|
||||
if (image.image_type === 'uploads') {
|
||||
dispatch(uploadUpserted(action.payload as UploadsImageDTO));
|
||||
}
|
||||
dispatch(imageUpserted(image));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -0,0 +1,26 @@
|
||||
import { startAppListening } from '..';
|
||||
import { imageUpdated } from 'services/thunks/image';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
||||
export const addImageUpdatedFulfilledListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: imageUpdated.fulfilled,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
moduleLog.debug(
|
||||
{ oldImage: action.meta.arg, updatedImage: action.payload },
|
||||
'Image updated'
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addImageUpdatedRejectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: imageUpdated.rejected,
|
||||
effect: (action, { dispatch }) => {
|
||||
moduleLog.debug({ oldImage: action.meta.arg }, 'Image update failed');
|
||||
},
|
||||
});
|
||||
};
|
@ -1,52 +1,28 @@
|
||||
import { startAppListening } from '..';
|
||||
import { uploadUpserted } from 'features/gallery/store/uploadsSlice';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { resultUpserted } from 'features/gallery/store/resultsSlice';
|
||||
import { isResultsImageDTO, isUploadsImageDTO } from 'services/types/guards';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
||||
export const addImageUploadedFulfilledListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.payload.is_intermediate === false,
|
||||
actionCreator: imageUploaded.fulfilled,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const image = action.payload;
|
||||
|
||||
moduleLog.debug({ arg: '<Blob>', image }, 'Image uploaded');
|
||||
|
||||
if (action.payload.is_intermediate) {
|
||||
// No further actions needed for intermediate images
|
||||
return;
|
||||
}
|
||||
|
||||
const state = getState();
|
||||
|
||||
// Handle uploads
|
||||
if (isUploadsImageDTO(image)) {
|
||||
dispatch(uploadUpserted(image));
|
||||
|
||||
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
|
||||
|
||||
if (state.gallery.shouldAutoSwitchToNewImages) {
|
||||
dispatch(imageSelected(image));
|
||||
}
|
||||
|
||||
if (action.meta.arg.activeTabName === 'img2img') {
|
||||
dispatch(initialImageSelected(image));
|
||||
}
|
||||
|
||||
if (action.meta.arg.activeTabName === 'unifiedCanvas') {
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
}
|
||||
}
|
||||
|
||||
// Handle results
|
||||
// TODO: Can this ever happen? I don't think so...
|
||||
if (isResultsImageDTO(image)) {
|
||||
dispatch(resultUpserted(image));
|
||||
}
|
||||
dispatch(imageUpserted(image));
|
||||
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
|
||||
},
|
||||
});
|
||||
};
|
||||
@ -55,6 +31,9 @@ export const addImageUploadedRejectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: imageUploaded.rejected,
|
||||
effect: (action, { dispatch }) => {
|
||||
const { formData, ...rest } = action.meta.arg;
|
||||
const sanitizedData = { arg: { ...rest, formData: { file: '<Blob>' } } };
|
||||
moduleLog.error({ data: sanitizedData }, 'Image upload failed');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Image Upload Failed',
|
||||
|
@ -1,8 +1,7 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageUrlsReceived } from 'services/thunks/image';
|
||||
import { resultsAdapter } from 'features/gallery/store/resultsSlice';
|
||||
import { uploadsAdapter } from 'features/gallery/store/uploadsSlice';
|
||||
import { imagesAdapter } from 'features/gallery/store/imagesSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
||||
@ -13,27 +12,15 @@ export const addImageUrlsReceivedFulfilledListener = () => {
|
||||
const image = action.payload;
|
||||
moduleLog.debug({ data: { image } }, 'Image URLs received');
|
||||
|
||||
const { image_type, image_name, image_url, thumbnail_url } = image;
|
||||
const { image_name, image_url, thumbnail_url } = image;
|
||||
|
||||
if (image_type === 'results') {
|
||||
resultsAdapter.updateOne(getState().results, {
|
||||
id: image_name,
|
||||
changes: {
|
||||
image_url,
|
||||
thumbnail_url,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
if (image_type === 'uploads') {
|
||||
uploadsAdapter.updateOne(getState().uploads, {
|
||||
id: image_name,
|
||||
changes: {
|
||||
image_url,
|
||||
thumbnail_url,
|
||||
},
|
||||
});
|
||||
}
|
||||
imagesAdapter.updateOne(getState().images, {
|
||||
id: image_name,
|
||||
changes: {
|
||||
image_url,
|
||||
thumbnail_url,
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,6 +1,4 @@
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { selectResultsById } from 'features/gallery/store/resultsSlice';
|
||||
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
|
||||
import { t } from 'i18next';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { startAppListening } from '..';
|
||||
@ -9,7 +7,7 @@ import {
|
||||
isImageDTO,
|
||||
} from 'features/parameters/store/actions';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { selectImagesById } from 'features/gallery/store/imagesSlice';
|
||||
|
||||
export const addInitialImageSelectedListener = () => {
|
||||
startAppListening({
|
||||
@ -30,16 +28,8 @@ export const addInitialImageSelectedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const { image_name, image_type } = action.payload;
|
||||
|
||||
let image: ImageDTO | undefined;
|
||||
const state = getState();
|
||||
|
||||
if (image_type === 'results') {
|
||||
image = selectResultsById(state, image_name);
|
||||
} else if (image_type === 'uploads') {
|
||||
image = selectUploadsById(state, image_name);
|
||||
}
|
||||
const imageName = action.payload;
|
||||
const image = selectImagesById(getState(), imageName);
|
||||
|
||||
if (!image) {
|
||||
dispatch(
|
||||
|
@ -1,31 +1,31 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { receivedResultImagesPage } from 'services/thunks/gallery';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { receivedPageOfImages } from 'services/thunks/image';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'gallery' });
|
||||
|
||||
export const addReceivedResultImagesPageFulfilledListener = () => {
|
||||
export const addReceivedPageOfImagesFulfilledListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: receivedResultImagesPage.fulfilled,
|
||||
actionCreator: receivedPageOfImages.fulfilled,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const page = action.payload;
|
||||
moduleLog.debug(
|
||||
{ data: { page } },
|
||||
`Received ${page.items.length} results`
|
||||
{ data: { payload: action.payload } },
|
||||
`Received ${page.items.length} images`
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addReceivedResultImagesPageRejectedListener = () => {
|
||||
export const addReceivedPageOfImagesRejectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: receivedResultImagesPage.rejected,
|
||||
actionCreator: receivedPageOfImages.rejected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
if (action.payload) {
|
||||
moduleLog.debug(
|
||||
{ data: { error: serializeError(action.payload.error) } },
|
||||
'Problem receiving results'
|
||||
{ data: { error: serializeError(action.payload) } },
|
||||
'Problem receiving images'
|
||||
);
|
||||
}
|
||||
},
|
@ -1,33 +0,0 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { receivedUploadImagesPage } from 'services/thunks/gallery';
|
||||
import { serializeError } from 'serialize-error';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'gallery' });
|
||||
|
||||
export const addReceivedUploadImagesPageFulfilledListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: receivedUploadImagesPage.fulfilled,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const page = action.payload;
|
||||
moduleLog.debug(
|
||||
{ data: { page } },
|
||||
`Received ${page.items.length} uploads`
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addReceivedUploadImagesPageRejectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: receivedUploadImagesPage.rejected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
if (action.payload) {
|
||||
moduleLog.debug(
|
||||
{ data: { error: serializeError(action.payload.error) } },
|
||||
'Problem receiving uploads'
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -1,16 +1,13 @@
|
||||
import { startAppListening } from '../..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { socketConnected } from 'services/events/actions';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedUploadImagesPage,
|
||||
} from 'services/thunks/gallery';
|
||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||
import { receivedPageOfImages } from 'services/thunks/image';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
|
||||
export const addSocketConnectedListener = () => {
|
||||
export const addSocketConnectedEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: socketConnected,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
@ -18,17 +15,12 @@ export const addSocketConnectedListener = () => {
|
||||
|
||||
moduleLog.debug({ timestamp }, 'Connected');
|
||||
|
||||
const { results, uploads, models, nodes, config } = getState();
|
||||
const { models, nodes, config, images } = getState();
|
||||
|
||||
const { disabledTabs } = config;
|
||||
|
||||
// These thunks need to be dispatch in middleware; cannot handle in a reducer
|
||||
if (!results.ids.length) {
|
||||
dispatch(receivedResultImagesPage());
|
||||
}
|
||||
|
||||
if (!uploads.ids.length) {
|
||||
dispatch(receivedUploadImagesPage());
|
||||
if (!images.ids.length) {
|
||||
dispatch(receivedPageOfImages());
|
||||
}
|
||||
|
||||
if (!models.ids.length) {
|
||||
@ -38,6 +30,9 @@ export const addSocketConnectedListener = () => {
|
||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||
dispatch(receivedOpenAPISchema());
|
||||
}
|
||||
|
||||
// pass along the socket event as an application action
|
||||
dispatch(appSocketConnected(action.payload));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,14 +1,19 @@
|
||||
import { startAppListening } from '../..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { socketDisconnected } from 'services/events/actions';
|
||||
import {
|
||||
socketDisconnected,
|
||||
appSocketDisconnected,
|
||||
} from 'services/events/actions';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
|
||||
export const addSocketDisconnectedListener = () => {
|
||||
export const addSocketDisconnectedEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: socketDisconnected,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
moduleLog.debug(action.payload, 'Disconnected');
|
||||
// pass along the socket event as an application action
|
||||
dispatch(appSocketDisconnected(action.payload));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,12 +1,15 @@
|
||||
import { startAppListening } from '../..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { generatorProgress } from 'services/events/actions';
|
||||
import {
|
||||
appSocketGeneratorProgress,
|
||||
socketGeneratorProgress,
|
||||
} from 'services/events/actions';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
|
||||
export const addGeneratorProgressListener = () => {
|
||||
export const addGeneratorProgressEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: generatorProgress,
|
||||
actionCreator: socketGeneratorProgress,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
if (
|
||||
getState().system.canceledSession ===
|
||||
@ -23,6 +26,9 @@ export const addGeneratorProgressListener = () => {
|
||||
action.payload,
|
||||
`Generator progress (${action.payload.data.node.type})`
|
||||
);
|
||||
|
||||
// pass along the socket event as an application action
|
||||
dispatch(appSocketGeneratorProgress(action.payload));
|
||||
},
|
||||
});
|
||||
};
|
@ -1,17 +1,22 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { graphExecutionStateComplete } from 'services/events/actions';
|
||||
import {
|
||||
appSocketGraphExecutionStateComplete,
|
||||
socketGraphExecutionStateComplete,
|
||||
} from 'services/events/actions';
|
||||
import { startAppListening } from '../..';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
|
||||
export const addGraphExecutionStateCompleteListener = () => {
|
||||
export const addGraphExecutionStateCompleteEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: graphExecutionStateComplete,
|
||||
actionCreator: socketGraphExecutionStateComplete,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
moduleLog.debug(
|
||||
action.payload,
|
||||
`Session invocation complete (${action.payload.data.graph_execution_state_id})`
|
||||
);
|
||||
// pass along the socket event as an application action
|
||||
dispatch(appSocketGraphExecutionStateComplete(action.payload));
|
||||
},
|
||||
});
|
||||
};
|
@ -1,19 +1,21 @@
|
||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
import { startAppListening } from '../..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { invocationComplete } from 'services/events/actions';
|
||||
import {
|
||||
appSocketInvocationComplete,
|
||||
socketInvocationComplete,
|
||||
} from 'services/events/actions';
|
||||
import { imageMetadataReceived } from 'services/thunks/image';
|
||||
import { sessionCanceled } from 'services/thunks/session';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import { progressImageSet } from 'features/system/store/systemSlice';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
const nodeDenylist = ['dataURL_image'];
|
||||
|
||||
export const addInvocationCompleteListener = () => {
|
||||
export const addInvocationCompleteEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: invocationComplete,
|
||||
actionCreator: socketInvocationComplete,
|
||||
effect: async (action, { dispatch, getState, take }) => {
|
||||
moduleLog.debug(
|
||||
action.payload,
|
||||
@ -34,13 +36,13 @@ export const addInvocationCompleteListener = () => {
|
||||
|
||||
// This complete event has an associated image output
|
||||
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
|
||||
const { image_name, image_type } = result.image;
|
||||
const { image_name, image_origin } = result.image;
|
||||
|
||||
// Get its metadata
|
||||
dispatch(
|
||||
imageMetadataReceived({
|
||||
imageName: image_name,
|
||||
imageType: image_type,
|
||||
imageOrigin: image_origin,
|
||||
})
|
||||
);
|
||||
|
||||
@ -48,27 +50,18 @@ export const addInvocationCompleteListener = () => {
|
||||
imageMetadataReceived.fulfilled.match
|
||||
);
|
||||
|
||||
if (getState().gallery.shouldAutoSwitchToNewImages) {
|
||||
dispatch(imageSelected(imageDTO));
|
||||
}
|
||||
|
||||
// Handle canvas image
|
||||
if (
|
||||
graph_execution_state_id ===
|
||||
getState().canvas.layerState.stagingArea.sessionId
|
||||
) {
|
||||
const [{ payload: image }] = await take(
|
||||
(
|
||||
action
|
||||
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
|
||||
imageMetadataReceived.fulfilled.match(action) &&
|
||||
action.payload.image_name === image_name
|
||||
);
|
||||
dispatch(addImageToStagingArea(image));
|
||||
dispatch(addImageToStagingArea(imageDTO));
|
||||
}
|
||||
|
||||
dispatch(progressImageSet(null));
|
||||
}
|
||||
// pass along the socket event as an application action
|
||||
dispatch(appSocketInvocationComplete(action.payload));
|
||||
},
|
||||
});
|
||||
};
|
@ -1,17 +1,21 @@
|
||||
import { startAppListening } from '../..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { invocationError } from 'services/events/actions';
|
||||
import {
|
||||
appSocketInvocationError,
|
||||
socketInvocationError,
|
||||
} from 'services/events/actions';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
|
||||
export const addInvocationErrorListener = () => {
|
||||
export const addInvocationErrorEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: invocationError,
|
||||
actionCreator: socketInvocationError,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
moduleLog.error(
|
||||
action.payload,
|
||||
`Invocation error (${action.payload.data.node.type})`
|
||||
);
|
||||
dispatch(appSocketInvocationError(action.payload));
|
||||
},
|
||||
});
|
||||
};
|
@ -1,12 +1,15 @@
|
||||
import { startAppListening } from '../..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { invocationStarted } from 'services/events/actions';
|
||||
import {
|
||||
appSocketInvocationStarted,
|
||||
socketInvocationStarted,
|
||||
} from 'services/events/actions';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
|
||||
export const addInvocationStartedListener = () => {
|
||||
export const addInvocationStartedEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: invocationStarted,
|
||||
actionCreator: socketInvocationStarted,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
if (
|
||||
getState().system.canceledSession ===
|
||||
@ -23,6 +26,7 @@ export const addInvocationStartedListener = () => {
|
||||
action.payload,
|
||||
`Invocation started (${action.payload.data.node.type})`
|
||||
);
|
||||
dispatch(appSocketInvocationStarted(action.payload));
|
||||
},
|
||||
});
|
||||
};
|
@ -1,10 +1,10 @@
|
||||
import { startAppListening } from '../..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { socketSubscribed } from 'services/events/actions';
|
||||
import { appSocketSubscribed, socketSubscribed } from 'services/events/actions';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
|
||||
export const addSocketSubscribedListener = () => {
|
||||
export const addSocketSubscribedEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: socketSubscribed,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
@ -12,6 +12,7 @@ export const addSocketSubscribedListener = () => {
|
||||
action.payload,
|
||||
`Subscribed (${action.payload.sessionId}))`
|
||||
);
|
||||
dispatch(appSocketSubscribed(action.payload));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,10 +1,13 @@
|
||||
import { startAppListening } from '../..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { socketUnsubscribed } from 'services/events/actions';
|
||||
import {
|
||||
appSocketUnsubscribed,
|
||||
socketUnsubscribed,
|
||||
} from 'services/events/actions';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
|
||||
export const addSocketUnsubscribedListener = () => {
|
||||
export const addSocketUnsubscribedEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: socketUnsubscribed,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
@ -12,6 +15,7 @@ export const addSocketUnsubscribedListener = () => {
|
||||
action.payload,
|
||||
`Unsubscribed (${action.payload.sessionId})`
|
||||
);
|
||||
dispatch(appSocketUnsubscribed(action.payload));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -0,0 +1,54 @@
|
||||
import { stagingAreaImageSaved } from 'features/canvas/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { imageUpdated } from 'services/thunks/image';
|
||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvas' });
|
||||
|
||||
export const addStagingAreaImageSavedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: stagingAreaImageSaved,
|
||||
effect: async (action, { dispatch, getState, take }) => {
|
||||
const { image_name, image_origin } = action.payload;
|
||||
|
||||
dispatch(
|
||||
imageUpdated({
|
||||
imageName: image_name,
|
||||
imageOrigin: image_origin,
|
||||
requestBody: {
|
||||
is_intermediate: false,
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
const [imageUpdatedAction] = await take(
|
||||
(action) =>
|
||||
(imageUpdated.fulfilled.match(action) ||
|
||||
imageUpdated.rejected.match(action)) &&
|
||||
action.meta.arg.imageName === image_name
|
||||
);
|
||||
|
||||
if (imageUpdated.rejected.match(imageUpdatedAction)) {
|
||||
moduleLog.error(
|
||||
{ data: { arg: imageUpdatedAction.meta.arg } },
|
||||
'Image saving failed'
|
||||
);
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Image Saving Failed',
|
||||
description: imageUpdatedAction.error.message,
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (imageUpdated.fulfilled.match(imageUpdatedAction)) {
|
||||
dispatch(imageUpserted(imageUpdatedAction.payload));
|
||||
dispatch(addToast({ title: 'Image Saved', status: 'success' }));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -101,6 +101,7 @@ export const addUserInvokedCanvasListener = () => {
|
||||
formData: {
|
||||
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
|
||||
},
|
||||
imageCategory: 'general',
|
||||
isIntermediate: true,
|
||||
})
|
||||
);
|
||||
@ -115,7 +116,7 @@ export const addUserInvokedCanvasListener = () => {
|
||||
// Update the base node with the image name and type
|
||||
baseNode.image = {
|
||||
image_name: baseImageDTO.image_name,
|
||||
image_type: baseImageDTO.image_type,
|
||||
image_origin: baseImageDTO.image_origin,
|
||||
};
|
||||
}
|
||||
|
||||
@ -127,6 +128,7 @@ export const addUserInvokedCanvasListener = () => {
|
||||
formData: {
|
||||
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
|
||||
},
|
||||
imageCategory: 'mask',
|
||||
isIntermediate: true,
|
||||
})
|
||||
);
|
||||
@ -141,7 +143,7 @@ export const addUserInvokedCanvasListener = () => {
|
||||
// Update the base node with the image name and type
|
||||
baseNode.mask = {
|
||||
image_name: maskImageDTO.image_name,
|
||||
image_type: maskImageDTO.image_type,
|
||||
image_origin: maskImageDTO.image_origin,
|
||||
};
|
||||
}
|
||||
|
||||
@ -158,7 +160,7 @@ export const addUserInvokedCanvasListener = () => {
|
||||
dispatch(
|
||||
imageUpdated({
|
||||
imageName: baseNode.image.image_name,
|
||||
imageType: baseNode.image.image_type,
|
||||
imageOrigin: baseNode.image.image_origin,
|
||||
requestBody: { session_id: sessionId },
|
||||
})
|
||||
);
|
||||
@ -169,7 +171,7 @@ export const addUserInvokedCanvasListener = () => {
|
||||
dispatch(
|
||||
imageUpdated({
|
||||
imageName: baseNode.mask.image_name,
|
||||
imageType: baseNode.mask.image_type,
|
||||
imageOrigin: baseNode.mask.image_origin,
|
||||
requestBody: { session_id: sessionId },
|
||||
})
|
||||
);
|
||||
|
@ -10,8 +10,7 @@ import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
|
||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||
import resultsReducer from 'features/gallery/store/resultsSlice';
|
||||
import uploadsReducer from 'features/gallery/store/uploadsSlice';
|
||||
import imagesReducer from 'features/gallery/store/imagesSlice';
|
||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||
import generationReducer from 'features/parameters/store/generationSlice';
|
||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||
@ -41,12 +40,11 @@ const allReducers = {
|
||||
models: modelsReducer,
|
||||
nodes: nodesReducer,
|
||||
postprocessing: postprocessingReducer,
|
||||
results: resultsReducer,
|
||||
system: systemReducer,
|
||||
config: configReducer,
|
||||
ui: uiReducer,
|
||||
uploads: uploadsReducer,
|
||||
hotkeys: hotkeysReducer,
|
||||
images: imagesReducer,
|
||||
// session: sessionReducer,
|
||||
};
|
||||
|
||||
@ -65,8 +63,6 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
||||
'system',
|
||||
'ui',
|
||||
// 'hotkeys',
|
||||
// 'results',
|
||||
// 'uploads',
|
||||
// 'config',
|
||||
];
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
import { SelectedImage } from 'features/parameters/store/actions';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { IRect } from 'konva/lib/types';
|
||||
import { ImageResponseMetadata, ImageType } from 'services/api';
|
||||
import { ImageResponseMetadata, ResourceOrigin } from 'services/api';
|
||||
import { O } from 'ts-toolbelt';
|
||||
|
||||
/**
|
||||
@ -124,7 +124,7 @@ export type PostProcessedImageMetadata = ESRGANMetadata | FacetoolMetadata;
|
||||
*/
|
||||
// export ty`pe Image = {
|
||||
// name: string;
|
||||
// type: ImageType;
|
||||
// type: image_origin;
|
||||
// url: string;
|
||||
// thumbnail: string;
|
||||
// metadata: ImageResponseMetadata;
|
||||
|
@ -4,7 +4,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
||||
type ImageUploadOverlayProps = {
|
||||
isDragAccept: boolean;
|
||||
isDragReject: boolean;
|
||||
overlaySecondaryText: string;
|
||||
setIsHandlingUpload: (isHandlingUpload: boolean) => void;
|
||||
};
|
||||
|
||||
@ -12,7 +11,6 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
|
||||
const {
|
||||
isDragAccept,
|
||||
isDragReject: _isDragAccept,
|
||||
overlaySecondaryText,
|
||||
setIsHandlingUpload,
|
||||
} = props;
|
||||
|
||||
@ -48,7 +46,7 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
|
||||
}}
|
||||
>
|
||||
{isDragAccept ? (
|
||||
<Heading size="lg">Upload Image{overlaySecondaryText}</Heading>
|
||||
<Heading size="lg">Drop to Upload</Heading>
|
||||
) : (
|
||||
<>
|
||||
<Heading size="lg">Invalid Upload</Heading>
|
||||
|
@ -69,11 +69,12 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
formData: { file },
|
||||
activeTabName,
|
||||
imageCategory: 'user',
|
||||
isIntermediate: false,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, activeTabName]
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onDrop = useCallback(
|
||||
@ -144,14 +145,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
};
|
||||
}, [inputRef, open, setOpenUploaderFunction]);
|
||||
|
||||
const overlaySecondaryText = useMemo(() => {
|
||||
if (['img2img', 'unifiedCanvas'].includes(activeTabName)) {
|
||||
return ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`;
|
||||
}
|
||||
|
||||
return '';
|
||||
}, [t, activeTabName]);
|
||||
|
||||
return (
|
||||
<Box
|
||||
{...getRootProps({ style: {} })}
|
||||
@ -166,7 +159,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
<ImageUploadOverlay
|
||||
isDragAccept={isDragAccept}
|
||||
isDragReject={isDragReject}
|
||||
overlaySecondaryText={overlaySecondaryText}
|
||||
setIsHandlingUpload={setIsHandlingUpload}
|
||||
/>
|
||||
)}
|
||||
|
@ -1,239 +0,0 @@
|
||||
import { forEach, size } from 'lodash-es';
|
||||
import {
|
||||
ImageField,
|
||||
LatentsField,
|
||||
ConditioningField,
|
||||
ControlField,
|
||||
} from 'services/api';
|
||||
|
||||
const OBJECT_TYPESTRING = '[object Object]';
|
||||
const STRING_TYPESTRING = '[object String]';
|
||||
const NUMBER_TYPESTRING = '[object Number]';
|
||||
const BOOLEAN_TYPESTRING = '[object Boolean]';
|
||||
const ARRAY_TYPESTRING = '[object Array]';
|
||||
|
||||
const isObject = (obj: unknown): obj is Record<string | number, any> =>
|
||||
Object.prototype.toString.call(obj) === OBJECT_TYPESTRING;
|
||||
|
||||
const isString = (obj: unknown): obj is string =>
|
||||
Object.prototype.toString.call(obj) === STRING_TYPESTRING;
|
||||
|
||||
const isNumber = (obj: unknown): obj is number =>
|
||||
Object.prototype.toString.call(obj) === NUMBER_TYPESTRING;
|
||||
|
||||
const isBoolean = (obj: unknown): obj is boolean =>
|
||||
Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING;
|
||||
|
||||
const isArray = (obj: unknown): obj is Array<any> =>
|
||||
Object.prototype.toString.call(obj) === ARRAY_TYPESTRING;
|
||||
|
||||
const parseImageField = (imageField: unknown): ImageField | undefined => {
|
||||
// Must be an object
|
||||
if (!isObject(imageField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// An ImageField must have both `image_name` and `image_type`
|
||||
if (!('image_name' in imageField && 'image_type' in imageField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// An ImageField's `image_type` must be one of the allowed values
|
||||
if (
|
||||
!['results', 'uploads', 'intermediates'].includes(imageField.image_type)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
// An ImageField's `image_name` must be a string
|
||||
if (typeof imageField.image_name !== 'string') {
|
||||
return;
|
||||
}
|
||||
|
||||
// Build a valid ImageField
|
||||
return {
|
||||
image_type: imageField.image_type,
|
||||
image_name: imageField.image_name,
|
||||
};
|
||||
};
|
||||
|
||||
const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
|
||||
// Must be an object
|
||||
if (!isObject(latentsField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// A LatentsField must have a `latents_name`
|
||||
if (!('latents_name' in latentsField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// A LatentsField's `latents_name` must be a string
|
||||
if (typeof latentsField.latents_name !== 'string') {
|
||||
return;
|
||||
}
|
||||
|
||||
// Build a valid LatentsField
|
||||
return {
|
||||
latents_name: latentsField.latents_name,
|
||||
};
|
||||
};
|
||||
|
||||
const parseConditioningField = (
|
||||
conditioningField: unknown
|
||||
): ConditioningField | undefined => {
|
||||
// Must be an object
|
||||
if (!isObject(conditioningField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// A ConditioningField must have a `conditioning_name`
|
||||
if (!('conditioning_name' in conditioningField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// A ConditioningField's `conditioning_name` must be a string
|
||||
if (typeof conditioningField.conditioning_name !== 'string') {
|
||||
return;
|
||||
}
|
||||
|
||||
// Build a valid ConditioningField
|
||||
return {
|
||||
conditioning_name: conditioningField.conditioning_name,
|
||||
};
|
||||
};
|
||||
|
||||
const parseControlField = (controlField: unknown): ControlField | undefined => {
|
||||
// Must be an object
|
||||
if (!isObject(controlField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// A ControlField must have a `control`
|
||||
if (!('control' in controlField)) {
|
||||
return;
|
||||
}
|
||||
// console.log(typeof controlField.control);
|
||||
|
||||
// Build a valid ControlField
|
||||
return {
|
||||
control: controlField.control,
|
||||
};
|
||||
};
|
||||
|
||||
type NodeMetadata = {
|
||||
[key: string]:
|
||||
| string
|
||||
| number
|
||||
| boolean
|
||||
| ImageField
|
||||
| LatentsField
|
||||
| ConditioningField
|
||||
| ControlField;
|
||||
};
|
||||
|
||||
type InvokeAIMetadata = {
|
||||
session_id?: string;
|
||||
node?: NodeMetadata;
|
||||
};
|
||||
|
||||
export const parseNodeMetadata = (
|
||||
nodeMetadata: Record<string | number, any>
|
||||
): NodeMetadata | undefined => {
|
||||
if (!isObject(nodeMetadata)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const parsed: NodeMetadata = {};
|
||||
|
||||
forEach(nodeMetadata, (nodeItem, nodeKey) => {
|
||||
// `id` and `type` must be strings if they are present
|
||||
if (['id', 'type'].includes(nodeKey)) {
|
||||
if (isString(nodeItem)) {
|
||||
parsed[nodeKey] = nodeItem;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// the only valid object types are ImageField, LatentsField, ConditioningField, ControlField
|
||||
if (isObject(nodeItem)) {
|
||||
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
|
||||
const imageField = parseImageField(nodeItem);
|
||||
if (imageField) {
|
||||
parsed[nodeKey] = imageField;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if ('latents_name' in nodeItem) {
|
||||
const latentsField = parseLatentsField(nodeItem);
|
||||
if (latentsField) {
|
||||
parsed[nodeKey] = latentsField;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if ('conditioning_name' in nodeItem) {
|
||||
const conditioningField = parseConditioningField(nodeItem);
|
||||
if (conditioningField) {
|
||||
parsed[nodeKey] = conditioningField;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if ('control' in nodeItem) {
|
||||
const controlField = parseControlField(nodeItem);
|
||||
if (controlField) {
|
||||
parsed[nodeKey] = controlField;
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// otherwise we accept any string, number or boolean
|
||||
if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) {
|
||||
parsed[nodeKey] = nodeItem;
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
if (size(parsed) === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
return parsed;
|
||||
};
|
||||
|
||||
export const parseInvokeAIMetadata = (
|
||||
metadata: Record<string | number, any> | undefined
|
||||
): InvokeAIMetadata | undefined => {
|
||||
if (metadata === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isObject(metadata)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const parsed: InvokeAIMetadata = {};
|
||||
|
||||
forEach(metadata, (item, key) => {
|
||||
if (key === 'session_id' && isString(item)) {
|
||||
parsed['session_id'] = item;
|
||||
}
|
||||
|
||||
if (key === 'node' && isObject(item)) {
|
||||
const nodeMetadata = parseNodeMetadata(item);
|
||||
|
||||
if (nodeMetadata) {
|
||||
parsed['node'] = nodeMetadata;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (size(parsed) === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
return parsed;
|
||||
};
|
@ -1,18 +1,24 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { ImageConfig } from 'konva/lib/shapes/Image';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { Image as KonvaImage } from 'react-konva';
|
||||
import { canvasSelector } from '../store/canvasSelectors';
|
||||
|
||||
const selector = createSelector(
|
||||
[(state: RootState) => state.gallery],
|
||||
(gallery: GalleryState) => {
|
||||
return gallery.intermediateImage ? gallery.intermediateImage : null;
|
||||
[systemSelector, canvasSelector],
|
||||
(system, canvas) => {
|
||||
const { progressImage, sessionId } = system;
|
||||
const { sessionId: canvasSessionId, boundingBox } =
|
||||
canvas.layerState.stagingArea;
|
||||
|
||||
return {
|
||||
boundingBox,
|
||||
progressImage: sessionId === canvasSessionId ? progressImage : undefined,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
@ -25,33 +31,34 @@ type Props = Omit<ImageConfig, 'image'>;
|
||||
|
||||
const IAICanvasIntermediateImage = (props: Props) => {
|
||||
const { ...rest } = props;
|
||||
const intermediateImage = useAppSelector(selector);
|
||||
const { getUrl } = useGetUrl();
|
||||
const { progressImage, boundingBox } = useAppSelector(selector);
|
||||
const [loadedImageElement, setLoadedImageElement] =
|
||||
useState<HTMLImageElement | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!intermediateImage) return;
|
||||
if (!progressImage) {
|
||||
return;
|
||||
}
|
||||
|
||||
const tempImage = new Image();
|
||||
|
||||
tempImage.onload = () => {
|
||||
setLoadedImageElement(tempImage);
|
||||
};
|
||||
tempImage.src = getUrl(intermediateImage.url);
|
||||
}, [intermediateImage, getUrl]);
|
||||
|
||||
if (!intermediateImage?.boundingBox) return null;
|
||||
tempImage.src = progressImage.dataURL;
|
||||
}, [progressImage]);
|
||||
|
||||
const {
|
||||
boundingBox: { x, y, width, height },
|
||||
} = intermediateImage;
|
||||
if (!(progressImage && boundingBox)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return loadedImageElement ? (
|
||||
<KonvaImage
|
||||
x={x}
|
||||
y={y}
|
||||
width={width}
|
||||
height={height}
|
||||
x={boundingBox.x}
|
||||
y={boundingBox.y}
|
||||
width={boundingBox.width}
|
||||
height={boundingBox.height}
|
||||
image={loadedImageElement}
|
||||
listening={false}
|
||||
{...rest}
|
||||
|
@ -62,7 +62,7 @@ const IAICanvasStagingArea = (props: Props) => {
|
||||
<Group {...rest}>
|
||||
{shouldShowStagingImage && currentStagingAreaImage && (
|
||||
<IAICanvasImage
|
||||
url={getUrl(currentStagingAreaImage.image.image_url)}
|
||||
url={getUrl(currentStagingAreaImage.image.image_url) ?? ''}
|
||||
x={x}
|
||||
y={y}
|
||||
/>
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
// import { saveStagingAreaImageToGallery } from 'app/socketio/actions';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
@ -26,13 +25,14 @@ import {
|
||||
FaPlus,
|
||||
FaSave,
|
||||
} from 'react-icons/fa';
|
||||
import { stagingAreaImageSaved } from '../store/actions';
|
||||
|
||||
const selector = createSelector(
|
||||
[canvasSelector],
|
||||
(canvas) => {
|
||||
const {
|
||||
layerState: {
|
||||
stagingArea: { images, selectedImageIndex },
|
||||
stagingArea: { images, selectedImageIndex, sessionId },
|
||||
},
|
||||
shouldShowStagingOutline,
|
||||
shouldShowStagingImage,
|
||||
@ -45,6 +45,7 @@ const selector = createSelector(
|
||||
isOnLastImage: selectedImageIndex === images.length - 1,
|
||||
shouldShowStagingImage,
|
||||
shouldShowStagingOutline,
|
||||
sessionId,
|
||||
};
|
||||
},
|
||||
{
|
||||
@ -61,6 +62,7 @@ const IAICanvasStagingAreaToolbar = () => {
|
||||
isOnLastImage,
|
||||
currentStagingAreaImage,
|
||||
shouldShowStagingImage,
|
||||
sessionId,
|
||||
} = useAppSelector(selector);
|
||||
|
||||
const { t } = useTranslation();
|
||||
@ -106,9 +108,20 @@ const IAICanvasStagingAreaToolbar = () => {
|
||||
}
|
||||
);
|
||||
|
||||
const handlePrevImage = () => dispatch(prevStagingAreaImage());
|
||||
const handleNextImage = () => dispatch(nextStagingAreaImage());
|
||||
const handleAccept = () => dispatch(commitStagingAreaImage());
|
||||
const handlePrevImage = useCallback(
|
||||
() => dispatch(prevStagingAreaImage()),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleNextImage = useCallback(
|
||||
() => dispatch(nextStagingAreaImage()),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleAccept = useCallback(
|
||||
() => dispatch(commitStagingAreaImage(sessionId)),
|
||||
[dispatch, sessionId]
|
||||
);
|
||||
|
||||
if (!currentStagingAreaImage) return null;
|
||||
|
||||
@ -157,19 +170,15 @@ const IAICanvasStagingAreaToolbar = () => {
|
||||
}
|
||||
colorScheme="accent"
|
||||
/>
|
||||
{/* <IAIIconButton
|
||||
<IAIIconButton
|
||||
tooltip={t('unifiedCanvas.saveToGallery')}
|
||||
aria-label={t('unifiedCanvas.saveToGallery')}
|
||||
icon={<FaSave />}
|
||||
onClick={() =>
|
||||
dispatch(
|
||||
saveStagingAreaImageToGallery(
|
||||
currentStagingAreaImage.image.image_url
|
||||
)
|
||||
)
|
||||
dispatch(stagingAreaImageSaved(currentStagingAreaImage.image))
|
||||
}
|
||||
colorScheme="accent"
|
||||
/> */}
|
||||
/>
|
||||
<IAIIconButton
|
||||
tooltip={t('unifiedCanvas.discardAll')}
|
||||
aria-label={t('unifiedCanvas.discardAll')}
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { ImageDTO } from 'services/api';
|
||||
|
||||
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
|
||||
|
||||
@ -11,3 +12,7 @@ export const canvasDownloadedAsImage = createAction(
|
||||
);
|
||||
|
||||
export const canvasMerged = createAction('canvas/canvasMerged');
|
||||
|
||||
export const stagingAreaImageSaved = createAction<ImageDTO>(
|
||||
'canvas/stagingAreaImageSaved'
|
||||
);
|
||||
|
@ -696,7 +696,10 @@ export const canvasSlice = createSlice({
|
||||
0
|
||||
);
|
||||
},
|
||||
commitStagingAreaImage: (state) => {
|
||||
commitStagingAreaImage: (
|
||||
state,
|
||||
action: PayloadAction<string | undefined>
|
||||
) => {
|
||||
if (!state.layerState.stagingArea.images.length) {
|
||||
return;
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual, isString } from 'lodash-es';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import {
|
||||
ButtonGroup,
|
||||
@ -25,8 +25,8 @@ import {
|
||||
} from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
setActiveTab,
|
||||
setShouldHidePreview,
|
||||
setShouldShowImageDetails,
|
||||
setShouldShowProgressInViewer,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -37,18 +37,14 @@ import {
|
||||
FaDownload,
|
||||
FaExpand,
|
||||
FaExpandArrowsAlt,
|
||||
FaEye,
|
||||
FaEyeSlash,
|
||||
FaGrinStars,
|
||||
FaHourglassHalf,
|
||||
FaQuoteRight,
|
||||
FaSeedling,
|
||||
FaShare,
|
||||
FaShareAlt,
|
||||
FaTrash,
|
||||
FaWrench,
|
||||
} from 'react-icons/fa';
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import DeleteImageModal from './DeleteImageModal';
|
||||
import { useCallback } from 'react';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
@ -90,7 +86,11 @@ const currentImageButtonsSelector = createSelector(
|
||||
|
||||
const { isLightboxOpen } = lightbox;
|
||||
|
||||
const { shouldShowImageDetails, shouldHidePreview } = ui;
|
||||
const {
|
||||
shouldShowImageDetails,
|
||||
shouldHidePreview,
|
||||
shouldShowProgressInViewer,
|
||||
} = ui;
|
||||
|
||||
const { selectedImage } = gallery;
|
||||
|
||||
@ -112,6 +112,7 @@ const currentImageButtonsSelector = createSelector(
|
||||
seed: selectedImage?.metadata?.seed,
|
||||
prompt: selectedImage?.metadata?.positive_conditioning,
|
||||
negativePrompt: selectedImage?.metadata?.negative_conditioning,
|
||||
shouldShowProgressInViewer,
|
||||
};
|
||||
},
|
||||
{
|
||||
@ -145,6 +146,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
image,
|
||||
canDeleteImage,
|
||||
shouldConfirmOnDelete,
|
||||
shouldShowProgressInViewer,
|
||||
} = useAppSelector(currentImageButtonsSelector);
|
||||
|
||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||
@ -229,10 +231,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
});
|
||||
}, [toaster, shouldTransformUrls, getUrl, t, image]);
|
||||
|
||||
const handlePreviewVisibility = useCallback(() => {
|
||||
dispatch(setShouldHidePreview(!shouldHidePreview));
|
||||
}, [dispatch, shouldHidePreview]);
|
||||
|
||||
const handleClickUseAllParameters = useCallback(() => {
|
||||
recallAllParameters(image);
|
||||
}, [image, recallAllParameters]);
|
||||
@ -386,6 +384,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
}
|
||||
}, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]);
|
||||
|
||||
const handleClickProgressImagesToggle = useCallback(() => {
|
||||
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
|
||||
}, [dispatch, shouldShowProgressInViewer]);
|
||||
|
||||
useHotkeys('delete', handleInitiateDelete, [
|
||||
image,
|
||||
shouldConfirmOnDelete,
|
||||
@ -412,8 +414,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
<IAIPopover
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
isDisabled={!image}
|
||||
aria-label={`${t('parameters.sendTo')}...`}
|
||||
tooltip={`${t('parameters.sendTo')}...`}
|
||||
isDisabled={!image}
|
||||
icon={<FaShareAlt />}
|
||||
/>
|
||||
}
|
||||
@ -465,21 +468,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
</Link>
|
||||
</Flex>
|
||||
</IAIPopover>
|
||||
{/* <IAIIconButton
|
||||
icon={shouldHidePreview ? <FaEyeSlash /> : <FaEye />}
|
||||
tooltip={
|
||||
!shouldHidePreview
|
||||
? t('parameters.hidePreview')
|
||||
: t('parameters.showPreview')
|
||||
}
|
||||
aria-label={
|
||||
!shouldHidePreview
|
||||
? t('parameters.hidePreview')
|
||||
: t('parameters.showPreview')
|
||||
}
|
||||
isChecked={shouldHidePreview}
|
||||
onClick={handlePreviewVisibility}
|
||||
/> */}
|
||||
{isLightboxEnabled && (
|
||||
<IAIIconButton
|
||||
icon={<FaExpand />}
|
||||
@ -604,6 +592,16 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
/>
|
||||
</ButtonGroup>
|
||||
|
||||
<ButtonGroup isAttached={true}>
|
||||
<IAIIconButton
|
||||
aria-label={t('settings.displayInProgress')}
|
||||
tooltip={t('settings.displayInProgress')}
|
||||
icon={<FaHourglassHalf />}
|
||||
isChecked={shouldShowProgressInViewer}
|
||||
onClick={handleClickProgressImagesToggle}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
|
||||
<ButtonGroup isAttached={true}>
|
||||
<DeleteImageButton image={image} />
|
||||
</ButtonGroup>
|
||||
|
@ -62,7 +62,6 @@ const CurrentImagePreview = () => {
|
||||
return;
|
||||
}
|
||||
e.dataTransfer.setData('invokeai/imageName', image.image_name);
|
||||
e.dataTransfer.setData('invokeai/imageType', image.image_type);
|
||||
e.dataTransfer.effectAllowed = 'move';
|
||||
},
|
||||
[image]
|
||||
|
@ -147,7 +147,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
const handleDragStart = useCallback(
|
||||
(e: DragEvent<HTMLDivElement>) => {
|
||||
e.dataTransfer.setData('invokeai/imageName', image.image_name);
|
||||
e.dataTransfer.setData('invokeai/imageType', image.image_type);
|
||||
e.dataTransfer.effectAllowed = 'move';
|
||||
},
|
||||
[image]
|
||||
|
@ -1,6 +1,8 @@
|
||||
import {
|
||||
Box,
|
||||
ButtonGroup,
|
||||
Checkbox,
|
||||
CheckboxGroup,
|
||||
Flex,
|
||||
FlexProps,
|
||||
Grid,
|
||||
@ -16,7 +18,6 @@ import IAIPopover from 'common/components/IAIPopover';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
setCurrentCategory,
|
||||
setGalleryImageMinimumWidth,
|
||||
setGalleryImageObjectFit,
|
||||
setShouldAutoSwitchToNewImages,
|
||||
@ -36,54 +37,48 @@ import {
|
||||
} from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs';
|
||||
import { FaImage, FaUser, FaWrench } from 'react-icons/fa';
|
||||
import {
|
||||
FaFilter,
|
||||
FaImage,
|
||||
FaImages,
|
||||
FaServer,
|
||||
FaWrench,
|
||||
} from 'react-icons/fa';
|
||||
import { MdPhotoLibrary } from 'react-icons/md';
|
||||
import HoverableImage from './HoverableImage';
|
||||
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import { resultsAdapter } from '../store/resultsSlice';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedUploadImagesPage,
|
||||
} from 'services/thunks/gallery';
|
||||
import { uploadsAdapter } from '../store/uploadsSlice';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import GalleryProgressImage from './GalleryProgressImage';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { ImageDTO } from 'services/api';
|
||||
|
||||
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
|
||||
const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER';
|
||||
import { ImageCategory } from 'services/api';
|
||||
import {
|
||||
ASSETS_CATEGORIES,
|
||||
IMAGE_CATEGORIES,
|
||||
imageCategoriesChanged,
|
||||
selectImagesAll,
|
||||
} from '../store/imagesSlice';
|
||||
import { receivedPageOfImages } from 'services/thunks/image';
|
||||
import { capitalize } from 'lodash-es';
|
||||
|
||||
const categorySelector = createSelector(
|
||||
[(state: RootState) => state],
|
||||
(state) => {
|
||||
const { results, uploads, system, gallery } = state;
|
||||
const { currentCategory } = gallery;
|
||||
const { images } = state;
|
||||
const { categories } = images;
|
||||
|
||||
if (currentCategory === 'results') {
|
||||
const tempImages: (ImageDTO | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = [];
|
||||
|
||||
if (system.progressImage) {
|
||||
tempImages.push(PROGRESS_IMAGE_PLACEHOLDER);
|
||||
}
|
||||
|
||||
return {
|
||||
images: tempImages.concat(
|
||||
resultsAdapter.getSelectors().selectAll(results)
|
||||
),
|
||||
isLoading: results.isLoading,
|
||||
areMoreImagesAvailable: results.page < results.pages - 1,
|
||||
};
|
||||
}
|
||||
const allImages = selectImagesAll(state);
|
||||
const filteredImages = allImages.filter((i) =>
|
||||
categories.includes(i.image_category)
|
||||
);
|
||||
|
||||
return {
|
||||
images: uploadsAdapter.getSelectors().selectAll(uploads),
|
||||
isLoading: uploads.isLoading,
|
||||
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
|
||||
images: filteredImages,
|
||||
isLoading: images.isLoading,
|
||||
areMoreImagesAvailable: filteredImages.length < images.total,
|
||||
categories: images.categories,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
@ -93,7 +88,6 @@ const mainSelector = createSelector(
|
||||
[gallerySelector, uiSelector],
|
||||
(gallery, ui) => {
|
||||
const {
|
||||
currentCategory,
|
||||
galleryImageMinimumWidth,
|
||||
galleryImageObjectFit,
|
||||
shouldAutoSwitchToNewImages,
|
||||
@ -104,7 +98,6 @@ const mainSelector = createSelector(
|
||||
const { shouldPinGallery } = ui;
|
||||
|
||||
return {
|
||||
currentCategory,
|
||||
shouldPinGallery,
|
||||
galleryImageMinimumWidth,
|
||||
galleryImageObjectFit,
|
||||
@ -120,7 +113,6 @@ const ImageGalleryContent = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const resizeObserverRef = useRef<HTMLDivElement>(null);
|
||||
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
|
||||
const rootRef = useRef(null);
|
||||
const [scroller, setScroller] = useState<HTMLElement | null>(null);
|
||||
const [initialize, osInstance] = useOverlayScrollbars({
|
||||
@ -137,7 +129,6 @@ const ImageGalleryContent = () => {
|
||||
});
|
||||
|
||||
const {
|
||||
currentCategory,
|
||||
shouldPinGallery,
|
||||
galleryImageMinimumWidth,
|
||||
galleryImageObjectFit,
|
||||
@ -146,18 +137,12 @@ const ImageGalleryContent = () => {
|
||||
selectedImage,
|
||||
} = useAppSelector(mainSelector);
|
||||
|
||||
const { images, areMoreImagesAvailable, isLoading } =
|
||||
const { images, areMoreImagesAvailable, isLoading, categories } =
|
||||
useAppSelector(categorySelector);
|
||||
|
||||
const handleClickLoadMore = () => {
|
||||
if (currentCategory === 'results') {
|
||||
dispatch(receivedResultImagesPage());
|
||||
}
|
||||
|
||||
if (currentCategory === 'uploads') {
|
||||
dispatch(receivedUploadImagesPage());
|
||||
}
|
||||
};
|
||||
const handleLoadMoreImages = useCallback(() => {
|
||||
dispatch(receivedPageOfImages());
|
||||
}, [dispatch]);
|
||||
|
||||
const handleChangeGalleryImageMinimumWidth = (v: number) => {
|
||||
dispatch(setGalleryImageMinimumWidth(v));
|
||||
@ -168,28 +153,6 @@ const ImageGalleryContent = () => {
|
||||
dispatch(requestCanvasRescale());
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (!resizeObserverRef.current) {
|
||||
return;
|
||||
}
|
||||
const resizeObserver = new ResizeObserver(() => {
|
||||
if (!resizeObserverRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
resizeObserverRef.current.clientWidth < GALLERY_SHOW_BUTTONS_MIN_WIDTH
|
||||
) {
|
||||
setShouldShouldIconButtons(true);
|
||||
return;
|
||||
}
|
||||
|
||||
setShouldShouldIconButtons(false);
|
||||
});
|
||||
resizeObserver.observe(resizeObserverRef.current);
|
||||
return () => resizeObserver.disconnect(); // clean up
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const { current: root } = rootRef;
|
||||
if (scroller && root) {
|
||||
@ -210,12 +173,23 @@ const ImageGalleryContent = () => {
|
||||
}, []);
|
||||
|
||||
const handleEndReached = useCallback(() => {
|
||||
if (currentCategory === 'results') {
|
||||
dispatch(receivedResultImagesPage());
|
||||
} else if (currentCategory === 'uploads') {
|
||||
dispatch(receivedUploadImagesPage());
|
||||
}
|
||||
}, [dispatch, currentCategory]);
|
||||
handleLoadMoreImages();
|
||||
}, [handleLoadMoreImages]);
|
||||
|
||||
const handleCategoriesChanged = useCallback(
|
||||
(newCategories: ImageCategory[]) => {
|
||||
dispatch(imageCategoriesChanged(newCategories));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleClickImagesCategory = useCallback(() => {
|
||||
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
|
||||
}, [dispatch]);
|
||||
|
||||
const handleClickAssetsCategory = useCallback(() => {
|
||||
dispatch(imageCategoriesChanged(ASSETS_CATEGORIES));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@ -232,59 +206,31 @@ const ImageGalleryContent = () => {
|
||||
alignItems="center"
|
||||
justifyContent="space-between"
|
||||
>
|
||||
<ButtonGroup
|
||||
size="sm"
|
||||
isAttached
|
||||
w="max-content"
|
||||
justifyContent="stretch"
|
||||
>
|
||||
{shouldShouldIconButtons ? (
|
||||
<>
|
||||
<IAIIconButton
|
||||
aria-label={t('gallery.showGenerations')}
|
||||
tooltip={t('gallery.showGenerations')}
|
||||
isChecked={currentCategory === 'results'}
|
||||
role="radio"
|
||||
icon={<FaImage />}
|
||||
onClick={() => dispatch(setCurrentCategory('results'))}
|
||||
/>
|
||||
<IAIIconButton
|
||||
aria-label={t('gallery.showUploads')}
|
||||
tooltip={t('gallery.showUploads')}
|
||||
role="radio"
|
||||
isChecked={currentCategory === 'uploads'}
|
||||
icon={<FaUser />}
|
||||
onClick={() => dispatch(setCurrentCategory('uploads'))}
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<IAIButton
|
||||
size="sm"
|
||||
isChecked={currentCategory === 'results'}
|
||||
onClick={() => dispatch(setCurrentCategory('results'))}
|
||||
flexGrow={1}
|
||||
>
|
||||
{t('gallery.generations')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
size="sm"
|
||||
isChecked={currentCategory === 'uploads'}
|
||||
onClick={() => dispatch(setCurrentCategory('uploads'))}
|
||||
flexGrow={1}
|
||||
>
|
||||
{t('gallery.uploads')}
|
||||
</IAIButton>
|
||||
</>
|
||||
)}
|
||||
<ButtonGroup isAttached>
|
||||
<IAIIconButton
|
||||
tooltip={t('gallery.images')}
|
||||
aria-label={t('gallery.images')}
|
||||
onClick={handleClickImagesCategory}
|
||||
isChecked={categories === IMAGE_CATEGORIES}
|
||||
size="sm"
|
||||
icon={<FaImage />}
|
||||
/>
|
||||
<IAIIconButton
|
||||
tooltip={t('gallery.assets')}
|
||||
aria-label={t('gallery.assets')}
|
||||
onClick={handleClickAssetsCategory}
|
||||
isChecked={categories === ASSETS_CATEGORIES}
|
||||
size="sm"
|
||||
icon={<FaServer />}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
|
||||
<Flex gap={2}>
|
||||
<IAIPopover
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
tooltip={t('gallery.gallerySettings')}
|
||||
aria-label={t('gallery.gallerySettings')}
|
||||
size="sm"
|
||||
icon={<FaWrench />}
|
||||
/>
|
||||
}
|
||||
@ -347,28 +293,17 @@ const ImageGalleryContent = () => {
|
||||
data={images}
|
||||
endReached={handleEndReached}
|
||||
scrollerRef={(ref) => setScrollerRef(ref)}
|
||||
itemContent={(index, image) => {
|
||||
const isSelected =
|
||||
image === PROGRESS_IMAGE_PLACEHOLDER
|
||||
? false
|
||||
: selectedImage?.image_name === image?.image_name;
|
||||
|
||||
return (
|
||||
<Flex sx={{ pb: 2 }}>
|
||||
{image === PROGRESS_IMAGE_PLACEHOLDER ? (
|
||||
<GalleryProgressImage
|
||||
key={PROGRESS_IMAGE_PLACEHOLDER}
|
||||
/>
|
||||
) : (
|
||||
<HoverableImage
|
||||
key={`${image.image_name}-${image.thumbnail_url}`}
|
||||
image={image}
|
||||
isSelected={isSelected}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}}
|
||||
itemContent={(index, image) => (
|
||||
<Flex sx={{ pb: 2 }}>
|
||||
<HoverableImage
|
||||
key={`${image.image_name}-${image.thumbnail_url}`}
|
||||
image={image}
|
||||
isSelected={
|
||||
selectedImage?.image_name === image?.image_name
|
||||
}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
/>
|
||||
) : (
|
||||
<VirtuosoGrid
|
||||
@ -380,27 +315,20 @@ const ImageGalleryContent = () => {
|
||||
List: ListContainer,
|
||||
}}
|
||||
scrollerRef={setScroller}
|
||||
itemContent={(index, image) => {
|
||||
const isSelected =
|
||||
image === PROGRESS_IMAGE_PLACEHOLDER
|
||||
? false
|
||||
: selectedImage?.image_name === image?.image_name;
|
||||
|
||||
return image === PROGRESS_IMAGE_PLACEHOLDER ? (
|
||||
<GalleryProgressImage key={PROGRESS_IMAGE_PLACEHOLDER} />
|
||||
) : (
|
||||
<HoverableImage
|
||||
key={`${image.image_name}-${image.thumbnail_url}`}
|
||||
image={image}
|
||||
isSelected={isSelected}
|
||||
/>
|
||||
);
|
||||
}}
|
||||
itemContent={(index, image) => (
|
||||
<HoverableImage
|
||||
key={`${image.image_name}-${image.thumbnail_url}`}
|
||||
image={image}
|
||||
isSelected={
|
||||
selectedImage?.image_name === image?.image_name
|
||||
}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
<IAIButton
|
||||
onClick={handleClickLoadMore}
|
||||
onClick={handleLoadMoreImages}
|
||||
isDisabled={!areMoreImagesAvailable}
|
||||
isLoading={isLoading}
|
||||
loadingText="Loading"
|
||||
|
@ -53,6 +53,11 @@ const MetadataItem = ({
|
||||
withCopy = false,
|
||||
}: MetadataItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!value) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex gap={2}>
|
||||
{onClick && (
|
||||
|
@ -9,6 +9,10 @@ import { gallerySelector } from '../store/gallerySelectors';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { imageSelected } from '../store/gallerySlice';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import {
|
||||
selectFilteredImagesAsObject,
|
||||
selectFilteredImagesIds,
|
||||
} from '../store/imagesSlice';
|
||||
|
||||
const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = {
|
||||
height: '100%',
|
||||
@ -21,9 +25,14 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
|
||||
};
|
||||
|
||||
export const nextPrevImageButtonsSelector = createSelector(
|
||||
[(state: RootState) => state, gallerySelector],
|
||||
(state, gallery) => {
|
||||
const { selectedImage, currentCategory } = gallery;
|
||||
[
|
||||
(state: RootState) => state,
|
||||
gallerySelector,
|
||||
selectFilteredImagesAsObject,
|
||||
selectFilteredImagesIds,
|
||||
],
|
||||
(state, gallery, filteredImagesAsObject, filteredImageIds) => {
|
||||
const { selectedImage } = gallery;
|
||||
|
||||
if (!selectedImage) {
|
||||
return {
|
||||
@ -32,29 +41,29 @@ export const nextPrevImageButtonsSelector = createSelector(
|
||||
};
|
||||
}
|
||||
|
||||
const currentImageIndex = state[currentCategory].ids.findIndex(
|
||||
const currentImageIndex = filteredImageIds.findIndex(
|
||||
(i) => i === selectedImage.image_name
|
||||
);
|
||||
|
||||
const nextImageIndex = clamp(
|
||||
currentImageIndex + 1,
|
||||
0,
|
||||
state[currentCategory].ids.length - 1
|
||||
filteredImageIds.length - 1
|
||||
);
|
||||
|
||||
const prevImageIndex = clamp(
|
||||
currentImageIndex - 1,
|
||||
0,
|
||||
state[currentCategory].ids.length - 1
|
||||
filteredImageIds.length - 1
|
||||
);
|
||||
|
||||
const nextImageId = state[currentCategory].ids[nextImageIndex];
|
||||
const prevImageId = state[currentCategory].ids[prevImageIndex];
|
||||
const nextImageId = filteredImageIds[nextImageIndex];
|
||||
const prevImageId = filteredImageIds[prevImageIndex];
|
||||
|
||||
const nextImage = state[currentCategory].entities[nextImageId];
|
||||
const prevImage = state[currentCategory].entities[prevImageId];
|
||||
const nextImage = filteredImagesAsObject[nextImageId];
|
||||
const prevImage = filteredImagesAsObject[prevImageId];
|
||||
|
||||
const imagesLength = state[currentCategory].ids.length;
|
||||
const imagesLength = filteredImageIds.length;
|
||||
|
||||
return {
|
||||
isOnFirstImage: currentImageIndex === 0,
|
||||
|
@ -1,33 +1,18 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ImageType } from 'services/api';
|
||||
import { selectResultsEntities } from '../store/resultsSlice';
|
||||
import { selectUploadsEntities } from '../store/uploadsSlice';
|
||||
import { selectImagesEntities } from '../store/imagesSlice';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
const useGetImageByNameSelector = createSelector(
|
||||
[selectResultsEntities, selectUploadsEntities],
|
||||
(allResults, allUploads) => {
|
||||
return { allResults, allUploads };
|
||||
}
|
||||
);
|
||||
|
||||
const useGetImageByNameAndType = () => {
|
||||
const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector);
|
||||
return (name: string, type: ImageType) => {
|
||||
if (type === 'results') {
|
||||
const resultImagesResult = allResults[name];
|
||||
if (resultImagesResult) {
|
||||
return resultImagesResult;
|
||||
const useGetImageByName = () => {
|
||||
const images = useAppSelector(selectImagesEntities);
|
||||
return useCallback(
|
||||
(name: string | undefined) => {
|
||||
if (!name) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (type === 'uploads') {
|
||||
const userImagesResult = allUploads[name];
|
||||
if (userImagesResult) {
|
||||
return userImagesResult;
|
||||
}
|
||||
}
|
||||
};
|
||||
return images[name];
|
||||
},
|
||||
[images]
|
||||
);
|
||||
};
|
||||
|
||||
export default useGetImageByNameAndType;
|
||||
export default useGetImageByName;
|
||||
|
@ -1,9 +1,9 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { ImageNameAndType } from 'features/parameters/store/actions';
|
||||
import { ImageNameAndOrigin } from 'features/parameters/store/actions';
|
||||
import { ImageDTO } from 'services/api';
|
||||
|
||||
export const requestedImageDeletion = createAction<
|
||||
ImageDTO | ImageNameAndType | undefined
|
||||
ImageDTO | ImageNameAndOrigin | undefined
|
||||
>('gallery/requestedImageDeletion');
|
||||
|
||||
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');
|
||||
|
@ -4,6 +4,5 @@ import { GalleryState } from './gallerySlice';
|
||||
* Gallery slice persist denylist
|
||||
*/
|
||||
export const galleryPersistDenylist: (keyof GalleryState)[] = [
|
||||
'currentCategory',
|
||||
'shouldAutoSwitchToNewImages',
|
||||
];
|
||||
|
@ -1,10 +1,7 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedUploadImagesPage,
|
||||
} from '../../../services/thunks/gallery';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { imageUpserted } from './imagesSlice';
|
||||
|
||||
type GalleryImageObjectFitType = 'contain' | 'cover';
|
||||
|
||||
@ -14,7 +11,6 @@ export interface GalleryState {
|
||||
galleryImageObjectFit: GalleryImageObjectFitType;
|
||||
shouldAutoSwitchToNewImages: boolean;
|
||||
shouldUseSingleGalleryColumn: boolean;
|
||||
currentCategory: 'results' | 'uploads';
|
||||
}
|
||||
|
||||
export const initialGalleryState: GalleryState = {
|
||||
@ -22,7 +18,6 @@ export const initialGalleryState: GalleryState = {
|
||||
galleryImageObjectFit: 'cover',
|
||||
shouldAutoSwitchToNewImages: true,
|
||||
shouldUseSingleGalleryColumn: false,
|
||||
currentCategory: 'results',
|
||||
};
|
||||
|
||||
export const gallerySlice = createSlice({
|
||||
@ -46,12 +41,6 @@ export const gallerySlice = createSlice({
|
||||
setShouldAutoSwitchToNewImages: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldAutoSwitchToNewImages = action.payload;
|
||||
},
|
||||
setCurrentCategory: (
|
||||
state,
|
||||
action: PayloadAction<'results' | 'uploads'>
|
||||
) => {
|
||||
state.currentCategory = action.payload;
|
||||
},
|
||||
setShouldUseSingleGalleryColumn: (
|
||||
state,
|
||||
action: PayloadAction<boolean>
|
||||
@ -59,37 +48,10 @@ export const gallerySlice = createSlice({
|
||||
state.shouldUseSingleGalleryColumn = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
||||
// rehydrate selectedImage URL when results list comes in
|
||||
// solves case when outdated URL is in local storage
|
||||
const selectedImage = state.selectedImage;
|
||||
if (selectedImage) {
|
||||
const selectedImageInResults = action.payload.items.find(
|
||||
(image) => image.image_name === selectedImage.image_name
|
||||
);
|
||||
|
||||
if (selectedImageInResults) {
|
||||
selectedImage.image_url = selectedImageInResults.image_url;
|
||||
selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
|
||||
state.selectedImage = selectedImage;
|
||||
}
|
||||
}
|
||||
});
|
||||
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
||||
// rehydrate selectedImage URL when results list comes in
|
||||
// solves case when outdated URL is in local storage
|
||||
const selectedImage = state.selectedImage;
|
||||
if (selectedImage) {
|
||||
const selectedImageInResults = action.payload.items.find(
|
||||
(image) => image.image_name === selectedImage.image_name
|
||||
);
|
||||
|
||||
if (selectedImageInResults) {
|
||||
selectedImage.image_url = selectedImageInResults.image_url;
|
||||
selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
|
||||
state.selectedImage = selectedImage;
|
||||
}
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(imageUpserted, (state, action) => {
|
||||
if (state.shouldAutoSwitchToNewImages) {
|
||||
state.selectedImage = action.payload;
|
||||
}
|
||||
});
|
||||
},
|
||||
@ -101,7 +63,6 @@ export const {
|
||||
setGalleryImageObjectFit,
|
||||
setShouldAutoSwitchToNewImages,
|
||||
setShouldUseSingleGalleryColumn,
|
||||
setCurrentCategory,
|
||||
} = gallerySlice.actions;
|
||||
|
||||
export default gallerySlice.reducer;
|
||||
|
135
invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts
Normal file
135
invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts
Normal file
@ -0,0 +1,135 @@
|
||||
import {
|
||||
PayloadAction,
|
||||
createEntityAdapter,
|
||||
createSelector,
|
||||
createSlice,
|
||||
} from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { ImageCategory, ImageDTO } from 'services/api';
|
||||
import { dateComparator } from 'common/util/dateComparator';
|
||||
import { isString, keyBy } from 'lodash-es';
|
||||
import { receivedPageOfImages } from 'services/thunks/image';
|
||||
|
||||
export const imagesAdapter = createEntityAdapter<ImageDTO>({
|
||||
selectId: (image) => image.image_name,
|
||||
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
|
||||
});
|
||||
|
||||
export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
|
||||
export const ASSETS_CATEGORIES: ImageCategory[] = [
|
||||
'control',
|
||||
'mask',
|
||||
'user',
|
||||
'other',
|
||||
];
|
||||
|
||||
type AdditionaImagesState = {
|
||||
offset: number;
|
||||
limit: number;
|
||||
total: number;
|
||||
isLoading: boolean;
|
||||
categories: ImageCategory[];
|
||||
};
|
||||
|
||||
export const initialImagesState =
|
||||
imagesAdapter.getInitialState<AdditionaImagesState>({
|
||||
offset: 0,
|
||||
limit: 0,
|
||||
total: 0,
|
||||
isLoading: false,
|
||||
categories: IMAGE_CATEGORIES,
|
||||
});
|
||||
|
||||
export type ImagesState = typeof initialImagesState;
|
||||
|
||||
const imagesSlice = createSlice({
|
||||
name: 'images',
|
||||
initialState: initialImagesState,
|
||||
reducers: {
|
||||
imageUpserted: (state, action: PayloadAction<ImageDTO>) => {
|
||||
imagesAdapter.upsertOne(state, action.payload);
|
||||
},
|
||||
imageRemoved: (state, action: PayloadAction<string | ImageDTO>) => {
|
||||
if (isString(action.payload)) {
|
||||
imagesAdapter.removeOne(state, action.payload);
|
||||
return;
|
||||
}
|
||||
|
||||
imagesAdapter.removeOne(state, action.payload.image_name);
|
||||
},
|
||||
imageCategoriesChanged: (state, action: PayloadAction<ImageCategory[]>) => {
|
||||
state.categories = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(receivedPageOfImages.pending, (state) => {
|
||||
state.isLoading = true;
|
||||
});
|
||||
builder.addCase(receivedPageOfImages.rejected, (state) => {
|
||||
state.isLoading = false;
|
||||
});
|
||||
builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
|
||||
state.isLoading = false;
|
||||
const { items, offset, limit, total } = action.payload;
|
||||
state.offset = offset;
|
||||
state.limit = limit;
|
||||
state.total = total;
|
||||
imagesAdapter.upsertMany(state, items);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectImagesAll,
|
||||
selectById: selectImagesById,
|
||||
selectEntities: selectImagesEntities,
|
||||
selectIds: selectImagesIds,
|
||||
selectTotal: selectImagesTotal,
|
||||
} = imagesAdapter.getSelectors<RootState>((state) => state.images);
|
||||
|
||||
export const { imageUpserted, imageRemoved, imageCategoriesChanged } =
|
||||
imagesSlice.actions;
|
||||
|
||||
export default imagesSlice.reducer;
|
||||
|
||||
export const selectFilteredImagesAsArray = createSelector(
|
||||
(state: RootState) => state,
|
||||
(state) => {
|
||||
const {
|
||||
images: { categories },
|
||||
} = state;
|
||||
|
||||
return selectImagesAll(state).filter((i) =>
|
||||
categories.includes(i.image_category)
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
export const selectFilteredImagesAsObject = createSelector(
|
||||
(state: RootState) => state,
|
||||
(state) => {
|
||||
const {
|
||||
images: { categories },
|
||||
} = state;
|
||||
|
||||
return keyBy(
|
||||
selectImagesAll(state).filter((i) =>
|
||||
categories.includes(i.image_category)
|
||||
),
|
||||
'image_name'
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
export const selectFilteredImagesIds = createSelector(
|
||||
(state: RootState) => state,
|
||||
(state) => {
|
||||
const {
|
||||
images: { categories },
|
||||
} = state;
|
||||
|
||||
return selectImagesAll(state)
|
||||
.filter((i) => categories.includes(i.image_category))
|
||||
.map((i) => i.image_name);
|
||||
}
|
||||
);
|
@ -1,8 +0,0 @@
|
||||
import { ResultsState } from './resultsSlice';
|
||||
|
||||
/**
|
||||
* Results slice persist denylist
|
||||
*
|
||||
* Currently denylisting results slice entirely, see `serialize.ts`
|
||||
*/
|
||||
export const resultsPersistDenylist: (keyof ResultsState)[] = [];
|
@ -1,88 +0,0 @@
|
||||
import {
|
||||
PayloadAction,
|
||||
createEntityAdapter,
|
||||
createSlice,
|
||||
} from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { dateComparator } from 'common/util/dateComparator';
|
||||
|
||||
export type ResultsImageDTO = Omit<ImageDTO, 'image_type'> & {
|
||||
image_type: 'results';
|
||||
};
|
||||
|
||||
export const resultsAdapter = createEntityAdapter<ResultsImageDTO>({
|
||||
selectId: (image) => image.image_name,
|
||||
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
|
||||
});
|
||||
|
||||
type AdditionalResultsState = {
|
||||
page: number;
|
||||
pages: number;
|
||||
isLoading: boolean;
|
||||
nextPage: number;
|
||||
upsertedImageCount: number;
|
||||
};
|
||||
|
||||
export const initialResultsState =
|
||||
resultsAdapter.getInitialState<AdditionalResultsState>({
|
||||
page: 0,
|
||||
pages: 0,
|
||||
isLoading: false,
|
||||
nextPage: 0,
|
||||
upsertedImageCount: 0,
|
||||
});
|
||||
|
||||
export type ResultsState = typeof initialResultsState;
|
||||
|
||||
const resultsSlice = createSlice({
|
||||
name: 'results',
|
||||
initialState: initialResultsState,
|
||||
reducers: {
|
||||
resultUpserted: (state, action: PayloadAction<ResultsImageDTO>) => {
|
||||
resultsAdapter.upsertOne(state, action.payload);
|
||||
state.upsertedImageCount += 1;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
/**
|
||||
* Received Result Images Page - PENDING
|
||||
*/
|
||||
builder.addCase(receivedResultImagesPage.pending, (state) => {
|
||||
state.isLoading = true;
|
||||
});
|
||||
|
||||
/**
|
||||
* Received Result Images Page - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
||||
const { page, pages } = action.payload;
|
||||
|
||||
// We know these will all be of the results type, but it's not represented in the API types
|
||||
const items = action.payload.items as ResultsImageDTO[];
|
||||
|
||||
resultsAdapter.setMany(state, items);
|
||||
|
||||
state.page = page;
|
||||
state.pages = pages;
|
||||
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
|
||||
state.isLoading = false;
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectResultsAll,
|
||||
selectById: selectResultsById,
|
||||
selectEntities: selectResultsEntities,
|
||||
selectIds: selectResultsIds,
|
||||
selectTotal: selectResultsTotal,
|
||||
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
|
||||
|
||||
export const { resultUpserted } = resultsSlice.actions;
|
||||
|
||||
export default resultsSlice.reducer;
|
@ -1,8 +0,0 @@
|
||||
import { UploadsState } from './uploadsSlice';
|
||||
|
||||
/**
|
||||
* Uploads slice persist denylist
|
||||
*
|
||||
* Currently denylisting uploads slice entirely, see `serialize.ts`
|
||||
*/
|
||||
export const uploadsPersistDenylist: (keyof UploadsState)[] = [];
|
@ -1,89 +0,0 @@
|
||||
import {
|
||||
PayloadAction,
|
||||
createEntityAdapter,
|
||||
createSlice,
|
||||
} from '@reduxjs/toolkit';
|
||||
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
receivedUploadImagesPage,
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { dateComparator } from 'common/util/dateComparator';
|
||||
|
||||
export type UploadsImageDTO = Omit<ImageDTO, 'image_type'> & {
|
||||
image_type: 'uploads';
|
||||
};
|
||||
|
||||
export const uploadsAdapter = createEntityAdapter<UploadsImageDTO>({
|
||||
selectId: (image) => image.image_name,
|
||||
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
|
||||
});
|
||||
|
||||
type AdditionalUploadsState = {
|
||||
page: number;
|
||||
pages: number;
|
||||
isLoading: boolean;
|
||||
nextPage: number;
|
||||
upsertedImageCount: number;
|
||||
};
|
||||
|
||||
export const initialUploadsState =
|
||||
uploadsAdapter.getInitialState<AdditionalUploadsState>({
|
||||
page: 0,
|
||||
pages: 0,
|
||||
nextPage: 0,
|
||||
isLoading: false,
|
||||
upsertedImageCount: 0,
|
||||
});
|
||||
|
||||
export type UploadsState = typeof initialUploadsState;
|
||||
|
||||
const uploadsSlice = createSlice({
|
||||
name: 'uploads',
|
||||
initialState: initialUploadsState,
|
||||
reducers: {
|
||||
uploadUpserted: (state, action: PayloadAction<UploadsImageDTO>) => {
|
||||
uploadsAdapter.upsertOne(state, action.payload);
|
||||
state.upsertedImageCount += 1;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
/**
|
||||
* Received Upload Images Page - PENDING
|
||||
*/
|
||||
builder.addCase(receivedUploadImagesPage.pending, (state) => {
|
||||
state.isLoading = true;
|
||||
});
|
||||
|
||||
/**
|
||||
* Received Upload Images Page - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
||||
const { page, pages } = action.payload;
|
||||
|
||||
// We know these will all be of the uploads type, but it's not represented in the API types
|
||||
const items = action.payload.items as UploadsImageDTO[];
|
||||
|
||||
uploadsAdapter.setMany(state, items);
|
||||
|
||||
state.page = page;
|
||||
state.pages = pages;
|
||||
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
|
||||
state.isLoading = false;
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectUploadsAll,
|
||||
selectById: selectUploadsById,
|
||||
selectEntities: selectUploadsEntities,
|
||||
selectIds: selectUploadsIds,
|
||||
selectTotal: selectUploadsTotal,
|
||||
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
|
||||
|
||||
export const { uploadUpserted } = uploadsSlice.actions;
|
||||
|
||||
export default uploadsSlice.reducer;
|
@ -2,7 +2,7 @@ import { Box, Image } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import useGetImageByNameAndType from 'features/gallery/hooks/useGetImageByName';
|
||||
import useGetImageByName from 'features/gallery/hooks/useGetImageByName';
|
||||
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
@ -11,7 +11,6 @@ import {
|
||||
} from 'features/nodes/types/types';
|
||||
import { DragEvent, memo, useCallback, useState } from 'react';
|
||||
|
||||
import { ImageType } from 'services/api';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const ImageInputFieldComponent = (
|
||||
@ -19,7 +18,7 @@ const ImageInputFieldComponent = (
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const getImageByNameAndType = useGetImageByNameAndType();
|
||||
const getImageByName = useGetImageByName();
|
||||
const dispatch = useAppDispatch();
|
||||
const [url, setUrl] = useState<string | undefined>(field.value?.image_url);
|
||||
const { getUrl } = useGetUrl();
|
||||
@ -27,13 +26,7 @@ const ImageInputFieldComponent = (
|
||||
const handleDrop = useCallback(
|
||||
(e: DragEvent<HTMLDivElement>) => {
|
||||
const name = e.dataTransfer.getData('invokeai/imageName');
|
||||
const type = e.dataTransfer.getData('invokeai/imageType') as ImageType;
|
||||
|
||||
if (!name || !type) {
|
||||
return;
|
||||
}
|
||||
|
||||
const image = getImageByNameAndType(name, type);
|
||||
const image = getImageByName(name);
|
||||
|
||||
if (!image) {
|
||||
return;
|
||||
@ -49,7 +42,7 @@ const ImageInputFieldComponent = (
|
||||
})
|
||||
);
|
||||
},
|
||||
[getImageByNameAndType, dispatch, field.name, nodeId]
|
||||
[getImageByName, dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
|
@ -26,18 +26,21 @@ const buildBaseNode = (
|
||||
| ImageToImageInvocation
|
||||
| InpaintInvocation
|
||||
| undefined => {
|
||||
const dimensionsOverride = state.canvas.boundingBoxDimensions;
|
||||
const overrides = {
|
||||
...state.canvas.boundingBoxDimensions,
|
||||
is_intermediate: true,
|
||||
};
|
||||
|
||||
if (nodeType === 'txt2img') {
|
||||
return buildTxt2ImgNode(state, dimensionsOverride);
|
||||
return buildTxt2ImgNode(state, overrides);
|
||||
}
|
||||
|
||||
if (nodeType === 'img2img') {
|
||||
return buildImg2ImgNode(state, dimensionsOverride);
|
||||
return buildImg2ImgNode(state, overrides);
|
||||
}
|
||||
|
||||
if (nodeType === 'inpaint' || nodeType === 'outpaint') {
|
||||
return buildInpaintNode(state, dimensionsOverride);
|
||||
return buildInpaintNode(state, overrides);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -64,7 +64,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
|
||||
model,
|
||||
image: {
|
||||
image_name: initialImage?.image_name,
|
||||
image_type: initialImage?.image_type,
|
||||
image_origin: initialImage?.image_origin,
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -58,7 +58,7 @@ export const buildImg2ImgNode = (
|
||||
|
||||
imageToImageNode.image = {
|
||||
image_name: initialImage.name,
|
||||
image_type: initialImage.type,
|
||||
image_origin: initialImage.type,
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -51,7 +51,7 @@ export const buildInpaintNode = (
|
||||
|
||||
inpaintNode.image = {
|
||||
image_name: initialImage.name,
|
||||
image_type: initialImage.type,
|
||||
image_origin: initialImage.type,
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -13,7 +13,7 @@ import {
|
||||
buildOutputFieldTemplates,
|
||||
} from './fieldTemplateBuilders';
|
||||
|
||||
const RESERVED_FIELD_NAMES = ['id', 'type', 'meta'];
|
||||
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate'];
|
||||
|
||||
const invocationDenylist = ['Graph', 'InvocationMeta'];
|
||||
|
||||
|
@ -15,7 +15,7 @@ const ParamInfillCollapse = () => {
|
||||
|
||||
return (
|
||||
<IAICollapse
|
||||
label={t('parameters.boundingBoxHeader')}
|
||||
label={t('parameters.infillScalingHeader')}
|
||||
isOpen={isOpen}
|
||||
onToggle={onToggle}
|
||||
>
|
||||
|
@ -5,7 +5,6 @@ import { useGetUrl } from 'common/util/getUrl';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { DragEvent, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ImageType } from 'services/api';
|
||||
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
@ -55,9 +54,7 @@ const InitialImagePreview = () => {
|
||||
const handleDrop = useCallback(
|
||||
(e: DragEvent<HTMLDivElement>) => {
|
||||
const name = e.dataTransfer.getData('invokeai/imageName');
|
||||
const type = e.dataTransfer.getData('invokeai/imageType') as ImageType;
|
||||
|
||||
dispatch(initialImageSelected({ image_name: name, image_type: type }));
|
||||
dispatch(initialImageSelected(name));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
@ -88,7 +88,7 @@ export const useParameters = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(initialImageSelected(image));
|
||||
dispatch(initialImageSelected(image.image_name));
|
||||
toaster({
|
||||
title: t('toast.initialImageSet'),
|
||||
status: 'info',
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { isObject } from 'lodash-es';
|
||||
import { ImageDTO, ImageType } from 'services/api';
|
||||
import { ImageDTO, ResourceOrigin } from 'services/api';
|
||||
|
||||
export type ImageNameAndType = {
|
||||
export type ImageNameAndOrigin = {
|
||||
image_name: string;
|
||||
image_type: ImageType;
|
||||
image_origin: ResourceOrigin;
|
||||
};
|
||||
|
||||
export const isImageDTO = (image: any): image is ImageDTO => {
|
||||
@ -13,8 +13,8 @@ export const isImageDTO = (image: any): image is ImageDTO => {
|
||||
isObject(image) &&
|
||||
'image_name' in image &&
|
||||
image?.image_name !== undefined &&
|
||||
'image_type' in image &&
|
||||
image?.image_type !== undefined &&
|
||||
'image_origin' in image &&
|
||||
image?.image_origin !== undefined &&
|
||||
'image_url' in image &&
|
||||
image?.image_url !== undefined &&
|
||||
'thumbnail_url' in image &&
|
||||
@ -26,6 +26,6 @@ export const isImageDTO = (image: any): image is ImageDTO => {
|
||||
);
|
||||
};
|
||||
|
||||
export const initialImageSelected = createAction<
|
||||
ImageDTO | ImageNameAndType | undefined
|
||||
>('generation/initialImageSelected');
|
||||
export const initialImageSelected = createAction<ImageDTO | string | undefined>(
|
||||
'generation/initialImageSelected'
|
||||
);
|
||||
|
@ -1,34 +1,3 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { selectResultsById } from 'features/gallery/store/resultsSlice';
|
||||
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
export const generationSelector = (state: RootState) => state.generation;
|
||||
|
||||
export const mayGenerateMultipleImagesSelector = createSelector(
|
||||
generationSelector,
|
||||
({ shouldRandomizeSeed, shouldGenerateVariations }) => {
|
||||
return shouldRandomizeSeed || shouldGenerateVariations;
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export const initialImageSelector = createSelector(
|
||||
[(state: RootState) => state, generationSelector],
|
||||
(state, generation) => {
|
||||
const { initialImage } = generation;
|
||||
|
||||
if (initialImage?.type === 'results') {
|
||||
return selectResultsById(state, initialImage.name);
|
||||
}
|
||||
|
||||
if (initialImage?.type === 'uploads') {
|
||||
return selectUploadsById(state, initialImage.name);
|
||||
}
|
||||
}
|
||||
);
|
||||
|
@ -2,17 +2,6 @@ import { UseToastOptions } from '@chakra-ui/react';
|
||||
import { PayloadAction, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import * as InvokeAI from 'app/types/invokeai';
|
||||
import {
|
||||
generatorProgress,
|
||||
graphExecutionStateComplete,
|
||||
invocationComplete,
|
||||
invocationError,
|
||||
invocationStarted,
|
||||
socketConnected,
|
||||
socketDisconnected,
|
||||
socketSubscribed,
|
||||
socketUnsubscribed,
|
||||
} from 'services/events/actions';
|
||||
|
||||
import { ProgressImage } from 'services/events/types';
|
||||
import { makeToast } from '../../../app/components/Toaster';
|
||||
@ -30,6 +19,17 @@ import { t } from 'i18next';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { LANGUAGES } from '../components/LanguagePicker';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import {
|
||||
appSocketConnected,
|
||||
appSocketDisconnected,
|
||||
appSocketGeneratorProgress,
|
||||
appSocketGraphExecutionStateComplete,
|
||||
appSocketInvocationComplete,
|
||||
appSocketInvocationError,
|
||||
appSocketInvocationStarted,
|
||||
appSocketSubscribed,
|
||||
appSocketUnsubscribed,
|
||||
} from 'services/events/actions';
|
||||
|
||||
export type CancelStrategy = 'immediate' | 'scheduled';
|
||||
|
||||
@ -227,7 +227,7 @@ export const systemSlice = createSlice({
|
||||
/**
|
||||
* Socket Subscribed
|
||||
*/
|
||||
builder.addCase(socketSubscribed, (state, action) => {
|
||||
builder.addCase(appSocketSubscribed, (state, action) => {
|
||||
state.sessionId = action.payload.sessionId;
|
||||
state.canceledSession = '';
|
||||
});
|
||||
@ -235,14 +235,14 @@ export const systemSlice = createSlice({
|
||||
/**
|
||||
* Socket Unsubscribed
|
||||
*/
|
||||
builder.addCase(socketUnsubscribed, (state) => {
|
||||
builder.addCase(appSocketUnsubscribed, (state) => {
|
||||
state.sessionId = null;
|
||||
});
|
||||
|
||||
/**
|
||||
* Socket Connected
|
||||
*/
|
||||
builder.addCase(socketConnected, (state) => {
|
||||
builder.addCase(appSocketConnected, (state) => {
|
||||
state.isConnected = true;
|
||||
state.isCancelable = true;
|
||||
state.isProcessing = false;
|
||||
@ -257,7 +257,7 @@ export const systemSlice = createSlice({
|
||||
/**
|
||||
* Socket Disconnected
|
||||
*/
|
||||
builder.addCase(socketDisconnected, (state) => {
|
||||
builder.addCase(appSocketDisconnected, (state) => {
|
||||
state.isConnected = false;
|
||||
state.isProcessing = false;
|
||||
state.isCancelable = true;
|
||||
@ -272,7 +272,7 @@ export const systemSlice = createSlice({
|
||||
/**
|
||||
* Invocation Started
|
||||
*/
|
||||
builder.addCase(invocationStarted, (state) => {
|
||||
builder.addCase(appSocketInvocationStarted, (state) => {
|
||||
state.isCancelable = true;
|
||||
state.isProcessing = true;
|
||||
state.currentStatusHasSteps = false;
|
||||
@ -286,7 +286,7 @@ export const systemSlice = createSlice({
|
||||
/**
|
||||
* Generator Progress
|
||||
*/
|
||||
builder.addCase(generatorProgress, (state, action) => {
|
||||
builder.addCase(appSocketGeneratorProgress, (state, action) => {
|
||||
const { step, total_steps, progress_image } = action.payload.data;
|
||||
|
||||
state.isProcessing = true;
|
||||
@ -303,7 +303,7 @@ export const systemSlice = createSlice({
|
||||
/**
|
||||
* Invocation Complete
|
||||
*/
|
||||
builder.addCase(invocationComplete, (state, action) => {
|
||||
builder.addCase(appSocketInvocationComplete, (state, action) => {
|
||||
const { data } = action.payload;
|
||||
|
||||
// state.currentIteration = 0;
|
||||
@ -322,7 +322,7 @@ export const systemSlice = createSlice({
|
||||
/**
|
||||
* Invocation Error
|
||||
*/
|
||||
builder.addCase(invocationError, (state) => {
|
||||
builder.addCase(appSocketInvocationError, (state) => {
|
||||
state.isProcessing = false;
|
||||
state.isCancelable = true;
|
||||
// state.currentIteration = 0;
|
||||
@ -339,7 +339,20 @@ export const systemSlice = createSlice({
|
||||
});
|
||||
|
||||
/**
|
||||
* Session Invoked - PENDING
|
||||
* Graph Execution State Complete
|
||||
*/
|
||||
builder.addCase(appSocketGraphExecutionStateComplete, (state) => {
|
||||
state.isProcessing = false;
|
||||
state.isCancelable = false;
|
||||
state.isCancelScheduled = false;
|
||||
state.currentStep = 0;
|
||||
state.totalSteps = 0;
|
||||
state.statusTranslationKey = 'common.statusConnected';
|
||||
state.progressImage = null;
|
||||
});
|
||||
|
||||
/**
|
||||
* User Invoked
|
||||
*/
|
||||
|
||||
builder.addCase(userInvoked, (state) => {
|
||||
@ -367,18 +380,6 @@ export const systemSlice = createSlice({
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Session Canceled
|
||||
*/
|
||||
builder.addCase(graphExecutionStateComplete, (state) => {
|
||||
state.isProcessing = false;
|
||||
state.isCancelable = false;
|
||||
state.isCancelScheduled = false;
|
||||
state.currentStep = 0;
|
||||
state.totalSteps = 0;
|
||||
state.statusTranslationKey = 'common.statusConnected';
|
||||
});
|
||||
|
||||
/**
|
||||
* Received available models from the backend
|
||||
*/
|
||||
|
@ -8,6 +8,7 @@ export type { OpenAPIConfig } from './core/OpenAPI';
|
||||
|
||||
export type { AddInvocation } from './models/AddInvocation';
|
||||
export type { Body_upload_image } from './models/Body_upload_image';
|
||||
export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';
|
||||
export type { CkptModelInfo } from './models/CkptModelInfo';
|
||||
export type { CollectInvocation } from './models/CollectInvocation';
|
||||
export type { CollectInvocationOutput } from './models/CollectInvocationOutput';
|
||||
@ -15,16 +16,23 @@ export type { ColorField } from './models/ColorField';
|
||||
export type { CompelInvocation } from './models/CompelInvocation';
|
||||
export type { CompelOutput } from './models/CompelOutput';
|
||||
export type { ConditioningField } from './models/ConditioningField';
|
||||
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
|
||||
export type { ControlField } from './models/ControlField';
|
||||
export type { ControlNetInvocation } from './models/ControlNetInvocation';
|
||||
export type { ControlOutput } from './models/ControlOutput';
|
||||
export type { CreateModelRequest } from './models/CreateModelRequest';
|
||||
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
|
||||
export type { DiffusersModelInfo } from './models/DiffusersModelInfo';
|
||||
export type { DivideInvocation } from './models/DivideInvocation';
|
||||
export type { Edge } from './models/Edge';
|
||||
export type { EdgeConnection } from './models/EdgeConnection';
|
||||
export type { FloatCollectionOutput } from './models/FloatCollectionOutput';
|
||||
export type { FloatOutput } from './models/FloatOutput';
|
||||
export type { Graph } from './models/Graph';
|
||||
export type { GraphExecutionState } from './models/GraphExecutionState';
|
||||
export type { GraphInvocation } from './models/GraphInvocation';
|
||||
export type { GraphInvocationOutput } from './models/GraphInvocationOutput';
|
||||
export type { HedImageprocessorInvocation } from './models/HedImageprocessorInvocation';
|
||||
export type { HTTPValidationError } from './models/HTTPValidationError';
|
||||
export type { ImageBlurInvocation } from './models/ImageBlurInvocation';
|
||||
export type { ImageCategory } from './models/ImageCategory';
|
||||
@ -39,10 +47,10 @@ export type { ImageMetadata } from './models/ImageMetadata';
|
||||
export type { ImageMultiplyInvocation } from './models/ImageMultiplyInvocation';
|
||||
export type { ImageOutput } from './models/ImageOutput';
|
||||
export type { ImagePasteInvocation } from './models/ImagePasteInvocation';
|
||||
export type { ImageProcessorInvocation } from './models/ImageProcessorInvocation';
|
||||
export type { ImageRecordChanges } from './models/ImageRecordChanges';
|
||||
export type { ImageToImageInvocation } from './models/ImageToImageInvocation';
|
||||
export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
|
||||
export type { ImageType } from './models/ImageType';
|
||||
export type { ImageUrlsDTO } from './models/ImageUrlsDTO';
|
||||
export type { InfillColorInvocation } from './models/InfillColorInvocation';
|
||||
export type { InfillPatchMatchInvocation } from './models/InfillPatchMatchInvocation';
|
||||
@ -56,22 +64,32 @@ export type { LatentsField } from './models/LatentsField';
|
||||
export type { LatentsOutput } from './models/LatentsOutput';
|
||||
export type { LatentsToImageInvocation } from './models/LatentsToImageInvocation';
|
||||
export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvocation';
|
||||
export type { LineartAnimeImageProcessorInvocation } from './models/LineartAnimeImageProcessorInvocation';
|
||||
export type { LineartImageProcessorInvocation } from './models/LineartImageProcessorInvocation';
|
||||
export type { LoadImageInvocation } from './models/LoadImageInvocation';
|
||||
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
|
||||
export type { MaskOutput } from './models/MaskOutput';
|
||||
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
|
||||
export type { MidasDepthImageProcessorInvocation } from './models/MidasDepthImageProcessorInvocation';
|
||||
export type { MlsdImageProcessorInvocation } from './models/MlsdImageProcessorInvocation';
|
||||
export type { ModelsList } from './models/ModelsList';
|
||||
export type { MultiplyInvocation } from './models/MultiplyInvocation';
|
||||
export type { NoiseInvocation } from './models/NoiseInvocation';
|
||||
export type { NoiseOutput } from './models/NoiseOutput';
|
||||
export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation';
|
||||
export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_';
|
||||
export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation';
|
||||
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
|
||||
export type { PaginatedResults_ImageDTO_ } from './models/PaginatedResults_ImageDTO_';
|
||||
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
|
||||
export type { ParamIntInvocation } from './models/ParamIntInvocation';
|
||||
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
|
||||
export type { PromptOutput } from './models/PromptOutput';
|
||||
export type { RandomIntInvocation } from './models/RandomIntInvocation';
|
||||
export type { RandomRangeInvocation } from './models/RandomRangeInvocation';
|
||||
export type { RangeInvocation } from './models/RangeInvocation';
|
||||
export type { RangeOfSizeInvocation } from './models/RangeOfSizeInvocation';
|
||||
export type { ResizeLatentsInvocation } from './models/ResizeLatentsInvocation';
|
||||
export type { ResourceOrigin } from './models/ResourceOrigin';
|
||||
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
|
||||
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
|
||||
export type { ShowImageInvocation } from './models/ShowImageInvocation';
|
||||
@ -81,6 +99,7 @@ export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
|
||||
export type { UpscaleInvocation } from './models/UpscaleInvocation';
|
||||
export type { VaeRepo } from './models/VaeRepo';
|
||||
export type { ValidationError } from './models/ValidationError';
|
||||
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
|
||||
|
||||
export { ImagesService } from './services/ImagesService';
|
||||
export { ModelsService } from './services/ModelsService';
|
||||
|
@ -12,6 +12,10 @@ export type CannyImageProcessorInvocation = {
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'canny_image_processor';
|
||||
/**
|
||||
* image to process
|
||||
|
@ -12,6 +12,10 @@ export type ContentShuffleImageProcessorInvocation = {
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'content_shuffle_image_processor';
|
||||
/**
|
||||
* image to process
|
||||
|
@ -12,6 +12,10 @@ export type ControlNetInvocation = {
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'controlnet';
|
||||
/**
|
||||
* image to process
|
||||
@ -20,7 +24,7 @@ export type ControlNetInvocation = {
|
||||
/**
|
||||
* control model used
|
||||
*/
|
||||
control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace';
|
||||
control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15' | 'CrucibleAI/ControlNetMediaPipeFace';
|
||||
/**
|
||||
* weight given to controlnet
|
||||
*/
|
||||
|
@ -0,0 +1,15 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* A collection of floats
|
||||
*/
|
||||
export type FloatCollectionOutput = {
|
||||
type?: 'float_collection';
|
||||
/**
|
||||
* The float collection
|
||||
*/
|
||||
collection?: Array<number>;
|
||||
};
|
||||
|
15
invokeai/frontend/web/src/services/api/models/FloatOutput.ts
Normal file
15
invokeai/frontend/web/src/services/api/models/FloatOutput.ts
Normal file
@ -0,0 +1,15 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* A float output
|
||||
*/
|
||||
export type FloatOutput = {
|
||||
type?: 'float_output';
|
||||
/**
|
||||
* The output float
|
||||
*/
|
||||
param?: number;
|
||||
};
|
||||
|
@ -3,12 +3,16 @@
|
||||
/* eslint-disable */
|
||||
|
||||
import type { AddInvocation } from './AddInvocation';
|
||||
import type { CannyImageProcessorInvocation } from './CannyImageProcessorInvocation';
|
||||
import type { CollectInvocation } from './CollectInvocation';
|
||||
import type { CompelInvocation } from './CompelInvocation';
|
||||
import type { ContentShuffleImageProcessorInvocation } from './ContentShuffleImageProcessorInvocation';
|
||||
import type { ControlNetInvocation } from './ControlNetInvocation';
|
||||
import type { CvInpaintInvocation } from './CvInpaintInvocation';
|
||||
import type { DivideInvocation } from './DivideInvocation';
|
||||
import type { Edge } from './Edge';
|
||||
import type { GraphInvocation } from './GraphInvocation';
|
||||
import type { HedImageprocessorInvocation } from './HedImageprocessorInvocation';
|
||||
import type { ImageBlurInvocation } from './ImageBlurInvocation';
|
||||
import type { ImageChannelInvocation } from './ImageChannelInvocation';
|
||||
import type { ImageConvertInvocation } from './ImageConvertInvocation';
|
||||
@ -17,6 +21,7 @@ import type { ImageInverseLerpInvocation } from './ImageInverseLerpInvocation';
|
||||
import type { ImageLerpInvocation } from './ImageLerpInvocation';
|
||||
import type { ImageMultiplyInvocation } from './ImageMultiplyInvocation';
|
||||
import type { ImagePasteInvocation } from './ImagePasteInvocation';
|
||||
import type { ImageProcessorInvocation } from './ImageProcessorInvocation';
|
||||
import type { ImageToImageInvocation } from './ImageToImageInvocation';
|
||||
import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation';
|
||||
import type { InfillColorInvocation } from './InfillColorInvocation';
|
||||
@ -26,11 +31,20 @@ import type { InpaintInvocation } from './InpaintInvocation';
|
||||
import type { IterateInvocation } from './IterateInvocation';
|
||||
import type { LatentsToImageInvocation } from './LatentsToImageInvocation';
|
||||
import type { LatentsToLatentsInvocation } from './LatentsToLatentsInvocation';
|
||||
import type { LineartAnimeImageProcessorInvocation } from './LineartAnimeImageProcessorInvocation';
|
||||
import type { LineartImageProcessorInvocation } from './LineartImageProcessorInvocation';
|
||||
import type { LoadImageInvocation } from './LoadImageInvocation';
|
||||
import type { MaskFromAlphaInvocation } from './MaskFromAlphaInvocation';
|
||||
import type { MediapipeFaceProcessorInvocation } from './MediapipeFaceProcessorInvocation';
|
||||
import type { MidasDepthImageProcessorInvocation } from './MidasDepthImageProcessorInvocation';
|
||||
import type { MlsdImageProcessorInvocation } from './MlsdImageProcessorInvocation';
|
||||
import type { MultiplyInvocation } from './MultiplyInvocation';
|
||||
import type { NoiseInvocation } from './NoiseInvocation';
|
||||
import type { NormalbaeImageProcessorInvocation } from './NormalbaeImageProcessorInvocation';
|
||||
import type { OpenposeImageProcessorInvocation } from './OpenposeImageProcessorInvocation';
|
||||
import type { ParamFloatInvocation } from './ParamFloatInvocation';
|
||||
import type { ParamIntInvocation } from './ParamIntInvocation';
|
||||
import type { PidiImageProcessorInvocation } from './PidiImageProcessorInvocation';
|
||||
import type { RandomIntInvocation } from './RandomIntInvocation';
|
||||
import type { RandomRangeInvocation } from './RandomRangeInvocation';
|
||||
import type { RangeInvocation } from './RangeInvocation';
|
||||
@ -43,6 +57,7 @@ import type { SubtractInvocation } from './SubtractInvocation';
|
||||
import type { TextToImageInvocation } from './TextToImageInvocation';
|
||||
import type { TextToLatentsInvocation } from './TextToLatentsInvocation';
|
||||
import type { UpscaleInvocation } from './UpscaleInvocation';
|
||||
import type { ZoeDepthImageProcessorInvocation } from './ZoeDepthImageProcessorInvocation';
|
||||
|
||||
export type Graph = {
|
||||
/**
|
||||
@ -52,7 +67,7 @@ export type Graph = {
|
||||
/**
|
||||
* The nodes in this graph
|
||||
*/
|
||||
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
|
||||
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageprocessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
|
||||
/**
|
||||
* The connections between nodes and their fields in this graph
|
||||
*/
|
||||
|
@ -4,6 +4,9 @@
|
||||
|
||||
import type { CollectInvocationOutput } from './CollectInvocationOutput';
|
||||
import type { CompelOutput } from './CompelOutput';
|
||||
import type { ControlOutput } from './ControlOutput';
|
||||
import type { FloatCollectionOutput } from './FloatCollectionOutput';
|
||||
import type { FloatOutput } from './FloatOutput';
|
||||
import type { Graph } from './Graph';
|
||||
import type { GraphInvocationOutput } from './GraphInvocationOutput';
|
||||
import type { ImageOutput } from './ImageOutput';
|
||||
@ -42,7 +45,7 @@ export type GraphExecutionState = {
|
||||
/**
|
||||
* The results of node executions
|
||||
*/
|
||||
results: Record<string, (ImageOutput | MaskOutput | PromptOutput | CompelOutput | IntOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
|
||||
results: Record<string, (ImageOutput | MaskOutput | ControlOutput | PromptOutput | CompelOutput | IntOutput | FloatOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | FloatCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
|
||||
/**
|
||||
* Errors raised when executing nodes
|
||||
*/
|
||||
|
@ -12,6 +12,10 @@ export type HedImageprocessorInvocation = {
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'hed_image_processor';
|
||||
/**
|
||||
* image to process
|
||||
|
@ -3,6 +3,12 @@
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* The category of an image. Use ImageCategory.OTHER for non-default categories.
|
||||
* The category of an image.
|
||||
*
|
||||
* - GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
|
||||
* - MASK: The image is a mask image.
|
||||
* - CONTROL: The image is a ControlNet control image.
|
||||
* - USER: The image is a user-provide image.
|
||||
* - OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
|
||||
*/
|
||||
export type ImageCategory = 'general' | 'control' | 'mask' | 'other';
|
||||
export type ImageCategory = 'general' | 'mask' | 'control' | 'user' | 'other';
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
import type { ImageCategory } from './ImageCategory';
|
||||
import type { ImageMetadata } from './ImageMetadata';
|
||||
import type { ImageType } from './ImageType';
|
||||
import type { ResourceOrigin } from './ResourceOrigin';
|
||||
|
||||
/**
|
||||
* Deserialized image record, enriched for the frontend with URLs.
|
||||
@ -17,7 +17,7 @@ export type ImageDTO = {
|
||||
/**
|
||||
* The type of the image.
|
||||
*/
|
||||
image_type: ImageType;
|
||||
image_origin: ResourceOrigin;
|
||||
/**
|
||||
* The URL of the image.
|
||||
*/
|
||||
|
@ -2,7 +2,7 @@
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ImageType } from './ImageType';
|
||||
import type { ResourceOrigin } from './ResourceOrigin';
|
||||
|
||||
/**
|
||||
* An image field used for passing image objects between invocations
|
||||
@ -11,7 +11,7 @@ export type ImageField = {
|
||||
/**
|
||||
* The type of the image
|
||||
*/
|
||||
image_type: ImageType;
|
||||
image_origin: ResourceOrigin;
|
||||
/**
|
||||
* The name of the image
|
||||
*/
|
||||
|
@ -12,6 +12,10 @@ export type ImageProcessorInvocation = {
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'image_processor';
|
||||
/**
|
||||
* image to process
|
||||
|
@ -10,6 +10,7 @@ import type { ImageCategory } from './ImageCategory';
|
||||
* Only limited changes are valid:
|
||||
* - `image_category`: change the category of an image
|
||||
* - `session_id`: change the session associated with an image
|
||||
* - `is_intermediate`: change the image's `is_intermediate` flag
|
||||
*/
|
||||
export type ImageRecordChanges = {
|
||||
/**
|
||||
@ -20,5 +21,9 @@ export type ImageRecordChanges = {
|
||||
* The image's new session ID.
|
||||
*/
|
||||
session_id?: string;
|
||||
/**
|
||||
* The image's new `is_intermediate` flag.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
};
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user