merge with main

This commit is contained in:
Lincoln Stein 2023-06-01 18:09:49 -04:00
commit 98773b20ac
235 changed files with 6634 additions and 3165 deletions

14
.github/CODEOWNERS vendored
View File

@ -2,7 +2,7 @@
/.github/workflows/ @lstein @blessedcoolant /.github/workflows/ @lstein @blessedcoolant
# documentation # documentation
/docs/ @lstein @tildebyte @blessedcoolant /docs/ @lstein @blessedcoolant @hipsterusername
/mkdocs.yml @lstein @blessedcoolant /mkdocs.yml @lstein @blessedcoolant
# nodes # nodes
@ -18,17 +18,17 @@
/invokeai/version @lstein @blessedcoolant /invokeai/version @lstein @blessedcoolant
# web ui # web ui
/invokeai/frontend @blessedcoolant @psychedelicious @lstein /invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp
/invokeai/backend @blessedcoolant @psychedelicious @lstein /invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
# generation, model management, postprocessing # generation, model management, postprocessing
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 /invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779
# front ends # front ends
/invokeai/frontend/CLI @lstein /invokeai/frontend/CLI @lstein
/invokeai/frontend/install @lstein @ebr /invokeai/frontend/install @lstein @ebr
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername /invokeai/frontend/merge @lstein @blessedcoolant
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername /invokeai/frontend/training @lstein @blessedcoolant
/invokeai/frontend/web @psychedelicious @blessedcoolant /invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp

View File

@ -5,6 +5,7 @@ import os
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -67,7 +68,7 @@ class ApiDependencies:
metadata = CoreMetadataService() metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
latents = ForwardCacheLatentsStorage( latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents") DiskLatentsStorage(f"{output_folder}/latents")
) )
@ -78,6 +79,7 @@ class ApiDependencies:
metadata=metadata, metadata=metadata,
url=urls, url=urls,
logger=logger, logger=logger,
names=names,
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
) )

View File

@ -1,39 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType
class ImageResponseMetadata(BaseModel):
"""An image's metadata. Used only in HTTP responses."""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# invokeai: Optional[InvokeAIMetadata] = Field(
# description="The image's InvokeAI-specific metadata"
# )
class ImageResponse(BaseModel):
"""The response type for images"""
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
image_url: str = Field(description="The url of the image")
thumbnail_url: str = Field(description="The url of the image's thumbnail")
metadata: ImageResponseMetadata = Field(description="The image's metadata")
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class SavedImage(BaseModel):
image_name: str = Field(description="The name of the saved image")
thumbnail_name: str = Field(description="The name of the saved thumbnail")
created: int = Field(description="The created timestamp of the saved image")

View File

@ -1,13 +1,19 @@
import io import io
from fastapi import HTTPException, Path, Query, Request, Response, UploadFile from typing import Optional
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from PIL import Image from PIL import Image
from invokeai.app.models.image import ( from invokeai.app.models.image import (
ImageCategory, ImageCategory,
ImageType, ResourceOrigin,
)
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
ImageDTO,
ImageRecordChanges,
ImageUrlsDTO,
) )
from invokeai.app.services.models.image_record import ImageDTO, ImageUrlsDTO
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -27,10 +33,13 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
) )
async def upload_image( async def upload_image(
file: UploadFile, file: UploadFile,
image_type: ImageType,
request: Request, request: Request,
response: Response, response: Response,
image_category: ImageCategory = ImageCategory.GENERAL, image_category: ImageCategory = Query(description="The category of the image"),
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
session_id: Optional[str] = Query(
default=None, description="The session ID associated with this upload, if any"
),
) -> ImageDTO: ) -> ImageDTO:
"""Uploads an image""" """Uploads an image"""
if not file.content_type.startswith("image"): if not file.content_type.startswith("image"):
@ -46,9 +55,11 @@ async def upload_image(
try: try:
image_dto = ApiDependencies.invoker.services.images.create( image_dto = ApiDependencies.invoker.services.images.create(
pil_image, image=pil_image,
image_type, image_origin=ResourceOrigin.EXTERNAL,
image_category, image_category=image_category,
session_id=session_id,
is_intermediate=is_intermediate,
) )
response.status_code = 201 response.status_code = 201
@ -59,41 +70,61 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image") raise HTTPException(status_code=500, detail="Failed to create image")
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image") @images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
async def delete_image( async def delete_image(
image_type: ImageType = Query(description="The type of image to delete"), image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
image_name: str = Path(description="The name of the image to delete"), image_name: str = Path(description="The name of the image to delete"),
) -> None: ) -> None:
"""Deletes an image""" """Deletes an image"""
try: try:
ApiDependencies.invoker.services.images.delete(image_type, image_name) ApiDependencies.invoker.services.images.delete(image_origin, image_name)
except Exception as e: except Exception as e:
# TODO: Does this need any exception handling at all? # TODO: Does this need any exception handling at all?
pass pass
@images_router.patch(
"/{image_origin}/{image_name}",
operation_id="update_image",
response_model=ImageDTO,
)
async def update_image(
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image"
),
) -> ImageDTO:
"""Updates an image"""
try:
return ApiDependencies.invoker.services.images.update(
image_origin, image_name, image_changes
)
except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get( @images_router.get(
"/{image_type}/{image_name}/metadata", "/{image_origin}/{image_name}/metadata",
operation_id="get_image_metadata", operation_id="get_image_metadata",
response_model=ImageDTO, response_model=ImageDTO,
) )
async def get_image_metadata( async def get_image_metadata(
image_type: ImageType = Path(description="The type of image to get"), image_origin: ResourceOrigin = Path(description="The origin of image to get"),
image_name: str = Path(description="The name of image to get"), image_name: str = Path(description="The name of image to get"),
) -> ImageDTO: ) -> ImageDTO:
"""Gets an image's metadata""" """Gets an image's metadata"""
try: try:
return ApiDependencies.invoker.services.images.get_dto( return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
image_type, image_name
)
except Exception as e: except Exception as e:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@images_router.get( @images_router.get(
"/{image_type}/{image_name}", "/{image_origin}/{image_name}",
operation_id="get_image_full", operation_id="get_image_full",
response_class=Response, response_class=Response,
responses={ responses={
@ -105,7 +136,7 @@ async def get_image_metadata(
}, },
) )
async def get_image_full( async def get_image_full(
image_type: ImageType = Path( image_origin: ResourceOrigin = Path(
description="The type of full-resolution image file to get" description="The type of full-resolution image file to get"
), ),
image_name: str = Path(description="The name of full-resolution image file to get"), image_name: str = Path(description="The name of full-resolution image file to get"),
@ -113,9 +144,7 @@ async def get_image_full(
"""Gets a full-resolution image file""" """Gets a full-resolution image file"""
try: try:
path = ApiDependencies.invoker.services.images.get_path( path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
image_type, image_name
)
if not ApiDependencies.invoker.services.images.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -131,7 +160,7 @@ async def get_image_full(
@images_router.get( @images_router.get(
"/{image_type}/{image_name}/thumbnail", "/{image_origin}/{image_name}/thumbnail",
operation_id="get_image_thumbnail", operation_id="get_image_thumbnail",
response_class=Response, response_class=Response,
responses={ responses={
@ -143,14 +172,14 @@ async def get_image_full(
}, },
) )
async def get_image_thumbnail( async def get_image_thumbnail(
image_type: ImageType = Path(description="The type of thumbnail image file to get"), image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
image_name: str = Path(description="The name of thumbnail image file to get"), image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse: ) -> FileResponse:
"""Gets a thumbnail image file""" """Gets a thumbnail image file"""
try: try:
path = ApiDependencies.invoker.services.images.get_path( path = ApiDependencies.invoker.services.images.get_path(
image_type, image_name, thumbnail=True image_origin, image_name, thumbnail=True
) )
if not ApiDependencies.invoker.services.images.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -163,25 +192,25 @@ async def get_image_thumbnail(
@images_router.get( @images_router.get(
"/{image_type}/{image_name}/urls", "/{image_origin}/{image_name}/urls",
operation_id="get_image_urls", operation_id="get_image_urls",
response_model=ImageUrlsDTO, response_model=ImageUrlsDTO,
) )
async def get_image_urls( async def get_image_urls(
image_type: ImageType = Path(description="The type of the image whose URL to get"), image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
image_name: str = Path(description="The name of the image whose URL to get"), image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO: ) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL""" """Gets an image and thumbnail URL"""
try: try:
image_url = ApiDependencies.invoker.services.images.get_url( image_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name image_origin, image_name
) )
thumbnail_url = ApiDependencies.invoker.services.images.get_url( thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name, thumbnail=True image_origin, image_name, thumbnail=True
) )
return ImageUrlsDTO( return ImageUrlsDTO(
image_type=image_type, image_origin=image_origin,
image_name=image_name, image_name=image_name,
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
@ -193,23 +222,29 @@ async def get_image_urls(
@images_router.get( @images_router.get(
"/", "/",
operation_id="list_images_with_metadata", operation_id="list_images_with_metadata",
response_model=PaginatedResults[ImageDTO], response_model=OffsetPaginatedResults[ImageDTO],
) )
async def list_images_with_metadata( async def list_images_with_metadata(
image_type: ImageType = Query(description="The type of images to list"), image_origin: Optional[ResourceOrigin] = Query(
image_category: ImageCategory = Query(description="The kind of images to list"), default=None, description="The origin of images to list"
page: int = Query(default=0, description="The page of image metadata to get"),
per_page: int = Query(
default=10, description="The number of image metadata per page"
), ),
) -> PaginatedResults[ImageDTO]: categories: Optional[list[ImageCategory]] = Query(
"""Gets a list of images with metadata""" default=None, description="The categories of image to include"
),
is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images"
),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of images"""
image_dtos = ApiDependencies.invoker.services.images.get_many( image_dtos = ApiDependencies.invoker.services.images.get_many(
image_type, offset,
image_category, limit,
page, image_origin,
per_page, categories,
is_intermediate,
) )
return image_dtos return image_dtos

View File

@ -12,11 +12,10 @@ from pydantic import BaseModel, ValidationError
from pydantic.fields import Field from pydantic.fields import Field
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
import invokeai.version
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import (CoreMetadataService, from invokeai.app.services.metadata import CoreMetadataService
PngMetadataService) from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from .cli.commands import (BaseCommand, CliContext, ExitCli, from .cli.commands import (BaseCommand, CliContext, ExitCli,
@ -232,6 +231,7 @@ def invoke_cli():
metadata = CoreMetadataService() metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
images = ImageService( images = ImageService(
image_record_storage=image_record_storage, image_record_storage=image_record_storage,
@ -239,6 +239,7 @@ def invoke_cli():
metadata=metadata, metadata=metadata,
url=urls, url=urls,
logger=logger, logger=logger,
names=names,
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
) )

View File

@ -78,6 +78,7 @@ class BaseInvocation(ABC, BaseModel):
#fmt: off #fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.") id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
#fmt: on #fmt: on
@ -95,6 +96,7 @@ class UIConfig(TypedDict, total=False):
"image", "image",
"latents", "latents",
"model", "model",
"control",
], ],
] ]
tags: List[str] tags: List[str]

View File

@ -22,6 +22,14 @@ class IntCollectionOutput(BaseInvocationOutput):
# Outputs # Outputs
collection: list[int] = Field(default=[], description="The int collection") collection: list[int] = Field(default=[], description="The int collection")
class FloatCollectionOutput(BaseInvocationOutput):
"""A collection of floats"""
type: Literal["float_collection"] = "float_collection"
# Outputs
collection: list[float] = Field(default=[], description="The float collection")
class RangeInvocation(BaseInvocation): class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step""" """Creates a range of numbers from start to stop with step"""

View File

@ -0,0 +1,428 @@
# InvokeAI nodes for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
import numpy as np
from typing import Literal, Optional, Union, List
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from controlnet_aux import (
CannyDetector,
HEDdetector,
LineartDetector,
LineartAnimeDetector,
MidasDetector,
MLSDdetector,
NormalBaeDetector,
OpenposeDetector,
PidiNetDetector,
ContentShuffleDetector,
ZoeDetector,
MediapipeFaceDetector,
)
from .image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [
###########################################
# lllyasviel sd v1.5, ControlNet v1.0 models
##############################################
"lllyasviel/sd-controlnet-canny",
"lllyasviel/sd-controlnet-depth",
"lllyasviel/sd-controlnet-hed",
"lllyasviel/sd-controlnet-seg",
"lllyasviel/sd-controlnet-openpose",
"lllyasviel/sd-controlnet-scribble",
"lllyasviel/sd-controlnet-normal",
"lllyasviel/sd-controlnet-mlsd",
#############################################
# lllyasviel sd v1.5, ControlNet v1.1 models
#############################################
"lllyasviel/control_v11p_sd15_canny",
"lllyasviel/control_v11p_sd15_openpose",
"lllyasviel/control_v11p_sd15_seg",
# "lllyasviel/control_v11p_sd15_depth", # broken
"lllyasviel/control_v11f1p_sd15_depth",
"lllyasviel/control_v11p_sd15_normalbae",
"lllyasviel/control_v11p_sd15_scribble",
"lllyasviel/control_v11p_sd15_mlsd",
"lllyasviel/control_v11p_sd15_softedge",
"lllyasviel/control_v11p_sd15s2_lineart_anime",
"lllyasviel/control_v11p_sd15_lineart",
"lllyasviel/control_v11p_sd15_inpaint",
# "lllyasviel/control_v11u_sd15_tile",
# problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile",
# so for now replace "lllyasviel/control_v11f1e_sd15_tile",
"lllyasviel/control_v11e_sd15_shuffle",
"lllyasviel/control_v11e_sd15_ip2p",
"lllyasviel/control_v11f1e_sd15_tile",
#################################################
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
##################################################
"thibaud/controlnet-sd21-openpose-diffusers",
"thibaud/controlnet-sd21-canny-diffusers",
"thibaud/controlnet-sd21-depth-diffusers",
"thibaud/controlnet-sd21-scribble-diffusers",
"thibaud/controlnet-sd21-hed-diffusers",
"thibaud/controlnet-sd21-zoedepth-diffusers",
"thibaud/controlnet-sd21-color-diffusers",
"thibaud/controlnet-sd21-openposev2-diffusers",
"thibaud/controlnet-sd21-lineart-diffusers",
"thibaud/controlnet-sd21-normalbae-diffusers",
"thibaud/controlnet-sd21-ade20k-diffusers",
##############################################
# ControlNetMediaPipeface, ControlNet v1.1
##############################################
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
# hacked t2l to split to model & subfolder if format is "model,subfolder"
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
]
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
class ControlField(BaseModel):
image: ImageField = Field(default=None, description="processed image")
control_model: Optional[str] = Field(default=None, description="control model used")
control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
begin_step_percent: float = Field(default=0, ge=0, le=1,
description="% of total steps at which controlnet is first applied")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="% of total steps at which controlnet is last applied")
class Config:
schema_extra = {
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"]
}
class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info"""
# fmt: off
type: Literal["control_output"] = "control_output"
control: ControlField = Field(default=None, description="The control info dict")
# fmt: on
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
# fmt: off
type: Literal["controlnet"] = "controlnet"
# Inputs
image: ImageField = Field(default=None, description="image to process")
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used")
control_weight: float = Field(default=1.0, ge=0, le=1, description="weight given to controlnet")
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
begin_step_percent: float = Field(default=0, ge=0, le=1,
description="% of total steps at which controlnet is first applied")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="% of total steps at which controlnet is last applied")
# fmt: on
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
image=self.image,
control_model=self.control_model,
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
),
)
# TODO: move image processors to separate file (image_analysis.py
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
"""Base class for invocations that preprocess images for ControlNet"""
# fmt: off
type: Literal["image_processor"] = "image_processor"
# Inputs
image: ImageField = Field(default=None, description="image to process")
# fmt: on
def run_processor(self, image):
# superclass just passes through image without processing
return image
def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
# FIXME: what happened to image metadata?
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# currently can't see processed image in node UI without a showImage node,
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
image_dto = context.services.images.create(
image=processed_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.CONTROL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate
)
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
)
return ImageOutput(
image=processed_image_field,
# width=processed_image.width,
width = image_dto.width,
# height=processed_image.height,
height = image_dto.height,
# mode=processed_image.mode,
)
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Canny edge detection for ControlNet"""
# fmt: off
type: Literal["canny_image_processor"] = "canny_image_processor"
# Input
low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient")
high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient")
# fmt: on
def run_processor(self, image):
canny_processor = CannyDetector()
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
return processed_image
class HedImageprocessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies HED edge detection to image"""
# fmt: off
type: Literal["hed_image_processor"] = "hed_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
# safe not supported in controlnet_aux v0.0.3
# safe: bool = Field(default=False, description="whether to use safe mode")
scribble: bool = Field(default=False, description="whether to use scribble mode")
# fmt: on
def run_processor(self, image):
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = hed_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
# safe not supported in controlnet_aux v0.0.3
# safe=self.safe,
scribble=self.scribble,
)
return processed_image
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies line art processing to image"""
# fmt: off
type: Literal["lineart_image_processor"] = "lineart_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
coarse: bool = Field(default=False, description="whether to use coarse mode")
# fmt: on
def run_processor(self, image):
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
processed_image = lineart_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
coarse=self.coarse)
return processed_image
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies line art anime processing to image"""
# fmt: off
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
# fmt: on
def run_processor(self, image):
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
return processed_image
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies Openpose processing to image"""
# fmt: off
type: Literal["openpose_image_processor"] = "openpose_image_processor"
# Inputs
hand_and_face: bool = Field(default=False, description="whether to use hands and face mode")
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
# fmt: on
def run_processor(self, image):
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = openpose_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
hand_and_face=self.hand_and_face,
)
return processed_image
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies Midas depth processing to image"""
# fmt: off
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
# Inputs
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter a = amult * PI")
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter bg_th")
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
# fmt: on
def run_processor(self, image):
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor(image,
a=np.pi * self.a_mult,
bg_th=self.bg_th,
# dept_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal=self.depth_and_normal,
)
return processed_image
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies NormalBae processing to image"""
# fmt: off
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
# fmt: on
def run_processor(self, image):
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = normalbae_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution)
return processed_image
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies MLSD processing to image"""
# fmt: off
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter thr_v")
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter thr_d")
# fmt: on
def run_processor(self, image):
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
thr_v=self.thr_v,
thr_d=self.thr_d)
return processed_image
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies PIDI processing to image"""
# fmt: off
type: Literal["pidi_image_processor"] = "pidi_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
safe: bool = Field(default=False, description="whether to use safe mode")
scribble: bool = Field(default=False, description="whether to use scribble mode")
# fmt: on
def run_processor(self, image):
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
processed_image = pidi_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
safe=self.safe,
scribble=self.scribble)
return processed_image
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies content shuffle processing to image"""
# fmt: off
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
# Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
h: Union[int | None] = Field(default=512, ge=0, description="content shuffle h parameter")
w: Union[int | None] = Field(default=512, ge=0, description="content shuffle w parameter")
f: Union[int | None] = Field(default=256, ge=0, description="cont")
# fmt: on
def run_processor(self, image):
content_shuffle_processor = ContentShuffleDetector()
processed_image = content_shuffle_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
h=self.h,
w=self.w,
f=self.f
)
return processed_image
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies Zoe depth processing to image"""
# fmt: off
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
# fmt: on
def run_processor(self, image):
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
return processed_image
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies mediapipe face processing to image"""
# fmt: off
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
# Inputs
max_faces: int = Field(default=1, ge=1, description="maximum number of faces to detect")
min_confidence: float = Field(default=0.5, ge=0, le=1, description="minimum confidence for face detection")
# fmt: on
def run_processor(self, image):
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
return processed_image

View File

@ -7,7 +7,7 @@ import numpy
from PIL import Image, ImageOps from PIL import Image, ImageOps
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageCategory, ImageField, ImageType from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput from .image import ImageOutput
@ -37,10 +37,10 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
mask = context.services.images.get_pil_image( mask = context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name self.mask.image_origin, self.mask.image_name
) )
# Convert to cv image/mask # Convert to cv image/mask
@ -57,16 +57,17 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image_inpainted, image=image_inpainted,
image_type=ImageType.INTERMEDIATE, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -3,16 +3,20 @@
from functools import partial from functools import partial
from typing import Literal, Optional, Union, get_args from typing import Literal, Optional, Union, get_args
import torch
from diffusers import ControlNetModel
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageCategory, ImageType, ColorField, ImageField from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
ResourceOrigin)
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput from ...backend.generator import Img2Img, Inpaint, InvokeAIGenerator, Txt2Img
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback from ..util.step_callback import stable_diffusion_step_callback
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
from .image import ImageOutput
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())] SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())] INFILL_METHODS = Literal[tuple(infill_methods())]
@ -53,6 +57,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)") model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
control_model: Optional[str] = Field(default=None, description="The control model to use")
control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
# fmt: on # fmt: on
# TODO: pass this an emitter method or something? or a session for dispatching? # TODO: pass this an emitter method or something? or a session for dispatching?
@ -73,17 +80,35 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# Handle invalid model parameter # Handle invalid model parameter
model = context.services.model_manager.get_model(self.model,node=self,context=context) model = context.services.model_manager.get_model(self.model,node=self,context=context)
# loading controlnet image (currently requires pre-processed image)
control_image = (
None if self.control_image is None
else context.services.images.get_pil_image(
self.control_image.image_origin, self.control_image.image_name
)
)
# loading controlnet model
if (self.control_model is None or self.control_model==''):
control_model = None
else:
# FIXME: change this to dropdown menu?
# FIXME: generalize so don't have to hardcode torch_dtype and device
control_model = ControlNetModel.from_pretrained(self.control_model,
torch_dtype=torch.float16).to("cuda")
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id context.graph_execution_state_id
) )
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Txt2Img(model).generate( txt2img = Txt2Img(model, control_model=control_model)
outputs = txt2img.generate(
prompt=self.prompt, prompt=self.prompt,
step_callback=partial(self.dispatch_progress, context, source_node_id), step_callback=partial(self.dispatch_progress, context, source_node_id),
control_image=control_image,
**self.dict( **self.dict(
exclude={"prompt"} exclude={"prompt", "control_image" }
), # Shorthand for passing all of the parameters above manually ), # Shorthand for passing all of the parameters above manually
) )
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
@ -92,16 +117,17 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=generate_output.image, image=generate_output.image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -141,7 +167,7 @@ class ImageToImageInvocation(TextToImageInvocation):
None None
if self.image is None if self.image is None
else context.services.images.get_pil_image( else context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
) )
@ -172,16 +198,17 @@ class ImageToImageInvocation(TextToImageInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=generator_output.image, image=generator_output.image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -253,13 +280,13 @@ class InpaintInvocation(ImageToImageInvocation):
None None
if self.image is None if self.image is None
else context.services.images.get_pil_image( else context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
) )
mask = ( mask = (
None None
if self.mask is None if self.mask is None
else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name) else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
) )
# Handle invalid model parameter # Handle invalid model parameter
@ -287,16 +314,17 @@ class InpaintInvocation(ImageToImageInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=generator_output.image, image=generator_output.image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -7,7 +7,7 @@ import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..models.image import ImageCategory, ImageField, ImageType from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
@ -72,12 +72,12 @@ class LoadImageInvocation(BaseInvocation):
) )
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name) image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=self.image.image_name, image_name=self.image.image_name,
image_type=self.image.image_type, image_origin=self.image.image_origin,
), ),
width=image.width, width=image.width,
height=image.height, height=image.height,
@ -96,7 +96,7 @@ class ShowImageInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
if image: if image:
image.show() image.show()
@ -106,7 +106,7 @@ class ShowImageInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=self.image.image_name, image_name=self.image.image_name,
image_type=self.image.image_type, image_origin=self.image.image_origin,
), ),
width=image.width, width=image.width,
height=image.height, height=image.height,
@ -129,7 +129,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
image_crop = Image.new( image_crop = Image.new(
@ -139,16 +139,17 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image_crop, image=image_crop,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -171,17 +172,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image( base_image = context.services.images.get_pil_image(
self.base_image.image_type, self.base_image.image_name self.base_image.image_origin, self.base_image.image_name
) )
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
mask = ( mask = (
None None
if self.mask is None if self.mask is None
else ImageOps.invert( else ImageOps.invert(
context.services.images.get_pil_image( context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name self.mask.image_origin, self.mask.image_name
) )
) )
) )
@ -200,16 +201,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=new_image, image=new_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -229,7 +231,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> MaskOutput: def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
image_mask = image.split()[-1] image_mask = image.split()[-1]
@ -238,15 +240,16 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image_mask, image=image_mask,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.MASK, image_category=ImageCategory.MASK,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return MaskOutput( return MaskOutput(
mask=ImageField( mask=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name image_origin=image_dto.image_origin, image_name=image_dto.image_name
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -266,25 +269,26 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image( image1 = context.services.images.get_pil_image(
self.image1.image_type, self.image1.image_name self.image1.image_origin, self.image1.image_name
) )
image2 = context.services.images.get_pil_image( image2 = context.services.images.get_pil_image(
self.image2.image_type, self.image2.image_name self.image2.image_origin, self.image2.image_name
) )
multiply_image = ImageChops.multiply(image1, image2) multiply_image = ImageChops.multiply(image1, image2)
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=multiply_image, image=multiply_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name image_origin=image_dto.image_origin, image_name=image_dto.image_name
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -307,22 +311,23 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
channel_image = image.getchannel(self.channel) channel_image = image.getchannel(self.channel)
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=channel_image, image=channel_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name image_origin=image_dto.image_origin, image_name=image_dto.image_name
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -345,22 +350,23 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
converted_image = image.convert(self.mode) converted_image = image.convert(self.mode)
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=converted_image, image=converted_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name image_origin=image_dto.image_origin, image_name=image_dto.image_name
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -381,7 +387,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
blur = ( blur = (
@ -393,16 +399,126 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=blur_image, image=blur_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
PIL_RESAMPLING_MODES = Literal[
"nearest",
"box",
"bilinear",
"hamming",
"bicubic",
"lanczos",
]
PIL_RESAMPLING_MAP = {
"nearest": Image.Resampling.NEAREST,
"box": Image.Resampling.BOX,
"bilinear": Image.Resampling.BILINEAR,
"hamming": Image.Resampling.HAMMING,
"bicubic": Image.Resampling.BICUBIC,
"lanczos": Image.Resampling.LANCZOS,
}
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
"""Resizes an image to specific dimensions"""
# fmt: off
type: Literal["img_resize"] = "img_resize"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
resize_image = image.resize(
(self.width, self.height),
resample=resample_mode,
)
image_dto = context.services.images.create(
image=resize_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
"""Scales an image by a factor"""
# fmt: off
type: Literal["img_scale"] = "img_scale"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width * self.scale_factor)
height = int(image.height * self.scale_factor)
resize_image = image.resize(
(width, height),
resample=resample_mode,
)
image_dto = context.services.images.create(
image=resize_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -423,7 +539,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
@ -433,16 +549,17 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=lerp_image, image=lerp_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -463,7 +580,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
image_arr = numpy.asarray(image, dtype=numpy.float32) image_arr = numpy.asarray(image, dtype=numpy.float32)
@ -478,16 +595,17 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=ilerp_image, image=ilerp_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -11,7 +11,7 @@ from invokeai.app.invocations.image import ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageCategory, ImageField, ImageType from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
InvocationContext, InvocationContext,
@ -135,7 +135,7 @@ class InfillColorInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
solid_bg = Image.new("RGBA", image.size, self.color.tuple()) solid_bg = Image.new("RGBA", image.size, self.color.tuple())
@ -145,16 +145,17 @@ class InfillColorInvocation(BaseInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=infilled, image=infilled,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -179,7 +180,7 @@ class InfillTileInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
infilled = tile_fill_missing( infilled = tile_fill_missing(
@ -189,16 +190,17 @@ class InfillTileInvocation(BaseInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=infilled, image=infilled,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -216,7 +218,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
if PatchMatch.patchmatch_available(): if PatchMatch.patchmatch_available():
@ -226,16 +228,17 @@ class InfillPatchMatchInvocation(BaseInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=infilled, image=infilled,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -1,37 +1,36 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional, Union from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import einops import einops
import torch import torch
from diffusers import ControlNetModel
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from contextlib import ExitStack
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.image_util.seamless import configure_model_padding from ...backend.image_util.seamless import configure_model_padding
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, StableDiffusionGeneratorPipeline, ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
image_resized_to_grid_as_tensor) image_resized_to_grid_as_tensor)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.util.devices import choose_torch_device, torch_dtype
from ..services.image_file_storage import ImageType from ...backend.model_management.lora import ModelPatcher
from ..services.model_manager_service import ModelManagerService
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext) InvocationConfig, InvocationContext)
from .compel import ConditioningField from .compel import ConditioningField
from .image import ImageCategory, ImageField, ImageOutput from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField from .model import ModelInfo, UNetField, VaeField
from ...backend.model_management.lora import LoRAHelper
class LatentsField(BaseModel): class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations""" """A latents field used for passing latents between invocations"""
@ -93,10 +92,12 @@ def get_scheduler(
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict()) orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config: if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"] scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config} scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config) scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py # hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'): if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False scheduler.uses_inpainting_model = lambda: False
@ -171,12 +172,13 @@ class TextToLatentsInvocation(BaseInvocation):
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use") noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = Field(default=None, description="UNet submodel")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
# fmt: on # fmt: on
# Schema customisation # Schema customisation
@ -184,6 +186,10 @@ class TextToLatentsInvocation(BaseInvocation):
schema_extra = { schema_extra = {
"ui": { "ui": {
"tags": ["latents", "image"], "tags": ["latents", "image"],
"type_hints": {
"model": "model",
"control": "control",
}
}, },
} }
@ -244,6 +250,82 @@ class TextToLatentsInvocation(BaseInvocation):
#precision="float16", # TODO: #precision="float16", # TODO:
) )
def prep_control_data(self,
context: InvocationContext,
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
control_input: List[ControlField],
latents_shape: List[int],
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
# assuming fixed dimensional scaling of 8:1 for image:latents
control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8
if control_input is None:
# print("control input is None")
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
# print("control input is empty list")
control_list = None
elif isinstance(control_input, ControlField):
# print("control input is ControlField")
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
# print("control input is list[ControlField]")
control_list = control_input
else:
# print("input control is unrecognized:", type(self.control))
control_list = None
if (control_list is None):
control_data = None
# from above handling, any control that is not None should now be of type list[ControlField]
else:
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
control_data = []
control_models = []
for control_info in control_list:
# handle control models
if ("," in control_info.control_model):
control_model_split = control_info.control_model.split(",")
control_name = control_model_split[0]
control_subfolder = control_model_split[1]
print("Using HF model subfolders")
print(" control_name: ", control_name)
print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(control_name,
subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(model.device)
else:
control_model = ControlNetModel.from_pretrained(control_info.control_model,
torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_origin,
control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = model.prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
)
control_item = ControlNetData(model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name) noise = context.services.latents.get(self.noise.latents_name)
@ -269,13 +351,19 @@ class TextToLatentsInvocation(BaseInvocation):
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
with LoRAHelper.apply_lora_unet(pipeline.unet, loras): print("type of control input: ", type(self.control))
control_data = self.prep_control_data(model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
do_classifier_free_guidance=(self.cfg_scale >= 1.0))
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise, noise=noise,
num_inference_steps=self.steps, num_inference_steps=self.steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback callback=step_callback
) )
@ -286,7 +374,6 @@ class TextToLatentsInvocation(BaseInvocation):
context.services.latents.save(name, result_latents) context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents) return build_latents_output(latents_name=name, latents=result_latents)
class LatentsToLatentsInvocation(TextToLatentsInvocation): class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image.""" """Generates latents using latents as base image."""
@ -294,13 +381,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image") latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use") strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {
"tags": ["latents"], "tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
}
}, },
} }
@ -315,7 +406,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state) self.dispatch_progress(context, source_node_id, state)
#unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(), **self.unet.unet.dict(),
) )
@ -345,7 +435,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
with LoRAHelper.apply_lora_unet(pipeline.unet, loras): with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents, latents=initial_latents,
timesteps=timesteps, timesteps=timesteps,
@ -413,7 +503,7 @@ class LatentsToImageInvocation(BaseInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image, image=image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -459,6 +549,7 @@ class ResizeLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache() torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents) context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents) return build_latents_output(latents_name=name, latents=resized_latents)
@ -489,6 +580,7 @@ class ScaleLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache() torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents) context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents) return build_latents_output(latents_name=name, latents=resized_latents)
@ -513,8 +605,11 @@ class ImageToLatentsInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
@ -543,6 +638,6 @@ class ImageToLatentsInvocation(BaseInvocation):
latents = 0.18215 * latents latents = 0.18215 * latents
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, latents)
context.services.latents.save(name, latents) context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents) return build_latents_output(latents_name=name, latents=latents)

View File

@ -34,6 +34,15 @@ class IntOutput(BaseInvocationOutput):
# fmt: on # fmt: on
class FloatOutput(BaseInvocationOutput):
"""A float output"""
# fmt: off
type: Literal["float_output"] = "float_output"
param: float = Field(default=None, description="The output float")
# fmt: on
class AddInvocation(BaseInvocation, MathInvocationConfig): class AddInvocation(BaseInvocation, MathInvocationConfig):
"""Adds two numbers""" """Adds two numbers"""

View File

@ -3,7 +3,7 @@
from typing import Literal from typing import Literal
from pydantic import Field from pydantic import Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from .math import IntOutput from .math import IntOutput, FloatOutput
# Pass-through parameter nodes - used by subgraphs # Pass-through parameter nodes - used by subgraphs
@ -16,3 +16,13 @@ class ParamIntInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IntOutput: def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a) return IntOutput(a=self.a)
class ParamFloatInvocation(BaseInvocation):
"""A float parameter"""
#fmt: off
type: Literal["param_float"] = "param_float"
param: float = Field(default=0.0, description="The float value")
#fmt: on
def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(param=self.param)

View File

@ -2,7 +2,7 @@ from typing import Literal, Union
from pydantic import Field from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ImageType from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput from .image import ImageOutput
@ -29,7 +29,7 @@ class RestoreFaceInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
results = context.services.restoration.upscale_and_reconstruct( results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]], image_list=[[image, 0]],
@ -43,16 +43,17 @@ class RestoreFaceInvocation(BaseInvocation):
# TODO: can this return multiple results? # TODO: can this return multiple results?
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=results[0][0], image=results[0][0],
image_type=ImageType.INTERMEDIATE, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -4,7 +4,7 @@ from typing import Literal, Union
from pydantic import Field from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ImageType from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput from .image import ImageOutput
@ -31,7 +31,7 @@ class UpscaleInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
results = context.services.restoration.upscale_and_reconstruct( results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]], image_list=[[image, 0]],
@ -45,16 +45,17 @@ class UpscaleInvocation(BaseInvocation):
# TODO: can this return multiple results? # TODO: can this return multiple results?
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=results[0][0], image=results[0][0],
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -5,31 +5,52 @@ from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.metaenum import MetaEnum
class ImageType(str, Enum, metaclass=MetaEnum): class ResourceOrigin(str, Enum, metaclass=MetaEnum):
"""The type of an image.""" """The origin of a resource (eg image).
RESULT = "results" - INTERNAL: The resource was created by the application.
UPLOAD = "uploads" - EXTERNAL: The resource was not created by the application.
INTERMEDIATE = "intermediates" This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
"""
INTERNAL = "internal"
"""The resource was created by the application."""
EXTERNAL = "external"
"""The resource was not created by the application.
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
"""
class InvalidImageTypeException(ValueError): class InvalidOriginException(ValueError):
"""Raised when a provided value is not a valid ImageType. """Raised when a provided value is not a valid ResourceOrigin.
Subclasses `ValueError`. Subclasses `ValueError`.
""" """
def __init__(self, message="Invalid image type."): def __init__(self, message="Invalid resource origin."):
super().__init__(message) super().__init__(message)
class ImageCategory(str, Enum, metaclass=MetaEnum): class ImageCategory(str, Enum, metaclass=MetaEnum):
"""The category of an image. Use ImageCategory.OTHER for non-default categories.""" """The category of an image.
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
- MASK: The image is a mask image.
- CONTROL: The image is a ControlNet control image.
- USER: The image is a user-provide image.
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
"""
GENERAL = "general" GENERAL = "general"
CONTROL = "control" """GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
MASK = "mask" MASK = "mask"
"""MASK: The image is a mask image."""
CONTROL = "control"
"""CONTROL: The image is a ControlNet control image."""
USER = "user"
"""USER: The image is a user-provide image."""
OTHER = "other" OTHER = "other"
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
class InvalidImageCategoryException(ValueError): class InvalidImageCategoryException(ValueError):
@ -45,13 +66,13 @@ class InvalidImageCategoryException(ValueError):
class ImageField(BaseModel): class ImageField(BaseModel):
"""An image field used for passing image objects between invocations""" """An image field used for passing image objects between invocations"""
image_type: ImageType = Field( image_origin: ResourceOrigin = Field(
default=ImageType.RESULT, description="The type of the image" default=ResourceOrigin.INTERNAL, description="The type of the image"
) )
image_name: Optional[str] = Field(default=None, description="The name of the image") image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config: class Config:
schema_extra = {"required": ["image_type", "image_name"]} schema_extra = {"required": ["image_origin", "image_name"]}
class ColorField(BaseModel): class ColorField(BaseModel):
@ -62,3 +83,11 @@ class ColorField(BaseModel):
def tuple(self) -> Tuple[int, int, int, int]: def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a) return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

View File

@ -1,7 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Optional from typing import Any
from invokeai.app.api.models.images import ProgressImage from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException

View File

@ -60,6 +60,35 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
node_input_field = node_inputs.get(field) or None node_input_field = node_inputs.get(field) or None
return node_input_field return node_input_field
from typing import Optional, Union, List, get_args
def is_union_subtype(t1, t2):
t1_args = get_args(t1)
t2_args = get_args(t2)
if not t1_args:
# t1 is a single type
return t1 in t2_args
else:
# t1 is a Union, check that all of its types are in t2_args
return all(arg in t2_args for arg in t1_args)
def is_list_or_contains_list(t):
t_args = get_args(t)
# If the type is a List
if get_origin(t) is list:
return True
# If the type is a Union
elif t_args:
# Check if any of the types in the Union is a List
for arg in t_args:
if get_origin(arg) is list:
return True
return False
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if not from_type: if not from_type:
@ -85,7 +114,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if to_type in get_args(from_type): if to_type in get_args(from_type):
return True return True
if not issubclass(from_type, to_type): # if not issubclass(from_type, to_type):
if not is_union_subtype(from_type, to_type):
return False return False
else: else:
return False return False
@ -694,7 +724,11 @@ class Graph(BaseModel):
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
# Verify that all outputs are lists # Verify that all outputs are lists
if not all((get_origin(f) == list for f in output_fields)): # if not all((get_origin(f) == list for f in output_fields)):
# return False
# Verify that all outputs are lists
if not all(is_list_or_contains_list(f) for f in output_fields):
return False return False
# Verify that all outputs match the input type (are a base class or the same class) # Verify that all outputs match the input type (are a base class or the same class)

View File

@ -9,7 +9,7 @@ from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.models.image import ImageType from invokeai.app.models.image import ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -40,13 +40,13 @@ class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files.""" """Low-level service responsible for storing and retrieving image files."""
@abstractmethod @abstractmethod
def get(self, image_type: ImageType, image_name: str) -> PILImageType: def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image.""" """Retrieves an image as PIL Image."""
pass pass
@abstractmethod @abstractmethod
def get_path( def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
"""Gets the internal path to an image or thumbnail.""" """Gets the internal path to an image or thumbnail."""
pass pass
@ -62,7 +62,7 @@ class ImageFileStorageBase(ABC):
def save( def save(
self, self,
image: PILImageType, image: PILImageType,
image_type: ImageType, image_origin: ResourceOrigin,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
@ -71,7 +71,7 @@ class ImageFileStorageBase(ABC):
pass pass
@abstractmethod @abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists).""" """Deletes an image and its thumbnail (if one exists)."""
pass pass
@ -93,17 +93,17 @@ class DiskImageFileStorage(ImageFileStorageBase):
Path(output_folder).mkdir(parents=True, exist_ok=True) Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath? # TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_type in ImageType: for image_origin in ResourceOrigin:
Path(os.path.join(output_folder, image_type)).mkdir( Path(os.path.join(output_folder, image_origin)).mkdir(
parents=True, exist_ok=True parents=True, exist_ok=True
) )
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir( Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
parents=True, exist_ok=True parents=True, exist_ok=True
) )
def get(self, image_type: ImageType, image_name: str) -> PILImageType: def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
try: try:
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_origin, image_name)
cache_item = self.__get_cache(image_path) cache_item = self.__get_cache(image_path)
if cache_item: if cache_item:
return cache_item return cache_item
@ -117,13 +117,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
def save( def save(
self, self,
image: PILImageType, image: PILImageType,
image_type: ImageType, image_origin: ResourceOrigin,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_origin, image_name)
if metadata is not None: if metadata is not None:
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
@ -133,7 +133,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image.save(image_path, "PNG") image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True) thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size) thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path) thumbnail_image.save(thumbnail_path)
@ -142,10 +142,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e: except Exception as e:
raise ImageFileSaveException from e raise ImageFileSaveException from e
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
try: try:
basename = os.path.basename(image_name) basename = os.path.basename(image_name)
image_path = self.get_path(image_type, basename) image_path = self.get_path(image_origin, basename)
if os.path.exists(image_path): if os.path.exists(image_path):
send2trash(image_path) send2trash(image_path)
@ -153,7 +153,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
del self.__cache[image_path] del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, True) thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
if os.path.exists(thumbnail_path): if os.path.exists(thumbnail_path):
send2trash(thumbnail_path) send2trash(thumbnail_path)
@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
# TODO: make this a bit more flexible for e.g. cloud storage # TODO: make this a bit more flexible for e.g. cloud storage
def get_path( def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
# strip out any relative path shenanigans # strip out any relative path shenanigans
basename = os.path.basename(image_name) basename = os.path.basename(image_name)
@ -172,10 +172,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
if thumbnail: if thumbnail:
thumbnail_name = get_thumbnail_name(basename) thumbnail_name = get_thumbnail_name(basename)
path = os.path.join( path = os.path.join(
self.__output_folder, image_type, "thumbnails", thumbnail_name self.__output_folder, image_origin, "thumbnails", thumbnail_name
) )
else: else:
path = os.path.join(self.__output_folder, image_type, basename) path = os.path.join(self.__output_folder, image_origin, basename)
abspath = os.path.abspath(path) abspath = os.path.abspath(path)

View File

@ -1,20 +1,35 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from typing import Optional, cast from typing import Generic, Optional, TypeVar, cast
import sqlite3 import sqlite3
import threading import threading
from typing import Optional, Union from typing import Optional, Union
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.models.image import ( from invokeai.app.models.image import (
ImageCategory, ImageCategory,
ImageType, ResourceOrigin,
) )
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageRecord, ImageRecord,
ImageRecordChanges,
deserialize_image_record, deserialize_image_record,
) )
from invokeai.app.services.item_storage import PaginatedResults
T = TypeVar("T", bound=BaseModel)
class OffsetPaginatedResults(GenericModel, Generic[T]):
"""Offset-paginated results"""
# fmt: off
items: list[T] = Field(description="Items")
offset: int = Field(description="Offset from which to retrieve items")
limit: int = Field(description="Limit of items to get")
total: int = Field(description="Total number of items in result")
# fmt: on
# TODO: Should these excpetions subclass existing python exceptions? # TODO: Should these excpetions subclass existing python exceptions?
@ -45,25 +60,36 @@ class ImageRecordStorageBase(ABC):
# TODO: Implement an `update()` method # TODO: Implement an `update()` method
@abstractmethod @abstractmethod
def get(self, image_type: ImageType, image_name: str) -> ImageRecord: def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
"""Gets an image record.""" """Gets an image record."""
pass pass
@abstractmethod
def update(
self,
image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
"""Updates an image record."""
pass
@abstractmethod @abstractmethod
def get_many( def get_many(
self, self,
image_type: ImageType, offset: int = 0,
image_category: ImageCategory, limit: int = 10,
page: int = 0, image_origin: Optional[ResourceOrigin] = None,
per_page: int = 10, categories: Optional[list[ImageCategory]] = None,
) -> PaginatedResults[ImageRecord]: is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records.""" """Gets a page of image records."""
pass pass
# TODO: The database has a nullable `deleted_at` column, currently unused. # TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage. # Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod @abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
"""Deletes an image record.""" """Deletes an image record."""
pass pass
@ -71,13 +97,14 @@ class ImageRecordStorageBase(ABC):
def save( def save(
self, self,
image_name: str, image_name: str,
image_type: ImageType, image_origin: ResourceOrigin,
image_category: ImageCategory, image_category: ImageCategory,
width: int, width: int,
height: int, height: int,
session_id: Optional[str], session_id: Optional[str],
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
) -> datetime: ) -> datetime:
"""Saves an image record.""" """Saves an image record."""
pass pass
@ -91,7 +118,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def __init__(self, filename: str) -> None: def __init__(self, filename: str) -> None:
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False) self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
@ -117,7 +143,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
CREATE TABLE IF NOT EXISTS images ( CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY, image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility -- This is an enum in python, unrestricted string here for flexibility
image_type TEXT NOT NULL, image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility -- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL, image_category TEXT NOT NULL,
width INTEGER NOT NULL, width INTEGER NOT NULL,
@ -125,9 +151,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id TEXT, session_id TEXT,
node_id TEXT, node_id TEXT,
metadata TEXT, metadata TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger -- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused -- Soft delete, currently unused
deleted_at DATETIME deleted_at DATETIME
); );
@ -142,7 +169,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
) )
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type); CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
""" """
) )
self._cursor.execute( self._cursor.execute(
@ -169,7 +196,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""" """
) )
def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]: def get(
self, image_origin: ResourceOrigin, image_name: str
) -> Union[ImageRecord, None]:
try: try:
self._lock.acquire() self._lock.acquire()
@ -193,38 +222,110 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result)) return deserialize_image_record(dict(result))
def update(
self,
image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
try:
self._lock.acquire()
# Change the category of the image
if changes.image_category is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_many( def get_many(
self, self,
image_type: ImageType, offset: int = 0,
image_category: ImageCategory, limit: int = 10,
page: int = 0, image_origin: Optional[ResourceOrigin] = None,
per_page: int = 10, categories: Optional[list[ImageCategory]] = None,
) -> PaginatedResults[ImageRecord]: is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( # Manually build two queries - one for the count, one for the records
f"""--sql
SELECT * FROM images count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
WHERE image_type = ? AND image_category = ? images_query = f"""SELECT * FROM images WHERE 1=1\n"""
ORDER BY created_at DESC
LIMIT ? OFFSET ?; query_conditions = ""
""", query_params = []
(image_type.value, image_category.value, per_page, page * per_page),
if image_origin is not None:
query_conditions += f"""AND image_origin = ?\n"""
query_params.append(image_origin.value)
if categories is not None:
## Convert the enum values to unique list of strings
category_strings = list(
map(lambda c: c.value, set(categories))
) )
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n"
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += f"""AND is_intermediate = ?\n"""
query_params.append(is_intermediate)
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
images_params.append(limit)
images_params.append(offset)
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall()) result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result)) images = list(map(lambda r: deserialize_image_record(dict(r)), result))
self._cursor.execute( # Set up and execute the count query, without pagination
"""--sql count_query += query_conditions + ";"
SELECT count(*) FROM images count_params = query_params.copy()
WHERE image_type = ? AND image_category = ? self._cursor.execute(count_query, count_params)
""",
(image_type.value, image_category.value),
)
count = self._cursor.fetchone()[0] count = self._cursor.fetchone()[0]
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
@ -232,13 +333,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally: finally:
self._lock.release() self._lock.release()
pageCount = int(count / per_page) + 1 return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
return PaginatedResults(
items=images, page=page, pages=pageCount, per_page=per_page, total=count
) )
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -258,13 +357,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def save( def save(
self, self,
image_name: str, image_name: str,
image_type: ImageType, image_origin: ResourceOrigin,
image_category: ImageCategory, image_category: ImageCategory,
session_id: Optional[str], session_id: Optional[str],
width: int, width: int,
height: int, height: int,
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
) -> datetime: ) -> datetime:
try: try:
metadata_json = ( metadata_json = (
@ -275,25 +375,27 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""--sql """--sql
INSERT OR IGNORE INTO images ( INSERT OR IGNORE INTO images (
image_name, image_name,
image_type, image_origin,
image_category, image_category,
width, width,
height, height,
node_id, node_id,
session_id, session_id,
metadata metadata,
is_intermediate
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?); VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
""", """,
( (
image_name, image_name,
image_type.value, image_origin.value,
image_category.value, image_category.value,
width, width,
height, height,
node_id, node_id,
session_id, session_id,
metadata_json, metadata_json,
is_intermediate,
), ),
) )
self._conn.commit() self._conn.commit()

View File

@ -1,14 +1,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger from logging import Logger
from typing import Optional, TYPE_CHECKING, Union from typing import Optional, TYPE_CHECKING, Union
import uuid
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from invokeai.app.models.image import ( from invokeai.app.models.image import (
ImageCategory, ImageCategory,
ImageType, ResourceOrigin,
InvalidImageCategoryException, InvalidImageCategoryException,
InvalidImageTypeException, InvalidOriginException,
) )
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.image_record_storage import ( from invokeai.app.services.image_record_storage import (
@ -16,10 +15,12 @@ from invokeai.app.services.image_record_storage import (
ImageRecordNotFoundException, ImageRecordNotFoundException,
ImageRecordSaveException, ImageRecordSaveException,
ImageRecordStorageBase, ImageRecordStorageBase,
OffsetPaginatedResults,
) )
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageRecord, ImageRecord,
ImageDTO, ImageDTO,
ImageRecordChanges,
image_record_to_dto, image_record_to_dto,
) )
from invokeai.app.services.image_file_storage import ( from invokeai.app.services.image_file_storage import (
@ -30,8 +31,8 @@ from invokeai.app.services.image_file_storage import (
) )
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.misc import get_iso_timestamp
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState from invokeai.app.services.graph import GraphExecutionState
@ -44,32 +45,42 @@ class ImageServiceABC(ABC):
def create( def create(
self, self,
image: PILImageType, image: PILImageType,
image_type: ImageType, image_origin: ResourceOrigin,
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
metadata: Optional[ImageMetadata] = None, intermediate: bool = False,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
pass pass
@abstractmethod @abstractmethod
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: def update(
self,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
"""Updates an image."""
pass
@abstractmethod
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
"""Gets an image as a PIL image.""" """Gets an image as a PIL image."""
pass pass
@abstractmethod @abstractmethod
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
"""Gets an image record.""" """Gets an image record."""
pass pass
@abstractmethod @abstractmethod
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
"""Gets an image DTO.""" """Gets an image DTO."""
pass pass
@abstractmethod @abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str: def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
"""Gets an image's path.""" """Gets an image's path."""
pass pass
@ -80,7 +91,7 @@ class ImageServiceABC(ABC):
@abstractmethod @abstractmethod
def get_url( def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
"""Gets an image's or thumbnail's URL.""" """Gets an image's or thumbnail's URL."""
pass pass
@ -88,16 +99,17 @@ class ImageServiceABC(ABC):
@abstractmethod @abstractmethod
def get_many( def get_many(
self, self,
image_type: ImageType, offset: int = 0,
image_category: ImageCategory, limit: int = 10,
page: int = 0, image_origin: Optional[ResourceOrigin] = None,
per_page: int = 10, categories: Optional[list[ImageCategory]] = None,
) -> PaginatedResults[ImageDTO]: is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs.""" """Gets a paginated list of image DTOs."""
pass pass
@abstractmethod @abstractmethod
def delete(self, image_type: ImageType, image_name: str): def delete(self, image_origin: ResourceOrigin, image_name: str):
"""Deletes an image.""" """Deletes an image."""
pass pass
@ -110,6 +122,7 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase metadata: MetadataServiceBase
urls: UrlServiceBase urls: UrlServiceBase
logger: Logger logger: Logger
names: NameServiceBase
graph_execution_manager: ItemStorageABC["GraphExecutionState"] graph_execution_manager: ItemStorageABC["GraphExecutionState"]
def __init__( def __init__(
@ -119,6 +132,7 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase, metadata: MetadataServiceBase,
url: UrlServiceBase, url: UrlServiceBase,
logger: Logger, logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
): ):
self.records = image_record_storage self.records = image_record_storage
@ -126,6 +140,7 @@ class ImageServiceDependencies:
self.metadata = metadata self.metadata = metadata
self.urls = url self.urls = url
self.logger = logger self.logger = logger
self.names = names
self.graph_execution_manager = graph_execution_manager self.graph_execution_manager = graph_execution_manager
@ -139,6 +154,7 @@ class ImageService(ImageServiceABC):
metadata: MetadataServiceBase, metadata: MetadataServiceBase,
url: UrlServiceBase, url: UrlServiceBase,
logger: Logger, logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
): ):
self._services = ImageServiceDependencies( self._services = ImageServiceDependencies(
@ -147,29 +163,26 @@ class ImageService(ImageServiceABC):
metadata=metadata, metadata=metadata,
url=url, url=url,
logger=logger, logger=logger,
names=names,
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
) )
def create( def create(
self, self,
image: PILImageType, image: PILImageType,
image_type: ImageType, image_origin: ResourceOrigin,
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
is_intermediate: bool = False,
) -> ImageDTO: ) -> ImageDTO:
if image_type not in ImageType: if image_origin not in ResourceOrigin:
raise InvalidImageTypeException raise InvalidOriginException
if image_category not in ImageCategory: if image_category not in ImageCategory:
raise InvalidImageCategoryException raise InvalidImageCategoryException
image_name = self._create_image_name( image_name = self._services.names.create_image_name()
image_type=image_type,
image_category=image_category,
node_id=node_id,
session_id=session_id,
)
metadata = self._get_metadata(session_id, node_id) metadata = self._get_metadata(session_id, node_id)
@ -180,10 +193,12 @@ class ImageService(ImageServiceABC):
created_at = self._services.records.save( created_at = self._services.records.save(
# Non-nullable fields # Non-nullable fields
image_name=image_name, image_name=image_name,
image_type=image_type, image_origin=image_origin,
image_category=image_category, image_category=image_category,
width=width, width=width,
height=height, height=height,
# Meta fields
is_intermediate=is_intermediate,
# Nullable fields # Nullable fields
node_id=node_id, node_id=node_id,
session_id=session_id, session_id=session_id,
@ -191,21 +206,21 @@ class ImageService(ImageServiceABC):
) )
self._services.files.save( self._services.files.save(
image_type=image_type, image_origin=image_origin,
image_name=image_name, image_name=image_name,
image=image, image=image,
metadata=metadata, metadata=metadata,
) )
image_url = self._services.urls.get_image_url(image_type, image_name) image_url = self._services.urls.get_image_url(image_origin, image_name)
thumbnail_url = self._services.urls.get_image_url( thumbnail_url = self._services.urls.get_image_url(
image_type, image_name, True image_origin, image_name, True
) )
return ImageDTO( return ImageDTO(
# Non-nullable fields # Non-nullable fields
image_name=image_name, image_name=image_name,
image_type=image_type, image_origin=image_origin,
image_category=image_category, image_category=image_category,
width=width, width=width,
height=height, height=height,
@ -217,6 +232,7 @@ class ImageService(ImageServiceABC):
created_at=created_at, created_at=created_at,
updated_at=created_at, # this is always the same as the created_at at this time updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None, deleted_at=None,
is_intermediate=is_intermediate,
# Extra non-nullable fields for DTO # Extra non-nullable fields for DTO
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
@ -231,9 +247,25 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem saving image record and file") self._services.logger.error("Problem saving image record and file")
raise e raise e
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: def update(
self,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
try: try:
return self._services.files.get(image_type, image_name) self._services.records.update(image_name, image_origin, changes)
return self.get_dto(image_origin, image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
except Exception as e:
self._services.logger.error("Problem updating image record")
raise e
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
try:
return self._services.files.get(image_origin, image_name)
except ImageFileNotFoundException: except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file") self._services.logger.error("Failed to get image file")
raise raise
@ -241,9 +273,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image file") self._services.logger.error("Problem getting image file")
raise e raise e
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
try: try:
return self._services.records.get(image_type, image_name) return self._services.records.get(image_origin, image_name)
except ImageRecordNotFoundException: except ImageRecordNotFoundException:
self._services.logger.error("Image record not found") self._services.logger.error("Image record not found")
raise raise
@ -251,14 +283,14 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image record") self._services.logger.error("Problem getting image record")
raise e raise e
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
try: try:
image_record = self._services.records.get(image_type, image_name) image_record = self._services.records.get(image_origin, image_name)
image_dto = image_record_to_dto( image_dto = image_record_to_dto(
image_record, image_record,
self._services.urls.get_image_url(image_type, image_name), self._services.urls.get_image_url(image_origin, image_name),
self._services.urls.get_image_url(image_type, image_name, True), self._services.urls.get_image_url(image_origin, image_name, True),
) )
return image_dto return image_dto
@ -270,10 +302,10 @@ class ImageService(ImageServiceABC):
raise e raise e
def get_path( def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
try: try:
return self._services.files.get_path(image_type, image_name, thumbnail) return self._services.files.get_path(image_origin, image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self._services.logger.error("Problem getting image path")
raise e raise e
@ -286,57 +318,58 @@ class ImageService(ImageServiceABC):
raise e raise e
def get_url( def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
try: try:
return self._services.urls.get_image_url(image_type, image_name, thumbnail) return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self._services.logger.error("Problem getting image path")
raise e raise e
def get_many( def get_many(
self, self,
image_type: ImageType, offset: int = 0,
image_category: ImageCategory, limit: int = 10,
page: int = 0, image_origin: Optional[ResourceOrigin] = None,
per_page: int = 10, categories: Optional[list[ImageCategory]] = None,
) -> PaginatedResults[ImageDTO]: is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try: try:
results = self._services.records.get_many( results = self._services.records.get_many(
image_type, offset,
image_category, limit,
page, image_origin,
per_page, categories,
is_intermediate,
) )
image_dtos = list( image_dtos = list(
map( map(
lambda r: image_record_to_dto( lambda r: image_record_to_dto(
r, r,
self._services.urls.get_image_url(image_type, r.image_name), self._services.urls.get_image_url(r.image_origin, r.image_name),
self._services.urls.get_image_url( self._services.urls.get_image_url(
image_type, r.image_name, True r.image_origin, r.image_name, True
), ),
), ),
results.items, results.items,
) )
) )
return PaginatedResults[ImageDTO]( return OffsetPaginatedResults[ImageDTO](
items=image_dtos, items=image_dtos,
page=results.page, offset=results.offset,
pages=results.pages, limit=results.limit,
per_page=results.per_page,
total=results.total, total=results.total,
) )
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting paginated image DTOs") self._services.logger.error("Problem getting paginated image DTOs")
raise e raise e
def delete(self, image_type: ImageType, image_name: str): def delete(self, image_origin: ResourceOrigin, image_name: str):
try: try:
self._services.files.delete(image_type, image_name) self._services.files.delete(image_origin, image_name)
self._services.records.delete(image_type, image_name) self._services.records.delete(image_origin, image_name)
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record") self._services.logger.error(f"Failed to delete image record")
raise raise
@ -347,21 +380,6 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem deleting image record and file") self._services.logger.error("Problem deleting image record and file")
raise e raise e
def _create_image_name(
self,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
"""Create a unique image name."""
uuid_str = str(uuid.uuid4())
if node_id is not None and session_id is not None:
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
def _get_metadata( def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Union[ImageMetadata, None]: ) -> Union[ImageMetadata, None]:

View File

@ -1,7 +1,7 @@
import datetime import datetime
from typing import Optional, Union from typing import Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
@ -11,8 +11,8 @@ class ImageRecord(BaseModel):
image_name: str = Field(description="The unique name of the image.") image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image.""" """The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.") image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The type of the image.""" """The origin of the image."""
image_category: ImageCategory = Field(description="The category of the image.") image_category: ImageCategory = Field(description="The category of the image.")
"""The category of the image.""" """The category of the image."""
width: int = Field(description="The width of the image in px.") width: int = Field(description="The width of the image in px.")
@ -31,6 +31,8 @@ class ImageRecord(BaseModel):
description="The deleted timestamp of the image." description="The deleted timestamp of the image."
) )
"""The deleted timestamp of the image.""" """The deleted timestamp of the image."""
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
"""Whether this is an intermediate image."""
session_id: Optional[str] = Field( session_id: Optional[str] = Field(
default=None, default=None,
description="The session ID that generated this image, if it is a generated image.", description="The session ID that generated this image, if it is a generated image.",
@ -48,13 +50,37 @@ class ImageRecord(BaseModel):
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.""" """A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
"""A set of changes to apply to an image record.
Only limited changes are valid:
- `image_category`: change the category of an image
- `session_id`: change the session associated with an image
- `is_intermediate`: change the image's `is_intermediate` flag
"""
image_category: Optional[ImageCategory] = Field(
description="The image's new category."
)
"""The image's new category."""
session_id: Optional[StrictStr] = Field(
default=None,
description="The image's new session ID.",
)
"""The image's new session ID."""
is_intermediate: Optional[StrictBool] = Field(
default=None, description="The image's new `is_intermediate` flag."
)
"""The image's new `is_intermediate` flag."""
class ImageUrlsDTO(BaseModel): class ImageUrlsDTO(BaseModel):
"""The URLs for an image and its thumbnail.""" """The URLs for an image and its thumbnail."""
image_name: str = Field(description="The unique name of the image.") image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image.""" """The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.") image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The type of the image.""" """The origin of the image."""
image_url: str = Field(description="The URL of the image.") image_url: str = Field(description="The URL of the image.")
"""The URL of the image.""" """The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.") thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
@ -84,7 +110,9 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
# Retrieve all the values, setting "reasonable" defaults if they are not present. # Retrieve all the values, setting "reasonable" defaults if they are not present.
image_name = image_dict.get("image_name", "unknown") image_name = image_dict.get("image_name", "unknown")
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value)) image_origin = ResourceOrigin(
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
)
image_category = ImageCategory( image_category = ImageCategory(
image_dict.get("image_category", ImageCategory.GENERAL.value) image_dict.get("image_category", ImageCategory.GENERAL.value)
) )
@ -95,6 +123,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
created_at = image_dict.get("created_at", get_iso_timestamp()) created_at = image_dict.get("created_at", get_iso_timestamp())
updated_at = image_dict.get("updated_at", get_iso_timestamp()) updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False)
raw_metadata = image_dict.get("metadata") raw_metadata = image_dict.get("metadata")
@ -105,7 +134,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
return ImageRecord( return ImageRecord(
image_name=image_name, image_name=image_name,
image_type=image_type, image_origin=image_origin,
image_category=image_category, image_category=image_category,
width=width, width=width,
height=height, height=height,
@ -115,4 +144,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
created_at=created_at, created_at=created_at,
updated_at=updated_at, updated_at=updated_at,
deleted_at=deleted_at, deleted_at=deleted_at,
is_intermediate=is_intermediate,
) )

View File

@ -0,0 +1,30 @@
from abc import ABC, abstractmethod
from enum import Enum, EnumMeta
import uuid
class ResourceType(str, Enum, metaclass=EnumMeta):
"""Enum for resource types."""
IMAGE = "image"
LATENT = "latent"
class NameServiceBase(ABC):
"""Low-level service responsible for naming resources (images, latents, etc)."""
# TODO: Add customizable naming schemes
@abstractmethod
def create_image_name(self) -> str:
"""Creates a name for an image."""
pass
class SimpleNameService(NameServiceBase):
"""Creates image names from UUIDs."""
# TODO: Add customizable naming schemes
def create_image_name(self) -> str:
uuid_str = str(uuid.uuid4())
filename = f"{uuid_str}.png"
return filename

View File

@ -1,7 +1,7 @@
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from invokeai.app.models.image import ImageType from invokeai.app.models.image import ResourceOrigin
from invokeai.app.util.thumbnails import get_thumbnail_name from invokeai.app.util.thumbnails import get_thumbnail_name
@ -10,7 +10,7 @@ class UrlServiceBase(ABC):
@abstractmethod @abstractmethod
def get_image_url( def get_image_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
"""Gets the URL for an image or thumbnail.""" """Gets the URL for an image or thumbnail."""
pass pass
@ -21,14 +21,14 @@ class LocalUrlService(UrlServiceBase):
self._base_url = base_url self._base_url = base_url
def get_image_url( def get_image_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
image_basename = os.path.basename(image_name) image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py # These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail: if thumbnail:
return ( return (
f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail" f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
) )
return f"{self._base_url}/images/{image_type.value}/{image_basename}" return f"{self._base_url}/images/{image_origin.value}/{image_basename}"

View File

@ -1,5 +1,5 @@
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
from invokeai.app.models.image import ProgressImage
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator from ...backend.generator.base import Generator

View File

@ -75,9 +75,11 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def __init__(self, def __init__(self,
model_info: dict, model_info: dict,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(), params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
**kwargs,
): ):
self.model_info=model_info self.model_info=model_info
self.params=params self.params=params
self.kwargs = kwargs
def generate(self, def generate(self,
prompt: str='', prompt: str='',
@ -120,7 +122,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
) )
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model) uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
gen_class = self._generator_class() gen_class = self._generator_class()
generator = gen_class(model, self.params.precision) generator = gen_class(model, self.params.precision, **self.kwargs)
if self.params.variation_amount > 0: if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'), generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'), generator_args.get('variation_amount'),
@ -275,7 +277,7 @@ class Generator:
precision: str precision: str
model: DiffusionPipeline model: DiffusionPipeline
def __init__(self, model: DiffusionPipeline, precision: str): def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
self.model = model self.model = model
self.precision = precision self.precision = precision
self.seed = None self.seed = None

View File

@ -4,6 +4,10 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
import PIL.Image import PIL.Image
import torch import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from ..stable_diffusion import ( from ..stable_diffusion import (
ConditioningData, ConditioningData,
PostprocessingSettings, PostprocessingSettings,
@ -13,8 +17,13 @@ from .base import Generator
class Txt2Img(Generator): class Txt2Img(Generator):
def __init__(self, model, precision): def __init__(self, model, precision,
super().__init__(model, precision) control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
**kwargs):
self.control_model = control_model
if isinstance(self.control_model, list):
self.control_model = MultiControlNetModel(self.control_model)
super().__init__(model, precision, **kwargs)
@torch.no_grad() @torch.no_grad()
def get_make_image( def get_make_image(
@ -42,9 +51,12 @@ class Txt2Img(Generator):
kwargs are 'width' and 'height' kwargs are 'width' and 'height'
""" """
self.perlin = perlin self.perlin = perlin
control_image = kwargs.get("control_image", None)
do_classifier_free_guidance = cfg_scale > 1.0
# noinspection PyTypeChecker # noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.control_model = self.control_model
pipeline.scheduler = sampler pipeline.scheduler = sampler
uc, c, extra_conditioning_info = conditioning uc, c, extra_conditioning_info = conditioning
@ -61,6 +73,37 @@ class Txt2Img(Generator):
), ),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
if control_image is not None:
if isinstance(self.control_model, ControlNetModel):
control_image = pipeline.prepare_control_image(
image=control_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=self.control_model.device,
dtype=self.control_model.dtype,
)
elif isinstance(self.control_model, MultiControlNetModel):
images = []
for image_ in control_image:
image_ = pipeline.prepare_control_image(
image=image_,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=self.control_model.device,
dtype=self.control_model.dtype,
)
images.append(image_)
control_image = images
kwargs["control_image"] = control_image
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image: def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
pipeline_output = pipeline.image_from_embeddings( pipeline_output = pipeline.image_from_embeddings(
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()), latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
@ -68,6 +111,7 @@ class Txt2Img(Generator):
num_inference_steps=steps, num_inference_steps=steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
callback=step_callback, callback=step_callback,
**kwargs,
) )
if ( if (

View File

@ -2,23 +2,29 @@ from __future__ import annotations
import dataclasses import dataclasses
import inspect import inspect
import math
import secrets import secrets
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, Field
import einops import einops
import PIL.Image import PIL.Image
import numpy as np
from accelerate.utils import set_seed from accelerate.utils import set_seed
import psutil import psutil
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
from compel import EmbeddingsProvider from compel import EmbeddingsProvider
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline, StableDiffusionPipeline,
) )
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline, StableDiffusionImg2ImgPipeline,
) )
@ -27,6 +33,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
) )
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils import PIL_INTERPOLATION
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
@ -68,10 +75,10 @@ class AddsMaskLatents:
initial_image_latents: torch.Tensor initial_image_latents: torch.Tensor
def __call__( def __call__(
self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
model_input = self.add_mask_channels(latents) model_input = self.add_mask_channels(latents)
return self.forward(model_input, t, text_embeddings) return self.forward(model_input, t, text_embeddings, **kwargs)
def add_mask_channels(self, latents): def add_mask_channels(self, latents):
batch_size = latents.size(0) batch_size = latents.size(0)
@ -207,6 +214,13 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
raise AssertionError("why was that an empty generator?") raise AssertionError("why was that an empty generator?")
return result return result
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None)
weight: float = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
@dataclass(frozen=True) @dataclass(frozen=True)
class ConditioningData: class ConditioningData:
@ -302,6 +316,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
feature_extractor: Optional[CLIPFeatureExtractor], feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False, requires_safety_checker: bool = False,
precision: str = "float32", precision: str = "float32",
control_model: ControlNetModel = None,
): ):
super().__init__( super().__init__(
vae, vae,
@ -322,6 +337,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
scheduler=scheduler, scheduler=scheduler,
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
# FIXME: can't currently register control module
# control_model=control_model,
) )
self.invokeai_diffuser = InvokeAIDiffuserComponent( self.invokeai_diffuser = InvokeAIDiffuserComponent(
self.unet, self._unet_forward, is_running_diffusers=True self.unet, self._unet_forward, is_running_diffusers=True
@ -341,6 +358,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self._model_group = FullyLoadedModelGroup(self.unet.device) self._model_group = FullyLoadedModelGroup(self.unet.device)
self._model_group.install(*self._submodels) self._model_group.install(*self._submodels)
self.control_model = control_model
def _adjust_memory_efficient_attention(self, latents: torch.Tensor): def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
""" """
@ -463,6 +481,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor, noise: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None, run_id=None,
**kwargs,
) -> InvokeAIStableDiffusionPipelineOutput: ) -> InvokeAIStableDiffusionPipelineOutput:
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
@ -483,6 +502,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise=noise, noise=noise,
run_id=run_id, run_id=run_id,
callback=callback, callback=callback,
**kwargs,
) )
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -507,6 +527,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
run_id=None, run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False): if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu') scheduler_device = torch.device('cpu')
@ -527,6 +549,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
run_id=run_id, run_id=run_id,
callback=callback, callback=callback,
control_data=control_data,
**kwargs,
) )
return result.latents, result.attention_map_saver return result.latents, result.attention_map_saver
@ -539,6 +563,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor, noise: torch.Tensor,
run_id: str = None, run_id: str = None,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs,
): ):
self._adjust_memory_efficient_attention(latents) self._adjust_memory_efficient_attention(latents)
if run_id is None: if run_id is None:
@ -568,7 +594,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = self.scheduler.add_noise(latents, noise, batched_t) latents = self.scheduler.add_noise(latents, noise, batched_t)
attention_map_saver: Optional[AttentionMapSaver] = None attention_map_saver: Optional[AttentionMapSaver] = None
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)): for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t) batched_t.fill_(t)
step_output = self.step( step_output = self.step(
@ -578,6 +604,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=i, step_index=i,
total_step_count=len(timesteps), total_step_count=len(timesteps),
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
control_data=control_data,
**kwargs,
) )
latents = step_output.prev_sample latents = step_output.prev_sample
@ -618,10 +646,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index: int, step_index: int,
total_step_count: int, total_step_count: int,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs,
): ):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0] timestep = t[0]
if additional_guidance is None: if additional_guidance is None:
additional_guidance = [] additional_guidance = []
@ -629,6 +658,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# i.e. before or after passing it to InvokeAIDiffuserComponent # i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep) latent_model_input = self.scheduler.scale_model_input(latents, timestep)
# default is no controlnet, so set controlnet processing output to None
down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None:
if conditioning_data.guidance_scale > 1.0:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
latent_control_input = torch.cat([latent_model_input] * 2)
else:
latent_control_input = latent_model_input
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
# print("controlnet", i, "==>", type(control_datum))
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index)
down_samples, mid_sample = control_datum.model(
sample=latent_control_input,
timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]),
controlnet_cond=control_datum.image_tensor,
conditioning_scale=control_datum.weight,
# cross_attention_kwargs,
guess_mode=False,
return_dict=False,
)
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
# add controlnet outputs together if have multiple controlnets
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
# predict the noise residual # predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step( noise_pred = self.invokeai_diffuser.do_diffusion_step(
latent_model_input, latent_model_input,
@ -638,6 +709,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data.guidance_scale, conditioning_data.guidance_scale,
step_index=step_index, step_index=step_index,
total_step_count=total_step_count, total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
) )
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
@ -659,6 +732,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
t, t,
text_embeddings, text_embeddings,
cross_attention_kwargs: Optional[dict[str, Any]] = None, cross_attention_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
): ):
"""predict the noise residual""" """predict the noise residual"""
if is_inpainting_model(self.unet) and latents.size(1) == 4: if is_inpainting_model(self.unet) and latents.size(1) == 4:
@ -678,7 +752,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# First three args should be positional, not keywords, so torch hooks can see them. # First three args should be positional, not keywords, so torch hooks can see them.
return self.unet( return self.unet(
latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
**kwargs,
).sample ).sample
def img2img_from_embeddings( def img2img_from_embeddings(
@ -940,3 +1015,51 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
debug_image( debug_image(
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
) )
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
# Returns torch.Tensor of shape (batch_size, 3, height, width)
@staticmethod
def prepare_control_image(
image,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
# latents,
width=512, # should be 8 * latent.shape[3]
height=512, # should be 8 * latent height[2]
batch_size=1,
num_images_per_prompt=1,
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
images = []
for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
image = images
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
image = torch.cat([image] * 2)
return image

View File

@ -181,6 +181,7 @@ class InvokeAIDiffuserComponent:
unconditional_guidance_scale: float, unconditional_guidance_scale: float,
step_index: Optional[int] = None, step_index: Optional[int] = None,
total_step_count: Optional[int] = None, total_step_count: Optional[int] = None,
**kwargs,
): ):
""" """
:param x: current latents :param x: current latents
@ -209,7 +210,7 @@ class InvokeAIDiffuserComponent:
if wants_hybrid_conditioning: if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning( unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
x, sigma, unconditioning, conditioning x, sigma, unconditioning, conditioning, **kwargs,
) )
elif wants_cross_attention_control: elif wants_cross_attention_control:
( (
@ -221,13 +222,14 @@ class InvokeAIDiffuserComponent:
unconditioning, unconditioning,
conditioning, conditioning,
cross_attention_control_types_to_do, cross_attention_control_types_to_do,
**kwargs,
) )
elif self.sequential_guidance: elif self.sequential_guidance:
( (
unconditioned_next_x, unconditioned_next_x,
conditioned_next_x, conditioned_next_x,
) = self._apply_standard_conditioning_sequentially( ) = self._apply_standard_conditioning_sequentially(
x, sigma, unconditioning, conditioning x, sigma, unconditioning, conditioning, **kwargs,
) )
else: else:
@ -235,7 +237,7 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x, unconditioned_next_x,
conditioned_next_x, conditioned_next_x,
) = self._apply_standard_conditioning( ) = self._apply_standard_conditioning(
x, sigma, unconditioning, conditioning x, sigma, unconditioning, conditioning, **kwargs,
) )
combined_next_x = self._combine( combined_next_x = self._combine(
@ -282,13 +284,13 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class. # methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning): def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
# fast batched path # fast batched path
x_twice = torch.cat([x] * 2) x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2) sigma_twice = torch.cat([sigma] * 2)
both_conditionings = torch.cat([unconditioning, conditioning]) both_conditionings = torch.cat([unconditioning, conditioning])
both_results = self.model_forward_callback( both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings x_twice, sigma_twice, both_conditionings, **kwargs,
) )
unconditioned_next_x, conditioned_next_x = both_results.chunk(2) unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
if conditioned_next_x.device.type == "mps": if conditioned_next_x.device.type == "mps":
@ -302,16 +304,17 @@ class InvokeAIDiffuserComponent:
sigma, sigma,
unconditioning: torch.Tensor, unconditioning: torch.Tensor,
conditioning: torch.Tensor, conditioning: torch.Tensor,
**kwargs,
): ):
# low-memory sequential path # low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning) conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
if conditioned_next_x.device.type == "mps": if conditioned_next_x.device.type == "mps":
# prevent a result filled with zeros. seems to be a torch bug. # prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone() conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning): def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
assert isinstance(conditioning, dict) assert isinstance(conditioning, dict)
assert isinstance(unconditioning, dict) assert isinstance(unconditioning, dict)
x_twice = torch.cat([x] * 2) x_twice = torch.cat([x] * 2)
@ -326,7 +329,7 @@ class InvokeAIDiffuserComponent:
else: else:
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]]) both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback( unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings x_twice, sigma_twice, both_conditionings, **kwargs,
).chunk(2) ).chunk(2)
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
@ -337,6 +340,7 @@ class InvokeAIDiffuserComponent:
unconditioning, unconditioning,
conditioning, conditioning,
cross_attention_control_types_to_do, cross_attention_control_types_to_do,
**kwargs,
): ):
if self.is_running_diffusers: if self.is_running_diffusers:
return self._apply_cross_attention_controlled_conditioning__diffusers( return self._apply_cross_attention_controlled_conditioning__diffusers(
@ -345,6 +349,7 @@ class InvokeAIDiffuserComponent:
unconditioning, unconditioning,
conditioning, conditioning,
cross_attention_control_types_to_do, cross_attention_control_types_to_do,
**kwargs,
) )
else: else:
return self._apply_cross_attention_controlled_conditioning__compvis( return self._apply_cross_attention_controlled_conditioning__compvis(
@ -353,6 +358,7 @@ class InvokeAIDiffuserComponent:
unconditioning, unconditioning,
conditioning, conditioning,
cross_attention_control_types_to_do, cross_attention_control_types_to_do,
**kwargs,
) )
def _apply_cross_attention_controlled_conditioning__diffusers( def _apply_cross_attention_controlled_conditioning__diffusers(
@ -362,6 +368,7 @@ class InvokeAIDiffuserComponent:
unconditioning, unconditioning,
conditioning, conditioning,
cross_attention_control_types_to_do, cross_attention_control_types_to_do,
**kwargs,
): ):
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
@ -377,6 +384,7 @@ class InvokeAIDiffuserComponent:
sigma, sigma,
unconditioning, unconditioning,
{"swap_cross_attn_context": cross_attn_processor_context}, {"swap_cross_attn_context": cross_attn_processor_context},
**kwargs,
) )
# do requested cross attention types for conditioning (positive prompt) # do requested cross attention types for conditioning (positive prompt)
@ -388,6 +396,7 @@ class InvokeAIDiffuserComponent:
sigma, sigma,
conditioning, conditioning,
{"swap_cross_attn_context": cross_attn_processor_context}, {"swap_cross_attn_context": cross_attn_processor_context},
**kwargs,
) )
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
@ -398,6 +407,7 @@ class InvokeAIDiffuserComponent:
unconditioning, unconditioning,
conditioning, conditioning,
cross_attention_control_types_to_do, cross_attention_control_types_to_do,
**kwargs,
): ):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS) # slower non-batched path (20% slower on mac MPS)
@ -411,13 +421,13 @@ class InvokeAIDiffuserComponent:
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
try: try:
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
# process x using the original prompt, saving the attention maps # process x using the original prompt, saving the attention maps
# print("saving attention maps for", cross_attention_control_types_to_do) # print("saving attention maps for", cross_attention_control_types_to_do)
for ca_type in cross_attention_control_types_to_do: for ca_type in cross_attention_control_types_to_do:
context.request_save_attention_maps(ca_type) context.request_save_attention_maps(ca_type)
_ = self.model_forward_callback(x, sigma, conditioning) _ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
context.clear_requests(cleanup=False) context.clear_requests(cleanup=False)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
@ -428,7 +438,7 @@ class InvokeAIDiffuserComponent:
self.conditioning.cross_attention_control_args.edited_conditioning self.conditioning.cross_attention_control_args.edited_conditioning
) )
conditioned_next_x = self.model_forward_callback( conditioned_next_x = self.model_forward_callback(
x, sigma, edited_conditioning x, sigma, edited_conditioning, **kwargs,
) )
context.clear_requests(cleanup=True) context.clear_requests(cleanup=True)

View File

@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-5fb14ef2.js"></script> <script type="module" crossorigin src="./assets/index-251c2c6e.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

View File

@ -122,7 +122,9 @@
"noImagesInGallery": "No Images In Gallery", "noImagesInGallery": "No Images In Gallery",
"deleteImage": "Delete Image", "deleteImage": "Delete Image",
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.", "deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored." "deleteImagePermanent": "Deleted images cannot be restored.",
"images": "Images",
"assets": "Assets"
}, },
"hotkeys": { "hotkeys": {
"keyboardShortcuts": "Keyboard Shortcuts", "keyboardShortcuts": "Keyboard Shortcuts",
@ -452,6 +454,8 @@
"height": "Height", "height": "Height",
"scheduler": "Scheduler", "scheduler": "Scheduler",
"seed": "Seed", "seed": "Seed",
"boundingBoxWidth": "Bounding Box Width",
"boundingBoxHeight": "Bounding Box Height",
"imageToImage": "Image to Image", "imageToImage": "Image to Image",
"randomizeSeed": "Randomize Seed", "randomizeSeed": "Randomize Seed",
"shuffle": "Shuffle Seed", "shuffle": "Shuffle Seed",
@ -524,7 +528,7 @@
}, },
"settings": { "settings": {
"models": "Models", "models": "Models",
"displayInProgress": "Display In-Progress Images", "displayInProgress": "Display Progress Images",
"saveSteps": "Save images every n steps", "saveSteps": "Save images every n steps",
"confirmOnDelete": "Confirm On Delete", "confirmOnDelete": "Confirm On Delete",
"displayHelpIcons": "Display Help Icons", "displayHelpIcons": "Display Help Icons",
@ -564,6 +568,8 @@
"canvasMerged": "Canvas Merged", "canvasMerged": "Canvas Merged",
"sentToImageToImage": "Sent To Image To Image", "sentToImageToImage": "Sent To Image To Image",
"sentToUnifiedCanvas": "Sent to Unified Canvas", "sentToUnifiedCanvas": "Sent to Unified Canvas",
"parameterSet": "Parameter set",
"parameterNotSet": "Parameter not set",
"parametersSet": "Parameters Set", "parametersSet": "Parameters Set",
"parametersNotSet": "Parameters Not Set", "parametersNotSet": "Parameters Not Set",
"parametersNotSetDesc": "No metadata found for this image.", "parametersNotSetDesc": "No metadata found for this image.",

View File

@ -101,7 +101,8 @@
"serialize-error": "^11.0.0", "serialize-error": "^11.0.0",
"socket.io-client": "^4.6.0", "socket.io-client": "^4.6.0",
"use-image": "^1.1.0", "use-image": "^1.1.0",
"uuid": "^9.0.0" "uuid": "^9.0.0",
"zod": "^3.21.4"
}, },
"peerDependencies": { "peerDependencies": {
"@chakra-ui/cli": "^2.4.0", "@chakra-ui/cli": "^2.4.0",

View File

@ -122,7 +122,9 @@
"noImagesInGallery": "No Images In Gallery", "noImagesInGallery": "No Images In Gallery",
"deleteImage": "Delete Image", "deleteImage": "Delete Image",
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.", "deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored." "deleteImagePermanent": "Deleted images cannot be restored.",
"images": "Images",
"assets": "Assets"
}, },
"hotkeys": { "hotkeys": {
"keyboardShortcuts": "Keyboard Shortcuts", "keyboardShortcuts": "Keyboard Shortcuts",
@ -452,6 +454,8 @@
"height": "Height", "height": "Height",
"scheduler": "Scheduler", "scheduler": "Scheduler",
"seed": "Seed", "seed": "Seed",
"boundingBoxWidth": "Bounding Box Width",
"boundingBoxHeight": "Bounding Box Height",
"imageToImage": "Image to Image", "imageToImage": "Image to Image",
"randomizeSeed": "Randomize Seed", "randomizeSeed": "Randomize Seed",
"shuffle": "Shuffle Seed", "shuffle": "Shuffle Seed",
@ -524,7 +528,7 @@
}, },
"settings": { "settings": {
"models": "Models", "models": "Models",
"displayInProgress": "Display In-Progress Images", "displayInProgress": "Display Progress Images",
"saveSteps": "Save images every n steps", "saveSteps": "Save images every n steps",
"confirmOnDelete": "Confirm On Delete", "confirmOnDelete": "Confirm On Delete",
"displayHelpIcons": "Display Help Icons", "displayHelpIcons": "Display Help Icons",
@ -564,6 +568,8 @@
"canvasMerged": "Canvas Merged", "canvasMerged": "Canvas Merged",
"sentToImageToImage": "Sent To Image To Image", "sentToImageToImage": "Sent To Image To Image",
"sentToUnifiedCanvas": "Sent to Unified Canvas", "sentToUnifiedCanvas": "Sent to Unified Canvas",
"parameterSet": "Parameter set",
"parameterNotSet": "Parameter not set",
"parametersSet": "Parameters Set", "parametersSet": "Parameters Set",
"parametersNotSet": "Parameters Not Set", "parametersNotSet": "Parameters Not Set",
"parametersNotSetDesc": "No metadata found for this image.", "parametersNotSetDesc": "No metadata found for this image.",

View File

@ -21,25 +21,11 @@ export const SCHEDULERS = [
export type Scheduler = (typeof SCHEDULERS)[number]; export type Scheduler = (typeof SCHEDULERS)[number];
export const isScheduler = (x: string): x is Scheduler =>
SCHEDULERS.includes(x as Scheduler);
// Valid image widths
export const WIDTHS: Array<number> = Array.from(Array(64)).map(
(_x, i) => (i + 1) * 64
);
// Valid image heights
export const HEIGHTS: Array<number> = Array.from(Array(64)).map(
(_x, i) => (i + 1) * 64
);
// Valid upscaling levels // Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [ export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
{ key: '2x', value: 2 }, { key: '2x', value: 2 },
{ key: '4x', value: 4 }, { key: '4x', value: 4 },
]; ];
export const NUMPY_RAND_MIN = 0; export const NUMPY_RAND_MIN = 0;
export const NUMPY_RAND_MAX = 2147483647; export const NUMPY_RAND_MAX = 2147483647;

View File

@ -1,7 +1,5 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist'; import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist'; import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist';
import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist';
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist'; import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist'; import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
@ -22,11 +20,9 @@ const serializationDenylist: {
models: modelsPersistDenylist, models: modelsPersistDenylist,
nodes: nodesPersistDenylist, nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist, postprocessing: postprocessingPersistDenylist,
results: resultsPersistDenylist,
system: systemPersistDenylist, system: systemPersistDenylist,
// config: configPersistDenyList, // config: configPersistDenyList,
ui: uiPersistDenylist, ui: uiPersistDenylist,
uploads: uploadsPersistDenylist,
// hotkeys: hotkeysPersistDenylist, // hotkeys: hotkeysPersistDenylist,
}; };

View File

@ -1,7 +1,6 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice'; import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice'; import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialResultsState } from 'features/gallery/store/resultsSlice'; import { initialImagesState } from 'features/gallery/store/imagesSlice';
import { initialUploadsState } from 'features/gallery/store/uploadsSlice';
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice'; import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
@ -24,12 +23,11 @@ const initialStates: {
models: initialModelsState, models: initialModelsState,
nodes: initialNodesState, nodes: initialNodesState,
postprocessing: initialPostprocessingState, postprocessing: initialPostprocessingState,
results: initialResultsState,
system: initialSystemState, system: initialSystemState,
config: initialConfigState, config: initialConfigState,
ui: initialUIState, ui: initialUIState,
uploads: initialUploadsState,
hotkeys: initialHotkeysState, hotkeys: initialHotkeysState,
images: initialImagesState,
}; };
export const unserialize: UnserializeFunction = (data, key) => { export const unserialize: UnserializeFunction = (data, key) => {

View File

@ -7,5 +7,6 @@ export const actionsDenylist = [
'canvas/setBoundingBoxDimensions', 'canvas/setBoundingBoxDimensions',
'canvas/setIsDrawing', 'canvas/setIsDrawing',
'canvas/addPointToCurrentLine', 'canvas/addPointToCurrentLine',
'socket/generatorProgress', 'socket/socketGeneratorProgress',
'socket/appSocketGeneratorProgress',
]; ];

View File

@ -8,9 +8,16 @@ import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
import type { RootState, AppDispatch } from '../../store'; import type { RootState, AppDispatch } from '../../store';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected'; import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addImageResultReceivedListener } from './listeners/invocationComplete'; import {
import { addImageUploadedListener } from './listeners/imageUploaded'; addImageUploadedFulfilledListener,
import { addRequestedImageDeletionListener } from './listeners/imageDeleted'; addImageUploadedRejectedListener,
} from './listeners/imageUploaded';
import {
addImageDeletedFulfilledListener,
addImageDeletedPendingListener,
addImageDeletedRejectedListener,
addRequestedImageDeletionListener,
} from './listeners/imageDeleted';
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
@ -19,6 +26,50 @@ import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGaller
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage'; import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasMergedListener } from './listeners/canvasMerged'; import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
import {
addImageMetadataReceivedFulfilledListener,
addImageMetadataReceivedRejectedListener,
} from './listeners/imageMetadataReceived';
import {
addImageUrlsReceivedFulfilledListener,
addImageUrlsReceivedRejectedListener,
} from './listeners/imageUrlsReceived';
import {
addSessionCreatedFulfilledListener,
addSessionCreatedPendingListener,
addSessionCreatedRejectedListener,
} from './listeners/sessionCreated';
import {
addSessionInvokedFulfilledListener,
addSessionInvokedPendingListener,
addSessionInvokedRejectedListener,
} from './listeners/sessionInvoked';
import {
addSessionCanceledFulfilledListener,
addSessionCanceledPendingListener,
addSessionCanceledRejectedListener,
} from './listeners/sessionCanceled';
import {
addImageUpdatedFulfilledListener,
addImageUpdatedRejectedListener,
} from './listeners/imageUpdated';
import {
addReceivedPageOfImagesFulfilledListener,
addReceivedPageOfImagesRejectedListener,
} from './listeners/receivedPageOfImages';
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -38,17 +89,87 @@ export type AppListenerEffect = ListenerEffect<
AppDispatch AppDispatch
>; >;
addImageUploadedListener(); // Image uploaded
addInitialImageSelectedListener(); addImageUploadedFulfilledListener();
addImageResultReceivedListener(); addImageUploadedRejectedListener();
addRequestedImageDeletionListener();
// Image updated
addImageUpdatedFulfilledListener();
addImageUpdatedRejectedListener();
// Image selected
addInitialImageSelectedListener();
// Image deleted
addRequestedImageDeletionListener();
addImageDeletedPendingListener();
addImageDeletedFulfilledListener();
addImageDeletedRejectedListener();
// Image metadata
addImageMetadataReceivedFulfilledListener();
addImageMetadataReceivedRejectedListener();
// Image URLs
addImageUrlsReceivedFulfilledListener();
addImageUrlsReceivedRejectedListener();
// User Invoked
addUserInvokedCanvasListener(); addUserInvokedCanvasListener();
addUserInvokedNodesListener(); addUserInvokedNodesListener();
addUserInvokedTextToImageListener(); addUserInvokedTextToImageListener();
addUserInvokedImageToImageListener(); addUserInvokedImageToImageListener();
addSessionReadyToInvokeListener();
// Canvas actions
addCanvasSavedToGalleryListener(); addCanvasSavedToGalleryListener();
addCanvasDownloadedAsImageListener(); addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener(); addCanvasCopiedToClipboardListener();
addCanvasMergedListener(); addCanvasMergedListener();
addStagingAreaImageSavedListener();
addCommitStagingAreaImageListener();
/**
* Socket.IO Events - these handle SIO events directly and pass on internal application actions.
* We don't handle SIO events in slices via `extraReducers` because some of these events shouldn't
* actually be handled at all.
*
* For example, we don't want to respond to progress events for canceled sessions. To avoid
* duplicating the logic to determine if an event should be responded to, we handle all of that
* "is this session canceled?" logic in these listeners.
*
* The `socketGeneratorProgress` listener will then only dispatch the `appSocketGeneratorProgress`
* action if it should be handled by the rest of the application. It is this `appSocketGeneratorProgress`
* action that is handled by reducers in slices.
*/
addGeneratorProgressListener();
addGraphExecutionStateCompleteListener();
addInvocationCompleteListener();
addInvocationErrorListener();
addInvocationStartedListener();
addSocketConnectedListener();
addSocketDisconnectedListener();
addSocketSubscribedListener();
addSocketUnsubscribedListener();
// Session Created
addSessionCreatedPendingListener();
addSessionCreatedFulfilledListener();
addSessionCreatedRejectedListener();
// Session Invoked
addSessionInvokedPendingListener();
addSessionInvokedFulfilledListener();
addSessionInvokedRejectedListener();
// Session Canceled
addSessionCanceledPendingListener();
addSessionCanceledFulfilledListener();
addSessionCanceledRejectedListener();
// Fetching images
addReceivedPageOfImagesFulfilledListener();
addReceivedPageOfImagesRejectedListener();
// Gallery
addImageCategoriesChangedListener();

View File

@ -0,0 +1,42 @@
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { commitStagingAreaImage } from 'features/canvas/store/canvasSlice';
import { sessionCanceled } from 'services/thunks/session';
const moduleLog = log.child({ namespace: 'canvas' });
export const addCommitStagingAreaImageListener = () => {
startAppListening({
actionCreator: commitStagingAreaImage,
effect: async (action, { dispatch, getState }) => {
const state = getState();
const { sessionId, isProcessing } = state.system;
const canvasSessionId = action.payload;
if (!isProcessing) {
// Only need to cancel if we are processing
return;
}
if (!canvasSessionId) {
moduleLog.debug('No canvas session, skipping cancel');
return;
}
if (canvasSessionId !== sessionId) {
moduleLog.debug(
{
data: {
canvasSessionId,
sessionId,
},
},
'Canvas session does not match global session, skipping cancel'
);
return;
}
dispatch(sessionCanceled({ sessionId }));
},
});
};

View File

@ -52,10 +52,11 @@ export const addCanvasMergedListener = () => {
dispatch( dispatch(
imageUploaded({ imageUploaded({
imageType: 'intermediates',
formData: { formData: {
file: new File([blob], filename, { type: 'image/png' }), file: new File([blob], filename, { type: 'image/png' }),
}, },
imageCategory: 'general',
isIntermediate: true,
}) })
); );
@ -65,7 +66,7 @@ export const addCanvasMergedListener = () => {
action.meta.arg.formData.file.name === filename action.meta.arg.formData.file.name === filename
); );
const mergedCanvasImage = payload.response; const mergedCanvasImage = payload;
dispatch( dispatch(
setMergedCanvas({ setMergedCanvas({

View File

@ -4,16 +4,18 @@ import { log } from 'app/logging/useLogger';
import { imageUploaded } from 'services/thunks/image'; import { imageUploaded } from 'services/thunks/image';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { v4 as uuidv4 } from 'uuid';
import { imageUpserted } from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' }); const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
export const addCanvasSavedToGalleryListener = () => { export const addCanvasSavedToGalleryListener = () => {
startAppListening({ startAppListening({
actionCreator: canvasSavedToGallery, actionCreator: canvasSavedToGallery,
effect: async (action, { dispatch, getState }) => { effect: async (action, { dispatch, getState, take }) => {
const state = getState(); const state = getState();
const blob = await getBaseLayerBlob(state); const blob = await getBaseLayerBlob(state, true);
if (!blob) { if (!blob) {
moduleLog.error('Problem getting base layer blob'); moduleLog.error('Problem getting base layer blob');
@ -27,14 +29,25 @@ export const addCanvasSavedToGalleryListener = () => {
return; return;
} }
const filename = `mergedCanvas_${uuidv4()}.png`;
dispatch( dispatch(
imageUploaded({ imageUploaded({
imageType: 'results',
formData: { formData: {
file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }), file: new File([blob], filename, { type: 'image/png' }),
}, },
imageCategory: 'general',
isIntermediate: false,
}) })
); );
const [{ payload: uploadedImageDTO }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === filename
);
dispatch(imageUpserted(uploadedImageDTO));
}, },
}); });
}; };

View File

@ -0,0 +1,24 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { receivedPageOfImages } from 'services/thunks/image';
import {
imageCategoriesChanged,
selectFilteredImagesAsArray,
} from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'gallery' });
export const addImageCategoriesChangedListener = () => {
startAppListening({
actionCreator: imageCategoriesChanged,
effect: (action, { getState, dispatch }) => {
const filteredImagesCount = selectFilteredImagesAsArray(
getState()
).length;
if (!filteredImagesCount) {
dispatch(receivedPageOfImages());
}
},
});
};

View File

@ -4,9 +4,18 @@ import { imageDeleted } from 'services/thunks/image';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es'; import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageSelected } from 'features/gallery/store/gallerySlice';
import {
imageRemoved,
imagesAdapter,
selectImagesEntities,
selectImagesIds,
} from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
/**
* Called when the user requests an image deletion
*/
export const addRequestedImageDeletionListener = () => { export const addRequestedImageDeletionListener = () => {
startAppListening({ startAppListening({
actionCreator: requestedImageDeletion, actionCreator: requestedImageDeletion,
@ -17,24 +26,20 @@ export const addRequestedImageDeletionListener = () => {
return; return;
} }
const { image_name, image_type } = image; const { image_name, image_origin } = image;
if (image_type !== 'uploads' && image_type !== 'results') { const state = getState();
moduleLog.warn({ data: image }, `Invalid image type ${image_type}`); const selectedImage = state.gallery.selectedImage;
return;
}
const selectedImageName = getState().gallery.selectedImage?.image_name; if (selectedImage && selectedImage.image_name === image_name) {
const ids = selectImagesIds(state);
const entities = selectImagesEntities(state);
if (selectedImageName === image_name) { const deletedImageIndex = ids.findIndex(
const allIds = getState()[image_type].ids;
const allEntities = getState()[image_type].entities;
const deletedImageIndex = allIds.findIndex(
(result) => result.toString() === image_name (result) => result.toString() === image_name
); );
const filteredIds = allIds.filter((id) => id.toString() !== image_name); const filteredIds = ids.filter((id) => id.toString() !== image_name);
const newSelectedImageIndex = clamp( const newSelectedImageIndex = clamp(
deletedImageIndex, deletedImageIndex,
@ -44,7 +49,7 @@ export const addRequestedImageDeletionListener = () => {
const newSelectedImageId = filteredIds[newSelectedImageIndex]; const newSelectedImageId = filteredIds[newSelectedImageIndex];
const newSelectedImage = allEntities[newSelectedImageId]; const newSelectedImage = entities[newSelectedImageId];
if (newSelectedImageId) { if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImage)); dispatch(imageSelected(newSelectedImage));
@ -53,7 +58,52 @@ export const addRequestedImageDeletionListener = () => {
} }
} }
dispatch(imageDeleted({ imageName: image_name, imageType: image_type })); dispatch(imageRemoved(image_name));
dispatch(
imageDeleted({ imageName: image_name, imageOrigin: image_origin })
);
},
});
};
/**
* Called when the actual delete request is sent to the server
*/
export const addImageDeletedPendingListener = () => {
startAppListening({
actionCreator: imageDeleted.pending,
effect: (action, { dispatch, getState }) => {
const { imageName, imageOrigin } = action.meta.arg;
// Preemptively remove the image from the gallery
imagesAdapter.removeOne(getState().images, imageName);
},
});
};
/**
* Called on successful delete
*/
export const addImageDeletedFulfilledListener = () => {
startAppListening({
actionCreator: imageDeleted.fulfilled,
effect: (action, { dispatch, getState }) => {
moduleLog.debug({ data: { image: action.meta.arg } }, 'Image deleted');
},
});
};
/**
* Called on failed delete
*/
export const addImageDeletedRejectedListener = () => {
startAppListening({
actionCreator: imageDeleted.rejected,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
{ data: { image: action.meta.arg } },
'Unable to delete image'
);
}, },
}); });
}; };

View File

@ -0,0 +1,33 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/thunks/image';
import { imageUpserted } from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'image' });
export const addImageMetadataReceivedFulfilledListener = () => {
startAppListening({
actionCreator: imageMetadataReceived.fulfilled,
effect: (action, { getState, dispatch }) => {
const image = action.payload;
if (image.is_intermediate) {
// No further actions needed for intermediate images
return;
}
moduleLog.debug({ data: { image } }, 'Image metadata received');
dispatch(imageUpserted(image));
},
});
};
export const addImageMetadataReceivedRejectedListener = () => {
startAppListening({
actionCreator: imageMetadataReceived.rejected,
effect: (action, { getState, dispatch }) => {
moduleLog.debug(
{ data: { image: action.meta.arg } },
'Problem receiving image metadata'
);
},
});
};

View File

@ -0,0 +1,26 @@
import { startAppListening } from '..';
import { imageUpdated } from 'services/thunks/image';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ namespace: 'image' });
export const addImageUpdatedFulfilledListener = () => {
startAppListening({
actionCreator: imageUpdated.fulfilled,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
{ oldImage: action.meta.arg, updatedImage: action.payload },
'Image updated'
);
},
});
};
export const addImageUpdatedRejectedListener = () => {
startAppListening({
actionCreator: imageUpdated.rejected,
effect: (action, { dispatch }) => {
moduleLog.debug({ oldImage: action.meta.arg }, 'Image update failed');
},
});
};

View File

@ -1,44 +1,46 @@
import { startAppListening } from '..'; import { startAppListening } from '..';
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageUploaded } from 'services/thunks/image'; import { imageUploaded } from 'services/thunks/image';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { initialImageSelected } from 'features/parameters/store/actions'; import { log } from 'app/logging/useLogger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { imageUpserted } from 'features/gallery/store/imagesSlice';
import { resultAdded } from 'features/gallery/store/resultsSlice';
import { isResultsImageDTO, isUploadsImageDTO } from 'services/types/guards';
export const addImageUploadedListener = () => { const moduleLog = log.child({ namespace: 'image' });
export const addImageUploadedFulfilledListener = () => {
startAppListening({ startAppListening({
predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> => actionCreator: imageUploaded.fulfilled,
imageUploaded.fulfilled.match(action) &&
action.payload.response.image_type !== 'intermediates',
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const { response: image } = action.payload; const image = action.payload;
moduleLog.debug({ arg: '<Blob>', image }, 'Image uploaded');
if (action.payload.is_intermediate) {
// No further actions needed for intermediate images
return;
}
const state = getState(); const state = getState();
if (isUploadsImageDTO(image)) { dispatch(imageUpserted(image));
dispatch(uploadAdded(image));
dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
},
if (state.gallery.shouldAutoSwitchToNewImages) { });
dispatch(imageSelected(image)); };
}
export const addImageUploadedRejectedListener = () => {
if (action.meta.arg.activeTabName === 'img2img') { startAppListening({
dispatch(initialImageSelected(image)); actionCreator: imageUploaded.rejected,
} effect: (action, { dispatch }) => {
const { formData, ...rest } = action.meta.arg;
if (action.meta.arg.activeTabName === 'unifiedCanvas') { const sanitizedData = { arg: { ...rest, formData: { file: '<Blob>' } } };
dispatch(setInitialCanvasImage(image)); moduleLog.error({ data: sanitizedData }, 'Image upload failed');
} dispatch(
} addToast({
title: 'Image Upload Failed',
if (isResultsImageDTO(image)) { description: action.error.message,
dispatch(resultAdded(image)); status: 'error',
} })
);
}, },
}); });
}; };

View File

@ -0,0 +1,38 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageUrlsReceived } from 'services/thunks/image';
import { imagesAdapter } from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'image' });
export const addImageUrlsReceivedFulfilledListener = () => {
startAppListening({
actionCreator: imageUrlsReceived.fulfilled,
effect: (action, { getState, dispatch }) => {
const image = action.payload;
moduleLog.debug({ data: { image } }, 'Image URLs received');
const { image_name, image_url, thumbnail_url } = image;
imagesAdapter.updateOne(getState().images, {
id: image_name,
changes: {
image_url,
thumbnail_url,
},
});
},
});
};
export const addImageUrlsReceivedRejectedListener = () => {
startAppListening({
actionCreator: imageUrlsReceived.rejected,
effect: (action, { getState, dispatch }) => {
moduleLog.debug(
{ data: { image: action.meta.arg } },
'Problem getting image URLs'
);
},
});
};

View File

@ -1,6 +1,4 @@
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { selectResultsById } from 'features/gallery/store/resultsSlice';
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
import { t } from 'i18next'; import { t } from 'i18next';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..'; import { startAppListening } from '..';
@ -9,7 +7,7 @@ import {
isImageDTO, isImageDTO,
} from 'features/parameters/store/actions'; } from 'features/parameters/store/actions';
import { makeToast } from 'app/components/Toaster'; import { makeToast } from 'app/components/Toaster';
import { ImageDTO } from 'services/api'; import { selectImagesById } from 'features/gallery/store/imagesSlice';
export const addInitialImageSelectedListener = () => { export const addInitialImageSelectedListener = () => {
startAppListening({ startAppListening({
@ -30,16 +28,8 @@ export const addInitialImageSelectedListener = () => {
return; return;
} }
const { image_name, image_type } = action.payload; const imageName = action.payload;
const image = selectImagesById(getState(), imageName);
let image: ImageDTO | undefined;
const state = getState();
if (image_type === 'results') {
image = selectResultsById(state, image_name);
} else if (image_type === 'uploads') {
image = selectUploadsById(state, image_name);
}
if (!image) { if (!image) {
dispatch( dispatch(

View File

@ -1,62 +0,0 @@
import { invocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/types/guards';
import {
imageMetadataReceived,
imageUrlsReceived,
} from 'services/thunks/image';
import { startAppListening } from '..';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
const nodeDenylist = ['dataURL_image'];
export const addImageResultReceivedListener = () => {
startAppListening({
predicate: (action) => {
if (
invocationComplete.match(action) &&
isImageOutput(action.payload.data.result)
) {
return true;
}
return false;
},
effect: async (action, { getState, dispatch, take }) => {
if (!invocationComplete.match(action)) {
return;
}
const { data } = action.payload;
const { result, node, graph_execution_state_id } = data;
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const { image_name, image_type } = result.image;
dispatch(
imageUrlsReceived({ imageName: image_name, imageType: image_type })
);
dispatch(
imageMetadataReceived({
imageName: image_name,
imageType: image_type,
})
);
// Handle canvas image
if (
graph_execution_state_id ===
getState().canvas.layerState.stagingArea.sessionId
) {
const [{ payload: image }] = await take(
(
action
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
imageMetadataReceived.fulfilled.match(action) &&
action.payload.image_name === image_name
);
dispatch(addImageToStagingArea(image));
}
}
},
});
};

View File

@ -0,0 +1,33 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { serializeError } from 'serialize-error';
import { receivedPageOfImages } from 'services/thunks/image';
const moduleLog = log.child({ namespace: 'gallery' });
export const addReceivedPageOfImagesFulfilledListener = () => {
startAppListening({
actionCreator: receivedPageOfImages.fulfilled,
effect: (action, { getState, dispatch }) => {
const page = action.payload;
moduleLog.debug(
{ data: { payload: action.payload } },
`Received ${page.items.length} images`
);
},
});
};
export const addReceivedPageOfImagesRejectedListener = () => {
startAppListening({
actionCreator: receivedPageOfImages.rejected,
effect: (action, { getState, dispatch }) => {
if (action.payload) {
moduleLog.debug(
{ data: { error: serializeError(action.payload) } },
'Problem receiving images'
);
}
},
});
};

View File

@ -0,0 +1,48 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { sessionCanceled } from 'services/thunks/session';
import { serializeError } from 'serialize-error';
const moduleLog = log.child({ namespace: 'session' });
export const addSessionCanceledPendingListener = () => {
startAppListening({
actionCreator: sessionCanceled.pending,
effect: (action, { getState, dispatch }) => {
//
},
});
};
export const addSessionCanceledFulfilledListener = () => {
startAppListening({
actionCreator: sessionCanceled.fulfilled,
effect: (action, { getState, dispatch }) => {
const { sessionId } = action.meta.arg;
moduleLog.debug(
{ data: { sessionId } },
`Session canceled (${sessionId})`
);
},
});
};
export const addSessionCanceledRejectedListener = () => {
startAppListening({
actionCreator: sessionCanceled.rejected,
effect: (action, { getState, dispatch }) => {
if (action.payload) {
const { arg, error } = action.payload;
moduleLog.error(
{
data: {
arg,
error: serializeError(error),
},
},
`Problem canceling session`
);
}
},
});
};

View File

@ -0,0 +1,45 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { sessionCreated } from 'services/thunks/session';
import { serializeError } from 'serialize-error';
const moduleLog = log.child({ namespace: 'session' });
export const addSessionCreatedPendingListener = () => {
startAppListening({
actionCreator: sessionCreated.pending,
effect: (action, { getState, dispatch }) => {
//
},
});
};
export const addSessionCreatedFulfilledListener = () => {
startAppListening({
actionCreator: sessionCreated.fulfilled,
effect: (action, { getState, dispatch }) => {
const session = action.payload;
moduleLog.debug({ data: { session } }, `Session created (${session.id})`);
},
});
};
export const addSessionCreatedRejectedListener = () => {
startAppListening({
actionCreator: sessionCreated.rejected,
effect: (action, { getState, dispatch }) => {
if (action.payload) {
const { arg, error } = action.payload;
moduleLog.error(
{
data: {
arg,
error: serializeError(error),
},
},
`Problem creating session`
);
}
},
});
};

View File

@ -0,0 +1,48 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { sessionInvoked } from 'services/thunks/session';
import { serializeError } from 'serialize-error';
const moduleLog = log.child({ namespace: 'session' });
export const addSessionInvokedPendingListener = () => {
startAppListening({
actionCreator: sessionInvoked.pending,
effect: (action, { getState, dispatch }) => {
//
},
});
};
export const addSessionInvokedFulfilledListener = () => {
startAppListening({
actionCreator: sessionInvoked.fulfilled,
effect: (action, { getState, dispatch }) => {
const { sessionId } = action.meta.arg;
moduleLog.debug(
{ data: { sessionId } },
`Session invoked (${sessionId})`
);
},
});
};
export const addSessionInvokedRejectedListener = () => {
startAppListening({
actionCreator: sessionInvoked.rejected,
effect: (action, { getState, dispatch }) => {
if (action.payload) {
const { arg, error } = action.payload;
moduleLog.error(
{
data: {
arg,
error: serializeError(error),
},
},
`Problem invoking session`
);
}
},
});
};

View File

@ -0,0 +1,22 @@
import { startAppListening } from '..';
import { sessionInvoked } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'session' });
export const addSessionReadyToInvokeListener = () => {
startAppListening({
actionCreator: sessionReadyToInvoke,
effect: (action, { getState, dispatch }) => {
const { sessionId } = getState().system;
if (sessionId) {
moduleLog.debug(
{ sessionId },
`Session ready to invoke (${sessionId})})`
);
dispatch(sessionInvoked({ sessionId }));
}
},
});
};

View File

@ -0,0 +1,38 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { receivedPageOfImages } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema';
const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketConnectedEventListener = () => {
startAppListening({
actionCreator: socketConnected,
effect: (action, { dispatch, getState }) => {
const { timestamp } = action.payload;
moduleLog.debug({ timestamp }, 'Connected');
const { models, nodes, config, images } = getState();
const { disabledTabs } = config;
if (!images.ids.length) {
dispatch(receivedPageOfImages());
}
if (!models.ids.length) {
dispatch(receivedModels());
}
if (!nodes.schema && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema());
}
// pass along the socket event as an application action
dispatch(appSocketConnected(action.payload));
},
});
};

View File

@ -0,0 +1,19 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import {
socketDisconnected,
appSocketDisconnected,
} from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketDisconnectedEventListener = () => {
startAppListening({
actionCreator: socketDisconnected,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(action.payload, 'Disconnected');
// pass along the socket event as an application action
dispatch(appSocketDisconnected(action.payload));
},
});
};

View File

@ -0,0 +1,34 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import {
appSocketGeneratorProgress,
socketGeneratorProgress,
} from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addGeneratorProgressEventListener = () => {
startAppListening({
actionCreator: socketGeneratorProgress,
effect: (action, { dispatch, getState }) => {
if (
getState().system.canceledSession ===
action.payload.data.graph_execution_state_id
) {
moduleLog.trace(
action.payload,
'Ignored generator progress for canceled session'
);
return;
}
moduleLog.trace(
action.payload,
`Generator progress (${action.payload.data.node.type})`
);
// pass along the socket event as an application action
dispatch(appSocketGeneratorProgress(action.payload));
},
});
};

View File

@ -0,0 +1,22 @@
import { log } from 'app/logging/useLogger';
import {
appSocketGraphExecutionStateComplete,
socketGraphExecutionStateComplete,
} from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addGraphExecutionStateCompleteEventListener = () => {
startAppListening({
actionCreator: socketGraphExecutionStateComplete,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
action.payload,
`Session invocation complete (${action.payload.data.graph_execution_state_id})`
);
// pass along the socket event as an application action
dispatch(appSocketGraphExecutionStateComplete(action.payload));
},
});
};

View File

@ -0,0 +1,67 @@
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import {
appSocketInvocationComplete,
socketInvocationComplete,
} from 'services/events/actions';
import { imageMetadataReceived } from 'services/thunks/image';
import { sessionCanceled } from 'services/thunks/session';
import { isImageOutput } from 'services/types/guards';
import { progressImageSet } from 'features/system/store/systemSlice';
const moduleLog = log.child({ namespace: 'socketio' });
const nodeDenylist = ['dataURL_image'];
export const addInvocationCompleteEventListener = () => {
startAppListening({
actionCreator: socketInvocationComplete,
effect: async (action, { dispatch, getState, take }) => {
moduleLog.debug(
action.payload,
`Invocation complete (${action.payload.data.node.type})`
);
const sessionId = action.payload.data.graph_execution_state_id;
const { cancelType, isCancelScheduled } = getState().system;
// Handle scheduled cancelation
if (cancelType === 'scheduled' && isCancelScheduled) {
dispatch(sessionCanceled({ sessionId }));
}
const { data } = action.payload;
const { result, node, graph_execution_state_id } = data;
// This complete event has an associated image output
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const { image_name, image_origin } = result.image;
// Get its metadata
dispatch(
imageMetadataReceived({
imageName: image_name,
imageOrigin: image_origin,
})
);
const [{ payload: imageDTO }] = await take(
imageMetadataReceived.fulfilled.match
);
// Handle canvas image
if (
graph_execution_state_id ===
getState().canvas.layerState.stagingArea.sessionId
) {
dispatch(addImageToStagingArea(imageDTO));
}
dispatch(progressImageSet(null));
}
// pass along the socket event as an application action
dispatch(appSocketInvocationComplete(action.payload));
},
});
};

View File

@ -0,0 +1,21 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import {
appSocketInvocationError,
socketInvocationError,
} from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addInvocationErrorEventListener = () => {
startAppListening({
actionCreator: socketInvocationError,
effect: (action, { dispatch, getState }) => {
moduleLog.error(
action.payload,
`Invocation error (${action.payload.data.node.type})`
);
dispatch(appSocketInvocationError(action.payload));
},
});
};

View File

@ -0,0 +1,32 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import {
appSocketInvocationStarted,
socketInvocationStarted,
} from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addInvocationStartedEventListener = () => {
startAppListening({
actionCreator: socketInvocationStarted,
effect: (action, { dispatch, getState }) => {
if (
getState().system.canceledSession ===
action.payload.data.graph_execution_state_id
) {
moduleLog.trace(
action.payload,
'Ignored invocation started for canceled session'
);
return;
}
moduleLog.debug(
action.payload,
`Invocation started (${action.payload.data.node.type})`
);
dispatch(appSocketInvocationStarted(action.payload));
},
});
};

View File

@ -0,0 +1,18 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { appSocketSubscribed, socketSubscribed } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketSubscribedEventListener = () => {
startAppListening({
actionCreator: socketSubscribed,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
action.payload,
`Subscribed (${action.payload.sessionId}))`
);
dispatch(appSocketSubscribed(action.payload));
},
});
};

View File

@ -0,0 +1,21 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import {
appSocketUnsubscribed,
socketUnsubscribed,
} from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketUnsubscribedEventListener = () => {
startAppListening({
actionCreator: socketUnsubscribed,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
action.payload,
`Unsubscribed (${action.payload.sessionId})`
);
dispatch(appSocketUnsubscribed(action.payload));
},
});
};

View File

@ -0,0 +1,54 @@
import { stagingAreaImageSaved } from 'features/canvas/store/actions';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { imageUpdated } from 'services/thunks/image';
import { imageUpserted } from 'features/gallery/store/imagesSlice';
import { addToast } from 'features/system/store/systemSlice';
const moduleLog = log.child({ namespace: 'canvas' });
export const addStagingAreaImageSavedListener = () => {
startAppListening({
actionCreator: stagingAreaImageSaved,
effect: async (action, { dispatch, getState, take }) => {
const { image_name, image_origin } = action.payload;
dispatch(
imageUpdated({
imageName: image_name,
imageOrigin: image_origin,
requestBody: {
is_intermediate: false,
},
})
);
const [imageUpdatedAction] = await take(
(action) =>
(imageUpdated.fulfilled.match(action) ||
imageUpdated.rejected.match(action)) &&
action.meta.arg.imageName === image_name
);
if (imageUpdated.rejected.match(imageUpdatedAction)) {
moduleLog.error(
{ data: { arg: imageUpdatedAction.meta.arg } },
'Image saving failed'
);
dispatch(
addToast({
title: 'Image Saving Failed',
description: imageUpdatedAction.error.message,
status: 'error',
})
);
return;
}
if (imageUpdated.fulfilled.match(imageUpdatedAction)) {
dispatch(imageUpserted(imageUpdatedAction.payload));
dispatch(addToast({ title: 'Image Saved', status: 'success' }));
}
},
});
};

View File

@ -1,9 +1,9 @@
import { startAppListening } from '..'; import { startAppListening } from '..';
import { sessionCreated, sessionInvoked } from 'services/thunks/session'; import { sessionCreated } from 'services/thunks/session';
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph'; import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { canvasGraphBuilt } from 'features/nodes/store/actions'; import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { imageUploaded } from 'services/thunks/image'; import { imageUpdated, imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import { Graph } from 'services/api'; import { Graph } from 'services/api';
import { import {
@ -15,12 +15,22 @@ import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode'; import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL'; import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' }); const moduleLog = log.child({ namespace: 'invoke' });
/** /**
* This listener is responsible for building the canvas graph and blobs when the user invokes the canvas. * This listener is responsible invoking the canvas. This involves a number of steps:
* It is also responsible for uploading the base and mask layers to the server. *
* 1. Generate image blobs from the canvas layers
* 2. Determine the generation mode from the layers (txt2img, img2img, inpaint)
* 3. Build the canvas graph
* 4. Create the session with the graph
* 5. Upload the init image if necessary
* 6. Upload the mask image if necessary
* 7. Update the init and mask images with the session ID
* 8. Initialize the staging area if not yet initialized
* 9. Dispatch the sessionReadyToInvoke action to invoke the session
*/ */
export const addUserInvokedCanvasListener = () => { export const addUserInvokedCanvasListener = () => {
startAppListening({ startAppListening({
@ -70,63 +80,7 @@ export const addUserInvokedCanvasListener = () => {
const { rangeNode, iterateNode, baseNode, edges } = graphComponents; const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
// Upload the base layer, to be used as init image // Assemble! Note that this graph *does not have the init or mask image set yet!*
const baseFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
},
})
);
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
const [{ payload: basePayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === baseFilename
);
const { image_name: baseName, image_type: baseType } =
basePayload.response;
baseNode.image = {
image_name: baseName,
image_type: baseType,
};
}
// Upload the mask layer image
const maskFilename = `${uuidv4()}.png`;
if (baseNode.type === 'inpaint') {
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
},
})
);
const [{ payload: maskPayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === maskFilename
);
const { image_name: maskName, image_type: maskType } =
maskPayload.response;
baseNode.mask = {
image_name: maskName,
image_type: maskType,
};
}
// Assemble!
const nodes: Graph['nodes'] = { const nodes: Graph['nodes'] = {
[rangeNode.id]: rangeNode, [rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode, [iterateNode.id]: iterateNode,
@ -136,15 +90,92 @@ export const addUserInvokedCanvasListener = () => {
const graph = { nodes, edges }; const graph = { nodes, edges };
dispatch(canvasGraphBuilt(graph)); dispatch(canvasGraphBuilt(graph));
moduleLog({ data: graph }, 'Canvas graph built');
// Actually create the session moduleLog.debug({ data: graph }, 'Canvas graph built');
// If we are generating img2img or inpaint, we need to upload the init images
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
const baseFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
formData: {
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
},
imageCategory: 'general',
isIntermediate: true,
})
);
// Wait for the image to be uploaded
const [{ payload: baseImageDTO }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === baseFilename
);
// Update the base node with the image name and type
baseNode.image = {
image_name: baseImageDTO.image_name,
image_origin: baseImageDTO.image_origin,
};
}
// For inpaint, we also need to upload the mask layer
if (baseNode.type === 'inpaint') {
const maskFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
formData: {
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
},
imageCategory: 'mask',
isIntermediate: true,
})
);
// Wait for the mask to be uploaded
const [{ payload: maskImageDTO }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === maskFilename
);
// Update the base node with the image name and type
baseNode.mask = {
image_name: maskImageDTO.image_name,
image_origin: maskImageDTO.image_origin,
};
}
// Create the session and wait for response
dispatch(sessionCreated({ graph })); dispatch(sessionCreated({ graph }));
const [sessionCreatedAction] = await take(sessionCreated.fulfilled.match);
const sessionId = sessionCreatedAction.payload.id;
// Wait for the session to be invoked (this is just the HTTP request to start processing) // Associate the init image with the session, now that we have the session ID
const [{ meta }] = await take(sessionInvoked.fulfilled.match); if (
(baseNode.type === 'img2img' || baseNode.type === 'inpaint') &&
baseNode.image
) {
dispatch(
imageUpdated({
imageName: baseNode.image.image_name,
imageOrigin: baseNode.image.image_origin,
requestBody: { session_id: sessionId },
})
);
}
const { sessionId } = meta.arg; // Associate the mask image with the session, now that we have the session ID
if (baseNode.type === 'inpaint' && baseNode.mask) {
dispatch(
imageUpdated({
imageName: baseNode.mask.image_name,
imageOrigin: baseNode.mask.image_origin,
requestBody: { session_id: sessionId },
})
);
}
if (!state.canvas.layerState.stagingArea.boundingBox) { if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch( dispatch(
@ -158,7 +189,11 @@ export const addUserInvokedCanvasListener = () => {
); );
} }
// Flag the session with the canvas session ID
dispatch(canvasSessionIdChanged(sessionId)); dispatch(canvasSessionIdChanged(sessionId));
// We are ready to invoke the session!
dispatch(sessionReadyToInvoke());
}, },
}); });
}; };

View File

@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { imageToImageGraphBuilt } from 'features/nodes/store/actions'; import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' }); const moduleLog = log.child({ namespace: 'invoke' });
@ -11,14 +12,18 @@ export const addUserInvokedImageToImageListener = () => {
startAppListening({ startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> => predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'img2img', userInvoked.match(action) && action.payload === 'img2img',
effect: (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch, take }) => {
const state = getState(); const state = getState();
const graph = buildImageToImageGraph(state); const graph = buildImageToImageGraph(state);
dispatch(imageToImageGraphBuilt(graph)); dispatch(imageToImageGraphBuilt(graph));
moduleLog({ data: graph }, 'Image to Image graph built'); moduleLog.debug({ data: graph }, 'Image to Image graph built');
dispatch(sessionCreated({ graph })); dispatch(sessionCreated({ graph }));
await take(sessionCreated.fulfilled.match);
dispatch(sessionReadyToInvoke());
}, },
}); });
}; };

View File

@ -4,6 +4,7 @@ import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGra
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { nodesGraphBuilt } from 'features/nodes/store/actions'; import { nodesGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' }); const moduleLog = log.child({ namespace: 'invoke' });
@ -11,14 +12,18 @@ export const addUserInvokedNodesListener = () => {
startAppListening({ startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> => predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'nodes', userInvoked.match(action) && action.payload === 'nodes',
effect: (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch, take }) => {
const state = getState(); const state = getState();
const graph = buildNodesGraph(state); const graph = buildNodesGraph(state);
dispatch(nodesGraphBuilt(graph)); dispatch(nodesGraphBuilt(graph));
moduleLog({ data: graph }, 'Nodes graph built'); moduleLog.debug({ data: graph }, 'Nodes graph built');
dispatch(sessionCreated({ graph })); dispatch(sessionCreated({ graph }));
await take(sessionCreated.fulfilled.match);
dispatch(sessionReadyToInvoke());
}, },
}); });
}; };

View File

@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { textToImageGraphBuilt } from 'features/nodes/store/actions'; import { textToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' }); const moduleLog = log.child({ namespace: 'invoke' });
@ -11,14 +12,20 @@ export const addUserInvokedTextToImageListener = () => {
startAppListening({ startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> => predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'txt2img', userInvoked.match(action) && action.payload === 'txt2img',
effect: (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch, take }) => {
const state = getState(); const state = getState();
const graph = buildTextToImageGraph(state); const graph = buildTextToImageGraph(state);
dispatch(textToImageGraphBuilt(graph)); dispatch(textToImageGraphBuilt(graph));
moduleLog({ data: graph }, 'Text to Image graph built');
moduleLog.debug({ data: graph }, 'Text to Image graph built');
dispatch(sessionCreated({ graph })); dispatch(sessionCreated({ graph }));
await take(sessionCreated.fulfilled.match);
dispatch(sessionReadyToInvoke());
}, },
}); });
}; };

View File

@ -10,12 +10,12 @@ import dynamicMiddlewares from 'redux-dynamic-middlewares';
import canvasReducer from 'features/canvas/store/canvasSlice'; import canvasReducer from 'features/canvas/store/canvasSlice';
import galleryReducer from 'features/gallery/store/gallerySlice'; import galleryReducer from 'features/gallery/store/gallerySlice';
import resultsReducer from 'features/gallery/store/resultsSlice'; import imagesReducer from 'features/gallery/store/imagesSlice';
import uploadsReducer from 'features/gallery/store/uploadsSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice'; import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice'; import systemReducer from 'features/system/store/systemSlice';
// import sessionReducer from 'features/system/store/sessionSlice';
import configReducer from 'features/system/store/configSlice'; import configReducer from 'features/system/store/configSlice';
import uiReducer from 'features/ui/store/uiSlice'; import uiReducer from 'features/ui/store/uiSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice';
@ -40,12 +40,12 @@ const allReducers = {
models: modelsReducer, models: modelsReducer,
nodes: nodesReducer, nodes: nodesReducer,
postprocessing: postprocessingReducer, postprocessing: postprocessingReducer,
results: resultsReducer,
system: systemReducer, system: systemReducer,
config: configReducer, config: configReducer,
ui: uiReducer, ui: uiReducer,
uploads: uploadsReducer,
hotkeys: hotkeysReducer, hotkeys: hotkeysReducer,
images: imagesReducer,
// session: sessionReducer,
}; };
const rootReducer = combineReducers(allReducers); const rootReducer = combineReducers(allReducers);
@ -63,8 +63,6 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'system', 'system',
'ui', 'ui',
// 'hotkeys', // 'hotkeys',
// 'results',
// 'uploads',
// 'config', // 'config',
]; ];

View File

@ -1,316 +1,82 @@
/**
* Types for images, the things they are made of, and the things
* they make up.
*
* Generated images are txt2img and img2img images. They may have
* had additional postprocessing done on them when they were first
* generated.
*
* Postprocessed images are images which were not generated here
* but only postprocessed by the app. They only get postprocessing
* metadata and have a different image type, e.g. 'esrgan' or
* 'gfpgan'.
*/
import { SelectedImage } from 'features/parameters/store/actions';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageResponseMetadata, ImageType } from 'services/api';
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
/** // These are old types from the model management UI
* TODO:
* Once an image has been generated, if it is postprocessed again,
* additional postprocessing steps are added to its postprocessing
* array.
*
* TODO: Better documentation of types.
*/
export type PromptItem = { // export type ModelStatus = 'active' | 'cached' | 'not loaded';
prompt: string;
weight: number;
};
// TECHDEBT: We need to retain compatibility with plain prompt strings and the structure Prompt type // export type Model = {
export type Prompt = Array<PromptItem> | string; // status: ModelStatus;
// description: string;
export type SeedWeightPair = { // weights: string;
seed: number; // config?: string;
weight: number; // vae?: string;
}; // width?: number;
// height?: number;
export type SeedWeights = Array<SeedWeightPair>; // default?: boolean;
// format?: string;
// All generated images contain these metadata.
export type CommonGeneratedImageMetadata = {
postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>;
sampler:
| 'ddim'
| 'ddpm'
| 'deis'
| 'lms'
| 'pndm'
| 'heun'
| 'heun_k'
| 'euler'
| 'euler_k'
| 'euler_a'
| 'kdpm_2'
| 'kdpm_2_a'
| 'dpmpp_2s'
| 'dpmpp_2m'
| 'dpmpp_2m_k'
| 'unipc';
prompt: Prompt;
seed: number;
variations: SeedWeights;
steps: number;
cfg_scale: number;
width: number;
height: number;
seamless: boolean;
hires_fix: boolean;
extra: null | Record<string, never>; // Pending development of RFC #266
};
// txt2img and img2img images have some unique attributes.
export type Txt2ImgMetadata = CommonGeneratedImageMetadata & {
type: 'txt2img';
};
export type Img2ImgMetadata = CommonGeneratedImageMetadata & {
type: 'img2img';
orig_hash: string;
strength: number;
fit: boolean;
init_image_path: string;
mask_image_path?: string;
};
// Superset of generated image metadata types.
export type GeneratedImageMetadata = Txt2ImgMetadata | Img2ImgMetadata;
// All post processed images contain these metadata.
export type CommonPostProcessedImageMetadata = {
orig_path: string;
orig_hash: string;
};
// esrgan and gfpgan images have some unique attributes.
export type ESRGANMetadata = CommonPostProcessedImageMetadata & {
type: 'esrgan';
scale: 2 | 4;
strength: number;
denoise_str: number;
};
export type FacetoolMetadata = CommonPostProcessedImageMetadata & {
type: 'gfpgan' | 'codeformer';
strength: number;
fidelity?: number;
};
// Superset of all postprocessed image metadata types..
export type PostProcessedImageMetadata = ESRGANMetadata | FacetoolMetadata;
// Metadata includes the system config and image metadata.
// export type Metadata = SystemGenerationMetadata & {
// image: GeneratedImageMetadata | PostProcessedImageMetadata;
// }; // };
/** // export type DiffusersModel = {
* ResultImage // status: ModelStatus;
*/ // description: string;
// export ty`pe Image = { // repo_id?: string;
// path?: string;
// vae?: {
// repo_id?: string;
// path?: string;
// };
// format?: string;
// default?: boolean;
// };
// export type ModelList = Record<string, Model & DiffusersModel>;
// export type FoundModel = {
// name: string; // name: string;
// type: ImageType; // location: string;
// url: string;
// thumbnail: string;
// metadata: ImageResponseMetadata;
// }; // };
// export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => { // export type InvokeModelConfigProps = {
// if ('url' in obj && 'thumbnail' in obj) { // name: string | undefined;
// return true; // description: string | undefined;
// } // config: string | undefined;
// weights: string | undefined;
// return false; // vae: string | undefined;
// width: number | undefined;
// height: number | undefined;
// default: boolean | undefined;
// format: string | undefined;
// }; // };
/** // export type InvokeDiffusersModelConfigProps = {
* Types related to the system status. // name: string | undefined;
*/ // description: string | undefined;
// repo_id: string | undefined;
// // This represents the processing status of the backend. // path: string | undefined;
// export type SystemStatus = { // default: boolean | undefined;
// isProcessing: boolean; // format: string | undefined;
// currentStep: number; // vae: {
// totalSteps: number; // repo_id: string | undefined;
// currentIteration: number; // path: string | undefined;
// totalIterations: number; // };
// currentStatus: string;
// currentStatusHasSteps: boolean;
// hasError: boolean;
// }; // };
// export type SystemGenerationMetadata = { // export type InvokeModelConversionProps = {
// model: string; // model_name: string;
// model_weights?: string; // save_location: string;
// model_id?: string; // custom_location: string | null;
// model_hash: string;
// app_id: string;
// app_version: string;
// }; // };
// export type SystemConfig = SystemGenerationMetadata & { // export type InvokeModelMergingProps = {
// model_list: ModelList; // models_to_merge: string[];
// infill_methods: string[]; // alpha: number;
// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
// force: boolean;
// merged_model_name: string;
// model_merge_save_path: string | null;
// }; // };
export type ModelStatus = 'active' | 'cached' | 'not loaded';
export type Model = {
status: ModelStatus;
description: string;
weights: string;
config?: string;
vae?: string;
width?: number;
height?: number;
default?: boolean;
format?: string;
};
export type DiffusersModel = {
status: ModelStatus;
description: string;
repo_id?: string;
path?: string;
vae?: {
repo_id?: string;
path?: string;
};
format?: string;
default?: boolean;
};
export type ModelList = Record<string, Model & DiffusersModel>;
export type FoundModel = {
name: string;
location: string;
};
export type InvokeModelConfigProps = {
name: string | undefined;
description: string | undefined;
config: string | undefined;
weights: string | undefined;
vae: string | undefined;
width: number | undefined;
height: number | undefined;
default: boolean | undefined;
format: string | undefined;
};
export type InvokeDiffusersModelConfigProps = {
name: string | undefined;
description: string | undefined;
repo_id: string | undefined;
path: string | undefined;
default: boolean | undefined;
format: string | undefined;
vae: {
repo_id: string | undefined;
path: string | undefined;
};
};
export type InvokeModelConversionProps = {
model_name: string;
save_location: string;
custom_location: string | null;
};
export type InvokeModelMergingProps = {
models_to_merge: string[];
alpha: number;
interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
force: boolean;
merged_model_name: string;
model_merge_save_path: string | null;
};
/**
* These types type data received from the server via socketio.
*/
export type ModelChangeResponse = {
model_name: string;
model_list: ModelList;
};
export type ModelConvertedResponse = {
converted_model_name: string;
model_list: ModelList;
};
export type ModelsMergedResponse = {
merged_models: string[];
merged_model_name: string;
model_list: ModelList;
};
export type ModelAddedResponse = {
new_model_name: string;
model_list: ModelList;
update: boolean;
};
export type ModelDeletedResponse = {
deleted_model_name: string;
model_list: ModelList;
};
export type FoundModelResponse = {
search_folder: string;
found_models: FoundModel[];
};
// export type SystemStatusResponse = SystemStatus;
// export type SystemConfigResponse = SystemConfig;
export type ImageResultResponse = Omit<Image, 'uuid'> & {
boundingBox?: IRect;
generationMode: InvokeTabName;
};
export type ImageUploadResponse = {
// image: Omit<Image, 'uuid' | 'metadata' | 'category'>;
url: string;
mtime: number;
width: number;
height: number;
thumbnail: string;
// bbox: [number, number, number, number];
};
export type ErrorResponse = {
message: string;
additionalData?: string;
};
export type ImageUrlResponse = {
url: string;
};
export type UploadOutpaintingMergeImagePayload = {
dataURL: string;
name: string;
};
/** /**
* A disable-able application feature * A disable-able application feature
*/ */
@ -322,7 +88,8 @@ export type AppFeature =
| 'githubLink' | 'githubLink'
| 'discordLink' | 'discordLink'
| 'bugLink' | 'bugLink'
| 'localization'; | 'localization'
| 'consoleLogging';
/** /**
* A disable-able Stable Diffusion feature * A disable-able Stable Diffusion feature
@ -351,6 +118,7 @@ export type AppConfig = {
disabledSDFeatures: SDFeature[]; disabledSDFeatures: SDFeature[];
canRestoreDeletedImagesFromBin: boolean; canRestoreDeletedImagesFromBin: boolean;
sd: { sd: {
defaultModel?: string;
iterations: { iterations: {
initial: number; initial: number;
min: number; min: number;

View File

@ -21,9 +21,12 @@ import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo } from 'react'; import { memo } from 'react';
export type ItemTooltips = { [key: string]: string };
type IAICustomSelectProps = { type IAICustomSelectProps = {
label?: string; label?: string;
items: string[]; items: string[];
itemTooltips?: ItemTooltips;
selectedItem: string; selectedItem: string;
setSelectedItem: (v: string | null | undefined) => void; setSelectedItem: (v: string | null | undefined) => void;
withCheckIcon?: boolean; withCheckIcon?: boolean;
@ -37,6 +40,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
const { const {
label, label,
items, items,
itemTooltips,
setSelectedItem, setSelectedItem,
selectedItem, selectedItem,
withCheckIcon, withCheckIcon,
@ -118,6 +122,13 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
> >
<OverlayScrollbarsComponent> <OverlayScrollbarsComponent>
{items.map((item, index) => ( {items.map((item, index) => (
<Tooltip
isDisabled={!itemTooltips}
key={`${item}${index}`}
label={itemTooltips?.[item]}
hasArrow
placement="right"
>
<ListItem <ListItem
sx={{ sx={{
bg: highlightedIndex === index ? 'base.700' : undefined, bg: highlightedIndex === index ? 'base.700' : undefined,
@ -160,6 +171,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
</Text> </Text>
)} )}
</ListItem> </ListItem>
</Tooltip>
))} ))}
</OverlayScrollbarsComponent> </OverlayScrollbarsComponent>
</List> </List>

View File

@ -4,7 +4,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
type ImageUploadOverlayProps = { type ImageUploadOverlayProps = {
isDragAccept: boolean; isDragAccept: boolean;
isDragReject: boolean; isDragReject: boolean;
overlaySecondaryText: string;
setIsHandlingUpload: (isHandlingUpload: boolean) => void; setIsHandlingUpload: (isHandlingUpload: boolean) => void;
}; };
@ -12,7 +11,6 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
const { const {
isDragAccept, isDragAccept,
isDragReject: _isDragAccept, isDragReject: _isDragAccept,
overlaySecondaryText,
setIsHandlingUpload, setIsHandlingUpload,
} = props; } = props;
@ -48,7 +46,7 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
}} }}
> >
{isDragAccept ? ( {isDragAccept ? (
<Heading size="lg">Upload Image{overlaySecondaryText}</Heading> <Heading size="lg">Drop to Upload</Heading>
) : ( ) : (
<> <>
<Heading size="lg">Invalid Upload</Heading> <Heading size="lg">Invalid Upload</Heading>

View File

@ -68,13 +68,13 @@ const ImageUploader = (props: ImageUploaderProps) => {
async (file: File) => { async (file: File) => {
dispatch( dispatch(
imageUploaded({ imageUploaded({
imageType: 'uploads',
formData: { file }, formData: { file },
activeTabName, imageCategory: 'user',
isIntermediate: false,
}) })
); );
}, },
[dispatch, activeTabName] [dispatch]
); );
const onDrop = useCallback( const onDrop = useCallback(
@ -145,14 +145,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
}; };
}, [inputRef, open, setOpenUploaderFunction]); }, [inputRef, open, setOpenUploaderFunction]);
const overlaySecondaryText = useMemo(() => {
if (['img2img', 'unifiedCanvas'].includes(activeTabName)) {
return ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`;
}
return '';
}, [t, activeTabName]);
return ( return (
<Box <Box
{...getRootProps({ style: {} })} {...getRootProps({ style: {} })}
@ -167,7 +159,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
<ImageUploadOverlay <ImageUploadOverlay
isDragAccept={isDragAccept} isDragAccept={isDragAccept}
isDragReject={isDragReject} isDragReject={isDragReject}
overlaySecondaryText={overlaySecondaryText}
setIsHandlingUpload={setIsHandlingUpload} setIsHandlingUpload={setIsHandlingUpload}
/> />
)} )}

View File

@ -1,119 +0,0 @@
/**
* PARTIAL ZOD IMPLEMENTATION
*
* doesn't work well bc like most validators, zod is not built to skip invalid values.
* it mostly works but just seems clearer and simpler to manually parse for now.
*
* in the future it would be really nice if we could use zod for some things:
* - zodios (axios + zod): https://github.com/ecyrbe/zodios
* - openapi to zodios: https://github.com/astahmer/openapi-zod-client
*/
// import { z } from 'zod';
// const zMetadataStringField = z.string();
// export type MetadataStringField = z.infer<typeof zMetadataStringField>;
// const zMetadataIntegerField = z.number().int();
// export type MetadataIntegerField = z.infer<typeof zMetadataIntegerField>;
// const zMetadataFloatField = z.number();
// export type MetadataFloatField = z.infer<typeof zMetadataFloatField>;
// const zMetadataBooleanField = z.boolean();
// export type MetadataBooleanField = z.infer<typeof zMetadataBooleanField>;
// const zMetadataImageField = z.object({
// image_type: z.union([
// z.literal('results'),
// z.literal('uploads'),
// z.literal('intermediates'),
// ]),
// image_name: z.string().min(1),
// });
// export type MetadataImageField = z.infer<typeof zMetadataImageField>;
// const zMetadataLatentsField = z.object({
// latents_name: z.string().min(1),
// });
// export type MetadataLatentsField = z.infer<typeof zMetadataLatentsField>;
// /**
// * zod Schema for any node field. Use a `transform()` to manually parse, skipping invalid values.
// */
// const zAnyMetadataField = z.any().transform((val, ctx) => {
// // Grab the field name from the path
// const fieldName = String(ctx.path[ctx.path.length - 1]);
// // `id` and `type` must be strings if they exist
// if (['id', 'type'].includes(fieldName)) {
// const reservedStringPropertyResult = zMetadataStringField.safeParse(val);
// if (reservedStringPropertyResult.success) {
// return reservedStringPropertyResult.data;
// }
// return;
// }
// // Parse the rest of the fields, only returning the data if the parsing is successful
// const stringFieldResult = zMetadataStringField.safeParse(val);
// if (stringFieldResult.success) {
// return stringFieldResult.data;
// }
// const integerFieldResult = zMetadataIntegerField.safeParse(val);
// if (integerFieldResult.success) {
// return integerFieldResult.data;
// }
// const floatFieldResult = zMetadataFloatField.safeParse(val);
// if (floatFieldResult.success) {
// return floatFieldResult.data;
// }
// const booleanFieldResult = zMetadataBooleanField.safeParse(val);
// if (booleanFieldResult.success) {
// return booleanFieldResult.data;
// }
// const imageFieldResult = zMetadataImageField.safeParse(val);
// if (imageFieldResult.success) {
// return imageFieldResult.data;
// }
// const latentsFieldResult = zMetadataImageField.safeParse(val);
// if (latentsFieldResult.success) {
// return latentsFieldResult.data;
// }
// });
// /**
// * The node metadata schema.
// */
// const zNodeMetadata = z.object({
// session_id: z.string().min(1).optional(),
// node: z.record(z.string().min(1), zAnyMetadataField).optional(),
// });
// export type NodeMetadata = z.infer<typeof zNodeMetadata>;
// const zMetadata = z.object({
// invokeai: zNodeMetadata.optional(),
// 'sd-metadata': z.record(z.string().min(1), z.any()).optional(),
// });
// export type Metadata = z.infer<typeof zMetadata>;
// export const parseMetadata = (
// metadata: Record<string, any>
// ): Metadata | undefined => {
// const result = zMetadata.safeParse(metadata);
// if (!result.success) {
// console.log(result.error.issues);
// return;
// }
// return result.data;
// };
export default {};

View File

@ -1,334 +0,0 @@
import { forEach, size } from 'lodash-es';
import {
ImageField,
LatentsField,
ConditioningField,
UNetField,
ClipField,
VaeField,
} from 'services/api';
const OBJECT_TYPESTRING = '[object Object]';
const STRING_TYPESTRING = '[object String]';
const NUMBER_TYPESTRING = '[object Number]';
const BOOLEAN_TYPESTRING = '[object Boolean]';
const ARRAY_TYPESTRING = '[object Array]';
const isObject = (obj: unknown): obj is Record<string | number, any> =>
Object.prototype.toString.call(obj) === OBJECT_TYPESTRING;
const isString = (obj: unknown): obj is string =>
Object.prototype.toString.call(obj) === STRING_TYPESTRING;
const isNumber = (obj: unknown): obj is number =>
Object.prototype.toString.call(obj) === NUMBER_TYPESTRING;
const isBoolean = (obj: unknown): obj is boolean =>
Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING;
const isArray = (obj: unknown): obj is Array<any> =>
Object.prototype.toString.call(obj) === ARRAY_TYPESTRING;
const parseImageField = (imageField: unknown): ImageField | undefined => {
// Must be an object
if (!isObject(imageField)) {
return;
}
// An ImageField must have both `image_name` and `image_type`
if (!('image_name' in imageField && 'image_type' in imageField)) {
return;
}
// An ImageField's `image_type` must be one of the allowed values
if (
!['results', 'uploads', 'intermediates'].includes(imageField.image_type)
) {
return;
}
// An ImageField's `image_name` must be a string
if (typeof imageField.image_name !== 'string') {
return;
}
// Build a valid ImageField
return {
image_type: imageField.image_type,
image_name: imageField.image_name,
};
};
const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
// Must be an object
if (!isObject(latentsField)) {
return;
}
// A LatentsField must have a `latents_name`
if (!('latents_name' in latentsField)) {
return;
}
// A LatentsField's `latents_name` must be a string
if (typeof latentsField.latents_name !== 'string') {
return;
}
// Build a valid LatentsField
return {
latents_name: latentsField.latents_name,
};
};
const parseConditioningField = (
conditioningField: unknown
): ConditioningField | undefined => {
// Must be an object
if (!isObject(conditioningField)) {
return;
}
// A ConditioningField must have a `conditioning_name`
if (!('conditioning_name' in conditioningField)) {
return;
}
// A ConditioningField's `conditioning_name` must be a string
if (typeof conditioningField.conditioning_name !== 'string') {
return;
}
// Build a valid ConditioningField
return {
conditioning_name: conditioningField.conditioning_name,
};
};
const _parseModelInfo = (modelInfo: unknown): ModelInfo | undefined => {
// Must be an object
if (!isObject(modelInfo)) {
return;
}
if (!('model_name' in modelInfo && typeof modelInfo.model_name == 'string')) {
return;
}
if (!('model_type' in modelInfo && typeof modelInfo.model_type == 'string')) {
return;
}
if (!('submodel' in modelInfo && typeof modelInfo.submodel == 'string')) {
return;
}
return {
model_name: modelInfo.model_name,
model_type: modelInfo.model_type,
submodel: modelInfo.submodel,
};
};
const parseUNetField = (unetField: unknown): UNetField | undefined => {
// Must be an object
if (!isObject(unetField)) {
return;
}
if (!('unet' in unetField && 'scheduler' in unetField)) {
return;
}
const unet = _parseModelInfo(unetField.unet);
const scheduler = _parseModelInfo(unetField.scheduler);
if (!(unet && scheduler)) {
return;
}
// Build a valid UNetField
return {
unet: unet,
scheduler: scheduler,
};
};
const parseClipField = (clipField: unknown): ClipField | undefined => {
// Must be an object
if (!isObject(clipField)) {
return;
}
if (!('tokenizer' in clipField && 'text_encoder' in clipField)) {
return;
}
const tokenizer = _parseModelInfo(clipField.tokenizer);
const text_encoder = _parseModelInfo(clipField.text_encoder);
if (!(tokenizer && text_encoder)) {
return;
}
// Build a valid ClipField
return {
tokenizer: tokenizer,
text_encoder: text_encoder,
};
};
const parseVaeField = (vaeField: unknown): VaeField | undefined => {
// Must be an object
if (!isObject(vaeField)) {
return;
}
if (!('vae' in vaeField)) {
return;
}
const vae = _parseModelInfo(vaeField.vae);
if (!vae) {
return;
}
// Build a valid VaeField
return {
vae: vae,
};
};
type NodeMetadata = {
[key: string]:
| string
| number
| boolean
| ImageField
| LatentsField
| ConditioningField
| UNetField
| ClipField
| VaeField;
};
type InvokeAIMetadata = {
session_id?: string;
node?: NodeMetadata;
};
export const parseNodeMetadata = (
nodeMetadata: Record<string | number, any>
): NodeMetadata | undefined => {
if (!isObject(nodeMetadata)) {
return;
}
const parsed: NodeMetadata = {};
forEach(nodeMetadata, (nodeItem, nodeKey) => {
// `id` and `type` must be strings if they are present
if (['id', 'type'].includes(nodeKey)) {
if (isString(nodeItem)) {
parsed[nodeKey] = nodeItem;
}
return;
}
// valid object types are:
// ImageField, LatentsField ConditioningField, UNetField, ClipField, VaeField
if (isObject(nodeItem)) {
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
const imageField = parseImageField(nodeItem);
if (imageField) {
parsed[nodeKey] = imageField;
}
return;
}
if ('latents_name' in nodeItem) {
const latentsField = parseLatentsField(nodeItem);
if (latentsField) {
parsed[nodeKey] = latentsField;
}
return;
}
if ('conditioning_name' in nodeItem) {
const conditioningField = parseConditioningField(nodeItem);
if (conditioningField) {
parsed[nodeKey] = conditioningField;
}
return;
}
if ('unet' in nodeItem && 'scheduler' in nodeItem) {
const unetField = parseUNetField(nodeItem);
if (unetField) {
parsed[nodeKey] = unetField;
}
}
if ('tokenizer' in nodeItem && 'text_encoder' in nodeItem) {
const clipField = parseClipField(nodeItem);
if (clipField) {
parsed[nodeKey] = clipField;
}
}
if ('vae' in nodeItem) {
const vaeField = parseVaeField(nodeItem);
if (vaeField) {
parsed[nodeKey] = vaeField;
}
}
}
// otherwise we accept any string, number or boolean
if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) {
parsed[nodeKey] = nodeItem;
return;
}
});
if (size(parsed) === 0) {
return;
}
return parsed;
};
export const parseInvokeAIMetadata = (
metadata: Record<string | number, any> | undefined
): InvokeAIMetadata | undefined => {
if (metadata === undefined) {
return;
}
if (!isObject(metadata)) {
return;
}
const parsed: InvokeAIMetadata = {};
forEach(metadata, (item, key) => {
if (key === 'session_id' && isString(item)) {
parsed['session_id'] = item;
}
if (key === 'node' && isObject(item)) {
const nodeMetadata = parseNodeMetadata(item);
if (nodeMetadata) {
parsed['node'] = nodeMetadata;
}
}
});
if (size(parsed) === 0) {
return;
}
return parsed;
};

View File

@ -1,18 +1,24 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { useGetUrl } from 'common/util/getUrl'; import { systemSelector } from 'features/system/store/systemSelectors';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { ImageConfig } from 'konva/lib/shapes/Image'; import { ImageConfig } from 'konva/lib/shapes/Image';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import { Image as KonvaImage } from 'react-konva'; import { Image as KonvaImage } from 'react-konva';
import { canvasSelector } from '../store/canvasSelectors';
const selector = createSelector( const selector = createSelector(
[(state: RootState) => state.gallery], [systemSelector, canvasSelector],
(gallery: GalleryState) => { (system, canvas) => {
return gallery.intermediateImage ? gallery.intermediateImage : null; const { progressImage, sessionId } = system;
const { sessionId: canvasSessionId, boundingBox } =
canvas.layerState.stagingArea;
return {
boundingBox,
progressImage: sessionId === canvasSessionId ? progressImage : undefined,
};
}, },
{ {
memoizeOptions: { memoizeOptions: {
@ -25,33 +31,34 @@ type Props = Omit<ImageConfig, 'image'>;
const IAICanvasIntermediateImage = (props: Props) => { const IAICanvasIntermediateImage = (props: Props) => {
const { ...rest } = props; const { ...rest } = props;
const intermediateImage = useAppSelector(selector); const { progressImage, boundingBox } = useAppSelector(selector);
const { getUrl } = useGetUrl();
const [loadedImageElement, setLoadedImageElement] = const [loadedImageElement, setLoadedImageElement] =
useState<HTMLImageElement | null>(null); useState<HTMLImageElement | null>(null);
useEffect(() => { useEffect(() => {
if (!intermediateImage) return; if (!progressImage) {
return;
}
const tempImage = new Image(); const tempImage = new Image();
tempImage.onload = () => { tempImage.onload = () => {
setLoadedImageElement(tempImage); setLoadedImageElement(tempImage);
}; };
tempImage.src = getUrl(intermediateImage.url);
}, [intermediateImage, getUrl]);
if (!intermediateImage?.boundingBox) return null; tempImage.src = progressImage.dataURL;
}, [progressImage]);
const { if (!(progressImage && boundingBox)) {
boundingBox: { x, y, width, height }, return null;
} = intermediateImage; }
return loadedImageElement ? ( return loadedImageElement ? (
<KonvaImage <KonvaImage
x={x} x={boundingBox.x}
y={y} y={boundingBox.y}
width={width} width={boundingBox.width}
height={height} height={boundingBox.height}
image={loadedImageElement} image={loadedImageElement}
listening={false} listening={false}
{...rest} {...rest}

View File

@ -62,7 +62,7 @@ const IAICanvasStagingArea = (props: Props) => {
<Group {...rest}> <Group {...rest}>
{shouldShowStagingImage && currentStagingAreaImage && ( {shouldShowStagingImage && currentStagingAreaImage && (
<IAICanvasImage <IAICanvasImage
url={getUrl(currentStagingAreaImage.image.image_url)} url={getUrl(currentStagingAreaImage.image.image_url) ?? ''}
x={x} x={x}
y={y} y={y}
/> />

View File

@ -1,6 +1,5 @@
import { ButtonGroup, Flex } from '@chakra-ui/react'; import { ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
// import { saveStagingAreaImageToGallery } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
@ -26,13 +25,14 @@ import {
FaPlus, FaPlus,
FaSave, FaSave,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { stagingAreaImageSaved } from '../store/actions';
const selector = createSelector( const selector = createSelector(
[canvasSelector], [canvasSelector],
(canvas) => { (canvas) => {
const { const {
layerState: { layerState: {
stagingArea: { images, selectedImageIndex }, stagingArea: { images, selectedImageIndex, sessionId },
}, },
shouldShowStagingOutline, shouldShowStagingOutline,
shouldShowStagingImage, shouldShowStagingImage,
@ -45,6 +45,7 @@ const selector = createSelector(
isOnLastImage: selectedImageIndex === images.length - 1, isOnLastImage: selectedImageIndex === images.length - 1,
shouldShowStagingImage, shouldShowStagingImage,
shouldShowStagingOutline, shouldShowStagingOutline,
sessionId,
}; };
}, },
{ {
@ -61,6 +62,7 @@ const IAICanvasStagingAreaToolbar = () => {
isOnLastImage, isOnLastImage,
currentStagingAreaImage, currentStagingAreaImage,
shouldShowStagingImage, shouldShowStagingImage,
sessionId,
} = useAppSelector(selector); } = useAppSelector(selector);
const { t } = useTranslation(); const { t } = useTranslation();
@ -106,9 +108,20 @@ const IAICanvasStagingAreaToolbar = () => {
} }
); );
const handlePrevImage = () => dispatch(prevStagingAreaImage()); const handlePrevImage = useCallback(
const handleNextImage = () => dispatch(nextStagingAreaImage()); () => dispatch(prevStagingAreaImage()),
const handleAccept = () => dispatch(commitStagingAreaImage()); [dispatch]
);
const handleNextImage = useCallback(
() => dispatch(nextStagingAreaImage()),
[dispatch]
);
const handleAccept = useCallback(
() => dispatch(commitStagingAreaImage(sessionId)),
[dispatch, sessionId]
);
if (!currentStagingAreaImage) return null; if (!currentStagingAreaImage) return null;
@ -157,19 +170,15 @@ const IAICanvasStagingAreaToolbar = () => {
} }
colorScheme="accent" colorScheme="accent"
/> />
{/* <IAIIconButton <IAIIconButton
tooltip={t('unifiedCanvas.saveToGallery')} tooltip={t('unifiedCanvas.saveToGallery')}
aria-label={t('unifiedCanvas.saveToGallery')} aria-label={t('unifiedCanvas.saveToGallery')}
icon={<FaSave />} icon={<FaSave />}
onClick={() => onClick={() =>
dispatch( dispatch(stagingAreaImageSaved(currentStagingAreaImage.image))
saveStagingAreaImageToGallery(
currentStagingAreaImage.image.image_url
)
)
} }
colorScheme="accent" colorScheme="accent"
/> */} />
<IAIIconButton <IAIIconButton
tooltip={t('unifiedCanvas.discardAll')} tooltip={t('unifiedCanvas.discardAll')}
aria-label={t('unifiedCanvas.discardAll')} aria-label={t('unifiedCanvas.discardAll')}

View File

@ -1,4 +1,5 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api';
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery'); export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
@ -11,3 +12,7 @@ export const canvasDownloadedAsImage = createAction(
); );
export const canvasMerged = createAction('canvas/canvasMerged'); export const canvasMerged = createAction('canvas/canvasMerged');
export const stagingAreaImageSaved = createAction<ImageDTO>(
'canvas/stagingAreaImageSaved'
);

View File

@ -29,6 +29,7 @@ import {
isCanvasMaskLine, isCanvasMaskLine,
} from './canvasTypes'; } from './canvasTypes';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { sessionCanceled } from 'services/thunks/session';
export const initialLayerState: CanvasLayerState = { export const initialLayerState: CanvasLayerState = {
objects: [], objects: [],
@ -696,7 +697,10 @@ export const canvasSlice = createSlice({
0 0
); );
}, },
commitStagingAreaImage: (state) => { commitStagingAreaImage: (
state,
action: PayloadAction<string | undefined>
) => {
if (!state.layerState.stagingArea.images.length) { if (!state.layerState.stagingArea.images.length) {
return; return;
} }
@ -841,6 +845,13 @@ export const canvasSlice = createSlice({
state.isTransformingBoundingBox = false; state.isTransformingBoundingBox = false;
}, },
}, },
extraReducers: (builder) => {
builder.addCase(sessionCanceled.pending, (state) => {
if (!state.layerState.stagingArea.images.length) {
state.layerState.stagingArea = initialLayerState.stagingArea;
}
});
},
}); });
export const { export const {

View File

@ -9,7 +9,8 @@ import { IRect } from 'konva/lib/types';
*/ */
const createMaskStage = async ( const createMaskStage = async (
lines: CanvasMaskLine[], lines: CanvasMaskLine[],
boundingBox: IRect boundingBox: IRect,
shouldInvertMask: boolean
): Promise<Konva.Stage> => { ): Promise<Konva.Stage> => {
// create an offscreen canvas and add the mask to it // create an offscreen canvas and add the mask to it
const { width, height } = boundingBox; const { width, height } = boundingBox;
@ -29,7 +30,7 @@ const createMaskStage = async (
baseLayer.add( baseLayer.add(
new Konva.Rect({ new Konva.Rect({
...boundingBox, ...boundingBox,
fill: 'white', fill: shouldInvertMask ? 'black' : 'white',
}) })
); );
@ -37,7 +38,7 @@ const createMaskStage = async (
maskLayer.add( maskLayer.add(
new Konva.Line({ new Konva.Line({
points: line.points, points: line.points,
stroke: 'black', stroke: shouldInvertMask ? 'white' : 'black',
strokeWidth: line.strokeWidth * 2, strokeWidth: line.strokeWidth * 2,
tension: 0, tension: 0,
lineCap: 'round', lineCap: 'round',

View File

@ -25,6 +25,7 @@ export const getCanvasData = async (state: RootState) => {
boundingBoxCoordinates, boundingBoxCoordinates,
boundingBoxDimensions, boundingBoxDimensions,
isMaskEnabled, isMaskEnabled,
shouldPreserveMaskedArea,
} = state.canvas; } = state.canvas;
const boundingBox = { const boundingBox = {
@ -58,7 +59,8 @@ export const getCanvasData = async (state: RootState) => {
// For the mask layer, use the normal boundingBox // For the mask layer, use the normal boundingBox
const maskStage = await createMaskStage( const maskStage = await createMaskStage(
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled
boundingBox boundingBox,
shouldPreserveMaskedArea
); );
const maskBlob = await konvaNodeToBlob(maskStage, boundingBox); const maskBlob = await konvaNodeToBlob(maskStage, boundingBox);
const maskImageData = await konvaNodeToImageData(maskStage, boundingBox); const maskImageData = await konvaNodeToImageData(maskStage, boundingBox);

View File

@ -1,5 +1,5 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual, isString } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { import {
ButtonGroup, ButtonGroup,
@ -25,8 +25,8 @@ import {
} from 'features/ui/store/uiSelectors'; } from 'features/ui/store/uiSelectors';
import { import {
setActiveTab, setActiveTab,
setShouldHidePreview,
setShouldShowImageDetails, setShouldShowImageDetails,
setShouldShowProgressInViewer,
} from 'features/ui/store/uiSlice'; } from 'features/ui/store/uiSlice';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -37,23 +37,19 @@ import {
FaDownload, FaDownload,
FaExpand, FaExpand,
FaExpandArrowsAlt, FaExpandArrowsAlt,
FaEye,
FaEyeSlash,
FaGrinStars, FaGrinStars,
FaHourglassHalf,
FaQuoteRight, FaQuoteRight,
FaSeedling, FaSeedling,
FaShare, FaShare,
FaShareAlt, FaShareAlt,
FaTrash,
FaWrench,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { gallerySelector } from '../store/gallerySelectors'; import { gallerySelector } from '../store/gallerySelectors';
import DeleteImageModal from './DeleteImageModal';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { useGetUrl } from 'common/util/getUrl'; import { useGetUrl } from 'common/util/getUrl';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useParameters } from 'features/parameters/hooks/useParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
import { import {
requestedImageDeletion, requestedImageDeletion,
@ -62,7 +58,6 @@ import {
} from '../store/actions'; } from '../store/actions';
import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings'; import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings'; import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
import { allParametersSet } from 'features/parameters/store/generationSlice';
import DeleteImageButton from './ImageActionButtons/DeleteImageButton'; import DeleteImageButton from './ImageActionButtons/DeleteImageButton';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
@ -90,7 +85,11 @@ const currentImageButtonsSelector = createSelector(
const { isLightboxOpen } = lightbox; const { isLightboxOpen } = lightbox;
const { shouldShowImageDetails, shouldHidePreview } = ui; const {
shouldShowImageDetails,
shouldHidePreview,
shouldShowProgressInViewer,
} = ui;
const { selectedImage } = gallery; const { selectedImage } = gallery;
@ -112,6 +111,7 @@ const currentImageButtonsSelector = createSelector(
seed: selectedImage?.metadata?.seed, seed: selectedImage?.metadata?.seed,
prompt: selectedImage?.metadata?.positive_conditioning, prompt: selectedImage?.metadata?.positive_conditioning,
negativePrompt: selectedImage?.metadata?.negative_conditioning, negativePrompt: selectedImage?.metadata?.negative_conditioning,
shouldShowProgressInViewer,
}; };
}, },
{ {
@ -145,6 +145,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
image, image,
canDeleteImage, canDeleteImage,
shouldConfirmOnDelete, shouldConfirmOnDelete,
shouldShowProgressInViewer,
} = useAppSelector(currentImageButtonsSelector); } = useAppSelector(currentImageButtonsSelector);
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled; const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
@ -163,7 +164,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const toaster = useAppToaster(); const toaster = useAppToaster();
const { t } = useTranslation(); const { t } = useTranslation();
const { recallPrompt, recallSeed, recallAllParameters } = useParameters(); const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
// const handleCopyImage = useCallback(async () => { // const handleCopyImage = useCallback(async () => {
// if (!image?.url) { // if (!image?.url) {
@ -229,10 +231,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}); });
}, [toaster, shouldTransformUrls, getUrl, t, image]); }, [toaster, shouldTransformUrls, getUrl, t, image]);
const handlePreviewVisibility = useCallback(() => {
dispatch(setShouldHidePreview(!shouldHidePreview));
}, [dispatch, shouldHidePreview]);
const handleClickUseAllParameters = useCallback(() => { const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(image); recallAllParameters(image);
}, [image, recallAllParameters]); }, [image, recallAllParameters]);
@ -252,11 +250,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys('s', handleUseSeed, [image]); useHotkeys('s', handleUseSeed, [image]);
const handleUsePrompt = useCallback(() => { const handleUsePrompt = useCallback(() => {
recallPrompt( recallBothPrompts(
image?.metadata?.positive_conditioning, image?.metadata?.positive_conditioning,
image?.metadata?.negative_conditioning image?.metadata?.negative_conditioning
); );
}, [image, recallPrompt]); }, [image, recallBothPrompts]);
useHotkeys('p', handleUsePrompt, [image]); useHotkeys('p', handleUsePrompt, [image]);
@ -386,6 +384,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
} }
}, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]); }, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]);
const handleClickProgressImagesToggle = useCallback(() => {
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
}, [dispatch, shouldShowProgressInViewer]);
useHotkeys('delete', handleInitiateDelete, [ useHotkeys('delete', handleInitiateDelete, [
image, image,
shouldConfirmOnDelete, shouldConfirmOnDelete,
@ -412,8 +414,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIPopover <IAIPopover
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
isDisabled={!image}
aria-label={`${t('parameters.sendTo')}...`} aria-label={`${t('parameters.sendTo')}...`}
tooltip={`${t('parameters.sendTo')}...`}
isDisabled={!image}
icon={<FaShareAlt />} icon={<FaShareAlt />}
/> />
} }
@ -458,28 +461,17 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
{t('parameters.copyImageToLink')} {t('parameters.copyImageToLink')}
</IAIButton> </IAIButton>
<Link download={true} href={getUrl(image?.image_url ?? '')}> <Link
download={true}
href={getUrl(image?.image_url ?? '')}
target="_blank"
>
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%"> <IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
{t('parameters.downloadImage')} {t('parameters.downloadImage')}
</IAIButton> </IAIButton>
</Link> </Link>
</Flex> </Flex>
</IAIPopover> </IAIPopover>
{/* <IAIIconButton
icon={shouldHidePreview ? <FaEyeSlash /> : <FaEye />}
tooltip={
!shouldHidePreview
? t('parameters.hidePreview')
: t('parameters.showPreview')
}
aria-label={
!shouldHidePreview
? t('parameters.hidePreview')
: t('parameters.showPreview')
}
isChecked={shouldHidePreview}
onClick={handlePreviewVisibility}
/> */}
{isLightboxEnabled && ( {isLightboxEnabled && (
<IAIIconButton <IAIIconButton
icon={<FaExpand />} icon={<FaExpand />}
@ -604,6 +596,16 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/> />
</ButtonGroup> </ButtonGroup>
<ButtonGroup isAttached={true}>
<IAIIconButton
aria-label={t('settings.displayInProgress')}
tooltip={t('settings.displayInProgress')}
icon={<FaHourglassHalf />}
isChecked={shouldShowProgressInViewer}
onClick={handleClickProgressImagesToggle}
/>
</ButtonGroup>
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true}>
<DeleteImageButton image={image} /> <DeleteImageButton image={image} />
</ButtonGroup> </ButtonGroup>

View File

@ -62,7 +62,6 @@ const CurrentImagePreview = () => {
return; return;
} }
e.dataTransfer.setData('invokeai/imageName', image.image_name); e.dataTransfer.setData('invokeai/imageName', image.image_name);
e.dataTransfer.setData('invokeai/imageType', image.image_type);
e.dataTransfer.effectAllowed = 'move'; e.dataTransfer.effectAllowed = 'move';
}, },
[image] [image]

View File

@ -30,7 +30,7 @@ import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useParameters } from 'features/parameters/hooks/useParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
import { import {
requestedImageDeletion, requestedImageDeletion,
@ -114,8 +114,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled; const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { recallSeed, recallPrompt, recallInitialImage, recallAllParameters } = const { recallBothPrompts, recallSeed, recallAllParameters } =
useParameters(); useRecallParameters();
const handleMouseOver = () => setIsHovered(true); const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false); const handleMouseOut = () => setIsHovered(false);
@ -147,7 +147,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleDragStart = useCallback( const handleDragStart = useCallback(
(e: DragEvent<HTMLDivElement>) => { (e: DragEvent<HTMLDivElement>) => {
e.dataTransfer.setData('invokeai/imageName', image.image_name); e.dataTransfer.setData('invokeai/imageName', image.image_name);
e.dataTransfer.setData('invokeai/imageType', image.image_type);
e.dataTransfer.effectAllowed = 'move'; e.dataTransfer.effectAllowed = 'move';
}, },
[image] [image]
@ -155,11 +154,15 @@ const HoverableImage = memo((props: HoverableImageProps) => {
// Recall parameters handlers // Recall parameters handlers
const handleRecallPrompt = useCallback(() => { const handleRecallPrompt = useCallback(() => {
recallPrompt( recallBothPrompts(
image.metadata?.positive_conditioning, image.metadata?.positive_conditioning,
image.metadata?.negative_conditioning image.metadata?.negative_conditioning
); );
}, [image, recallPrompt]); }, [
image.metadata?.negative_conditioning,
image.metadata?.positive_conditioning,
recallBothPrompts,
]);
const handleRecallSeed = useCallback(() => { const handleRecallSeed = useCallback(() => {
recallSeed(image.metadata?.seed); recallSeed(image.metadata?.seed);

View File

@ -16,7 +16,6 @@ import IAIPopover from 'common/components/IAIPopover';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { gallerySelector } from 'features/gallery/store/gallerySelectors'; import { gallerySelector } from 'features/gallery/store/gallerySelectors';
import { import {
setCurrentCategory,
setGalleryImageMinimumWidth, setGalleryImageMinimumWidth,
setGalleryImageObjectFit, setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages, setShouldAutoSwitchToNewImages,
@ -31,59 +30,46 @@ import {
memo, memo,
useCallback, useCallback,
useEffect, useEffect,
useMemo,
useRef, useRef,
useState, useState,
} from 'react'; } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs'; import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs';
import { FaImage, FaUser, FaWrench } from 'react-icons/fa'; import { FaImage, FaServer, FaWrench } from 'react-icons/fa';
import { MdPhotoLibrary } from 'react-icons/md'; import { MdPhotoLibrary } from 'react-icons/md';
import HoverableImage from './HoverableImage'; import HoverableImage from './HoverableImage';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { resultsAdapter } from '../store/resultsSlice';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { uploadsAdapter } from '../store/uploadsSlice';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso'; import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import GalleryProgressImage from './GalleryProgressImage';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { uiSelector } from 'features/ui/store/uiSelectors';
import { ImageDTO } from 'services/api'; import {
ASSETS_CATEGORIES,
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290; IMAGE_CATEGORIES,
const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER'; imageCategoriesChanged,
selectImagesAll,
} from '../store/imagesSlice';
import { receivedPageOfImages } from 'services/thunks/image';
const categorySelector = createSelector( const categorySelector = createSelector(
[(state: RootState) => state], [(state: RootState) => state],
(state) => { (state) => {
const { results, uploads, system, gallery } = state; const { images } = state;
const { currentCategory } = gallery; const { categories } = images;
if (currentCategory === 'results') { const allImages = selectImagesAll(state);
const tempImages: (ImageDTO | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = []; const filteredImages = allImages.filter((i) =>
categories.includes(i.image_category)
if (system.progressImage) { );
tempImages.push(PROGRESS_IMAGE_PLACEHOLDER);
}
return { return {
images: tempImages.concat( images: filteredImages,
resultsAdapter.getSelectors().selectAll(results) isLoading: images.isLoading,
), areMoreImagesAvailable: filteredImages.length < images.total,
isLoading: results.isLoading, categories: images.categories,
areMoreImagesAvailable: results.page < results.pages - 1,
};
}
return {
images: uploadsAdapter.getSelectors().selectAll(uploads),
isLoading: uploads.isLoading,
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -93,7 +79,6 @@ const mainSelector = createSelector(
[gallerySelector, uiSelector], [gallerySelector, uiSelector],
(gallery, ui) => { (gallery, ui) => {
const { const {
currentCategory,
galleryImageMinimumWidth, galleryImageMinimumWidth,
galleryImageObjectFit, galleryImageObjectFit,
shouldAutoSwitchToNewImages, shouldAutoSwitchToNewImages,
@ -104,7 +89,6 @@ const mainSelector = createSelector(
const { shouldPinGallery } = ui; const { shouldPinGallery } = ui;
return { return {
currentCategory,
shouldPinGallery, shouldPinGallery,
galleryImageMinimumWidth, galleryImageMinimumWidth,
galleryImageObjectFit, galleryImageObjectFit,
@ -120,7 +104,6 @@ const ImageGalleryContent = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const resizeObserverRef = useRef<HTMLDivElement>(null); const resizeObserverRef = useRef<HTMLDivElement>(null);
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
const rootRef = useRef(null); const rootRef = useRef(null);
const [scroller, setScroller] = useState<HTMLElement | null>(null); const [scroller, setScroller] = useState<HTMLElement | null>(null);
const [initialize, osInstance] = useOverlayScrollbars({ const [initialize, osInstance] = useOverlayScrollbars({
@ -137,7 +120,6 @@ const ImageGalleryContent = () => {
}); });
const { const {
currentCategory,
shouldPinGallery, shouldPinGallery,
galleryImageMinimumWidth, galleryImageMinimumWidth,
galleryImageObjectFit, galleryImageObjectFit,
@ -146,18 +128,19 @@ const ImageGalleryContent = () => {
selectedImage, selectedImage,
} = useAppSelector(mainSelector); } = useAppSelector(mainSelector);
const { images, areMoreImagesAvailable, isLoading } = const { images, areMoreImagesAvailable, isLoading, categories } =
useAppSelector(categorySelector); useAppSelector(categorySelector);
const handleClickLoadMore = () => { const handleLoadMoreImages = useCallback(() => {
if (currentCategory === 'results') { dispatch(receivedPageOfImages());
dispatch(receivedResultImagesPage()); }, [dispatch]);
}
if (currentCategory === 'uploads') { const handleEndReached = useMemo(() => {
dispatch(receivedUploadImagesPage()); if (areMoreImagesAvailable && !isLoading) {
return handleLoadMoreImages;
} }
}; return undefined;
}, [areMoreImagesAvailable, handleLoadMoreImages, isLoading]);
const handleChangeGalleryImageMinimumWidth = (v: number) => { const handleChangeGalleryImageMinimumWidth = (v: number) => {
dispatch(setGalleryImageMinimumWidth(v)); dispatch(setGalleryImageMinimumWidth(v));
@ -168,28 +151,6 @@ const ImageGalleryContent = () => {
dispatch(requestCanvasRescale()); dispatch(requestCanvasRescale());
}; };
useEffect(() => {
if (!resizeObserverRef.current) {
return;
}
const resizeObserver = new ResizeObserver(() => {
if (!resizeObserverRef.current) {
return;
}
if (
resizeObserverRef.current.clientWidth < GALLERY_SHOW_BUTTONS_MIN_WIDTH
) {
setShouldShouldIconButtons(true);
return;
}
setShouldShouldIconButtons(false);
});
resizeObserver.observe(resizeObserverRef.current);
return () => resizeObserver.disconnect(); // clean up
}, []);
useEffect(() => { useEffect(() => {
const { current: root } = rootRef; const { current: root } = rootRef;
if (scroller && root) { if (scroller && root) {
@ -209,13 +170,13 @@ const ImageGalleryContent = () => {
} }
}, []); }, []);
const handleEndReached = useCallback(() => { const handleClickImagesCategory = useCallback(() => {
if (currentCategory === 'results') { dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
dispatch(receivedResultImagesPage()); }, [dispatch]);
} else if (currentCategory === 'uploads') {
dispatch(receivedUploadImagesPage()); const handleClickAssetsCategory = useCallback(() => {
} dispatch(imageCategoriesChanged(ASSETS_CATEGORIES));
}, [dispatch, currentCategory]); }, [dispatch]);
return ( return (
<Flex <Flex
@ -232,59 +193,31 @@ const ImageGalleryContent = () => {
alignItems="center" alignItems="center"
justifyContent="space-between" justifyContent="space-between"
> >
<ButtonGroup <ButtonGroup isAttached>
size="sm"
isAttached
w="max-content"
justifyContent="stretch"
>
{shouldShouldIconButtons ? (
<>
<IAIIconButton <IAIIconButton
aria-label={t('gallery.showGenerations')} tooltip={t('gallery.images')}
tooltip={t('gallery.showGenerations')} aria-label={t('gallery.images')}
isChecked={currentCategory === 'results'} onClick={handleClickImagesCategory}
role="radio" isChecked={categories === IMAGE_CATEGORIES}
size="sm"
icon={<FaImage />} icon={<FaImage />}
onClick={() => dispatch(setCurrentCategory('results'))}
/> />
<IAIIconButton <IAIIconButton
aria-label={t('gallery.showUploads')} tooltip={t('gallery.assets')}
tooltip={t('gallery.showUploads')} aria-label={t('gallery.assets')}
role="radio" onClick={handleClickAssetsCategory}
isChecked={currentCategory === 'uploads'} isChecked={categories === ASSETS_CATEGORIES}
icon={<FaUser />} size="sm"
onClick={() => dispatch(setCurrentCategory('uploads'))} icon={<FaServer />}
/> />
</>
) : (
<>
<IAIButton
size="sm"
isChecked={currentCategory === 'results'}
onClick={() => dispatch(setCurrentCategory('results'))}
flexGrow={1}
>
{t('gallery.generations')}
</IAIButton>
<IAIButton
size="sm"
isChecked={currentCategory === 'uploads'}
onClick={() => dispatch(setCurrentCategory('uploads'))}
flexGrow={1}
>
{t('gallery.uploads')}
</IAIButton>
</>
)}
</ButtonGroup> </ButtonGroup>
<Flex gap={2}> <Flex gap={2}>
<IAIPopover <IAIPopover
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
size="sm" tooltip={t('gallery.gallerySettings')}
aria-label={t('gallery.gallerySettings')} aria-label={t('gallery.gallerySettings')}
size="sm"
icon={<FaWrench />} icon={<FaWrench />}
/> />
} }
@ -347,28 +280,17 @@ const ImageGalleryContent = () => {
data={images} data={images}
endReached={handleEndReached} endReached={handleEndReached}
scrollerRef={(ref) => setScrollerRef(ref)} scrollerRef={(ref) => setScrollerRef(ref)}
itemContent={(index, image) => { itemContent={(index, image) => (
const isSelected =
image === PROGRESS_IMAGE_PLACEHOLDER
? false
: selectedImage?.image_name === image?.image_name;
return (
<Flex sx={{ pb: 2 }}> <Flex sx={{ pb: 2 }}>
{image === PROGRESS_IMAGE_PLACEHOLDER ? (
<GalleryProgressImage
key={PROGRESS_IMAGE_PLACEHOLDER}
/>
) : (
<HoverableImage <HoverableImage
key={`${image.image_name}-${image.thumbnail_url}`} key={`${image.image_name}-${image.thumbnail_url}`}
image={image} image={image}
isSelected={isSelected} isSelected={
selectedImage?.image_name === image?.image_name
}
/> />
)}
</Flex> </Flex>
); )}
}}
/> />
) : ( ) : (
<VirtuosoGrid <VirtuosoGrid
@ -380,27 +302,20 @@ const ImageGalleryContent = () => {
List: ListContainer, List: ListContainer,
}} }}
scrollerRef={setScroller} scrollerRef={setScroller}
itemContent={(index, image) => { itemContent={(index, image) => (
const isSelected =
image === PROGRESS_IMAGE_PLACEHOLDER
? false
: selectedImage?.image_name === image?.image_name;
return image === PROGRESS_IMAGE_PLACEHOLDER ? (
<GalleryProgressImage key={PROGRESS_IMAGE_PLACEHOLDER} />
) : (
<HoverableImage <HoverableImage
key={`${image.image_name}-${image.thumbnail_url}`} key={`${image.image_name}-${image.thumbnail_url}`}
image={image} image={image}
isSelected={isSelected} isSelected={
selectedImage?.image_name === image?.image_name
}
/> />
); )}
}}
/> />
)} )}
</Box> </Box>
<IAIButton <IAIButton
onClick={handleClickLoadMore} onClick={handleLoadMoreImages}
isDisabled={!areMoreImagesAvailable} isDisabled={!areMoreImagesAvailable}
isLoading={isLoading} isLoading={isLoading}
loadingText="Loading" loadingText="Loading"

View File

@ -31,6 +31,7 @@ import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { Scheduler } from 'app/constants'; import { Scheduler } from 'app/constants';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
type MetadataItemProps = { type MetadataItemProps = {
isLink?: boolean; isLink?: boolean;
@ -53,6 +54,11 @@ const MetadataItem = ({
withCopy = false, withCopy = false,
}: MetadataItemProps) => { }: MetadataItemProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
if (!value) {
return null;
}
return ( return (
<Flex gap={2}> <Flex gap={2}>
{onClick && ( {onClick && (
@ -115,6 +121,21 @@ const memoEqualityCheck = (
*/ */
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => { const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const {
recallBothPrompts,
recallPositivePrompt,
recallNegativePrompt,
recallSeed,
recallInitialImage,
recallCfgScale,
recallModel,
recallScheduler,
recallSteps,
recallWidth,
recallHeight,
recallStrength,
recallAllParameters,
} = useRecallParameters();
useHotkeys('esc', () => { useHotkeys('esc', () => {
dispatch(setShouldShowImageDetails(false)); dispatch(setShouldShowImageDetails(false));
@ -161,52 +182,53 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
{metadata.type && ( {metadata.type && (
<MetadataItem label="Invocation type" value={metadata.type} /> <MetadataItem label="Invocation type" value={metadata.type} />
)} )}
{metadata.width && ( {sessionId && <MetadataItem label="Session ID" value={sessionId} />}
<MetadataItem
label="Width"
value={metadata.width}
onClick={() => dispatch(setWidth(Number(metadata.width)))}
/>
)}
{metadata.height && (
<MetadataItem
label="Height"
value={metadata.height}
onClick={() => dispatch(setHeight(Number(metadata.height)))}
/>
)}
{metadata.model && (
<MetadataItem label="Model" value={metadata.model} />
)}
{metadata.positive_conditioning && ( {metadata.positive_conditioning && (
<MetadataItem <MetadataItem
label="Prompt" label="Positive Prompt"
labelPosition="top" labelPosition="top"
value={ value={metadata.positive_conditioning}
typeof metadata.positive_conditioning === 'string' onClick={() =>
? metadata.positive_conditioning recallPositivePrompt(metadata.positive_conditioning)
: promptToString(metadata.positive_conditioning)
} }
onClick={() => setPositivePrompt(metadata.positive_conditioning!)}
/> />
)} )}
{metadata.negative_conditioning && ( {metadata.negative_conditioning && (
<MetadataItem <MetadataItem
label="Prompt" label="Negative Prompt"
labelPosition="top" labelPosition="top"
value={ value={metadata.negative_conditioning}
typeof metadata.negative_conditioning === 'string' onClick={() =>
? metadata.negative_conditioning recallNegativePrompt(metadata.negative_conditioning)
: promptToString(metadata.negative_conditioning)
} }
onClick={() => setNegativePrompt(metadata.negative_conditioning!)}
/> />
)} )}
{metadata.seed !== undefined && ( {metadata.seed !== undefined && (
<MetadataItem <MetadataItem
label="Seed" label="Seed"
value={metadata.seed} value={metadata.seed}
onClick={() => dispatch(setSeed(Number(metadata.seed)))} onClick={() => recallSeed(metadata.seed)}
/>
)}
{metadata.model !== undefined && (
<MetadataItem
label="Model"
value={metadata.model}
onClick={() => recallModel(metadata.model)}
/>
)}
{metadata.width && (
<MetadataItem
label="Width"
value={metadata.width}
onClick={() => recallWidth(metadata.width)}
/>
)}
{metadata.height && (
<MetadataItem
label="Height"
value={metadata.height}
onClick={() => recallHeight(metadata.height)}
/> />
)} )}
{/* {metadata.threshold !== undefined && ( {/* {metadata.threshold !== undefined && (
@ -227,23 +249,21 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
<MetadataItem <MetadataItem
label="Scheduler" label="Scheduler"
value={metadata.scheduler} value={metadata.scheduler}
onClick={() => onClick={() => recallScheduler(metadata.scheduler)}
dispatch(setScheduler(metadata.scheduler as Scheduler))
}
/> />
)} )}
{metadata.steps && ( {metadata.steps && (
<MetadataItem <MetadataItem
label="Steps" label="Steps"
value={metadata.steps} value={metadata.steps}
onClick={() => dispatch(setSteps(Number(metadata.steps)))} onClick={() => recallSteps(metadata.steps)}
/> />
)} )}
{metadata.cfg_scale !== undefined && ( {metadata.cfg_scale !== undefined && (
<MetadataItem <MetadataItem
label="CFG scale" label="CFG scale"
value={metadata.cfg_scale} value={metadata.cfg_scale}
onClick={() => dispatch(setCfgScale(Number(metadata.cfg_scale)))} onClick={() => recallCfgScale(metadata.cfg_scale)}
/> />
)} )}
{/* {metadata.variations && metadata.variations.length > 0 && ( {/* {metadata.variations && metadata.variations.length > 0 && (
@ -284,9 +304,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
<MetadataItem <MetadataItem
label="Image to image strength" label="Image to image strength"
value={metadata.strength} value={metadata.strength}
onClick={() => onClick={() => recallStrength(metadata.strength)}
dispatch(setImg2imgStrength(Number(metadata.strength)))
}
/> />
)} )}
{/* {metadata.fit && ( {/* {metadata.fit && (

View File

@ -9,6 +9,10 @@ import { gallerySelector } from '../store/gallerySelectors';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { imageSelected } from '../store/gallerySlice'; import { imageSelected } from '../store/gallerySlice';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import {
selectFilteredImagesAsObject,
selectFilteredImagesIds,
} from '../store/imagesSlice';
const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = { const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = {
height: '100%', height: '100%',
@ -21,9 +25,14 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
}; };
export const nextPrevImageButtonsSelector = createSelector( export const nextPrevImageButtonsSelector = createSelector(
[(state: RootState) => state, gallerySelector], [
(state, gallery) => { (state: RootState) => state,
const { selectedImage, currentCategory } = gallery; gallerySelector,
selectFilteredImagesAsObject,
selectFilteredImagesIds,
],
(state, gallery, filteredImagesAsObject, filteredImageIds) => {
const { selectedImage } = gallery;
if (!selectedImage) { if (!selectedImage) {
return { return {
@ -32,29 +41,29 @@ export const nextPrevImageButtonsSelector = createSelector(
}; };
} }
const currentImageIndex = state[currentCategory].ids.findIndex( const currentImageIndex = filteredImageIds.findIndex(
(i) => i === selectedImage.image_name (i) => i === selectedImage.image_name
); );
const nextImageIndex = clamp( const nextImageIndex = clamp(
currentImageIndex + 1, currentImageIndex + 1,
0, 0,
state[currentCategory].ids.length - 1 filteredImageIds.length - 1
); );
const prevImageIndex = clamp( const prevImageIndex = clamp(
currentImageIndex - 1, currentImageIndex - 1,
0, 0,
state[currentCategory].ids.length - 1 filteredImageIds.length - 1
); );
const nextImageId = state[currentCategory].ids[nextImageIndex]; const nextImageId = filteredImageIds[nextImageIndex];
const prevImageId = state[currentCategory].ids[prevImageIndex]; const prevImageId = filteredImageIds[prevImageIndex];
const nextImage = state[currentCategory].entities[nextImageId]; const nextImage = filteredImagesAsObject[nextImageId];
const prevImage = state[currentCategory].entities[prevImageId]; const prevImage = filteredImagesAsObject[prevImageId];
const imagesLength = state[currentCategory].ids.length; const imagesLength = filteredImageIds.length;
return { return {
isOnFirstImage: currentImageIndex === 0, isOnFirstImage: currentImageIndex === 0,

View File

@ -1,33 +1,18 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { ImageType } from 'services/api'; import { selectImagesEntities } from '../store/imagesSlice';
import { selectResultsEntities } from '../store/resultsSlice'; import { useCallback } from 'react';
import { selectUploadsEntities } from '../store/uploadsSlice';
const useGetImageByNameSelector = createSelector( const useGetImageByName = () => {
[selectResultsEntities, selectUploadsEntities], const images = useAppSelector(selectImagesEntities);
(allResults, allUploads) => { return useCallback(
return { allResults, allUploads }; (name: string | undefined) => {
if (!name) {
return;
} }
); return images[name];
},
const useGetImageByNameAndType = () => { [images]
const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector); );
return (name: string, type: ImageType) => {
if (type === 'results') {
const resultImagesResult = allResults[name];
if (resultImagesResult) {
return resultImagesResult;
}
}
if (type === 'uploads') {
const userImagesResult = allUploads[name];
if (userImagesResult) {
return userImagesResult;
}
}
};
}; };
export default useGetImageByNameAndType; export default useGetImageByName;

View File

@ -1,9 +1,9 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { ImageNameAndType } from 'features/parameters/store/actions'; import { ImageNameAndOrigin } from 'features/parameters/store/actions';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
export const requestedImageDeletion = createAction< export const requestedImageDeletion = createAction<
ImageDTO | ImageNameAndType | undefined ImageDTO | ImageNameAndOrigin | undefined
>('gallery/requestedImageDeletion'); >('gallery/requestedImageDeletion');
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas'); export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');

View File

@ -4,6 +4,5 @@ import { GalleryState } from './gallerySlice';
* Gallery slice persist denylist * Gallery slice persist denylist
*/ */
export const galleryPersistDenylist: (keyof GalleryState)[] = [ export const galleryPersistDenylist: (keyof GalleryState)[] = [
'currentCategory',
'shouldAutoSwitchToNewImages', 'shouldAutoSwitchToNewImages',
]; ];

View File

@ -1,10 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from '../../../services/thunks/gallery';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { imageUpserted } from './imagesSlice';
type GalleryImageObjectFitType = 'contain' | 'cover'; type GalleryImageObjectFitType = 'contain' | 'cover';
@ -14,7 +11,6 @@ export interface GalleryState {
galleryImageObjectFit: GalleryImageObjectFitType; galleryImageObjectFit: GalleryImageObjectFitType;
shouldAutoSwitchToNewImages: boolean; shouldAutoSwitchToNewImages: boolean;
shouldUseSingleGalleryColumn: boolean; shouldUseSingleGalleryColumn: boolean;
currentCategory: 'results' | 'uploads';
} }
export const initialGalleryState: GalleryState = { export const initialGalleryState: GalleryState = {
@ -22,7 +18,6 @@ export const initialGalleryState: GalleryState = {
galleryImageObjectFit: 'cover', galleryImageObjectFit: 'cover',
shouldAutoSwitchToNewImages: true, shouldAutoSwitchToNewImages: true,
shouldUseSingleGalleryColumn: false, shouldUseSingleGalleryColumn: false,
currentCategory: 'results',
}; };
export const gallerySlice = createSlice({ export const gallerySlice = createSlice({
@ -46,12 +41,6 @@ export const gallerySlice = createSlice({
setShouldAutoSwitchToNewImages: (state, action: PayloadAction<boolean>) => { setShouldAutoSwitchToNewImages: (state, action: PayloadAction<boolean>) => {
state.shouldAutoSwitchToNewImages = action.payload; state.shouldAutoSwitchToNewImages = action.payload;
}, },
setCurrentCategory: (
state,
action: PayloadAction<'results' | 'uploads'>
) => {
state.currentCategory = action.payload;
},
setShouldUseSingleGalleryColumn: ( setShouldUseSingleGalleryColumn: (
state, state,
action: PayloadAction<boolean> action: PayloadAction<boolean>
@ -59,37 +48,10 @@ export const gallerySlice = createSlice({
state.shouldUseSingleGalleryColumn = action.payload; state.shouldUseSingleGalleryColumn = action.payload;
}, },
}, },
extraReducers(builder) { extraReducers: (builder) => {
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => { builder.addCase(imageUpserted, (state, action) => {
// rehydrate selectedImage URL when results list comes in if (state.shouldAutoSwitchToNewImages) {
// solves case when outdated URL is in local storage state.selectedImage = action.payload;
const selectedImage = state.selectedImage;
if (selectedImage) {
const selectedImageInResults = action.payload.items.find(
(image) => image.image_name === selectedImage.image_name
);
if (selectedImageInResults) {
selectedImage.image_url = selectedImageInResults.image_url;
selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
state.selectedImage = selectedImage;
}
}
});
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
// rehydrate selectedImage URL when results list comes in
// solves case when outdated URL is in local storage
const selectedImage = state.selectedImage;
if (selectedImage) {
const selectedImageInResults = action.payload.items.find(
(image) => image.image_name === selectedImage.image_name
);
if (selectedImageInResults) {
selectedImage.image_url = selectedImageInResults.image_url;
selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
state.selectedImage = selectedImage;
}
} }
}); });
}, },
@ -101,7 +63,6 @@ export const {
setGalleryImageObjectFit, setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages, setShouldAutoSwitchToNewImages,
setShouldUseSingleGalleryColumn, setShouldUseSingleGalleryColumn,
setCurrentCategory,
} = gallerySlice.actions; } = gallerySlice.actions;
export default gallerySlice.reducer; export default gallerySlice.reducer;

View File

@ -0,0 +1,135 @@
import {
PayloadAction,
createEntityAdapter,
createSelector,
createSlice,
} from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { ImageCategory, ImageDTO } from 'services/api';
import { dateComparator } from 'common/util/dateComparator';
import { isString, keyBy } from 'lodash-es';
import { receivedPageOfImages } from 'services/thunks/image';
export const imagesAdapter = createEntityAdapter<ImageDTO>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
});
export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
export const ASSETS_CATEGORIES: ImageCategory[] = [
'control',
'mask',
'user',
'other',
];
type AdditionaImagesState = {
offset: number;
limit: number;
total: number;
isLoading: boolean;
categories: ImageCategory[];
};
export const initialImagesState =
imagesAdapter.getInitialState<AdditionaImagesState>({
offset: 0,
limit: 0,
total: 0,
isLoading: false,
categories: IMAGE_CATEGORIES,
});
export type ImagesState = typeof initialImagesState;
const imagesSlice = createSlice({
name: 'images',
initialState: initialImagesState,
reducers: {
imageUpserted: (state, action: PayloadAction<ImageDTO>) => {
imagesAdapter.upsertOne(state, action.payload);
},
imageRemoved: (state, action: PayloadAction<string | ImageDTO>) => {
if (isString(action.payload)) {
imagesAdapter.removeOne(state, action.payload);
return;
}
imagesAdapter.removeOne(state, action.payload.image_name);
},
imageCategoriesChanged: (state, action: PayloadAction<ImageCategory[]>) => {
state.categories = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(receivedPageOfImages.pending, (state) => {
state.isLoading = true;
});
builder.addCase(receivedPageOfImages.rejected, (state) => {
state.isLoading = false;
});
builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
state.isLoading = false;
const { items, offset, limit, total } = action.payload;
state.offset = offset;
state.limit = limit;
state.total = total;
imagesAdapter.upsertMany(state, items);
});
},
});
export const {
selectAll: selectImagesAll,
selectById: selectImagesById,
selectEntities: selectImagesEntities,
selectIds: selectImagesIds,
selectTotal: selectImagesTotal,
} = imagesAdapter.getSelectors<RootState>((state) => state.images);
export const { imageUpserted, imageRemoved, imageCategoriesChanged } =
imagesSlice.actions;
export default imagesSlice.reducer;
export const selectFilteredImagesAsArray = createSelector(
(state: RootState) => state,
(state) => {
const {
images: { categories },
} = state;
return selectImagesAll(state).filter((i) =>
categories.includes(i.image_category)
);
}
);
export const selectFilteredImagesAsObject = createSelector(
(state: RootState) => state,
(state) => {
const {
images: { categories },
} = state;
return keyBy(
selectImagesAll(state).filter((i) =>
categories.includes(i.image_category)
),
'image_name'
);
}
);
export const selectFilteredImagesIds = createSelector(
(state: RootState) => state,
(state) => {
const {
images: { categories },
} = state;
return selectImagesAll(state)
.filter((i) => categories.includes(i.image_category))
.map((i) => i.image_name);
}
);

View File

@ -1,8 +0,0 @@
import { ResultsState } from './resultsSlice';
/**
* Results slice persist denylist
*
* Currently denylisting results slice entirely, see `serialize.ts`
*/
export const resultsPersistDenylist: (keyof ResultsState)[] = [];

View File

@ -1,125 +0,0 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
receivedResultImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import {
imageDeleted,
imageMetadataReceived,
imageUrlsReceived,
} from 'services/thunks/image';
import { ImageDTO } from 'services/api';
import { dateComparator } from 'common/util/dateComparator';
export type ResultsImageDTO = Omit<ImageDTO, 'image_type'> & {
image_type: 'results';
};
export const resultsAdapter = createEntityAdapter<ResultsImageDTO>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
});
type AdditionalResultsState = {
page: number;
pages: number;
isLoading: boolean;
nextPage: number;
};
export const initialResultsState =
resultsAdapter.getInitialState<AdditionalResultsState>({
page: 0,
pages: 0,
isLoading: false,
nextPage: 0,
});
export type ResultsState = typeof initialResultsState;
const resultsSlice = createSlice({
name: 'results',
initialState: initialResultsState,
reducers: {
resultAdded: resultsAdapter.upsertOne,
},
extraReducers: (builder) => {
/**
* Received Result Images Page - PENDING
*/
builder.addCase(receivedResultImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Result Images Page - FULFILLED
*/
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
const { page, pages } = action.payload;
// We know these will all be of the results type, but it's not represented in the API types
const items = action.payload.items as ResultsImageDTO[];
resultsAdapter.setMany(state, items);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Image Metadata Received - FULFILLED
*/
builder.addCase(imageMetadataReceived.fulfilled, (state, action) => {
const { image_type } = action.payload;
if (image_type === 'results') {
resultsAdapter.upsertOne(state, action.payload as ResultsImageDTO);
}
});
/**
* Image URLs Received - FULFILLED
*/
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_type, image_url, thumbnail_url } =
action.payload;
if (image_type === 'results') {
resultsAdapter.updateOne(state, {
id: image_name,
changes: {
image_url: image_url,
thumbnail_url: thumbnail_url,
},
});
}
});
/**
* Delete Image - PENDING
* Pre-emptively remove the image from the gallery
*/
builder.addCase(imageDeleted.pending, (state, action) => {
const { imageType, imageName } = action.meta.arg;
if (imageType === 'results') {
resultsAdapter.removeOne(state, imageName);
}
});
},
});
export const {
selectAll: selectResultsAll,
selectById: selectResultsById,
selectEntities: selectResultsEntities,
selectIds: selectResultsIds,
selectTotal: selectResultsTotal,
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
export const { resultAdded } = resultsSlice.actions;
export default resultsSlice.reducer;

View File

@ -1,8 +0,0 @@
import { UploadsState } from './uploadsSlice';
/**
* Uploads slice persist denylist
*
* Currently denylisting uploads slice entirely, see `serialize.ts`
*/
export const uploadsPersistDenylist: (keyof UploadsState)[] = [];

View File

@ -1,111 +0,0 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
receivedUploadImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
import { ImageDTO } from 'services/api';
import { dateComparator } from 'common/util/dateComparator';
export type UploadsImageDTO = Omit<ImageDTO, 'image_type'> & {
image_type: 'uploads';
};
export const uploadsAdapter = createEntityAdapter<UploadsImageDTO>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
});
type AdditionalUploadsState = {
page: number;
pages: number;
isLoading: boolean;
nextPage: number;
};
export const initialUploadsState =
uploadsAdapter.getInitialState<AdditionalUploadsState>({
page: 0,
pages: 0,
nextPage: 0,
isLoading: false,
});
export type UploadsState = typeof initialUploadsState;
const uploadsSlice = createSlice({
name: 'uploads',
initialState: initialUploadsState,
reducers: {
uploadAdded: uploadsAdapter.upsertOne,
},
extraReducers: (builder) => {
/**
* Received Upload Images Page - PENDING
*/
builder.addCase(receivedUploadImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Upload Images Page - FULFILLED
*/
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
const { page, pages } = action.payload;
// We know these will all be of the uploads type, but it's not represented in the API types
const items = action.payload.items as UploadsImageDTO[];
uploadsAdapter.setMany(state, items);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Image URLs Received - FULFILLED
*/
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_type, image_url, thumbnail_url } =
action.payload;
if (image_type === 'uploads') {
uploadsAdapter.updateOne(state, {
id: image_name,
changes: {
image_url: image_url,
thumbnail_url: thumbnail_url,
},
});
}
});
/**
* Delete Image - pending
* Pre-emptively remove the image from the gallery
*/
builder.addCase(imageDeleted.pending, (state, action) => {
const { imageType, imageName } = action.meta.arg;
if (imageType === 'uploads') {
uploadsAdapter.removeOne(state, imageName);
}
});
},
});
export const {
selectAll: selectUploadsAll,
selectById: selectUploadsById,
selectEntities: selectUploadsEntities,
selectIds: selectUploadsIds,
selectTotal: selectUploadsTotal,
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
export const { uploadAdded } = uploadsSlice.actions;
export default uploadsSlice.reducer;

View File

@ -10,6 +10,7 @@ import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComp
import UNetInputFieldComponent from './fields/UNetInputFieldComponent'; import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
import ClipInputFieldComponent from './fields/ClipInputFieldComponent'; import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent'; import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent';
@ -130,6 +131,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
); );
} }
if (type === 'control' && template.type === 'control') {
return (
<ControlInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'model' && template.type === 'model') { if (type === 'model' && template.type === 'model') {
return ( return (
<ModelInputFieldComponent <ModelInputFieldComponent

Some files were not shown because too many files have changed in this diff Show More