Merge branch 'main' into doc_updates_23

Lincoln Stein 2023-05-29 08:13:12 -04:00
commit 00cb8a0c64
166 changed files with 3796 additions and 677 deletions

.github/CODEOWNERS

@@ -2,7 +2,7 @@
 /.github/workflows/ @lstein @blessedcoolant
 # documentation
-/docs/ @lstein @tildebyte @blessedcoolant
+/docs/ @lstein @blessedcoolant @hipsterusername
 /mkdocs.yml @lstein @blessedcoolant
 # nodes
@@ -18,17 +18,17 @@
 /invokeai/version @lstein @blessedcoolant
 # web ui
-/invokeai/frontend @blessedcoolant @psychedelicious @lstein
-/invokeai/backend @blessedcoolant @psychedelicious @lstein
+/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp
+/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
 # generation, model management, postprocessing
-/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
+/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779
 # front ends
 /invokeai/frontend/CLI @lstein
 /invokeai/frontend/install @lstein @ebr
-/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
-/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
+/invokeai/frontend/merge @lstein @blessedcoolant
+/invokeai/frontend/training @lstein @blessedcoolant
-/invokeai/frontend/web @psychedelicious @blessedcoolant
+/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp


@@ -125,6 +125,7 @@ jobs:
 --no-nsfw_checker
 --precision=float32
 --always_use_cpu
+--use_memory_db
 --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
 --from_file ${{ env.TEST_PROMPTS }}


@@ -1,5 +1,6 @@
 import io
-from fastapi import HTTPException, Path, Query, Request, Response, UploadFile
+from typing import Optional
+from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
 from fastapi.routing import APIRouter
 from fastapi.responses import FileResponse
 from PIL import Image
@@ -7,7 +8,11 @@ from invokeai.app.models.image import (
     ImageCategory,
     ImageType,
 )
-from invokeai.app.services.models.image_record import ImageDTO, ImageUrlsDTO
+from invokeai.app.services.models.image_record import (
+    ImageDTO,
+    ImageRecordChanges,
+    ImageUrlsDTO,
+)
 from invokeai.app.services.item_storage import PaginatedResults
 from ..dependencies import ApiDependencies
@@ -27,10 +32,17 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
 )
 async def upload_image(
     file: UploadFile,
-    image_type: ImageType,
     request: Request,
     response: Response,
-    image_category: ImageCategory = ImageCategory.GENERAL,
+    image_category: ImageCategory = Query(
+        default=ImageCategory.GENERAL, description="The category of the image"
+    ),
+    is_intermediate: bool = Query(
+        default=False, description="Whether this is an intermediate image"
+    ),
+    session_id: Optional[str] = Query(
+        default=None, description="The session ID associated with this upload, if any"
+    ),
 ) -> ImageDTO:
     """Uploads an image"""
     if not file.content_type.startswith("image"):
@@ -46,9 +58,11 @@ async def upload_image(
     try:
         image_dto = ApiDependencies.invoker.services.images.create(
-            pil_image,
-            image_type,
-            image_category,
+            image=pil_image,
+            image_type=ImageType.UPLOAD,
+            image_category=image_category,
+            session_id=session_id,
+            is_intermediate=is_intermediate,
         )

         response.status_code = 201
@@ -61,7 +75,7 @@ async def upload_image(
 @images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
 async def delete_image(
-    image_type: ImageType = Query(description="The type of image to delete"),
+    image_type: ImageType = Path(description="The type of image to delete"),
     image_name: str = Path(description="The name of the image to delete"),
 ) -> None:
     """Deletes an image"""
@@ -73,6 +87,28 @@ async def delete_image(
         pass
+
+
+@images_router.patch(
+    "/{image_type}/{image_name}",
+    operation_id="update_image",
+    response_model=ImageDTO,
+)
+async def update_image(
+    image_type: ImageType = Path(description="The type of image to update"),
+    image_name: str = Path(description="The name of the image to update"),
+    image_changes: ImageRecordChanges = Body(
+        description="The changes to apply to the image"
+    ),
+) -> ImageDTO:
+    """Updates an image"""
+
+    try:
+        return ApiDependencies.invoker.services.images.update(
+            image_type, image_name, image_changes
+        )
+    except Exception as e:
+        raise HTTPException(status_code=400, detail="Failed to update image")
+
+
 @images_router.get(
     "/{image_type}/{image_name}/metadata",
     operation_id="get_image_metadata",
@@ -85,9 +121,7 @@ async def get_image_metadata(
     """Gets an image's metadata"""

     try:
-        return ApiDependencies.invoker.services.images.get_dto(
-            image_type, image_name
-        )
+        return ApiDependencies.invoker.services.images.get_dto(image_type, image_name)
     except Exception as e:
         raise HTTPException(status_code=404)
@@ -113,9 +147,7 @@ async def get_image_full(
     """Gets a full-resolution image file"""

     try:
-        path = ApiDependencies.invoker.services.images.get_path(
-            image_type, image_name
-        )
+        path = ApiDependencies.invoker.services.images.get_path(image_type, image_name)

         if not ApiDependencies.invoker.services.images.validate_path(path):
             raise HTTPException(status_code=404)
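
For reference, a minimal client-side sketch of the new upload query parameters and the update_image route added above. The base URL, the exact upload path, and the fields accepted by ImageRecordChanges are assumptions for illustration; they are not shown in this diff.

# sketch only: host/port, the "/api" prefix and the PATCH body fields are assumed
import requests

BASE = "http://localhost:9090/api/v1/images"

# Upload: image_category, is_intermediate and session_id are now query parameters.
with open("photo.png", "rb") as f:
    resp = requests.post(
        BASE,
        params={"image_category": "general", "is_intermediate": "false"},
        files={"file": ("photo.png", f, "image/png")},
    )
image_dto = resp.json()

# Update: PATCH /{image_type}/{image_name} applies an ImageRecordChanges body to the record.
requests.patch(
    f"{BASE}/{image_dto['image_type']}/{image_dto['image_name']}",
    json={"image_category": "general"},  # assumed field name, for illustration only
)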


@@ -13,10 +13,13 @@ from typing import (
 from pydantic import BaseModel, ValidationError
 from pydantic.fields import Field
+from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
+from invokeai.app.services.images import ImageService
+from invokeai.app.services.metadata import CoreMetadataService
+from invokeai.app.services.urls import LocalUrlService
 import invokeai.backend.util.logging as logger
-from invokeai.app.services.metadata import PngMetadataService
 from .services.default_graphs import create_system_graphs
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@@ -188,6 +191,9 @@ def invoke_all(context: CliContext):
         raise SessionError()

+logger = logger.InvokeAILogger.getLogger()
+
+
 def invoke_cli():
     # this gets the basic configuration
     config = get_invokeai_config()
@@ -206,24 +212,43 @@ def invoke_cli():
     events = EventServiceBase()
     output_folder = config.output_path
-    metadata = PngMetadataService()

     # TODO: build a file/path manager?
-    db_location = os.path.join(output_folder, "invokeai.db")
+    if config.use_memory_db:
+        db_location = ":memory:"
+    else:
+        db_location = os.path.join(output_folder, "invokeai.db")
+    logger.info(f'InvokeAI database location is "{db_location}"')
+
+    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
+        filename=db_location, table_name="graph_executions"
+    )
+
+    urls = LocalUrlService()
+    metadata = CoreMetadataService()
+    image_record_storage = SqliteImageRecordStorage(db_location)
+    image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
+    images = ImageService(
+        image_record_storage=image_record_storage,
+        image_file_storage=image_file_storage,
+        metadata=metadata,
+        url=urls,
+        logger=logger,
+        graph_execution_manager=graph_execution_manager,
+    )

     services = InvocationServices(
         model_manager=model_manager,
         events=events,
         latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
-        images=DiskImageFileStorage(f'{output_folder}/images', metadata_service=metadata),
-        metadata=metadata,
+        images=images,
         queue=MemoryInvocationQueue(),
         graph_library=SqliteItemStorage[LibraryGraph](
             filename=db_location, table_name="graphs"
         ),
-        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
-            filename=db_location, table_name="graph_executions"
-        ),
+        graph_execution_manager=graph_execution_manager,
         processor=DefaultInvocationProcessor(),
         restoration=RestorationServices(config,logger=logger),
         logger=logger,
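
The use_memory_db branch above is what the new --use_memory_db flag (exercised by the CI workflow earlier in this commit) toggles. A standalone sketch of the same selection logic, with the helper name chosen only for illustration:

import os
import sqlite3

def resolve_db_location(output_folder: str, use_memory_db: bool) -> str:
    # ":memory:" asks SQLite for a throwaway in-memory database, as used by the test run
    return ":memory:" if use_memory_db else os.path.join(output_folder, "invokeai.db")

conn = sqlite3.connect(resolve_db_location("outputs", use_memory_db=True))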


@@ -78,6 +78,7 @@ class BaseInvocation(ABC, BaseModel):
     #fmt: off
     id: str = Field(description="The id of this node. Must be unique among all nodes.")
+    is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
     #fmt: on
@@ -95,6 +96,7 @@ class UIConfig(TypedDict, total=False):
             "image",
             "latents",
             "model",
+            "control",
         ],
     ]
     tags: List[str]


@@ -22,6 +22,14 @@ class IntCollectionOutput(BaseInvocationOutput):
     # Outputs
     collection: list[int] = Field(default=[], description="The int collection")

+
+class FloatCollectionOutput(BaseInvocationOutput):
+    """A collection of floats"""
+
+    type: Literal["float_collection"] = "float_collection"
+
+    # Outputs
+    collection: list[float] = Field(default=[], description="The float collection")
+

 class RangeInvocation(BaseInvocation):
     """Creates a range of numbers from start to stop with step"""


@@ -0,0 +1,428 @@
# InvokeAI nodes for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux

import numpy as np
from typing import Literal, Optional, Union, List
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field

from ..models.image import ImageField, ImageType, ImageCategory
from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    InvocationContext,
    InvocationConfig,
)

from controlnet_aux import (
    CannyDetector,
    HEDdetector,
    LineartDetector,
    LineartAnimeDetector,
    MidasDetector,
    MLSDdetector,
    NormalBaeDetector,
    OpenposeDetector,
    PidiNetDetector,
    ContentShuffleDetector,
    ZoeDetector,
    MediapipeFaceDetector,
)

from .image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [
    ###########################################
    # lllyasviel sd v1.5, ControlNet v1.0 models
    ###########################################
    "lllyasviel/sd-controlnet-canny",
    "lllyasviel/sd-controlnet-depth",
    "lllyasviel/sd-controlnet-hed",
    "lllyasviel/sd-controlnet-seg",
    "lllyasviel/sd-controlnet-openpose",
    "lllyasviel/sd-controlnet-scribble",
    "lllyasviel/sd-controlnet-normal",
    "lllyasviel/sd-controlnet-mlsd",

    #############################################
    # lllyasviel sd v1.5, ControlNet v1.1 models
    #############################################
    "lllyasviel/control_v11p_sd15_canny",
    "lllyasviel/control_v11p_sd15_openpose",
    "lllyasviel/control_v11p_sd15_seg",
    # "lllyasviel/control_v11p_sd15_depth",  # broken
    "lllyasviel/control_v11f1p_sd15_depth",
    "lllyasviel/control_v11p_sd15_normalbae",
    "lllyasviel/control_v11p_sd15_scribble",
    "lllyasviel/control_v11p_sd15_mlsd",
    "lllyasviel/control_v11p_sd15_softedge",
    "lllyasviel/control_v11p_sd15s2_lineart_anime",
    "lllyasviel/control_v11p_sd15_lineart",
    "lllyasviel/control_v11p_sd15_inpaint",
    # "lllyasviel/control_v11u_sd15_tile",
    # problem (temporary?) with huggingface "lllyasviel/control_v11u_sd15_tile",
    # so for now replaced with "lllyasviel/control_v11f1e_sd15_tile"
    "lllyasviel/control_v11e_sd15_shuffle",
    "lllyasviel/control_v11e_sd15_ip2p",
    "lllyasviel/control_v11f1e_sd15_tile",

    ##################################################
    # thibaud sd v2.1 models (ControlNet v1.0? or v1.1?)
    ##################################################
    "thibaud/controlnet-sd21-openpose-diffusers",
    "thibaud/controlnet-sd21-canny-diffusers",
    "thibaud/controlnet-sd21-depth-diffusers",
    "thibaud/controlnet-sd21-scribble-diffusers",
    "thibaud/controlnet-sd21-hed-diffusers",
    "thibaud/controlnet-sd21-zoedepth-diffusers",
    "thibaud/controlnet-sd21-color-diffusers",
    "thibaud/controlnet-sd21-openposev2-diffusers",
    "thibaud/controlnet-sd21-lineart-diffusers",
    "thibaud/controlnet-sd21-normalbae-diffusers",
    "thibaud/controlnet-sd21-ade20k-diffusers",

    ##############################################
    # ControlNetMediaPipeface, ControlNet v1.1
    ##############################################
    # ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"],  # SD 1.5
    #   diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
    #   hacked t2l to split to model & subfolder if format is "model,subfolder"
    "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15",  # SD 1.5
    "CrucibleAI/ControlNetMediaPipeFace",  # SD 2.1?
]

CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
class ControlField(BaseModel):
    image: ImageField = Field(default=None, description="processed image")
    control_model: Optional[str] = Field(default=None, description="control model used")
    control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
    begin_step_percent: float = Field(default=0, ge=0, le=1,
                                      description="% of total steps at which controlnet is first applied")
    end_step_percent: float = Field(default=1, ge=0, le=1,
                                    description="% of total steps at which controlnet is last applied")

    class Config:
        schema_extra = {
            "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"]
        }


class ControlOutput(BaseInvocationOutput):
    """node output for ControlNet info"""
    # fmt: off
    type: Literal["control_output"] = "control_output"
    control: ControlField = Field(default=None, description="The control info dict")
    # fmt: on


class ControlNetInvocation(BaseInvocation):
    """Collects ControlNet info to pass to other nodes"""
    # fmt: off
    type: Literal["controlnet"] = "controlnet"
    # Inputs
    image: ImageField = Field(default=None, description="image to process")
    control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
                                                  description="control model used")
    control_weight: float = Field(default=1.0, ge=0, le=1, description="weight given to controlnet")
    # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
    begin_step_percent: float = Field(default=0, ge=0, le=1,
                                      description="% of total steps at which controlnet is first applied")
    end_step_percent: float = Field(default=1, ge=0, le=1,
                                    description="% of total steps at which controlnet is last applied")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ControlOutput:
        return ControlOutput(
            control=ControlField(
                image=self.image,
                control_model=self.control_model,
                control_weight=self.control_weight,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
            ),
        )
# TODO: move image processors to separate file (image_analysis.py)
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
    """Base class for invocations that preprocess images for ControlNet"""
    # fmt: off
    type: Literal["image_processor"] = "image_processor"
    # Inputs
    image: ImageField = Field(default=None, description="image to process")
    # fmt: on

    def run_processor(self, image):
        # superclass just passes through image without processing
        return image

    def invoke(self, context: InvocationContext) -> ImageOutput:
        raw_image = context.services.images.get_pil_image(
            self.image.image_type, self.image.image_name
        )
        # image type should be PIL.PngImagePlugin.PngImageFile ?
        processed_image = self.run_processor(raw_image)

        # FIXME: what happened to image metadata?
        # metadata = context.services.metadata.build_metadata(
        #     session_id=context.graph_execution_state_id, node=self
        # )

        # currently can't see processed image in node UI without a showImage node,
        # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
        image_dto = context.services.images.create(
            image=processed_image,
            image_type=ImageType.RESULT,
            image_category=ImageCategory.GENERAL,
            session_id=context.graph_execution_state_id,
            node_id=self.id,
            is_intermediate=self.is_intermediate
        )

        """Builds an ImageOutput and its ImageField"""
        processed_image_field = ImageField(
            image_name=image_dto.image_name,
            image_type=image_dto.image_type,
        )
        return ImageOutput(
            image=processed_image_field,
            # width=processed_image.width,
            width=image_dto.width,
            # height=processed_image.height,
            height=image_dto.height,
            # mode=processed_image.mode,
        )
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Canny edge detection for ControlNet"""
    # fmt: off
    type: Literal["canny_image_processor"] = "canny_image_processor"
    # Input
    low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient")
    high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient")
    # fmt: on

    def run_processor(self, image):
        canny_processor = CannyDetector()
        processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
        return processed_image


class HedImageprocessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies HED edge detection to image"""
    # fmt: off
    type: Literal["hed_image_processor"] = "hed_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    # safe not supported in controlnet_aux v0.0.3
    # safe: bool = Field(default=False, description="whether to use safe mode")
    scribble: bool = Field(default=False, description="whether to use scribble mode")
    # fmt: on

    def run_processor(self, image):
        hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = hed_processor(image,
                                        detect_resolution=self.detect_resolution,
                                        image_resolution=self.image_resolution,
                                        # safe not supported in controlnet_aux v0.0.3
                                        # safe=self.safe,
                                        scribble=self.scribble,
                                        )
        return processed_image


class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies line art processing to image"""
    # fmt: off
    type: Literal["lineart_image_processor"] = "lineart_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    coarse: bool = Field(default=False, description="whether to use coarse mode")
    # fmt: on

    def run_processor(self, image):
        lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = lineart_processor(image,
                                            detect_resolution=self.detect_resolution,
                                            image_resolution=self.image_resolution,
                                            coarse=self.coarse)
        return processed_image


class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies line art anime processing to image"""
    # fmt: off
    type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    # fmt: on

    def run_processor(self, image):
        processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = processor(image,
                                    detect_resolution=self.detect_resolution,
                                    image_resolution=self.image_resolution,
                                    )
        return processed_image


class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies Openpose processing to image"""
    # fmt: off
    type: Literal["openpose_image_processor"] = "openpose_image_processor"
    # Inputs
    hand_and_face: bool = Field(default=False, description="whether to use hands and face mode")
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    # fmt: on

    def run_processor(self, image):
        openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = openpose_processor(image,
                                             detect_resolution=self.detect_resolution,
                                             image_resolution=self.image_resolution,
                                             hand_and_face=self.hand_and_face,
                                             )
        return processed_image


class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies Midas depth processing to image"""
    # fmt: off
    type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
    # Inputs
    a_mult: float = Field(default=2.0, ge=0, description="Midas parameter a = amult * PI")
    bg_th: float = Field(default=0.1, ge=0, description="Midas parameter bg_th")
    # depth_and_normal not supported in controlnet_aux v0.0.3
    # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
    # fmt: on

    def run_processor(self, image):
        midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = midas_processor(image,
                                          a=np.pi * self.a_mult,
                                          bg_th=self.bg_th,
                                          # depth_and_normal not supported in controlnet_aux v0.0.3
                                          # depth_and_normal=self.depth_and_normal,
                                          )
        return processed_image


class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies NormalBae processing to image"""
    # fmt: off
    type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    # fmt: on

    def run_processor(self, image):
        normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = normalbae_processor(image,
                                              detect_resolution=self.detect_resolution,
                                              image_resolution=self.image_resolution)
        return processed_image


class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies MLSD processing to image"""
    # fmt: off
    type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter thr_v")
    thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter thr_d")
    # fmt: on

    def run_processor(self, image):
        mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = mlsd_processor(image,
                                         detect_resolution=self.detect_resolution,
                                         image_resolution=self.image_resolution,
                                         thr_v=self.thr_v,
                                         thr_d=self.thr_d)
        return processed_image


class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies PIDI processing to image"""
    # fmt: off
    type: Literal["pidi_image_processor"] = "pidi_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    safe: bool = Field(default=False, description="whether to use safe mode")
    scribble: bool = Field(default=False, description="whether to use scribble mode")
    # fmt: on

    def run_processor(self, image):
        pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = pidi_processor(image,
                                         detect_resolution=self.detect_resolution,
                                         image_resolution=self.image_resolution,
                                         safe=self.safe,
                                         scribble=self.scribble)
        return processed_image


class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies content shuffle processing to image"""
    # fmt: off
    type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    h: Union[int, None] = Field(default=512, ge=0, description="content shuffle h parameter")
    w: Union[int, None] = Field(default=512, ge=0, description="content shuffle w parameter")
    f: Union[int, None] = Field(default=256, ge=0, description="content shuffle f parameter")
    # fmt: on

    def run_processor(self, image):
        content_shuffle_processor = ContentShuffleDetector()
        processed_image = content_shuffle_processor(image,
                                                    detect_resolution=self.detect_resolution,
                                                    image_resolution=self.image_resolution,
                                                    h=self.h,
                                                    w=self.w,
                                                    f=self.f
                                                    )
        return processed_image


# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies Zoe depth processing to image"""
    # fmt: off
    type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
    # fmt: on

    def run_processor(self, image):
        zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = zoe_depth_processor(image)
        return processed_image


class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies mediapipe face processing to image"""
    # fmt: off
    type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
    # Inputs
    max_faces: int = Field(default=1, ge=1, description="maximum number of faces to detect")
    min_confidence: float = Field(default=0.5, ge=0, le=1, description="minimum confidence for face detection")
    # fmt: on

    def run_processor(self, image):
        mediapipe_face_processor = MediapipeFaceDetector()
        processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
        return processed_image
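
Every preprocessor above follows the same pattern: subclass ImageProcessorInvocation, declare a unique type literal plus any parameters as pydantic fields, and override run_processor(); the base class handles fetching the input image and saving the result. A hypothetical extra processor, for illustration only (not part of this commit), assuming the imports at the top of the new file are in scope:

class BlurImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies a Gaussian blur (illustrative example, not in the codebase)"""
    # fmt: off
    type: Literal["blur_image_processor"] = "blur_image_processor"
    # Inputs
    radius: float = Field(default=2.0, ge=0, description="blur radius in pixels")
    # fmt: on

    def run_processor(self, image):
        # PIL-level processing only; image loading and saving are inherited from the base class
        return image.filter(ImageFilter.GaussianBlur(self.radius))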


@@ -57,10 +57,11 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
         image_dto = context.services.images.create(
             image=image_inpainted,
-            image_type=ImageType.INTERMEDIATE,
+            image_type=ImageType.RESULT,
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(


@@ -4,7 +4,9 @@ from functools import partial
 from typing import Literal, Optional, Union, get_args

 import numpy as np
+from diffusers import ControlNetModel
 from torch import Tensor
+import torch
 from pydantic import BaseModel, Field
@@ -56,8 +58,11 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
     height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
     cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
-    scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
+    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
     model: str = Field(default="", description="The model to use (currently ignored)")
+    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
+    control_model: Optional[str] = Field(default=None, description="The control model to use")
+    control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
     # fmt: on

     # TODO: pass this an emitter method or something? or a session for dispatching?
@@ -78,17 +83,35 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
         # Handle invalid model parameter
         model = choose_model(context.services.model_manager, self.model)

+        # loading controlnet image (currently requires pre-processed image)
+        control_image = (
+            None if self.control_image is None
+            else context.services.images.get(
+                self.control_image.image_type, self.control_image.image_name
+            )
+        )
+        # loading controlnet model
+        if (self.control_model is None or self.control_model==''):
+            control_model = None
+        else:
+            # FIXME: change this to dropdown menu?
+            # FIXME: generalize so don't have to hardcode torch_dtype and device
+            control_model = ControlNetModel.from_pretrained(self.control_model,
+                                                            torch_dtype=torch.float16).to("cuda")
+
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
             context.graph_execution_state_id
         )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]

-        outputs = Txt2Img(model).generate(
+        txt2img = Txt2Img(model, control_model=control_model)
+        outputs = txt2img.generate(
             prompt=self.prompt,
             step_callback=partial(self.dispatch_progress, context, source_node_id),
+            control_image=control_image,
             **self.dict(
-                exclude={"prompt"}
+                exclude={"prompt", "control_image" }
             ),  # Shorthand for passing all of the parameters above manually
         )

         # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
@@ -101,6 +124,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
             image_category=ImageCategory.GENERAL,
             session_id=context.graph_execution_state_id,
             node_id=self.id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(
@@ -181,6 +205,7 @@ class ImageToImageInvocation(TextToImageInvocation):
             image_category=ImageCategory.GENERAL,
             session_id=context.graph_execution_state_id,
             node_id=self.id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(
@@ -296,6 +321,7 @@ class InpaintInvocation(ImageToImageInvocation):
             image_category=ImageCategory.GENERAL,
             session_id=context.graph_execution_state_id,
             node_id=self.id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(


@@ -143,6 +143,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(
@@ -204,6 +205,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(
@@ -242,6 +244,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.MASK,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return MaskOutput(
@@ -280,6 +283,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(
@@ -318,6 +322,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(
@@ -356,6 +361,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(
@@ -397,6 +403,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(
@@ -437,6 +444,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(
@@ -482,6 +490,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(


@@ -149,6 +149,7 @@ class InfillColorInvocation(BaseInvocation):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(
@@ -193,6 +194,7 @@ class InfillTileInvocation(BaseInvocation):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(
@@ -230,6 +232,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(


@@ -1,8 +1,11 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

 import random
-from typing import Literal, Optional, Union
 import einops
+from typing import Literal, Optional, Union, List
+
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
 from pydantic import BaseModel, Field, validator
 import torch
@@ -11,14 +14,18 @@ from invokeai.app.models.image import ImageCategory
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.app.util.step_callback import stable_diffusion_step_callback

+from .controlnet_image_processors import ControlField
+
 from ...backend.model_management.model_manager import ModelManager
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
 from ...backend.prompting.conditioning import get_uc_and_c_and_ec
 from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
+
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 import numpy as np
 from ..services.image_file_storage import ImageType
@@ -28,7 +35,7 @@ from .compel import ConditioningField
 from ...backend.stable_diffusion import PipelineIntermediateState
 from diffusers.schedulers import SchedulerMixin as Scheduler
 import diffusers
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, ControlNetModel


 class LatentsField(BaseModel):
@@ -167,8 +174,9 @@ class TextToLatentsInvocation(BaseInvocation):
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
-    scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
+    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
     model: str = Field(default="", description="The model to use (currently ignored)")
+    control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
     # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
     # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     # fmt: on
@@ -179,7 +187,8 @@ class TextToLatentsInvocation(BaseInvocation):
             "ui": {
                 "tags": ["latents", "image"],
                 "type_hints": {
-                    "model": "model"
+                    "model": "model",
+                    "control": "control",
                 }
             },
         }
@@ -238,6 +247,81 @@
         ).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
         return conditioning_data

+    def prep_control_data(self,
+                          context: InvocationContext,
+                          model: StableDiffusionGeneratorPipeline,  # really only need model for dtype and device
+                          control_input: List[ControlField],
+                          latents_shape: List[int],
+                          do_classifier_free_guidance: bool = True,
+                          ) -> List[ControlNetData]:
+        # assuming fixed dimensional scaling of 8:1 for image:latents
+        control_height_resize = latents_shape[2] * 8
+        control_width_resize = latents_shape[3] * 8
+        if control_input is None:
+            # print("control input is None")
+            control_list = None
+        elif isinstance(control_input, list) and len(control_input) == 0:
+            # print("control input is empty list")
+            control_list = None
+        elif isinstance(control_input, ControlField):
+            # print("control input is ControlField")
+            control_list = [control_input]
+        elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
+            # print("control input is list[ControlField]")
+            control_list = control_input
+        else:
+            # print("input control is unrecognized:", type(self.control))
+            control_list = None
+        if (control_list is None):
+            control_data = None
+            # from above handling, any control that is not None should now be of type list[ControlField]
+        else:
+            # FIXME: add checks to skip entry if model or image is None
+            #        and if weight is None, populate with default 1.0?
+            control_data = []
+            control_models = []
+            for control_info in control_list:
+                # handle control models
+                if ("," in control_info.control_model):
+                    control_model_split = control_info.control_model.split(",")
+                    control_name = control_model_split[0]
+                    control_subfolder = control_model_split[1]
+                    print("Using HF model subfolders")
+                    print("    control_name: ", control_name)
+                    print("    control_subfolder: ", control_subfolder)
+                    control_model = ControlNetModel.from_pretrained(control_name,
+                                                                    subfolder=control_subfolder,
+                                                                    torch_dtype=model.unet.dtype).to(model.device)
+                else:
+                    control_model = ControlNetModel.from_pretrained(control_info.control_model,
+                                                                    torch_dtype=model.unet.dtype).to(model.device)
+                control_models.append(control_model)
+                control_image_field = control_info.image
+                input_image = context.services.images.get_pil_image(control_image_field.image_type,
+                                                                    control_image_field.image_name)
+                # self.image.image_type, self.image.image_name
+                # FIXME: still need to test with different widths, heights, devices, dtypes
+                #        and add in batch_size, num_images_per_prompt?
+                #        and do real check for classifier_free_guidance?
+                # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
+                control_image = model.prepare_control_image(
+                    image=input_image,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    width=control_width_resize,
+                    height=control_height_resize,
+                    # batch_size=batch_size * num_images_per_prompt,
+                    # num_images_per_prompt=num_images_per_prompt,
+                    device=control_model.device,
+                    dtype=control_model.dtype,
+                )
+                control_item = ControlNetData(model=control_model,
+                                              image_tensor=control_image,
+                                              weight=control_info.control_weight,
+                                              begin_step_percent=control_info.begin_step_percent,
+                                              end_step_percent=control_info.end_step_percent)
+                control_data.append(control_item)
+            # MultiControlNetModel has been refactored out, just need list[ControlNetData]
+        return control_data

     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)
@@ -252,14 +336,19 @@ class TextToLatentsInvocation(BaseInvocation):
         model = self.get_model(context.services.model_manager)
         conditioning_data = self.get_conditioning_data(context, model)

-        # TODO: Verify the noise is the right size
+        print("type of control input: ", type(self.control))
+        control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
+                                              latents_shape=noise.shape,
+                                              do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+
+        # TODO: Verify the noise is the right size
         result_latents, result_attention_map_saver = model.latents_from_embeddings(
             latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
             noise=noise,
             num_inference_steps=self.steps,
             conditioning_data=conditioning_data,
-            callback=step_callback
+            control_data=control_data,  # list[ControlNetData]
+            callback=step_callback,
         )

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
@@ -285,7 +374,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             "ui": {
                 "tags": ["latents"],
                 "type_hints": {
-                    "model": "model"
+                    "model": "model",
+                    "control": "control",
                 }
             },
         }
@@ -304,6 +394,11 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         model = self.get_model(context.services.model_manager)
         conditioning_data = self.get_conditioning_data(context, model)

+        print("type of control input: ", type(self.control))
+        control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
+                                              latents_shape=noise.shape,
+                                              do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+
         # TODO: Verify the noise is the right size
         initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
@@ -318,6 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             noise=noise,
             num_inference_steps=self.steps,
             conditioning_data=conditioning_data,
+            control_data=control_data,  # list[ControlNetData]
             callback=step_callback
         )
@@ -362,14 +458,21 @@ class LatentsToImageInvocation(BaseInvocation):
         np_image = model.decode_latents(latents)
         image = model.numpy_to_pil(np_image)[0]

+        # what happened to metadata?
+        # metadata = context.services.metadata.build_metadata(
+        #     session_id=context.graph_execution_state_id, node=self
+
         torch.cuda.empty_cache()

+        # new (post Image service refactor) way of using services to save image
+        # and generate unique image_name
         image_dto = context.services.images.create(
             image=image,
             image_type=ImageType.RESULT,
             image_category=ImageCategory.GENERAL,
             session_id=context.graph_execution_state_id,
             node_id=self.id,
+            is_intermediate=self.is_intermediate
         )

         return ImageOutput(
@@ -413,6 +516,7 @@ class ResizeLatentsInvocation(BaseInvocation):
         torch.cuda.empty_cache()

         name = f"{context.graph_execution_state_id}__{self.id}"
+        # context.services.latents.set(name, resized_latents)
         context.services.latents.save(name, resized_latents)
         return build_latents_output(latents_name=name, latents=resized_latents)
@@ -443,6 +547,7 @@ class ScaleLatentsInvocation(BaseInvocation):
         torch.cuda.empty_cache()

         name = f"{context.graph_execution_state_id}__{self.id}"
+        # context.services.latents.set(name, resized_latents)
         context.services.latents.save(name, resized_latents)
         return build_latents_output(latents_name=name, latents=resized_latents)
@@ -467,6 +572,9 @@ class ImageToLatentsInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
+        # image = context.services.images.get(
+        #     self.image.image_type, self.image.image_name
+        # )
         image = context.services.images.get_pil_image(
             self.image.image_type, self.image.image_name
         )
@@ -487,6 +595,6 @@ class ImageToLatentsInvocation(BaseInvocation):
         )

         name = f"{context.graph_execution_state_id}__{self.id}"
+        # context.services.latents.set(name, latents)
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents)
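
Two conventions in prep_control_data above are worth calling out: latent dimensions are scaled by a fixed factor of 8 to get the control-image size, and a control model name containing a comma is split into a Hugging Face repo id plus a subfolder argument for from_pretrained(). A small standalone sketch, with illustrative values:

latents_shape = [1, 4, 64, 96]                 # (batch, channels, height/8, width/8)
control_height_resize = latents_shape[2] * 8   # 512, the 8:1 latent-to-pixel scaling
control_width_resize = latents_shape[3] * 8    # 768

model_name = "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15"
if "," in model_name:
    repo_id, subfolder = model_name.split(",", 1)   # repo id plus subfolder for from_pretrained()
else:
    repo_id, subfolder = model_name, None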


@@ -34,6 +34,15 @@ class IntOutput(BaseInvocationOutput):
     # fmt: on


+class FloatOutput(BaseInvocationOutput):
+    """A float output"""
+
+    # fmt: off
+    type: Literal["float_output"] = "float_output"
+    param: float = Field(default=None, description="The output float")
+    # fmt: on
+
+
 class AddInvocation(BaseInvocation, MathInvocationConfig):
     """Adds two numbers"""


@@ -3,7 +3,7 @@
 from typing import Literal
 from pydantic import Field
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
-from .math import IntOutput
+from .math import IntOutput, FloatOutput

 # Pass-through parameter nodes - used by subgraphs
@@ -16,3 +16,13 @@ class ParamIntInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> IntOutput:
         return IntOutput(a=self.a)
+
+
+class ParamFloatInvocation(BaseInvocation):
+    """A float parameter"""
+    #fmt: off
+    type: Literal["param_float"] = "param_float"
+    param: float = Field(default=0.0, description="The float value")
+    #fmt: on
+
+    def invoke(self, context: InvocationContext) -> FloatOutput:
+        return FloatOutput(param=self.param)


@@ -43,10 +43,11 @@ class RestoreFaceInvocation(BaseInvocation):
         # TODO: can this return multiple results?
         image_dto = context.services.images.create(
             image=results[0][0],
-            image_type=ImageType.INTERMEDIATE,
+            image_type=ImageType.RESULT,
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(


@@ -49,6 +49,7 @@ class UpscaleInvocation(BaseInvocation):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )

         return ImageOutput(


@@ -10,7 +10,6 @@ class ImageType(str, Enum, metaclass=MetaEnum):
     RESULT = "results"
     UPLOAD = "uploads"
-    INTERMEDIATE = "intermediates"


 class InvalidImageTypeException(ValueError):


@ -353,6 +353,7 @@ setting environment variables INVOKEAI_<setting>.
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths') root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths') autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths') conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
@ -362,6 +363,7 @@ setting environment variables INVOKEAI_<setting>.
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths') lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths') outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths') from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models') model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models') embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
@ -511,7 +513,7 @@ class PagingArgumentParser(argparse.ArgumentParser):
text = self.format_help() text = self.format_help()
pydoc.pager(text) pydoc.pager(text)
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAISettings: def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAIAppConfig:
''' '''
This returns a singleton InvokeAIAppConfig configuration object. This returns a singleton InvokeAIAppConfig configuration object.
''' '''
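
The new use_memory_db flag follows the same INVOKEAI_<setting> environment-variable convention mentioned in the hunk above, so it can be flipped for a throwaway run without editing any config file. A minimal sketch, assuming the settings object reads the variable when the singleton is first built (only use_memory_db and get_invokeai_config come from this diff):

import os

os.environ["INVOKEAI_use_memory_db"] = "true"  # hypothetical override for an ephemeral test run

from invokeai.app.services.config import get_invokeai_config

config = get_invokeai_config()
print(config.use_memory_db)  # True -> image metadata is kept in an in-memory SQLite database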

View File

@ -60,6 +60,35 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
node_input_field = node_inputs.get(field) or None node_input_field = node_inputs.get(field) or None
return node_input_field return node_input_field
from typing import Optional, Union, List, get_args
def is_union_subtype(t1, t2):
t1_args = get_args(t1)
t2_args = get_args(t2)
if not t1_args:
# t1 is a single type
return t1 in t2_args
else:
# t1 is a Union, check that all of its types are in t2_args
return all(arg in t2_args for arg in t1_args)
def is_list_or_contains_list(t):
t_args = get_args(t)
# If the type is a List
if get_origin(t) is list:
return True
# If the type is a Union
elif t_args:
# Check if any of the types in the Union is a List
for arg in t_args:
if get_origin(arg) is list:
return True
return False
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if not from_type: if not from_type:
@ -85,7 +114,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if to_type in get_args(from_type): if to_type in get_args(from_type):
return True return True
if not issubclass(from_type, to_type): # if not issubclass(from_type, to_type):
if not is_union_subtype(from_type, to_type):
return False return False
else: else:
return False return False
@ -694,7 +724,11 @@ class Graph(BaseModel):
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
# Verify that all outputs are lists # Verify that all outputs are lists
if not all((get_origin(f) == list for f in output_fields)): # if not all((get_origin(f) == list for f in output_fields)):
# return False
# Verify that all outputs are lists
if not all(is_list_or_contains_list(f) for f in output_fields):
return False return False
# Verify that all outputs match the input type (are a base class or the same class) # Verify that all outputs match the input type (are a base class or the same class)

View File

@ -12,6 +12,7 @@ from invokeai.app.models.image import (
) )
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageRecord, ImageRecord,
ImageRecordChanges,
deserialize_image_record, deserialize_image_record,
) )
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import PaginatedResults
@ -49,6 +50,16 @@ class ImageRecordStorageBase(ABC):
"""Gets an image record.""" """Gets an image record."""
pass pass
@abstractmethod
def update(
self,
image_name: str,
image_type: ImageType,
changes: ImageRecordChanges,
) -> None:
"""Updates an image record."""
pass
@abstractmethod @abstractmethod
def get_many( def get_many(
self, self,
@ -78,6 +89,7 @@ class ImageRecordStorageBase(ABC):
session_id: Optional[str], session_id: Optional[str],
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
) -> datetime: ) -> datetime:
"""Saves an image record.""" """Saves an image record."""
pass pass
@ -125,6 +137,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id TEXT, session_id TEXT,
node_id TEXT, node_id TEXT,
metadata TEXT, metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
-- Updated via trigger -- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
@ -193,6 +206,42 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result)) return deserialize_image_record(dict(result))
def update(
self,
image_name: str,
image_type: ImageType,
changes: ImageRecordChanges,
) -> None:
try:
self._lock.acquire()
# Change the category of the image
if changes.image_category is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_many( def get_many(
self, self,
image_type: ImageType, image_type: ImageType,
@ -265,6 +314,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
height: int, height: int,
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
) -> datetime: ) -> datetime:
try: try:
metadata_json = ( metadata_json = (
@ -281,9 +331,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
height, height,
node_id, node_id,
session_id, session_id,
metadata metadata,
is_intermediate
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?); VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
""", """,
( (
image_name, image_name,
@ -294,6 +345,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id, node_id,
session_id, session_id,
metadata_json, metadata_json,
is_intermediate,
), ),
) )
self._conn.commit() self._conn.commit()
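
For readers who want to try the storage changes in isolation, here is a reduced, self-contained SQLite sketch of the two additions above: the is_intermediate BOOLEAN DEFAULT FALSE column and the guarded partial UPDATE that update() performs. The schema is deliberately trimmed; the real table has more columns plus triggers.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE images (
        image_name TEXT PRIMARY KEY,
        image_category TEXT,
        session_id TEXT,
        is_intermediate BOOLEAN DEFAULT FALSE
    );"""
)
conn.execute(
    "INSERT INTO images (image_name, image_category, session_id, is_intermediate) VALUES (?, ?, ?, ?);",
    ("img_1.png", "general", None, True),
)

# Apply only the fields that are actually present in the changes object, as update() does.
changes = {"image_category": "control", "session_id": None}
if changes.get("image_category") is not None:
    conn.execute(
        "UPDATE images SET image_category = ? WHERE image_name = ?;",
        (changes["image_category"], "img_1.png"),
    )
if changes.get("session_id") is not None:
    conn.execute(
        "UPDATE images SET session_id = ? WHERE image_name = ?;",
        (changes["session_id"], "img_1.png"),
    )
conn.commit()
print(conn.execute("SELECT image_category, is_intermediate FROM images;").fetchone())
# -> ('control', 1)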

View File

@ -20,6 +20,7 @@ from invokeai.app.services.image_record_storage import (
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageRecord, ImageRecord,
ImageDTO, ImageDTO,
ImageRecordChanges,
image_record_to_dto, image_record_to_dto,
) )
from invokeai.app.services.image_file_storage import ( from invokeai.app.services.image_file_storage import (
@ -31,7 +32,6 @@ from invokeai.app.services.image_file_storage import (
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.misc import get_iso_timestamp
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState from invokeai.app.services.graph import GraphExecutionState
@ -48,11 +48,21 @@ class ImageServiceABC(ABC):
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
metadata: Optional[ImageMetadata] = None, intermediate: bool = False,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
pass pass
@abstractmethod
def update(
self,
image_type: ImageType,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
"""Updates an image."""
pass
@abstractmethod @abstractmethod
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Gets an image as a PIL image.""" """Gets an image as a PIL image."""
@ -157,6 +167,7 @@ class ImageService(ImageServiceABC):
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
is_intermediate: bool = False,
) -> ImageDTO: ) -> ImageDTO:
if image_type not in ImageType: if image_type not in ImageType:
raise InvalidImageTypeException raise InvalidImageTypeException
@ -184,6 +195,8 @@ class ImageService(ImageServiceABC):
image_category=image_category, image_category=image_category,
width=width, width=width,
height=height, height=height,
# Meta fields
is_intermediate=is_intermediate,
# Nullable fields # Nullable fields
node_id=node_id, node_id=node_id,
session_id=session_id, session_id=session_id,
@ -217,6 +230,7 @@ class ImageService(ImageServiceABC):
created_at=created_at, created_at=created_at,
updated_at=created_at, # this is always the same as the created_at at this time updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None, deleted_at=None,
is_intermediate=is_intermediate,
# Extra non-nullable fields for DTO # Extra non-nullable fields for DTO
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
@ -231,6 +245,23 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem saving image record and file") self._services.logger.error("Problem saving image record and file")
raise e raise e
def update(
self,
image_type: ImageType,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
try:
self._services.records.update(image_name, image_type, changes)
return self.get_dto(image_type, image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
except Exception as e:
self._services.logger.error("Problem updating image record")
raise e
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
try: try:
return self._services.files.get(image_type, image_name) return self._services.files.get(image_type, image_name)
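
A hypothetical call site for the new service-level update(). Only ImageService.update, ImageType, ImageCategory, ImageRecordChanges and ImageDTO come from this diff; the helper function and the way the service instance is obtained are illustrative.

from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.services.images import ImageService
from invokeai.app.services.models.image_record import ImageDTO, ImageRecordChanges

def recategorize_result(images: ImageService, image_name: str) -> ImageDTO:
    """Move one result image into the GENERAL category, leaving every other field untouched."""
    changes = ImageRecordChanges(image_category=ImageCategory.GENERAL)
    return images.update(ImageType.RESULT, image_name, changes)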

View File

@ -1,18 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from logging import Logger
from invokeai.app.services.images import ImageService
from invokeai.backend import ModelManager
from .events import EventServiceBase
from .latent_storage import LatentsStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .config import InvokeAISettings
if TYPE_CHECKING: if TYPE_CHECKING:
from logging import Logger
from invokeai.app.services.images import ImageService
from invokeai.backend import ModelManager
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.restoration_services import RestorationServices
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.config import InvokeAISettings
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.invoker import InvocationProcessorABC from invokeai.app.services.invoker import InvocationProcessorABC
@ -20,32 +19,33 @@ if TYPE_CHECKING:
class InvocationServices: class InvocationServices:
"""Services that can be used by invocations""" """Services that can be used by invocations"""
events: EventServiceBase # TODO: Just forward-declared everything due to circular dependencies. Fix structure.
latents: LatentsStorageBase events: "EventServiceBase"
queue: InvocationQueueABC latents: "LatentsStorageBase"
model_manager: ModelManager queue: "InvocationQueueABC"
restoration: RestorationServices model_manager: "ModelManager"
configuration: InvokeAISettings restoration: "RestorationServices"
images: ImageService configuration: "InvokeAISettings"
images: "ImageService"
# NOTE: we must forward-declare any types that include invocations, since invocations can use services # NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"] graph_library: "ItemStorageABC"["LibraryGraph"]
graph_execution_manager: ItemStorageABC["GraphExecutionState"] graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
processor: "InvocationProcessorABC" processor: "InvocationProcessorABC"
def __init__( def __init__(
self, self,
model_manager: ModelManager, model_manager: "ModelManager",
events: EventServiceBase, events: "EventServiceBase",
logger: Logger, logger: "Logger",
latents: LatentsStorageBase, latents: "LatentsStorageBase",
images: ImageService, images: "ImageService",
queue: InvocationQueueABC, queue: "InvocationQueueABC",
graph_library: ItemStorageABC["LibraryGraph"], graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
processor: "InvocationProcessorABC", processor: "InvocationProcessorABC",
restoration: RestorationServices, restoration: "RestorationServices",
configuration: InvokeAISettings = None, configuration: "InvokeAISettings",
): ):
self.model_manager = model_manager self.model_manager = model_manager
self.events = events self.events = events

View File

@ -1,6 +1,6 @@
import datetime import datetime
from typing import Optional, Union from typing import Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Extra, Field, StrictStr
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
@ -31,6 +31,8 @@ class ImageRecord(BaseModel):
description="The deleted timestamp of the image." description="The deleted timestamp of the image."
) )
"""The deleted timestamp of the image.""" """The deleted timestamp of the image."""
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
"""Whether this is an intermediate image."""
session_id: Optional[str] = Field( session_id: Optional[str] = Field(
default=None, default=None,
description="The session ID that generated this image, if it is a generated image.", description="The session ID that generated this image, if it is a generated image.",
@ -48,6 +50,25 @@ class ImageRecord(BaseModel):
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.""" """A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
"""A set of changes to apply to an image record.
Only limited changes are valid:
- `image_category`: change the category of an image
- `session_id`: change the session associated with an image
"""
image_category: Optional[ImageCategory] = Field(
description="The image's new category."
)
"""The image's new category."""
session_id: Optional[StrictStr] = Field(
default=None,
description="The image's new session ID.",
)
"""The image's new session ID."""
class ImageUrlsDTO(BaseModel): class ImageUrlsDTO(BaseModel):
"""The URLs for an image and its thumbnail.""" """The URLs for an image and its thumbnail."""
@ -95,6 +116,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
created_at = image_dict.get("created_at", get_iso_timestamp()) created_at = image_dict.get("created_at", get_iso_timestamp())
updated_at = image_dict.get("updated_at", get_iso_timestamp()) updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False)
raw_metadata = image_dict.get("metadata") raw_metadata = image_dict.get("metadata")
@ -115,4 +137,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
created_at=created_at, created_at=created_at,
updated_at=updated_at, updated_at=updated_at,
deleted_at=deleted_at, deleted_at=deleted_at,
is_intermediate=is_intermediate,
) )
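
Because ImageRecordChanges is declared with extra=Extra.forbid, anything other than image_category and session_id is rejected at validation time; a quick sketch (field names taken from the class above, the rest illustrative):

from pydantic import ValidationError
from invokeai.app.models.image import ImageCategory
from invokeai.app.services.models.image_record import ImageRecordChanges

ok = ImageRecordChanges(image_category=ImageCategory.GENERAL, session_id="abc123")

try:
    ImageRecordChanges(image_name="not-allowed.png")  # not one of the two permitted fields
except ValidationError as err:
    print(err)  # extra fields not permitted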

View File

@ -1,10 +1,7 @@
import time import time
import traceback import traceback
from threading import Event, Thread, BoundedSemaphore from threading import Event, Thread, BoundedSemaphore
from typing import Any, TypeGuard
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.models.image import ImageType
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker from .invoker import InvocationProcessorABC, Invoker

View File

@ -75,9 +75,11 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def __init__(self, def __init__(self,
model_info: dict, model_info: dict,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(), params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
**kwargs,
): ):
self.model_info=model_info self.model_info=model_info
self.params=params self.params=params
self.kwargs = kwargs
def generate(self, def generate(self,
prompt: str='', prompt: str='',
@ -118,9 +120,12 @@ class InvokeAIGenerator(metaclass=ABCMeta):
model=model, model=model,
scheduler_name=generator_args.get('scheduler') scheduler_name=generator_args.get('scheduler')
) )
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
# get conditioning from prompt via Compel package
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model)
gen_class = self._generator_class() gen_class = self._generator_class()
generator = gen_class(model, self.params.precision) generator = gen_class(model, self.params.precision, **self.kwargs)
if self.params.variation_amount > 0: if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'), generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'), generator_args.get('variation_amount'),
@ -276,7 +281,7 @@ class Generator:
precision: str precision: str
model: DiffusionPipeline model: DiffusionPipeline
def __init__(self, model: DiffusionPipeline, precision: str): def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
self.model = model self.model = model
self.precision = precision self.precision = precision
self.seed = None self.seed = None
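
The widened signatures above mean any extra keyword arguments handed to an InvokeAIGenerator are stashed in self.kwargs and replayed when the concrete generator is built (generator = gen_class(model, self.params.precision, **self.kwargs)). A toy illustration of the pass-through with stand-in classes, not the real ones:

class Generator:
    # mirrors the widened __init__(self, model, precision, **kwargs) in the hunk above
    def __init__(self, model, precision, **kwargs):
        self.model = model
        self.precision = precision
        self.extra = kwargs

class TinyControlledGenerator(Generator):
    # stand-in for a subclass such as Txt2Img that consumes one of the extra kwargs
    def __init__(self, model, precision, control_model=None, **kwargs):
        self.control_model = control_model
        super().__init__(model, precision, **kwargs)

gen = TinyControlledGenerator("dummy-pipeline", "float32", control_model="dummy-controlnet")
print(gen.control_model, gen.precision)  # dummy-controlnet float32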

View File

@ -4,6 +4,10 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
import PIL.Image import PIL.Image
import torch import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from ..stable_diffusion import ( from ..stable_diffusion import (
ConditioningData, ConditioningData,
PostprocessingSettings, PostprocessingSettings,
@ -13,8 +17,13 @@ from .base import Generator
class Txt2Img(Generator): class Txt2Img(Generator):
def __init__(self, model, precision): def __init__(self, model, precision,
super().__init__(model, precision) control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
**kwargs):
self.control_model = control_model
if isinstance(self.control_model, list):
self.control_model = MultiControlNetModel(self.control_model)
super().__init__(model, precision, **kwargs)
@torch.no_grad() @torch.no_grad()
def get_make_image( def get_make_image(
@ -42,9 +51,12 @@ class Txt2Img(Generator):
kwargs are 'width' and 'height' kwargs are 'width' and 'height'
""" """
self.perlin = perlin self.perlin = perlin
control_image = kwargs.get("control_image", None)
do_classifier_free_guidance = cfg_scale > 1.0
# noinspection PyTypeChecker # noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.control_model = self.control_model
pipeline.scheduler = sampler pipeline.scheduler = sampler
uc, c, extra_conditioning_info = conditioning uc, c, extra_conditioning_info = conditioning
@ -61,6 +73,37 @@ class Txt2Img(Generator):
), ),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
if control_image is not None:
if isinstance(self.control_model, ControlNetModel):
control_image = pipeline.prepare_control_image(
image=control_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=self.control_model.device,
dtype=self.control_model.dtype,
)
elif isinstance(self.control_model, MultiControlNetModel):
images = []
for image_ in control_image:
image_ = self.model.prepare_control_image(
image=image_,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=self.control_model.device,
dtype=self.control_model.dtype,
)
images.append(image_)
control_image = images
kwargs["control_image"] = control_image
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image: def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
pipeline_output = pipeline.image_from_embeddings( pipeline_output = pipeline.image_from_embeddings(
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()), latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
@ -68,6 +111,7 @@ class Txt2Img(Generator):
num_inference_steps=steps, num_inference_steps=steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
callback=step_callback, callback=step_callback,
**kwargs,
) )
if ( if (

View File

@ -2,23 +2,29 @@ from __future__ import annotations
import dataclasses import dataclasses
import inspect import inspect
import math
import secrets import secrets
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, Field
import einops import einops
import PIL.Image import PIL.Image
import numpy as np
from accelerate.utils import set_seed from accelerate.utils import set_seed
import psutil import psutil
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
from compel import EmbeddingsProvider from compel import EmbeddingsProvider
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline, StableDiffusionPipeline,
) )
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline, StableDiffusionImg2ImgPipeline,
) )
@ -27,6 +33,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
) )
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils import PIL_INTERPOLATION
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
@ -207,6 +214,13 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
raise AssertionError("why was that an empty generator?") raise AssertionError("why was that an empty generator?")
return result return result
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None)
weight: float = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
@dataclass(frozen=True) @dataclass(frozen=True)
class ConditioningData: class ConditioningData:
@ -302,6 +316,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
feature_extractor: Optional[CLIPFeatureExtractor], feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False, requires_safety_checker: bool = False,
precision: str = "float32", precision: str = "float32",
control_model: ControlNetModel = None,
): ):
super().__init__( super().__init__(
vae, vae,
@ -322,6 +337,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
scheduler=scheduler, scheduler=scheduler,
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
# FIXME: can't currently register control module
# control_model=control_model,
) )
self.invokeai_diffuser = InvokeAIDiffuserComponent( self.invokeai_diffuser = InvokeAIDiffuserComponent(
self.unet, self._unet_forward, is_running_diffusers=True self.unet, self._unet_forward, is_running_diffusers=True
@ -341,6 +358,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self._model_group = FullyLoadedModelGroup(self.unet.device) self._model_group = FullyLoadedModelGroup(self.unet.device)
self._model_group.install(*self._submodels) self._model_group.install(*self._submodels)
self.control_model = control_model
def _adjust_memory_efficient_attention(self, latents: torch.Tensor): def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
""" """
@ -463,6 +481,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor, noise: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None, run_id=None,
**kwargs,
) -> InvokeAIStableDiffusionPipelineOutput: ) -> InvokeAIStableDiffusionPipelineOutput:
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
@ -483,6 +502,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise=noise, noise=noise,
run_id=run_id, run_id=run_id,
callback=callback, callback=callback,
**kwargs,
) )
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -507,6 +527,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
run_id=None, run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False): if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu') scheduler_device = torch.device('cpu')
@ -527,6 +549,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
run_id=run_id, run_id=run_id,
callback=callback, callback=callback,
control_data=control_data,
**kwargs,
) )
return result.latents, result.attention_map_saver return result.latents, result.attention_map_saver
@ -539,6 +563,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor, noise: torch.Tensor,
run_id: str = None, run_id: str = None,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs,
): ):
self._adjust_memory_efficient_attention(latents) self._adjust_memory_efficient_attention(latents)
if run_id is None: if run_id is None:
@ -568,7 +594,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = self.scheduler.add_noise(latents, noise, batched_t) latents = self.scheduler.add_noise(latents, noise, batched_t)
attention_map_saver: Optional[AttentionMapSaver] = None attention_map_saver: Optional[AttentionMapSaver] = None
# print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)): for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t) batched_t.fill_(t)
step_output = self.step( step_output = self.step(
@ -578,6 +604,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=i, step_index=i,
total_step_count=len(timesteps), total_step_count=len(timesteps),
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
control_data=control_data,
**kwargs,
) )
latents = step_output.prev_sample latents = step_output.prev_sample
@ -618,10 +646,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index: int, step_index: int,
total_step_count: int, total_step_count: int,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
**kwargs,
): ):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0] timestep = t[0]
if additional_guidance is None: if additional_guidance is None:
additional_guidance = [] additional_guidance = []
@ -629,6 +658,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# i.e. before or after passing it to InvokeAIDiffuserComponent # i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep) latent_model_input = self.scheduler.scale_model_input(latents, timestep)
# default is no controlnet, so set controlnet processing output to None
down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None:
if conditioning_data.guidance_scale > 1.0:
# expand the latents input to control model if doing classifier free guidance
# (which for now is always true; a conditional elsewhere stops execution if
# classifier_free_guidance is <= 1.0)
latent_control_input = torch.cat([latent_model_input] * 2)
else:
latent_control_input = latent_model_input
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
# print("controlnet", i, "==>", type(control_datum))
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index)
down_samples, mid_sample = control_datum.model(
sample=latent_control_input,
timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]),
controlnet_cond=control_datum.image_tensor,
conditioning_scale=control_datum.weight,
# cross_attention_kwargs,
guess_mode=False,
return_dict=False,
)
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
# add controlnet outputs together if have multiple controlnets
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
# predict the noise residual # predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step( noise_pred = self.invokeai_diffuser.do_diffusion_step(
latent_model_input, latent_model_input,
@ -638,6 +709,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data.guidance_scale, conditioning_data.guidance_scale,
step_index=step_index, step_index=step_index,
total_step_count=total_step_count, total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
) )
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
@ -659,6 +732,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
t, t,
text_embeddings, text_embeddings,
cross_attention_kwargs: Optional[dict[str, Any]] = None, cross_attention_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
): ):
"""predict the noise residual""" """predict the noise residual"""
if is_inpainting_model(self.unet) and latents.size(1) == 4: if is_inpainting_model(self.unet) and latents.size(1) == 4:
@ -678,7 +752,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# First three args should be positional, not keywords, so torch hooks can see them. # First three args should be positional, not keywords, so torch hooks can see them.
return self.unet( return self.unet(
latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
**kwargs,
).sample ).sample
def img2img_from_embeddings( def img2img_from_embeddings(
@ -940,3 +1015,51 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
debug_image( debug_image(
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
) )
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
# Returns torch.Tensor of shape (batch_size, 3, height, width)
def prepare_control_image(
self,
image,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
# latents,
width=512, # should be 8 * latent.shape[3]
height=512, # should be 8 * latent.shape[2]
batch_size=1,
num_images_per_prompt=1,
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
images = []
for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
image = images
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
image = torch.cat([image] * 2)
return image
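
Putting the new pieces together: a hedged sketch of preparing a ControlNet hint image and wrapping it in ControlNetData for latents_from_embeddings(..., control_data=...). Only ControlNetData, prepare_control_image and the control_data parameter come from this hunk; the invokeai import path and the way the pipeline and control model are loaded are assumptions.

from typing import List

import PIL.Image
from diffusers.models.controlnet import ControlNetModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import (  # assumed module path
    ControlNetData,
    StableDiffusionGeneratorPipeline,
)

def build_control_data(
    pipeline: StableDiffusionGeneratorPipeline,
    control_model: ControlNetModel,
    hint: PIL.Image.Image,
) -> List[ControlNetData]:
    # Convert the hint image to the (batch, 3, H, W) tensor the control model expects.
    image_tensor = pipeline.prepare_control_image(
        image=hint,
        width=512,
        height=512,
        device=control_model.device,
        dtype=control_model.dtype,
        do_classifier_free_guidance=True,
    )
    # One entry per ControlNet; several entries behave like MultiControlNet.
    return [
        ControlNetData(
            model=control_model,
            image_tensor=image_tensor,
            weight=1.0,
            begin_step_percent=0.0,
            end_step_percent=1.0,
        )
    ]

# later: pipeline.latents_from_embeddings(..., control_data=build_control_data(pipeline, cn, hint))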

View File

@ -181,6 +181,7 @@ class InvokeAIDiffuserComponent:
unconditional_guidance_scale: float, unconditional_guidance_scale: float,
step_index: Optional[int] = None, step_index: Optional[int] = None,
total_step_count: Optional[int] = None, total_step_count: Optional[int] = None,
**kwargs,
): ):
""" """
:param x: current latents :param x: current latents
@ -209,7 +210,7 @@ class InvokeAIDiffuserComponent:
if wants_hybrid_conditioning: if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning( unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
x, sigma, unconditioning, conditioning x, sigma, unconditioning, conditioning, **kwargs,
) )
elif wants_cross_attention_control: elif wants_cross_attention_control:
( (
@ -221,13 +222,14 @@ class InvokeAIDiffuserComponent:
unconditioning, unconditioning,
conditioning, conditioning,
cross_attention_control_types_to_do, cross_attention_control_types_to_do,
**kwargs,
) )
elif self.sequential_guidance: elif self.sequential_guidance:
( (
unconditioned_next_x, unconditioned_next_x,
conditioned_next_x, conditioned_next_x,
) = self._apply_standard_conditioning_sequentially( ) = self._apply_standard_conditioning_sequentially(
x, sigma, unconditioning, conditioning x, sigma, unconditioning, conditioning, **kwargs,
) )
else: else:
@ -235,7 +237,7 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x, unconditioned_next_x,
conditioned_next_x, conditioned_next_x,
) = self._apply_standard_conditioning( ) = self._apply_standard_conditioning(
x, sigma, unconditioning, conditioning x, sigma, unconditioning, conditioning, **kwargs,
) )
combined_next_x = self._combine( combined_next_x = self._combine(
@ -282,13 +284,13 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class. # methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning): def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
# fast batched path # fast batched path
x_twice = torch.cat([x] * 2) x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2) sigma_twice = torch.cat([sigma] * 2)
both_conditionings = torch.cat([unconditioning, conditioning]) both_conditionings = torch.cat([unconditioning, conditioning])
both_results = self.model_forward_callback( both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings x_twice, sigma_twice, both_conditionings, **kwargs,
) )
unconditioned_next_x, conditioned_next_x = both_results.chunk(2) unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
if conditioned_next_x.device.type == "mps": if conditioned_next_x.device.type == "mps":
@ -302,16 +304,17 @@ class InvokeAIDiffuserComponent:
sigma, sigma,
unconditioning: torch.Tensor, unconditioning: torch.Tensor,
conditioning: torch.Tensor, conditioning: torch.Tensor,
**kwargs,
): ):
# low-memory sequential path # low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning) conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
if conditioned_next_x.device.type == "mps": if conditioned_next_x.device.type == "mps":
# prevent a result filled with zeros. seems to be a torch bug. # prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone() conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning): def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
assert isinstance(conditioning, dict) assert isinstance(conditioning, dict)
assert isinstance(unconditioning, dict) assert isinstance(unconditioning, dict)
x_twice = torch.cat([x] * 2) x_twice = torch.cat([x] * 2)
@ -326,7 +329,7 @@ class InvokeAIDiffuserComponent:
else: else:
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]]) both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback( unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings x_twice, sigma_twice, both_conditionings, **kwargs,
).chunk(2) ).chunk(2)
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
@ -337,6 +340,7 @@ class InvokeAIDiffuserComponent:
unconditioning, unconditioning,
conditioning, conditioning,
cross_attention_control_types_to_do, cross_attention_control_types_to_do,
**kwargs,
): ):
if self.is_running_diffusers: if self.is_running_diffusers:
return self._apply_cross_attention_controlled_conditioning__diffusers( return self._apply_cross_attention_controlled_conditioning__diffusers(
@ -345,6 +349,7 @@ class InvokeAIDiffuserComponent:
unconditioning, unconditioning,
conditioning, conditioning,
cross_attention_control_types_to_do, cross_attention_control_types_to_do,
**kwargs,
) )
else: else:
return self._apply_cross_attention_controlled_conditioning__compvis( return self._apply_cross_attention_controlled_conditioning__compvis(
@ -353,6 +358,7 @@ class InvokeAIDiffuserComponent:
unconditioning, unconditioning,
conditioning, conditioning,
cross_attention_control_types_to_do, cross_attention_control_types_to_do,
**kwargs,
) )
def _apply_cross_attention_controlled_conditioning__diffusers( def _apply_cross_attention_controlled_conditioning__diffusers(
@ -362,6 +368,7 @@ class InvokeAIDiffuserComponent:
unconditioning, unconditioning,
conditioning, conditioning,
cross_attention_control_types_to_do, cross_attention_control_types_to_do,
**kwargs,
): ):
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
@ -377,6 +384,7 @@ class InvokeAIDiffuserComponent:
sigma, sigma,
unconditioning, unconditioning,
{"swap_cross_attn_context": cross_attn_processor_context}, {"swap_cross_attn_context": cross_attn_processor_context},
**kwargs,
) )
# do requested cross attention types for conditioning (positive prompt) # do requested cross attention types for conditioning (positive prompt)
@ -388,6 +396,7 @@ class InvokeAIDiffuserComponent:
sigma, sigma,
conditioning, conditioning,
{"swap_cross_attn_context": cross_attn_processor_context}, {"swap_cross_attn_context": cross_attn_processor_context},
**kwargs,
) )
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
@ -398,6 +407,7 @@ class InvokeAIDiffuserComponent:
unconditioning, unconditioning,
conditioning, conditioning,
cross_attention_control_types_to_do, cross_attention_control_types_to_do,
**kwargs,
): ):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS) # slower non-batched path (20% slower on mac MPS)
@ -411,13 +421,13 @@ class InvokeAIDiffuserComponent:
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
try: try:
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
# process x using the original prompt, saving the attention maps # process x using the original prompt, saving the attention maps
# print("saving attention maps for", cross_attention_control_types_to_do) # print("saving attention maps for", cross_attention_control_types_to_do)
for ca_type in cross_attention_control_types_to_do: for ca_type in cross_attention_control_types_to_do:
context.request_save_attention_maps(ca_type) context.request_save_attention_maps(ca_type)
_ = self.model_forward_callback(x, sigma, conditioning) _ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
context.clear_requests(cleanup=False) context.clear_requests(cleanup=False)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
@ -428,7 +438,7 @@ class InvokeAIDiffuserComponent:
self.conditioning.cross_attention_control_args.edited_conditioning self.conditioning.cross_attention_control_args.edited_conditioning
) )
conditioned_next_x = self.model_forward_callback( conditioned_next_x = self.model_forward_callback(
x, sigma, edited_conditioning x, sigma, edited_conditioning, **kwargs,
) )
context.clear_requests(cleanup=True) context.clear_requests(cleanup=True)
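
The net effect of the **kwargs threading above is that the ControlNet residuals computed in step() ride along, untouched, from do_diffusion_step() through whichever conditioning path runs and into model_forward_callback (ultimately the UNet). A toy stand-in showing just that mechanism, not the real classes:

def model_forward_callback(x, sigma, conditioning, **kwargs):
    # stand-in for the UNet forward; the real one receives the residuals as
    # down_block_additional_residuals / mid_block_additional_residual
    print("forward received extra kwargs:", sorted(kwargs))
    return x

def apply_standard_conditioning(x, sigma, unconditioning, conditioning, **kwargs):
    # mirrors _apply_standard_conditioning: the kwargs simply ride along
    return model_forward_callback(x, sigma, conditioning, **kwargs)

apply_standard_conditioning(
    0.0, 0.0, None, None,
    down_block_additional_residuals=["..."],  # from controlnet(s)
    mid_block_additional_residual="...",      # from controlnet(s)
)
# -> forward received extra kwargs: ['down_block_additional_residuals', 'mid_block_additional_residual']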

View File

@ -8,9 +8,16 @@ import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
import type { RootState, AppDispatch } from '../../store'; import type { RootState, AppDispatch } from '../../store';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected'; import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addImageResultReceivedListener } from './listeners/invocationComplete'; import {
import { addImageUploadedListener } from './listeners/imageUploaded'; addImageUploadedFulfilledListener,
import { addRequestedImageDeletionListener } from './listeners/imageDeleted'; addImageUploadedRejectedListener,
} from './listeners/imageUploaded';
import {
addImageDeletedFulfilledListener,
addImageDeletedPendingListener,
addImageDeletedRejectedListener,
addRequestedImageDeletionListener,
} from './listeners/imageDeleted';
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
@ -19,6 +26,47 @@ import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGaller
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage'; import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasMergedListener } from './listeners/canvasMerged'; import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addGeneratorProgressListener } from './listeners/socketio/generatorProgress';
import { addGraphExecutionStateCompleteListener } from './listeners/socketio/graphExecutionStateComplete';
import { addInvocationCompleteListener } from './listeners/socketio/invocationComplete';
import { addInvocationErrorListener } from './listeners/socketio/invocationError';
import { addInvocationStartedListener } from './listeners/socketio/invocationStarted';
import { addSocketConnectedListener } from './listeners/socketio/socketConnected';
import { addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
import { addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
import {
addImageMetadataReceivedFulfilledListener,
addImageMetadataReceivedRejectedListener,
} from './listeners/imageMetadataReceived';
import {
addImageUrlsReceivedFulfilledListener,
addImageUrlsReceivedRejectedListener,
} from './listeners/imageUrlsReceived';
import {
addSessionCreatedFulfilledListener,
addSessionCreatedPendingListener,
addSessionCreatedRejectedListener,
} from './listeners/sessionCreated';
import {
addSessionInvokedFulfilledListener,
addSessionInvokedPendingListener,
addSessionInvokedRejectedListener,
} from './listeners/sessionInvoked';
import {
addSessionCanceledFulfilledListener,
addSessionCanceledPendingListener,
addSessionCanceledRejectedListener,
} from './listeners/sessionCanceled';
import {
addReceivedResultImagesPageFulfilledListener,
addReceivedResultImagesPageRejectedListener,
} from './listeners/receivedResultImagesPage';
import {
addReceivedUploadImagesPageFulfilledListener,
addReceivedUploadImagesPageRejectedListener,
} from './listeners/receivedUploadImagesPage';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -38,17 +86,67 @@ export type AppListenerEffect = ListenerEffect<
AppDispatch AppDispatch
>; >;
addImageUploadedListener(); // Image uploaded
addInitialImageSelectedListener(); addImageUploadedFulfilledListener();
addImageResultReceivedListener(); addImageUploadedRejectedListener();
addRequestedImageDeletionListener();
addInitialImageSelectedListener();
// Image deleted
addRequestedImageDeletionListener();
addImageDeletedPendingListener();
addImageDeletedFulfilledListener();
addImageDeletedRejectedListener();
// Image metadata
addImageMetadataReceivedFulfilledListener();
addImageMetadataReceivedRejectedListener();
// Image URLs
addImageUrlsReceivedFulfilledListener();
addImageUrlsReceivedRejectedListener();
// User Invoked
addUserInvokedCanvasListener(); addUserInvokedCanvasListener();
addUserInvokedNodesListener(); addUserInvokedNodesListener();
addUserInvokedTextToImageListener(); addUserInvokedTextToImageListener();
addUserInvokedImageToImageListener(); addUserInvokedImageToImageListener();
addSessionReadyToInvokeListener();
// Canvas actions
addCanvasSavedToGalleryListener(); addCanvasSavedToGalleryListener();
addCanvasDownloadedAsImageListener(); addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener(); addCanvasCopiedToClipboardListener();
addCanvasMergedListener(); addCanvasMergedListener();
// socketio
addGeneratorProgressListener();
addGraphExecutionStateCompleteListener();
addInvocationCompleteListener();
addInvocationErrorListener();
addInvocationStartedListener();
addSocketConnectedListener();
addSocketDisconnectedListener();
addSocketSubscribedListener();
addSocketUnsubscribedListener();
// Session Created
addSessionCreatedPendingListener();
addSessionCreatedFulfilledListener();
addSessionCreatedRejectedListener();
// Session Invoked
addSessionInvokedPendingListener();
addSessionInvokedFulfilledListener();
addSessionInvokedRejectedListener();
// Session Canceled
addSessionCanceledPendingListener();
addSessionCanceledFulfilledListener();
addSessionCanceledRejectedListener();
// Gallery pages
addReceivedResultImagesPageFulfilledListener();
addReceivedResultImagesPageRejectedListener();
addReceivedUploadImagesPageFulfilledListener();
addReceivedUploadImagesPageRejectedListener();

View File

@ -52,7 +52,6 @@ export const addCanvasMergedListener = () => {
dispatch( dispatch(
imageUploaded({ imageUploaded({
imageType: 'intermediates',
formData: { formData: {
file: new File([blob], filename, { type: 'image/png' }), file: new File([blob], filename, { type: 'image/png' }),
}, },
@ -65,7 +64,7 @@ export const addCanvasMergedListener = () => {
action.meta.arg.formData.file.name === filename action.meta.arg.formData.file.name === filename
); );
const mergedCanvasImage = payload.response; const mergedCanvasImage = payload;
dispatch( dispatch(
setMergedCanvas({ setMergedCanvas({

View File

@ -29,7 +29,6 @@ export const addCanvasSavedToGalleryListener = () => {
dispatch( dispatch(
imageUploaded({ imageUploaded({
imageType: 'results',
formData: { formData: {
file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }), file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }),
}, },

View File

@ -4,9 +4,14 @@ import { imageDeleted } from 'services/thunks/image';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es'; import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageSelected } from 'features/gallery/store/gallerySlice';
import { uploadsAdapter } from 'features/gallery/store/uploadsSlice';
import { resultsAdapter } from 'features/gallery/store/resultsSlice';
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
/**
* Called when the user requests an image deletion
*/
export const addRequestedImageDeletionListener = () => { export const addRequestedImageDeletionListener = () => {
startAppListening({ startAppListening({
actionCreator: requestedImageDeletion, actionCreator: requestedImageDeletion,
@ -19,11 +24,6 @@ export const addRequestedImageDeletionListener = () => {
const { image_name, image_type } = image; const { image_name, image_type } = image;
if (image_type !== 'uploads' && image_type !== 'results') {
moduleLog.warn({ data: image }, `Invalid image type ${image_type}`);
return;
}
const selectedImageName = getState().gallery.selectedImage?.image_name; const selectedImageName = getState().gallery.selectedImage?.image_name;
if (selectedImageName === image_name) { if (selectedImageName === image_name) {
@ -57,3 +57,49 @@ export const addRequestedImageDeletionListener = () => {
}, },
}); });
}; };
/**
* Called when the actual delete request is sent to the server
*/
export const addImageDeletedPendingListener = () => {
startAppListening({
actionCreator: imageDeleted.pending,
effect: (action, { dispatch, getState }) => {
const { imageName, imageType } = action.meta.arg;
// Preemptively remove the image from the gallery
if (imageType === 'uploads') {
uploadsAdapter.removeOne(getState().uploads, imageName);
}
if (imageType === 'results') {
resultsAdapter.removeOne(getState().results, imageName);
}
},
});
};
/**
* Called on successful delete
*/
export const addImageDeletedFulfilledListener = () => {
startAppListening({
actionCreator: imageDeleted.fulfilled,
effect: (action, { dispatch, getState }) => {
moduleLog.debug({ data: { image: action.meta.arg } }, 'Image deleted');
},
});
};
/**
* Called on failed delete
*/
export const addImageDeletedRejectedListener = () => {
startAppListening({
actionCreator: imageDeleted.rejected,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
{ data: { image: action.meta.arg } },
'Unable to delete image'
);
},
});
};

View File

@ -0,0 +1,43 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/thunks/image';
import {
ResultsImageDTO,
resultUpserted,
} from 'features/gallery/store/resultsSlice';
import {
UploadsImageDTO,
uploadUpserted,
} from 'features/gallery/store/uploadsSlice';
const moduleLog = log.child({ namespace: 'image' });
export const addImageMetadataReceivedFulfilledListener = () => {
startAppListening({
actionCreator: imageMetadataReceived.fulfilled,
effect: (action, { getState, dispatch }) => {
const image = action.payload;
moduleLog.debug({ data: { image } }, 'Image metadata received');
if (image.image_type === 'results') {
dispatch(resultUpserted(action.payload as ResultsImageDTO));
}
if (image.image_type === 'uploads') {
dispatch(uploadUpserted(action.payload as UploadsImageDTO));
}
},
});
};
export const addImageMetadataReceivedRejectedListener = () => {
startAppListening({
actionCreator: imageMetadataReceived.rejected,
effect: (action, { getState, dispatch }) => {
moduleLog.debug(
{ data: { image: action.meta.arg } },
'Problem receiving image metadata'
);
},
});
};


@@ -1,25 +1,31 @@
import { startAppListening } from '..';
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
import { uploadUpserted } from 'features/gallery/store/uploadsSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageUploaded } from 'services/thunks/image';
import { addToast } from 'features/system/store/systemSlice';
import { initialImageSelected } from 'features/parameters/store/actions';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { resultAdded } from 'features/gallery/store/resultsSlice';
import { resultUpserted } from 'features/gallery/store/resultsSlice';
import { isResultsImageDTO, isUploadsImageDTO } from 'services/types/guards';
import { log } from 'app/logging/useLogger';
export const addImageUploadedListener = () => {
const moduleLog = log.child({ namespace: 'image' });
export const addImageUploadedFulfilledListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.payload.response.image_type !== 'intermediates',
action.payload.is_intermediate === false,
effect: (action, { dispatch, getState }) => {
const { response: image } = action.payload;
const image = action.payload;
moduleLog.debug({ arg: '<Blob>', image }, 'Image uploaded');
const state = getState();
// Handle uploads
if (isUploadsImageDTO(image)) {
dispatch(uploadAdded(image));
dispatch(uploadUpserted(image));
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
@@ -36,9 +42,26 @@ export const addImageUploadedListener = () => {
}
}
// Handle results
// TODO: Can this ever happen? I don't think so...
if (isResultsImageDTO(image)) {
dispatch(resultAdded(image));
dispatch(resultUpserted(image));
}
},
});
};
export const addImageUploadedRejectedListener = () => {
startAppListening({
actionCreator: imageUploaded.rejected,
effect: (action, { dispatch }) => {
dispatch(
addToast({
title: 'Image Upload Failed',
description: action.error.message,
status: 'error',
})
);
},
});
};


@ -0,0 +1,51 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageUrlsReceived } from 'services/thunks/image';
import { resultsAdapter } from 'features/gallery/store/resultsSlice';
import { uploadsAdapter } from 'features/gallery/store/uploadsSlice';
const moduleLog = log.child({ namespace: 'image' });
export const addImageUrlsReceivedFulfilledListener = () => {
startAppListening({
actionCreator: imageUrlsReceived.fulfilled,
effect: (action, { getState, dispatch }) => {
const image = action.payload;
moduleLog.debug({ data: { image } }, 'Image URLs received');
const { image_type, image_name, image_url, thumbnail_url } = image;
if (image_type === 'results') {
resultsAdapter.updateOne(getState().results, {
id: image_name,
changes: {
image_url,
thumbnail_url,
},
});
}
if (image_type === 'uploads') {
uploadsAdapter.updateOne(getState().uploads, {
id: image_name,
changes: {
image_url,
thumbnail_url,
},
});
}
},
});
};
export const addImageUrlsReceivedRejectedListener = () => {
startAppListening({
actionCreator: imageUrlsReceived.rejected,
effect: (action, { getState, dispatch }) => {
moduleLog.debug(
{ data: { image: action.meta.arg } },
'Problem getting image URLs'
);
},
});
};


@ -1,62 +0,0 @@
import { invocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/types/guards';
import {
imageMetadataReceived,
imageUrlsReceived,
} from 'services/thunks/image';
import { startAppListening } from '..';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
const nodeDenylist = ['dataURL_image'];
export const addImageResultReceivedListener = () => {
startAppListening({
predicate: (action) => {
if (
invocationComplete.match(action) &&
isImageOutput(action.payload.data.result)
) {
return true;
}
return false;
},
effect: async (action, { getState, dispatch, take }) => {
if (!invocationComplete.match(action)) {
return;
}
const { data } = action.payload;
const { result, node, graph_execution_state_id } = data;
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const { image_name, image_type } = result.image;
dispatch(
imageUrlsReceived({ imageName: image_name, imageType: image_type })
);
dispatch(
imageMetadataReceived({
imageName: image_name,
imageType: image_type,
})
);
// Handle canvas image
if (
graph_execution_state_id ===
getState().canvas.layerState.stagingArea.sessionId
) {
const [{ payload: image }] = await take(
(
action
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
imageMetadataReceived.fulfilled.match(action) &&
action.payload.image_name === image_name
);
dispatch(addImageToStagingArea(image));
}
}
},
});
};


@ -0,0 +1,33 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { receivedResultImagesPage } from 'services/thunks/gallery';
import { serializeError } from 'serialize-error';
const moduleLog = log.child({ namespace: 'gallery' });
export const addReceivedResultImagesPageFulfilledListener = () => {
startAppListening({
actionCreator: receivedResultImagesPage.fulfilled,
effect: (action, { getState, dispatch }) => {
const page = action.payload;
moduleLog.debug(
{ data: { page } },
`Received ${page.items.length} results`
);
},
});
};
export const addReceivedResultImagesPageRejectedListener = () => {
startAppListening({
actionCreator: receivedResultImagesPage.rejected,
effect: (action, { getState, dispatch }) => {
if (action.payload) {
moduleLog.debug(
{ data: { error: serializeError(action.payload.error) } },
'Problem receiving results'
);
}
},
});
};


@ -0,0 +1,33 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { receivedUploadImagesPage } from 'services/thunks/gallery';
import { serializeError } from 'serialize-error';
const moduleLog = log.child({ namespace: 'gallery' });
export const addReceivedUploadImagesPageFulfilledListener = () => {
startAppListening({
actionCreator: receivedUploadImagesPage.fulfilled,
effect: (action, { getState, dispatch }) => {
const page = action.payload;
moduleLog.debug(
{ data: { page } },
`Received ${page.items.length} uploads`
);
},
});
};
export const addReceivedUploadImagesPageRejectedListener = () => {
startAppListening({
actionCreator: receivedUploadImagesPage.rejected,
effect: (action, { getState, dispatch }) => {
if (action.payload) {
moduleLog.debug(
{ data: { error: serializeError(action.payload.error) } },
'Problem receiving uploads'
);
}
},
});
};


@ -0,0 +1,48 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { sessionCanceled } from 'services/thunks/session';
import { serializeError } from 'serialize-error';
const moduleLog = log.child({ namespace: 'session' });
export const addSessionCanceledPendingListener = () => {
startAppListening({
actionCreator: sessionCanceled.pending,
effect: (action, { getState, dispatch }) => {
//
},
});
};
export const addSessionCanceledFulfilledListener = () => {
startAppListening({
actionCreator: sessionCanceled.fulfilled,
effect: (action, { getState, dispatch }) => {
const { sessionId } = action.meta.arg;
moduleLog.debug(
{ data: { sessionId } },
`Session canceled (${sessionId})`
);
},
});
};
export const addSessionCanceledRejectedListener = () => {
startAppListening({
actionCreator: sessionCanceled.rejected,
effect: (action, { getState, dispatch }) => {
if (action.payload) {
const { arg, error } = action.payload;
moduleLog.error(
{
data: {
arg,
error: serializeError(error),
},
},
`Problem canceling session`
);
}
},
});
};


@ -0,0 +1,45 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { sessionCreated } from 'services/thunks/session';
import { serializeError } from 'serialize-error';
const moduleLog = log.child({ namespace: 'session' });
export const addSessionCreatedPendingListener = () => {
startAppListening({
actionCreator: sessionCreated.pending,
effect: (action, { getState, dispatch }) => {
//
},
});
};
export const addSessionCreatedFulfilledListener = () => {
startAppListening({
actionCreator: sessionCreated.fulfilled,
effect: (action, { getState, dispatch }) => {
const session = action.payload;
moduleLog.debug({ data: { session } }, `Session created (${session.id})`);
},
});
};
export const addSessionCreatedRejectedListener = () => {
startAppListening({
actionCreator: sessionCreated.rejected,
effect: (action, { getState, dispatch }) => {
if (action.payload) {
const { arg, error } = action.payload;
moduleLog.error(
{
data: {
arg,
error: serializeError(error),
},
},
`Problem creating session`
);
}
},
});
};


@ -0,0 +1,48 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { sessionInvoked } from 'services/thunks/session';
import { serializeError } from 'serialize-error';
const moduleLog = log.child({ namespace: 'session' });
export const addSessionInvokedPendingListener = () => {
startAppListening({
actionCreator: sessionInvoked.pending,
effect: (action, { getState, dispatch }) => {
//
},
});
};
export const addSessionInvokedFulfilledListener = () => {
startAppListening({
actionCreator: sessionInvoked.fulfilled,
effect: (action, { getState, dispatch }) => {
const { sessionId } = action.meta.arg;
moduleLog.debug(
{ data: { sessionId } },
`Session invoked (${sessionId})`
);
},
});
};
export const addSessionInvokedRejectedListener = () => {
startAppListening({
actionCreator: sessionInvoked.rejected,
effect: (action, { getState, dispatch }) => {
if (action.payload) {
const { arg, error } = action.payload;
moduleLog.error(
{
data: {
arg,
error: serializeError(error),
},
},
`Problem invoking session`
);
}
},
});
};


@ -0,0 +1,22 @@
import { startAppListening } from '..';
import { sessionInvoked } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'session' });
export const addSessionReadyToInvokeListener = () => {
startAppListening({
actionCreator: sessionReadyToInvoke,
effect: (action, { getState, dispatch }) => {
const { sessionId } = getState().system;
if (sessionId) {
moduleLog.debug(
{ sessionId },
`Session ready to invoke (${sessionId})`
);
dispatch(sessionInvoked({ sessionId }));
}
},
});
};


@ -0,0 +1,28 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { generatorProgress } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addGeneratorProgressListener = () => {
startAppListening({
actionCreator: generatorProgress,
effect: (action, { dispatch, getState }) => {
if (
getState().system.canceledSession ===
action.payload.data.graph_execution_state_id
) {
moduleLog.trace(
action.payload,
'Ignored generator progress for canceled session'
);
return;
}
moduleLog.trace(
action.payload,
`Generator progress (${action.payload.data.node.type})`
);
},
});
};


@ -0,0 +1,17 @@
import { log } from 'app/logging/useLogger';
import { graphExecutionStateComplete } from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
export const addGraphExecutionStateCompleteListener = () => {
startAppListening({
actionCreator: graphExecutionStateComplete,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
action.payload,
`Session invocation complete (${action.payload.data.graph_execution_state_id})`
);
},
});
};


@ -0,0 +1,74 @@
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { invocationComplete } from 'services/events/actions';
import { imageMetadataReceived } from 'services/thunks/image';
import { sessionCanceled } from 'services/thunks/session';
import { isImageOutput } from 'services/types/guards';
import { progressImageSet } from 'features/system/store/systemSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'socketio' });
const nodeDenylist = ['dataURL_image'];
export const addInvocationCompleteListener = () => {
startAppListening({
actionCreator: invocationComplete,
effect: async (action, { dispatch, getState, take }) => {
moduleLog.debug(
action.payload,
`Invocation complete (${action.payload.data.node.type})`
);
const sessionId = action.payload.data.graph_execution_state_id;
const { cancelType, isCancelScheduled } = getState().system;
// Handle scheduled cancelation
if (cancelType === 'scheduled' && isCancelScheduled) {
dispatch(sessionCanceled({ sessionId }));
}
const { data } = action.payload;
const { result, node, graph_execution_state_id } = data;
// This complete event has an associated image output
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const { image_name, image_type } = result.image;
// Get its metadata
dispatch(
imageMetadataReceived({
imageName: image_name,
imageType: image_type,
})
);
const [{ payload: imageDTO }] = await take(
imageMetadataReceived.fulfilled.match
);
if (getState().gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(imageDTO));
}
// Handle canvas image
if (
graph_execution_state_id ===
getState().canvas.layerState.stagingArea.sessionId
) {
const [{ payload: image }] = await take(
(
action
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
imageMetadataReceived.fulfilled.match(action) &&
action.payload.image_name === image_name
);
dispatch(addImageToStagingArea(image));
}
dispatch(progressImageSet(null));
}
},
});
};


@ -0,0 +1,17 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { invocationError } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addInvocationErrorListener = () => {
startAppListening({
actionCreator: invocationError,
effect: (action, { dispatch, getState }) => {
moduleLog.error(
action.payload,
`Invocation error (${action.payload.data.node.type})`
);
},
});
};


@ -0,0 +1,28 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { invocationStarted } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addInvocationStartedListener = () => {
startAppListening({
actionCreator: invocationStarted,
effect: (action, { dispatch, getState }) => {
if (
getState().system.canceledSession ===
action.payload.data.graph_execution_state_id
) {
moduleLog.trace(
action.payload,
'Ignored invocation started for canceled session'
);
return;
}
moduleLog.debug(
action.payload,
`Invocation started (${action.payload.data.node.type})`
);
},
});
};


@ -0,0 +1,43 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { socketConnected } from 'services/events/actions';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema';
const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketConnectedListener = () => {
startAppListening({
actionCreator: socketConnected,
effect: (action, { dispatch, getState }) => {
const { timestamp } = action.payload;
moduleLog.debug({ timestamp }, 'Connected');
const { results, uploads, models, nodes, config } = getState();
const { disabledTabs } = config;
// These thunks need to be dispatched in middleware; they cannot be handled in a reducer
if (!results.ids.length) {
dispatch(receivedResultImagesPage());
}
if (!uploads.ids.length) {
dispatch(receivedUploadImagesPage());
}
if (!models.ids.length) {
dispatch(receivedModels());
}
if (!nodes.schema && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema());
}
},
});
};


@ -0,0 +1,14 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { socketDisconnected } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketDisconnectedListener = () => {
startAppListening({
actionCreator: socketDisconnected,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(action.payload, 'Disconnected');
},
});
};


@ -0,0 +1,17 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { socketSubscribed } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketSubscribedListener = () => {
startAppListening({
actionCreator: socketSubscribed,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
action.payload,
`Subscribed (${action.payload.sessionId})`
);
},
});
};


@ -0,0 +1,17 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { socketUnsubscribed } from 'services/events/actions';
const moduleLog = log.child({ namespace: 'socketio' });
export const addSocketUnsubscribedListener = () => {
startAppListening({
actionCreator: socketUnsubscribed,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
action.payload,
`Unsubscribed (${action.payload.sessionId})`
);
},
});
};


@@ -1,9 +1,9 @@
import { startAppListening } from '..';
import { sessionCreated, sessionInvoked } from 'services/thunks/session';
import { sessionCreated } from 'services/thunks/session';
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
import { log } from 'app/logging/useLogger';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { imageUploaded } from 'services/thunks/image';
import { imageUpdated, imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid';
import { Graph } from 'services/api';
import {
@@ -15,12 +15,22 @@ import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
/**
* This listener is responsible for building the canvas graph and blobs when the user invokes the canvas.
* It is also responsible for uploading the base and mask layers to the server.
* This listener is responsible for invoking the canvas. This involves a number of steps:
*
* 1. Generate image blobs from the canvas layers
* 2. Determine the generation mode from the layers (txt2img, img2img, inpaint)
* 3. Build the canvas graph
* 4. Create the session with the graph
* 5. Upload the init image if necessary
* 6. Upload the mask image if necessary
* 7. Update the init and mask images with the session ID
* 8. Initialize the staging area if not yet initialized
* 9. Dispatch the sessionReadyToInvoke action to invoke the session
*/
export const addUserInvokedCanvasListener = () => {
startAppListening({
@@ -70,63 +80,7 @@ export const addUserInvokedCanvasListener = () => {
const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
// Upload the base layer, to be used as init image
// Assemble! Note that this graph *does not have the init or mask image set yet!*
const baseFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
},
})
);
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
const [{ payload: basePayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === baseFilename
);
const { image_name: baseName, image_type: baseType } =
basePayload.response;
baseNode.image = {
image_name: baseName,
image_type: baseType,
};
}
// Upload the mask layer image
const maskFilename = `${uuidv4()}.png`;
if (baseNode.type === 'inpaint') {
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
},
})
);
const [{ payload: maskPayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === maskFilename
);
const { image_name: maskName, image_type: maskType } =
maskPayload.response;
baseNode.mask = {
image_name: maskName,
image_type: maskType,
};
}
// Assemble!
const nodes: Graph['nodes'] = {
[rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode,
@@ -136,15 +90,90 @@ export const addUserInvokedCanvasListener = () => {
const graph = { nodes, edges };
dispatch(canvasGraphBuilt(graph));
moduleLog({ data: graph }, 'Canvas graph built');
// Actually create the session
moduleLog.debug({ data: graph }, 'Canvas graph built');
// If we are generating img2img or inpaint, we need to upload the init images
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
const baseFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
formData: {
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
},
isIntermediate: true,
})
);
// Wait for the image to be uploaded
const [{ payload: baseImageDTO }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === baseFilename
);
// Update the base node with the image name and type
baseNode.image = {
image_name: baseImageDTO.image_name,
image_type: baseImageDTO.image_type,
};
}
// For inpaint, we also need to upload the mask layer
if (baseNode.type === 'inpaint') {
const maskFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
formData: {
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
},
isIntermediate: true,
})
);
// Wait for the mask to be uploaded
const [{ payload: maskImageDTO }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === maskFilename
);
// Update the base node with the image name and type
baseNode.mask = {
image_name: maskImageDTO.image_name,
image_type: maskImageDTO.image_type,
};
}
// Create the session and wait for response
dispatch(sessionCreated({ graph }));
const [sessionCreatedAction] = await take(sessionCreated.fulfilled.match);
const sessionId = sessionCreatedAction.payload.id;
// Wait for the session to be invoked (this is just the HTTP request to start processing)
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
// Associate the init image with the session, now that we have the session ID
if (
(baseNode.type === 'img2img' || baseNode.type === 'inpaint') &&
baseNode.image
) {
dispatch(
imageUpdated({
imageName: baseNode.image.image_name,
imageType: baseNode.image.image_type,
requestBody: { session_id: sessionId },
})
);
}
const { sessionId } = meta.arg;
// Associate the mask image with the session, now that we have the session ID
if (baseNode.type === 'inpaint' && baseNode.mask) {
dispatch(
imageUpdated({
imageName: baseNode.mask.image_name,
imageType: baseNode.mask.image_type,
requestBody: { session_id: sessionId },
})
);
}
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
@@ -158,7 +187,11 @@ export const addUserInvokedCanvasListener = () => {
);
}
// Flag the session with the canvas session ID
dispatch(canvasSessionIdChanged(sessionId));
// We are ready to invoke the session!
dispatch(sessionReadyToInvoke());
},
});
};
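
Steps 4 through 9 in the comment above all lean on one pattern: dispatch a thunk, then use the listener effect's take to suspend until that thunk's fulfilled action arrives before doing the dependent work. A stripped-down sketch of that pattern, assuming the same thunks and import paths used in this file (the trigger action, placeholder graph, and listener name are illustrative, not part of the commit):

import { startAppListening } from '..';
import { sessionCreated } from 'services/thunks/session';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { userInvoked } from 'app/store/actions';

// Illustrative only: a minimal dispatch-then-await listener.
export const addDispatchThenAwaitExampleListener = () => {
  startAppListening({
    actionCreator: userInvoked,
    effect: async (action, { dispatch, take }) => {
      const graph = { nodes: {}, edges: [] }; // placeholder graph for the sketch

      // Kick off session creation...
      dispatch(sessionCreated({ graph }));

      // ...and suspend this effect until that exact thunk settles successfully.
      const [created] = await take(sessionCreated.fulfilled.match);

      // created.payload.id is the new session's ID; dependent dispatches
      // (imageUpdated with session_id, canvasSessionIdChanged, ...) belong here.
      dispatch(sessionReadyToInvoke());
    },
  });
};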


@@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
@@ -11,14 +12,18 @@ export const addUserInvokedImageToImageListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'img2img',
effect: (action, { getState, dispatch }) => {
effect: async (action, { getState, dispatch, take }) => {
const state = getState();
const graph = buildImageToImageGraph(state);
dispatch(imageToImageGraphBuilt(graph));
moduleLog({ data: graph }, 'Image to Image graph built');
moduleLog.debug({ data: graph }, 'Image to Image graph built');
dispatch(sessionCreated({ graph }));
await take(sessionCreated.fulfilled.match);
dispatch(sessionReadyToInvoke());
},
});
};


@@ -4,6 +4,7 @@ import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGra
import { log } from 'app/logging/useLogger';
import { nodesGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
@@ -11,14 +12,18 @@ export const addUserInvokedNodesListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'nodes',
effect: (action, { getState, dispatch }) => {
effect: async (action, { getState, dispatch, take }) => {
const state = getState();
const graph = buildNodesGraph(state);
dispatch(nodesGraphBuilt(graph));
moduleLog({ data: graph }, 'Nodes graph built');
moduleLog.debug({ data: graph }, 'Nodes graph built');
dispatch(sessionCreated({ graph }));
await take(sessionCreated.fulfilled.match);
dispatch(sessionReadyToInvoke());
},
});
};


@@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
@@ -11,14 +12,20 @@ export const addUserInvokedTextToImageListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'txt2img',
effect: (action, { getState, dispatch }) => {
effect: async (action, { getState, dispatch, take }) => {
const state = getState();
const graph = buildTextToImageGraph(state);
dispatch(textToImageGraphBuilt(graph));
moduleLog({ data: graph }, 'Text to Image graph built');
moduleLog.debug({ data: graph }, 'Text to Image graph built');
dispatch(sessionCreated({ graph }));
await take(sessionCreated.fulfilled.match);
dispatch(sessionReadyToInvoke());
},
});
};


@@ -16,6 +16,7 @@ import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
// import sessionReducer from 'features/system/store/sessionSlice';
import configReducer from 'features/system/store/configSlice';
import uiReducer from 'features/ui/store/uiSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
@@ -46,6 +47,7 @@ const allReducers = {
ui: uiReducer,
uploads: uploadsReducer,
hotkeys: hotkeysReducer,
// session: sessionReducer,
};
const rootReducer = combineReducers(allReducers);


@@ -68,7 +68,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
async (file: File) => {
dispatch(
imageUploaded({
imageType: 'uploads',
formData: { file },
activeTabName,
})


@ -1,5 +1,10 @@
import { forEach, size } from 'lodash-es'; import { forEach, size } from 'lodash-es';
import { ImageField, LatentsField, ConditioningField } from 'services/api'; import {
ImageField,
LatentsField,
ConditioningField,
ControlField,
} from 'services/api';
const OBJECT_TYPESTRING = '[object Object]'; const OBJECT_TYPESTRING = '[object Object]';
const STRING_TYPESTRING = '[object String]'; const STRING_TYPESTRING = '[object String]';
@ -98,6 +103,24 @@ const parseConditioningField = (
}; };
}; };
const parseControlField = (controlField: unknown): ControlField | undefined => {
// Must be an object
if (!isObject(controlField)) {
return;
}
// A ControlField must have a `control`
if (!('control' in controlField)) {
return;
}
// console.log(typeof controlField.control);
// Build a valid ControlField
return {
control: controlField.control,
};
};
type NodeMetadata = { type NodeMetadata = {
[key: string]: [key: string]:
| string | string
@ -105,7 +128,8 @@ type NodeMetadata = {
| boolean | boolean
| ImageField | ImageField
| LatentsField | LatentsField
| ConditioningField; | ConditioningField
| ControlField;
}; };
type InvokeAIMetadata = { type InvokeAIMetadata = {
@ -131,7 +155,7 @@ export const parseNodeMetadata = (
return; return;
} }
// the only valid object types are ImageField, LatentsField and ConditioningField // the only valid object types are ImageField, LatentsField, ConditioningField, ControlField
if (isObject(nodeItem)) { if (isObject(nodeItem)) {
if ('image_name' in nodeItem || 'image_type' in nodeItem) { if ('image_name' in nodeItem || 'image_type' in nodeItem) {
const imageField = parseImageField(nodeItem); const imageField = parseImageField(nodeItem);
@ -156,6 +180,14 @@ export const parseNodeMetadata = (
} }
return; return;
} }
if ('control' in nodeItem) {
const controlField = parseControlField(nodeItem);
if (controlField) {
parsed[nodeKey] = controlField;
}
return;
}
} }
// otherwise we accept any string, number or boolean // otherwise we accept any string, number or boolean


@ -109,8 +109,9 @@ const currentImageButtonsSelector = createSelector(
isLightboxOpen, isLightboxOpen,
shouldHidePreview, shouldHidePreview,
image: selectedImage, image: selectedImage,
seed: selectedImage?.metadata?.invokeai?.node?.seed, seed: selectedImage?.metadata?.seed,
prompt: selectedImage?.metadata?.invokeai?.node?.prompt, prompt: selectedImage?.metadata?.positive_conditioning,
negativePrompt: selectedImage?.metadata?.negative_conditioning,
}; };
}, },
{ {
@ -245,13 +246,16 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
); );
const handleUseSeed = useCallback(() => { const handleUseSeed = useCallback(() => {
recallSeed(image?.metadata?.invokeai?.node?.seed); recallSeed(image?.metadata?.seed);
}, [image, recallSeed]); }, [image, recallSeed]);
useHotkeys('s', handleUseSeed, [image]); useHotkeys('s', handleUseSeed, [image]);
const handleUsePrompt = useCallback(() => { const handleUsePrompt = useCallback(() => {
recallPrompt(image?.metadata?.invokeai?.node?.prompt); recallPrompt(
image?.metadata?.positive_conditioning,
image?.metadata?.negative_conditioning
);
}, [image, recallPrompt]); }, [image, recallPrompt]);
useHotkeys('p', handleUsePrompt, [image]); useHotkeys('p', handleUsePrompt, [image]);
@ -454,7 +458,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
{t('parameters.copyImageToLink')} {t('parameters.copyImageToLink')}
</IAIButton> </IAIButton>
<Link download={true} href={getUrl(image?.url ?? '')}> <Link download={true} href={getUrl(image?.image_url ?? '')}>
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%"> <IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
{t('parameters.downloadImage')} {t('parameters.downloadImage')}
</IAIButton> </IAIButton>
@ -500,7 +504,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaQuoteRight />} icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`} tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`} aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!image?.metadata?.invokeai?.node?.prompt} isDisabled={!image?.metadata?.positive_conditioning}
onClick={handleUsePrompt} onClick={handleUsePrompt}
/> />
@ -508,7 +512,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaSeedling />} icon={<FaSeedling />}
tooltip={`${t('parameters.useSeed')} (S)`} tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`} aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!image?.metadata?.invokeai?.node?.seed} isDisabled={!image?.metadata?.seed}
onClick={handleUseSeed} onClick={handleUseSeed}
/> />
@ -517,9 +521,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
tooltip={`${t('parameters.useAll')} (A)`} tooltip={`${t('parameters.useAll')} (A)`}
aria-label={`${t('parameters.useAll')} (A)`} aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={ isDisabled={
!['txt2img', 'img2img', 'inpaint'].includes( // not sure what this list should be
String(image?.metadata?.invokeai?.node?.type) !['t2l', 'l2l', 'inpaint'].includes(String(image?.metadata?.type))
)
} }
onClick={handleClickUseAllParameters} onClick={handleClickUseAllParameters}
/> />


@ -155,7 +155,10 @@ const HoverableImage = memo((props: HoverableImageProps) => {
// Recall parameters handlers // Recall parameters handlers
const handleRecallPrompt = useCallback(() => { const handleRecallPrompt = useCallback(() => {
recallPrompt(image.metadata?.positive_conditioning); recallPrompt(
image.metadata?.positive_conditioning,
image.metadata?.negative_conditioning
);
}, [image, recallPrompt]); }, [image, recallPrompt]);
const handleRecallSeed = useCallback(() => { const handleRecallSeed = useCallback(() => {
@ -248,7 +251,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
icon={<IoArrowUndoCircleOutline />} icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleUseAllParameters} onClickCapture={handleUseAllParameters}
isDisabled={ isDisabled={
!['txt2img', 'img2img', 'inpaint'].includes( // what should these be
!['t2l', 'l2l', 'inpaint'].includes(
String(image?.metadata?.type) String(image?.metadata?.type)
) )
} }


@ -8,29 +8,20 @@ import {
Text, Text,
Tooltip, Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import * as InvokeAI from 'app/types/invokeai';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { useGetUrl } from 'common/util/getUrl'; import { useGetUrl } from 'common/util/getUrl';
import promptToString from 'common/util/promptToString'; import promptToString from 'common/util/promptToString';
import { seedWeightsToString } from 'common/util/seedWeightPairs';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import { import {
setCfgScale, setCfgScale,
setHeight, setHeight,
setImg2imgStrength, setImg2imgStrength,
setNegativePrompt, setNegativePrompt,
setPerlin,
setPositivePrompt, setPositivePrompt,
setScheduler, setScheduler,
setSeamless,
setSeed, setSeed,
setSeedWeights,
setShouldFitToWidthHeight,
setSteps, setSteps,
setThreshold,
setWidth, setWidth,
} from 'features/parameters/store/generationSlice'; } from 'features/parameters/store/generationSlice';
import { setHiresFix } from 'features/parameters/store/postprocessingSlice';
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice'; import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
import { memo } from 'react'; import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
@ -39,7 +30,6 @@ import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5'; import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { filter } from 'lodash-es';
import { Scheduler } from 'app/constants'; import { Scheduler } from 'app/constants';
type MetadataItemProps = { type MetadataItemProps = {
@ -126,8 +116,6 @@ const memoEqualityCheck = (
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => { const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const setBothPrompts = useSetBothPrompts();
useHotkeys('esc', () => { useHotkeys('esc', () => {
dispatch(setShouldShowImageDetails(false)); dispatch(setShouldShowImageDetails(false));
}); });


@@ -33,7 +33,7 @@ export const nextPrevImageButtonsSelector = createSelector(
}
const currentImageIndex = state[currentCategory].ids.findIndex(
(i) => i === selectedImage.name
(i) => i === selectedImage.image_name
);
const nextImageIndex = clamp(


@@ -1,14 +1,13 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import {
PayloadAction,
createEntityAdapter,
createSlice,
} from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
receivedResultImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import {
imageDeleted,
imageMetadataReceived,
imageUrlsReceived,
} from 'services/thunks/image';
import { ImageDTO } from 'services/api';
import { dateComparator } from 'common/util/dateComparator';
@@ -26,6 +25,7 @@ type AdditionalResultsState = {
pages: number;
isLoading: boolean;
nextPage: number;
upsertedImageCount: number;
};
export const initialResultsState =
@@ -34,6 +34,7 @@ export const initialResultsState =
pages: 0,
isLoading: false,
nextPage: 0,
upsertedImageCount: 0,
});
export type ResultsState = typeof initialResultsState;
@@ -42,7 +43,10 @@ const resultsSlice = createSlice({
name: 'results',
initialState: initialResultsState,
reducers: {
resultAdded: resultsAdapter.upsertOne,
resultUpserted: (state, action: PayloadAction<ResultsImageDTO>) => {
resultsAdapter.upsertOne(state, action.payload);
state.upsertedImageCount += 1;
},
},
extraReducers: (builder) => {
/**
@@ -68,47 +72,6 @@ const resultsSlice = createSlice({
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Image Metadata Received - FULFILLED
*/
builder.addCase(imageMetadataReceived.fulfilled, (state, action) => {
const { image_type } = action.payload;
if (image_type === 'results') {
resultsAdapter.upsertOne(state, action.payload as ResultsImageDTO);
}
});
/**
* Image URLs Received - FULFILLED
*/
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_type, image_url, thumbnail_url } =
action.payload;
if (image_type === 'results') {
resultsAdapter.updateOne(state, {
id: image_name,
changes: {
image_url: image_url,
thumbnail_url: thumbnail_url,
},
});
}
});
/**
* Delete Image - PENDING
* Pre-emptively remove the image from the gallery
*/
builder.addCase(imageDeleted.pending, (state, action) => {
const { imageType, imageName } = action.meta.arg;
if (imageType === 'results') {
resultsAdapter.removeOne(state, imageName);
}
});
},
});
@@ -120,6 +83,6 @@ const resultsSlice = createSlice({
selectTotal: selectResultsTotal,
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
export const { resultAdded } = resultsSlice.actions;
export const { resultUpserted } = resultsSlice.actions;
export default resultsSlice.reducer;
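
A quick usage sketch of what the reworked slice exposes; the store import path and the image values below are hypothetical, and only resultUpserted, selectResultsTotal, ResultsImageDTO, and upsertedImageCount come from this file:

import { store } from 'app/store/store'; // assuming the configured store is exported here
import {
  resultUpserted,
  selectResultsTotal,
  ResultsImageDTO,
} from 'features/gallery/store/resultsSlice';

// Hypothetical DTO; a real ImageDTO carries more fields (metadata, timestamps, ...).
const dto = {
  image_name: 'abc123.png',
  image_type: 'results',
  image_url: '/images/results/abc123.png',
  thumbnail_url: '/thumbnails/results/abc123.png',
} as ResultsImageDTO;

// Upserting twice for the same image_name updates the single entity in place,
// while upsertedImageCount counts every upsert.
store.dispatch(resultUpserted(dto));
store.dispatch(resultUpserted(dto));

console.log(selectResultsTotal(store.getState())); // 1 entity
console.log(store.getState().results.upsertedImageCount); // 2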


@ -1,11 +1,14 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; import {
PayloadAction,
createEntityAdapter,
createSlice,
} from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { import {
receivedUploadImagesPage, receivedUploadImagesPage,
IMAGES_PER_PAGE, IMAGES_PER_PAGE,
} from 'services/thunks/gallery'; } from 'services/thunks/gallery';
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { dateComparator } from 'common/util/dateComparator'; import { dateComparator } from 'common/util/dateComparator';
@ -23,6 +26,7 @@ type AdditionalUploadsState = {
pages: number; pages: number;
isLoading: boolean; isLoading: boolean;
nextPage: number; nextPage: number;
upsertedImageCount: number;
}; };
export const initialUploadsState = export const initialUploadsState =
@ -31,6 +35,7 @@ export const initialUploadsState =
pages: 0, pages: 0,
nextPage: 0, nextPage: 0,
isLoading: false, isLoading: false,
upsertedImageCount: 0,
}); });
export type UploadsState = typeof initialUploadsState; export type UploadsState = typeof initialUploadsState;
@ -39,7 +44,10 @@ const uploadsSlice = createSlice({
name: 'uploads', name: 'uploads',
initialState: initialUploadsState, initialState: initialUploadsState,
reducers: { reducers: {
uploadAdded: uploadsAdapter.upsertOne, uploadUpserted: (state, action: PayloadAction<UploadsImageDTO>) => {
uploadsAdapter.upsertOne(state, action.payload);
state.upsertedImageCount += 1;
},
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
/** /**
@ -65,36 +73,6 @@ const uploadsSlice = createSlice({
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1; state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false; state.isLoading = false;
}); });
/**
* Image URLs Received - FULFILLED
*/
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_type, image_url, thumbnail_url } =
action.payload;
if (image_type === 'uploads') {
uploadsAdapter.updateOne(state, {
id: image_name,
changes: {
image_url: image_url,
thumbnail_url: thumbnail_url,
},
});
}
});
/**
* Delete Image - pending
* Pre-emptively remove the image from the gallery
*/
builder.addCase(imageDeleted.pending, (state, action) => {
const { imageType, imageName } = action.meta.arg;
if (imageType === 'uploads') {
uploadsAdapter.removeOne(state, imageName);
}
});
}, },
}); });
@ -106,6 +84,6 @@ export const {
selectTotal: selectUploadsTotal, selectTotal: selectUploadsTotal,
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads); } = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
export const { uploadAdded } = uploadsSlice.actions; export const { uploadUpserted } = uploadsSlice.actions;
export default uploadsSlice.reducer; export default uploadsSlice.reducer;


@ -7,6 +7,7 @@ import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent';
@ -97,6 +98,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
); );
} }
if (type === 'control' && template.type === 'control') {
return (
<ControlInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'model' && template.type === 'model') { if (type === 'model' && template.type === 'model') {
return ( return (
<ModelInputFieldComponent <ModelInputFieldComponent


@ -0,0 +1,16 @@
import {
ControlInputFieldTemplate,
ControlInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';
const ControlInputFieldComponent = (
props: FieldComponentProps<ControlInputFieldValue, ControlInputFieldTemplate>
) => {
const { nodeId, field } = props;
return null;
};
export default memo(ControlInputFieldComponent);


@ -4,6 +4,7 @@ export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
export const FIELD_TYPE_MAP: Record<string, FieldType> = { export const FIELD_TYPE_MAP: Record<string, FieldType> = {
integer: 'integer', integer: 'integer',
float: 'float',
number: 'float', number: 'float',
string: 'string', string: 'string',
boolean: 'boolean', boolean: 'boolean',
@ -15,6 +16,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
array: 'array', array: 'array',
item: 'item', item: 'item',
ColorField: 'color', ColorField: 'color',
ControlField: 'control',
control: 'control',
}; };
const COLOR_TOKEN_VALUE = 500; const COLOR_TOKEN_VALUE = 500;
@ -22,6 +25,9 @@ const COLOR_TOKEN_VALUE = 500;
const getColorTokenCssVariable = (color: string) => const getColorTokenCssVariable = (color: string) =>
`var(--invokeai-colors-${color}-${COLOR_TOKEN_VALUE})`; `var(--invokeai-colors-${color}-${COLOR_TOKEN_VALUE})`;
// @ts-ignore
// @ts-ignore
// @ts-ignore
export const FIELDS: Record<FieldType, FieldUIConfig> = { export const FIELDS: Record<FieldType, FieldUIConfig> = {
integer: { integer: {
color: 'red', color: 'red',
@ -71,6 +77,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Conditioning', title: 'Conditioning',
description: 'Conditioning may be passed between nodes.', description: 'Conditioning may be passed between nodes.',
}, },
control: {
color: 'cyan',
colorCssVar: getColorTokenCssVariable('cyan'), // TODO: no free color left
title: 'Control',
description: 'Control info passed between nodes.',
},
model: { model: {
color: 'teal', color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'), colorCssVar: getColorTokenCssVariable('teal'),


@ -61,6 +61,7 @@ export type FieldType =
| 'image' | 'image'
| 'latents' | 'latents'
| 'conditioning' | 'conditioning'
| 'control'
| 'model' | 'model'
| 'array' | 'array'
| 'item' | 'item'
@ -82,6 +83,7 @@ export type InputFieldValue =
| ImageInputFieldValue | ImageInputFieldValue
| LatentsInputFieldValue | LatentsInputFieldValue
| ConditioningInputFieldValue | ConditioningInputFieldValue
| ControlInputFieldValue
| EnumInputFieldValue | EnumInputFieldValue
| ModelInputFieldValue | ModelInputFieldValue
| ArrayInputFieldValue | ArrayInputFieldValue
@ -102,6 +104,7 @@ export type InputFieldTemplate =
| ImageInputFieldTemplate | ImageInputFieldTemplate
| LatentsInputFieldTemplate | LatentsInputFieldTemplate
| ConditioningInputFieldTemplate | ConditioningInputFieldTemplate
| ControlInputFieldTemplate
| EnumInputFieldTemplate | EnumInputFieldTemplate
| ModelInputFieldTemplate | ModelInputFieldTemplate
| ArrayInputFieldTemplate | ArrayInputFieldTemplate
@ -177,6 +180,11 @@ export type LatentsInputFieldValue = FieldValueBase & {
export type ConditioningInputFieldValue = FieldValueBase & { export type ConditioningInputFieldValue = FieldValueBase & {
type: 'conditioning'; type: 'conditioning';
value?: string;
};
export type ControlInputFieldValue = FieldValueBase & {
type: 'control';
value?: undefined; value?: undefined;
}; };
@ -262,6 +270,11 @@ export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
type: 'conditioning'; type: 'conditioning';
}; };
export type ControlInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'control';
};
export type EnumInputFieldTemplate = InputFieldTemplateBase & { export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string | number; default: string | number;
type: 'enum'; type: 'enum';


@@ -10,6 +10,7 @@ import {
IntegerInputFieldTemplate,
LatentsInputFieldTemplate,
ConditioningInputFieldTemplate,
ControlInputFieldTemplate,
StringInputFieldTemplate,
ModelInputFieldTemplate,
ArrayInputFieldTemplate,
@@ -215,6 +216,21 @@ const buildConditioningInputFieldTemplate = ({
return template;
};
const buildControlInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ControlInputFieldTemplate => {
const template: ControlInputFieldTemplate = {
...baseField,
type: 'control',
inputRequirement: 'always',
inputKind: 'connection',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildEnumInputFieldTemplate = ({
schemaObject,
baseField,
@@ -286,9 +302,20 @@ export const getFieldType = (
if (typeHints && name in typeHints) {
rawFieldType = typeHints[name];
} else if (!schemaObject.type) {
rawFieldType = refObjectToFieldType(
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
);
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
if (schemaObject.allOf) {
rawFieldType = refObjectToFieldType(
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.anyOf) {
rawFieldType = refObjectToFieldType(
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.oneOf) {
rawFieldType = refObjectToFieldType(
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
);
}
} else if (schemaObject.enum) {
rawFieldType = 'enum';
} else if (schemaObject.type) {
@@ -331,6 +358,9 @@ export const buildInputFieldTemplate = (
if (['conditioning'].includes(fieldType)) {
return buildConditioningInputFieldTemplate({ schemaObject, baseField });
}
if (['control'].includes(fieldType)) {
return buildControlInputFieldTemplate({ schemaObject, baseField });
}
if (['model'].includes(fieldType)) {
return buildModelInputFieldTemplate({ schemaObject, baseField });
}
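
For reference, the kind of schema object that reaches the new allOf branch typically looks like the sketch below; the property title, description, and variable name are invented, but the shape (no type, a single $ref wrapped in allOf) is the usual Pydantic/OpenAPI output for a referenced model:

import { OpenAPIV3 } from 'openapi-types';

// Hypothetical invocation input as it appears in the generated OpenAPI schema.
const exampleControlProperty: OpenAPIV3.SchemaObject = {
  title: 'Control',
  description: 'The control info to apply',
  allOf: [{ $ref: '#/components/schemas/ControlField' }],
};

// getFieldType sees no `type`, finds `allOf`, and hands the $ref to
// refObjectToFieldType, presumably resolving the name 'ControlField',
// which FIELD_TYPE_MAP then maps to the UI field type 'control'.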


@ -52,6 +52,10 @@ export const buildInputFieldValue = (
fieldValue.value = undefined; fieldValue.value = undefined;
} }
if (template.type === 'control') {
fieldValue.value = undefined;
}
if (template.type === 'model') { if (template.type === 'model') {
fieldValue.value = undefined; fieldValue.value = undefined;
} }


@@ -11,7 +11,7 @@ import { addNoiseNodes } from '../nodeBuilders/addNoiseNodes';
const POSITIVE_CONDITIONING = 'positive_conditioning';
const NEGATIVE_CONDITIONING = 'negative_conditioning';
const TEXT_TO_LATENTS = 'text_to_latents';
const LATENTS_TO_IMAGE = 'latnets_to_image';
const LATENTS_TO_IMAGE = 'latents_to_image';
/**
* Builds the Text to Image tab graph.


@ -13,7 +13,9 @@ import {
buildOutputFieldTemplates, buildOutputFieldTemplates,
} from './fieldTemplateBuilders'; } from './fieldTemplateBuilders';
const invocationDenylist = ['Graph']; const RESERVED_FIELD_NAMES = ['id', 'type', 'meta'];
const invocationDenylist = ['Graph', 'InvocationMeta'];
export const parseSchema = (openAPI: OpenAPIV3.Document) => { export const parseSchema = (openAPI: OpenAPIV3.Document) => {
// filter out non-invocation schemas, plus some tricky invocations for now // filter out non-invocation schemas, plus some tricky invocations for now
@ -73,7 +75,7 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
(inputsAccumulator, property, propertyName) => { (inputsAccumulator, property, propertyName) => {
if ( if (
// `type` and `id` are not valid inputs/outputs // `type` and `id` are not valid inputs/outputs
!['type', 'id'].includes(propertyName) && !RESERVED_FIELD_NAMES.includes(propertyName) &&
isSchemaObject(property) isSchemaObject(property)
) { ) {
const field: InputFieldTemplate | undefined = const field: InputFieldTemplate | undefined =
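The effect of widening the reserved list is that `meta` is now skipped along with `type` and `id` when collecting input field templates. A minimal sketch under a hypothetical properties object:

// Illustrative only; the properties object is invented for this example.
const RESERVED_FIELD_NAMES = ['id', 'type', 'meta'];

const properties: Record<string, unknown> = {
  id: { type: 'string' },
  type: { type: 'string' },
  meta: { type: 'object' },
  strength: { type: 'number' },
};

const inputPropertyNames = Object.keys(properties).filter(
  (name) => !RESERVED_FIELD_NAMES.includes(name)
);
// inputPropertyNames === ['strength']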

View File

@ -21,8 +21,8 @@ export const useParameters = () => {
* Sets prompt with toast * Sets prompt with toast
*/ */
const recallPrompt = useCallback( const recallPrompt = useCallback(
(prompt: unknown) => { (prompt: unknown, negativePrompt?: unknown) => {
if (!isString(prompt)) { if (!isString(prompt) || !isString(negativePrompt)) {
toaster({ toaster({
title: t('toast.promptNotSet'), title: t('toast.promptNotSet'),
description: t('toast.promptNotSetDesc'), description: t('toast.promptNotSetDesc'),
@ -33,7 +33,7 @@ export const useParameters = () => {
return; return;
} }
setBothPrompts(prompt); setBothPrompts(prompt, negativePrompt);
toaster({ toaster({
title: t('toast.promptSet'), title: t('toast.promptSet'),
status: 'info', status: 'info',
@ -112,12 +112,13 @@ export const useParameters = () => {
const recallAllParameters = useCallback( const recallAllParameters = useCallback(
(image: ImageDTO | undefined) => { (image: ImageDTO | undefined) => {
const type = image?.metadata?.type; const type = image?.metadata?.type;
if (['txt2img', 'img2img', 'inpaint'].includes(String(type))) { // not sure what this list should be
if (['t2l', 'l2l', 'inpaint'].includes(String(type))) {
dispatch(allParametersSet(image)); dispatch(allParametersSet(image));
if (image?.metadata?.type === 'img2img') { if (image?.metadata?.type === 'l2l') {
dispatch(setActiveTab('img2img')); dispatch(setActiveTab('img2img'));
} else if (image?.metadata?.type === 'txt2img') { } else if (image?.metadata?.type === 't2l') {
dispatch(setActiveTab('txt2img')); dispatch(setActiveTab('txt2img'));
} }
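For reference, a small sketch of the new metadata-type-to-tab mapping used by `recallAllParameters`; the 't2l'/'l2l' names come from the hunk above, and as the in-code comment notes, the final list is still uncertain:

// Illustrative helper only; mirrors the branching in recallAllParameters.
const tabForMetadataType = (type: string): 'txt2img' | 'img2img' | undefined => {
  if (type === 't2l') return 'txt2img';
  if (type === 'l2l') return 'img2img';
  return undefined; // inpaint and unknown types do not switch tabs here
};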

View File

@ -12,15 +12,8 @@ const useSetBothPrompts = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
return useCallback( return useCallback(
(inputPrompt: InvokeAI.Prompt) => { (inputPrompt: InvokeAI.Prompt, negativePrompt: InvokeAI.Prompt) => {
const promptString = dispatch(setPositivePrompt(inputPrompt));
typeof inputPrompt === 'string'
? inputPrompt
: promptToString(inputPrompt);
const [prompt, negativePrompt] = getPromptAndNegative(promptString);
dispatch(setPositivePrompt(prompt));
dispatch(setNegativePrompt(negativePrompt)); dispatch(setNegativePrompt(negativePrompt));
}, },
[dispatch] [dispatch]

View File

@ -52,7 +52,7 @@ export const initialGenerationState: GenerationState = {
perlin: 0, perlin: 0,
positivePrompt: '', positivePrompt: '',
negativePrompt: '', negativePrompt: '',
scheduler: 'lms', scheduler: 'euler',
seamBlur: 16, seamBlur: 16,
seamSize: 96, seamSize: 96,
seamSteps: 30, seamSteps: 30,

View File

@ -7,19 +7,29 @@ export const setAllParametersReducer = (
state: Draft<GenerationState>, state: Draft<GenerationState>,
action: PayloadAction<ImageDTO | undefined> action: PayloadAction<ImageDTO | undefined>
) => { ) => {
const node = action.payload?.metadata.invokeai?.node; const metadata = action.payload?.metadata;
if (!node) { if (!metadata) {
return; return;
} }
// not sure what this list should be
if ( if (
node.type === 'txt2img' || metadata.type === 't2l' ||
node.type === 'img2img' || metadata.type === 'l2l' ||
node.type === 'inpaint' metadata.type === 'inpaint'
) { ) {
const { cfg_scale, height, model, prompt, scheduler, seed, steps, width } = const {
node; cfg_scale,
height,
model,
positive_conditioning,
negative_conditioning,
scheduler,
seed,
steps,
width,
} = metadata;
if (cfg_scale !== undefined) { if (cfg_scale !== undefined) {
state.cfgScale = Number(cfg_scale); state.cfgScale = Number(cfg_scale);
@ -30,8 +40,11 @@ export const setAllParametersReducer = (
if (model !== undefined) { if (model !== undefined) {
state.model = String(model); state.model = String(model);
} }
if (prompt !== undefined) { if (positive_conditioning !== undefined) {
state.positivePrompt = String(prompt); state.positivePrompt = String(positive_conditioning);
}
if (negative_conditioning !== undefined) {
state.negativePrompt = String(negative_conditioning);
} }
if (scheduler !== undefined) { if (scheduler !== undefined) {
const schedulerString = String(scheduler); const schedulerString = String(scheduler);
@ -51,8 +64,8 @@ export const setAllParametersReducer = (
} }
} }
if (node.type === 'img2img') { if (metadata.type === 'l2l') {
const { fit, image } = node as ImageToImageInvocation; const { fit, image } = metadata as ImageToImageInvocation;
if (fit !== undefined) { if (fit !== undefined) {
state.shouldFitToWidthHeight = Boolean(fit); state.shouldFitToWidthHeight = Boolean(fit);
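A hedged example of the kind of image metadata the reducer above now reads directly (instead of the old nested `invokeai.node`); all values are invented for illustration:

// Hypothetical metadata payload; field names follow the destructuring above.
const exampleMetadata = {
  type: 't2l',
  positive_conditioning: 'a painting of a lighthouse',
  negative_conditioning: 'blurry, low quality',
  cfg_scale: 7.5,
  width: 512,
  height: 512,
  seed: 12345,
  steps: 30,
  scheduler: 'euler',
  model: 'stable-diffusion-1.5',
};
// With type 't2l', the reducer copies the conditioning strings into
// positivePrompt / negativePrompt rather than the old single `prompt` field.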

View File

@ -0,0 +1,3 @@
import { createAction } from '@reduxjs/toolkit';
export const sessionReadyToInvoke = createAction('system/sessionReadyToInvoke');

View File

@ -0,0 +1,62 @@
// TODO: split system slice into this
// import type { PayloadAction } from '@reduxjs/toolkit';
// import { createSlice } from '@reduxjs/toolkit';
// import { socketSubscribed, socketUnsubscribed } from 'services/events/actions';
// export type SessionState = {
// /**
// * The current socket session id
// */
// sessionId: string;
// /**
// * Whether the current session is a canvas session. Needed to manage the staging area.
// */
// isCanvasSession: boolean;
// /**
// * When a session is canceled, its ID is stored here until a new session is created.
// */
// canceledSessionId: string;
// };
// export const initialSessionState: SessionState = {
// sessionId: '',
// isCanvasSession: false,
// canceledSessionId: '',
// };
// export const sessionSlice = createSlice({
// name: 'session',
// initialState: initialSessionState,
// reducers: {
// sessionIdChanged: (state, action: PayloadAction<string>) => {
// state.sessionId = action.payload;
// },
// isCanvasSessionChanged: (state, action: PayloadAction<boolean>) => {
// state.isCanvasSession = action.payload;
// },
// },
// extraReducers: (builder) => {
// /**
// * Socket Subscribed
// */
// builder.addCase(socketSubscribed, (state, action) => {
// state.sessionId = action.payload.sessionId;
// state.canceledSessionId = '';
// });
// /**
// * Socket Unsubscribed
// */
// builder.addCase(socketUnsubscribed, (state) => {
// state.sessionId = '';
// });
// },
// });
// export const { sessionIdChanged, isCanvasSessionChanged } =
// sessionSlice.actions;
// export default sessionSlice.reducer;
export default {};

View File

@ -1,5 +1,5 @@
import { UseToastOptions } from '@chakra-ui/react'; import { UseToastOptions } from '@chakra-ui/react';
import type { PayloadAction } from '@reduxjs/toolkit'; import { PayloadAction, isAnyOf } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai'; import * as InvokeAI from 'app/types/invokeai';
import { import {
@ -16,7 +16,11 @@ import {
import { ProgressImage } from 'services/events/types'; import { ProgressImage } from 'services/events/types';
import { makeToast } from '../../../app/components/Toaster'; import { makeToast } from '../../../app/components/Toaster';
import { sessionCanceled, sessionInvoked } from 'services/thunks/session'; import {
sessionCanceled,
sessionCreated,
sessionInvoked,
} from 'services/thunks/session';
import { receivedModels } from 'services/thunks/model'; import { receivedModels } from 'services/thunks/model';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice'; import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { LogLevelName } from 'roarr'; import { LogLevelName } from 'roarr';
@ -215,6 +219,9 @@ export const systemSlice = createSlice({
languageChanged: (state, action: PayloadAction<keyof typeof LANGUAGES>) => { languageChanged: (state, action: PayloadAction<keyof typeof LANGUAGES>) => {
state.language = action.payload; state.language = action.payload;
}, },
progressImageSet(state, action: PayloadAction<ProgressImage | null>) {
state.progressImage = action.payload;
},
}, },
extraReducers(builder) { extraReducers(builder) {
/** /**
@ -305,7 +312,6 @@ export const systemSlice = createSlice({
state.currentStep = 0; state.currentStep = 0;
state.totalSteps = 0; state.totalSteps = 0;
state.statusTranslationKey = 'common.statusProcessingComplete'; state.statusTranslationKey = 'common.statusProcessingComplete';
state.progressImage = null;
if (state.canceledSession === data.graph_execution_state_id) { if (state.canceledSession === data.graph_execution_state_id) {
state.isProcessing = false; state.isProcessing = false;
@ -343,15 +349,8 @@ export const systemSlice = createSlice({
state.statusTranslationKey = 'common.statusPreparing'; state.statusTranslationKey = 'common.statusPreparing';
}); });
builder.addCase(sessionInvoked.rejected, (state, action) => {
const error = action.payload as string | undefined;
state.toastQueue.push(
makeToast({ title: error || t('toast.serverError'), status: 'error' })
);
});
/** /**
* Session Canceled * Session Canceled - FULFILLED
*/ */
builder.addCase(sessionCanceled.fulfilled, (state, action) => { builder.addCase(sessionCanceled.fulfilled, (state, action) => {
state.canceledSession = action.meta.arg.sessionId; state.canceledSession = action.meta.arg.sessionId;
@ -414,6 +413,26 @@ export const systemSlice = createSlice({
builder.addCase(imageUploaded.fulfilled, (state) => { builder.addCase(imageUploaded.fulfilled, (state) => {
state.isUploading = false; state.isUploading = false;
}); });
// *** Matchers - must be after all cases ***
/**
* Session Invoked - REJECTED
* Session Created - REJECTED
*/
builder.addMatcher(isAnySessionRejected, (state, action) => {
state.isProcessing = false;
state.isCancelable = false;
state.isCancelScheduled = false;
state.currentStep = 0;
state.totalSteps = 0;
state.statusTranslationKey = 'common.statusConnected';
state.progressImage = null;
state.toastQueue.push(
makeToast({ title: t('toast.serverError'), status: 'error' })
);
});
}, },
}); });
@ -438,6 +457,12 @@ export const {
isPersistedChanged, isPersistedChanged,
shouldAntialiasProgressImageChanged, shouldAntialiasProgressImageChanged,
languageChanged, languageChanged,
progressImageSet,
} = systemSlice.actions; } = systemSlice.actions;
export default systemSlice.reducer; export default systemSlice.reducer;
const isAnySessionRejected = isAnyOf(
sessionCreated.rejected,
sessionInvoked.rejected
);
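A short, self-contained sketch of the matcher pattern introduced above: `isAnyOf` from Redux Toolkit combines several action creators into one predicate, and `addMatcher` calls must come after all `addCase` calls in the builder, which is why the rejected handling moved below the cases:

// Minimal sketch of the addMatcher pattern; names are generic placeholders.
import { createAction, createReducer, isAnyOf } from '@reduxjs/toolkit';

const aRejected = createAction('a/rejected');
const bRejected = createAction('b/rejected');
const isEitherRejected = isAnyOf(aRejected, bRejected);

const reducer = createReducer({ isProcessing: true }, (builder) => {
  // addCase calls would go here, before any matchers
  builder.addMatcher(isEitherRejected, (state) => {
    state.isProcessing = false;
  });
});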

View File

@ -7,7 +7,6 @@ export { OpenAPI } from './core/OpenAPI';
export type { OpenAPIConfig } from './core/OpenAPI'; export type { OpenAPIConfig } from './core/OpenAPI';
export type { AddInvocation } from './models/AddInvocation'; export type { AddInvocation } from './models/AddInvocation';
export type { BlurInvocation } from './models/BlurInvocation';
export type { Body_upload_image } from './models/Body_upload_image'; export type { Body_upload_image } from './models/Body_upload_image';
export type { CkptModelInfo } from './models/CkptModelInfo'; export type { CkptModelInfo } from './models/CkptModelInfo';
export type { CollectInvocation } from './models/CollectInvocation'; export type { CollectInvocation } from './models/CollectInvocation';
@ -17,7 +16,6 @@ export type { CompelInvocation } from './models/CompelInvocation';
export type { CompelOutput } from './models/CompelOutput'; export type { CompelOutput } from './models/CompelOutput';
export type { ConditioningField } from './models/ConditioningField'; export type { ConditioningField } from './models/ConditioningField';
export type { CreateModelRequest } from './models/CreateModelRequest'; export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CropImageInvocation } from './models/CropImageInvocation';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation'; export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
export type { DiffusersModelInfo } from './models/DiffusersModelInfo'; export type { DiffusersModelInfo } from './models/DiffusersModelInfo';
export type { DivideInvocation } from './models/DivideInvocation'; export type { DivideInvocation } from './models/DivideInvocation';
@ -28,11 +26,20 @@ export type { GraphExecutionState } from './models/GraphExecutionState';
export type { GraphInvocation } from './models/GraphInvocation'; export type { GraphInvocation } from './models/GraphInvocation';
export type { GraphInvocationOutput } from './models/GraphInvocationOutput'; export type { GraphInvocationOutput } from './models/GraphInvocationOutput';
export type { HTTPValidationError } from './models/HTTPValidationError'; export type { HTTPValidationError } from './models/HTTPValidationError';
export type { ImageBlurInvocation } from './models/ImageBlurInvocation';
export type { ImageCategory } from './models/ImageCategory'; export type { ImageCategory } from './models/ImageCategory';
export type { ImageChannelInvocation } from './models/ImageChannelInvocation';
export type { ImageConvertInvocation } from './models/ImageConvertInvocation';
export type { ImageCropInvocation } from './models/ImageCropInvocation';
export type { ImageDTO } from './models/ImageDTO'; export type { ImageDTO } from './models/ImageDTO';
export type { ImageField } from './models/ImageField'; export type { ImageField } from './models/ImageField';
export type { ImageInverseLerpInvocation } from './models/ImageInverseLerpInvocation';
export type { ImageLerpInvocation } from './models/ImageLerpInvocation';
export type { ImageMetadata } from './models/ImageMetadata'; export type { ImageMetadata } from './models/ImageMetadata';
export type { ImageMultiplyInvocation } from './models/ImageMultiplyInvocation';
export type { ImageOutput } from './models/ImageOutput'; export type { ImageOutput } from './models/ImageOutput';
export type { ImagePasteInvocation } from './models/ImagePasteInvocation';
export type { ImageRecordChanges } from './models/ImageRecordChanges';
export type { ImageToImageInvocation } from './models/ImageToImageInvocation'; export type { ImageToImageInvocation } from './models/ImageToImageInvocation';
export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation'; export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
export type { ImageType } from './models/ImageType'; export type { ImageType } from './models/ImageType';
@ -43,14 +50,12 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation';
export type { InpaintInvocation } from './models/InpaintInvocation'; export type { InpaintInvocation } from './models/InpaintInvocation';
export type { IntCollectionOutput } from './models/IntCollectionOutput'; export type { IntCollectionOutput } from './models/IntCollectionOutput';
export type { IntOutput } from './models/IntOutput'; export type { IntOutput } from './models/IntOutput';
export type { InverseLerpInvocation } from './models/InverseLerpInvocation';
export type { IterateInvocation } from './models/IterateInvocation'; export type { IterateInvocation } from './models/IterateInvocation';
export type { IterateInvocationOutput } from './models/IterateInvocationOutput'; export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
export type { LatentsField } from './models/LatentsField'; export type { LatentsField } from './models/LatentsField';
export type { LatentsOutput } from './models/LatentsOutput'; export type { LatentsOutput } from './models/LatentsOutput';
export type { LatentsToImageInvocation } from './models/LatentsToImageInvocation'; export type { LatentsToImageInvocation } from './models/LatentsToImageInvocation';
export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvocation'; export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvocation';
export type { LerpInvocation } from './models/LerpInvocation';
export type { LoadImageInvocation } from './models/LoadImageInvocation'; export type { LoadImageInvocation } from './models/LoadImageInvocation';
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation'; export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput'; export type { MaskOutput } from './models/MaskOutput';
@ -61,7 +66,6 @@ export type { NoiseOutput } from './models/NoiseOutput';
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_'; export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
export type { PaginatedResults_ImageDTO_ } from './models/PaginatedResults_ImageDTO_'; export type { PaginatedResults_ImageDTO_ } from './models/PaginatedResults_ImageDTO_';
export type { ParamIntInvocation } from './models/ParamIntInvocation'; export type { ParamIntInvocation } from './models/ParamIntInvocation';
export type { PasteImageInvocation } from './models/PasteImageInvocation';
export type { PromptOutput } from './models/PromptOutput'; export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation'; export type { RandomIntInvocation } from './models/RandomIntInvocation';
export type { RandomRangeInvocation } from './models/RandomRangeInvocation'; export type { RandomRangeInvocation } from './models/RandomRangeInvocation';

View File

@ -10,6 +10,10 @@ export type AddInvocation = {
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'add'; type?: 'add';
/** /**
* The first number * The first number

View File

@ -0,0 +1,29 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Canny edge detection for ControlNet
*/
export type CannyImageProcessorInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
type?: 'canny_image_processor';
/**
* image to process
*/
image?: ImageField;
/**
* low threshold of Canny pixel gradient
*/
low_threshold?: number;
/**
* high threshold of Canny pixel gradient
*/
high_threshold?: number;
};
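A hedged example of constructing this node on the client; the import path follows the generated-model layout in this diff, while the `ImageField` shape (`image_name` / `image_type`) and the threshold values are assumptions:

// Illustrative only; image_name / image_type values are placeholders.
import type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';

const cannyNode: CannyImageProcessorInvocation = {
  id: 'canny_1',
  type: 'canny_image_processor',
  image: { image_name: 'example.png', image_type: 'results' },
  low_threshold: 100,
  high_threshold: 200,
};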

View File

@ -10,6 +10,10 @@ export type CollectInvocation = {
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'collect'; type?: 'collect';
/** /**
* The item to collect (all inputs must be of the same type) * The item to collect (all inputs must be of the same type)

View File

@ -10,6 +10,10 @@ export type CompelInvocation = {
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'compel'; type?: 'compel';
/** /**
* Prompt * Prompt

View File

@ -0,0 +1,41 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Applies content shuffle processing to image
*/
export type ContentShuffleImageProcessorInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
type?: 'content_shuffle_image_processor';
/**
* image to process
*/
image?: ImageField;
/**
* pixel resolution for edge detection
*/
detect_resolution?: number;
/**
* pixel resolution for output image
*/
image_resolution?: number;
/**
* content shuffle h parameter
*/
'h'?: number;
/**
* content shuffle w parameter
*/
'w'?: number;
/**
* cont
*/
'f'?: number;
};

View File

@ -0,0 +1,29 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
export type ControlField = {
/**
* processed image
*/
image: ImageField;
/**
* control model used
*/
control_model: string;
/**
* weight given to controlnet
*/
control_weight: number;
/**
* % of total steps at which controlnet is first applied
*/
begin_step_percent: number;
/**
* % of total steps at which controlnet is last applied
*/
end_step_percent: number;
};

View File

@ -0,0 +1,37 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Collects ControlNet info to pass to other nodes
*/
export type ControlNetInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
type?: 'controlnet';
/**
* image to process
*/
image?: ImageField;
/**
* control model used
*/
control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace';
/**
* weight given to controlnet
*/
control_weight?: number;
/**
* % of total steps at which controlnet is first applied
*/
begin_step_percent?: number;
/**
* % of total steps at which controlnet is last applied
*/
end_step_percent?: number;
};
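A hedged sketch of setting up a ControlNet node and wiring its output into a downstream node's `control` input; the model identifier is one of the literals from the union above, while the `ImageField` shape and the edge field names are assumptions:

// Illustrative only; node ids, edge shape and image values are placeholders.
import type { ControlNetInvocation } from './models/ControlNetInvocation';

const controlNetNode: ControlNetInvocation = {
  id: 'controlnet_1',
  type: 'controlnet',
  image: { image_name: 'canny_edges.png', image_type: 'results' },
  control_model: 'lllyasviel/sd-controlnet-canny',
  control_weight: 1.0,
  begin_step_percent: 0,
  end_step_percent: 1,
};

// Its ControlOutput.control field would then be connected to a denoising
// node's new 'control' input via a graph edge, e.g.:
const controlEdge = {
  source: { node_id: 'controlnet_1', field: 'control' },
  destination: { node_id: 'text_to_latents', field: 'control' },
};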

View File

@ -0,0 +1,17 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ControlField } from './ControlField';
/**
* node output for ControlNet info
*/
export type ControlOutput = {
type?: 'control_output';
/**
* The control info dict
*/
control?: ControlField;
};

View File

@ -12,6 +12,10 @@ export type CvInpaintInvocation = {
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'cv_inpaint'; type?: 'cv_inpaint';
/** /**
* The image to inpaint * The image to inpaint

View File

@ -10,6 +10,10 @@ export type DivideInvocation = {
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'div'; type?: 'div';
/** /**
* The first number * The first number

View File

@ -3,31 +3,34 @@
/* eslint-disable */ /* eslint-disable */
import type { AddInvocation } from './AddInvocation'; import type { AddInvocation } from './AddInvocation';
import type { BlurInvocation } from './BlurInvocation';
import type { CollectInvocation } from './CollectInvocation'; import type { CollectInvocation } from './CollectInvocation';
import type { CompelInvocation } from './CompelInvocation'; import type { CompelInvocation } from './CompelInvocation';
import type { CropImageInvocation } from './CropImageInvocation';
import type { CvInpaintInvocation } from './CvInpaintInvocation'; import type { CvInpaintInvocation } from './CvInpaintInvocation';
import type { DivideInvocation } from './DivideInvocation'; import type { DivideInvocation } from './DivideInvocation';
import type { Edge } from './Edge'; import type { Edge } from './Edge';
import type { GraphInvocation } from './GraphInvocation'; import type { GraphInvocation } from './GraphInvocation';
import type { ImageBlurInvocation } from './ImageBlurInvocation';
import type { ImageChannelInvocation } from './ImageChannelInvocation';
import type { ImageConvertInvocation } from './ImageConvertInvocation';
import type { ImageCropInvocation } from './ImageCropInvocation';
import type { ImageInverseLerpInvocation } from './ImageInverseLerpInvocation';
import type { ImageLerpInvocation } from './ImageLerpInvocation';
import type { ImageMultiplyInvocation } from './ImageMultiplyInvocation';
import type { ImagePasteInvocation } from './ImagePasteInvocation';
import type { ImageToImageInvocation } from './ImageToImageInvocation'; import type { ImageToImageInvocation } from './ImageToImageInvocation';
import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation'; import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation';
import type { InfillColorInvocation } from './InfillColorInvocation'; import type { InfillColorInvocation } from './InfillColorInvocation';
import type { InfillPatchMatchInvocation } from './InfillPatchMatchInvocation'; import type { InfillPatchMatchInvocation } from './InfillPatchMatchInvocation';
import type { InfillTileInvocation } from './InfillTileInvocation'; import type { InfillTileInvocation } from './InfillTileInvocation';
import type { InpaintInvocation } from './InpaintInvocation'; import type { InpaintInvocation } from './InpaintInvocation';
import type { InverseLerpInvocation } from './InverseLerpInvocation';
import type { IterateInvocation } from './IterateInvocation'; import type { IterateInvocation } from './IterateInvocation';
import type { LatentsToImageInvocation } from './LatentsToImageInvocation'; import type { LatentsToImageInvocation } from './LatentsToImageInvocation';
import type { LatentsToLatentsInvocation } from './LatentsToLatentsInvocation'; import type { LatentsToLatentsInvocation } from './LatentsToLatentsInvocation';
import type { LerpInvocation } from './LerpInvocation';
import type { LoadImageInvocation } from './LoadImageInvocation'; import type { LoadImageInvocation } from './LoadImageInvocation';
import type { MaskFromAlphaInvocation } from './MaskFromAlphaInvocation'; import type { MaskFromAlphaInvocation } from './MaskFromAlphaInvocation';
import type { MultiplyInvocation } from './MultiplyInvocation'; import type { MultiplyInvocation } from './MultiplyInvocation';
import type { NoiseInvocation } from './NoiseInvocation'; import type { NoiseInvocation } from './NoiseInvocation';
import type { ParamIntInvocation } from './ParamIntInvocation'; import type { ParamIntInvocation } from './ParamIntInvocation';
import type { PasteImageInvocation } from './PasteImageInvocation';
import type { RandomIntInvocation } from './RandomIntInvocation'; import type { RandomIntInvocation } from './RandomIntInvocation';
import type { RandomRangeInvocation } from './RandomRangeInvocation'; import type { RandomRangeInvocation } from './RandomRangeInvocation';
import type { RangeInvocation } from './RangeInvocation'; import type { RangeInvocation } from './RangeInvocation';
@ -49,7 +52,7 @@ export type Graph = {
/** /**
* The nodes in this graph * The nodes in this graph
*/ */
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>; nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
/** /**
* The connections between nodes and their fields in this graph * The connections between nodes and their fields in this graph
*/ */
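For illustration, a hedged sketch of a tiny `Graph` using one of the renamed image nodes from the union above; only the `nodes` and `edges` members shown in this hunk are used, and the `ImageField` shape is an assumption:

// Illustrative only; demonstrates the renamed 'img_blur' node type.
import type { Graph } from './Graph';

const graph: Graph = {
  nodes: {
    blur_1: {
      id: 'blur_1',
      type: 'img_blur',
      image: { image_name: 'photo.png', image_type: 'uploads' },
    },
  },
  edges: [],
};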

View File

@ -5,14 +5,17 @@
import type { Graph } from './Graph'; import type { Graph } from './Graph';
/** /**
* A node to process inputs and produce outputs. * Execute a graph
* May use dependency injection in __init__ to receive providers.
*/ */
export type GraphInvocation = { export type GraphInvocation = {
/** /**
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'graph'; type?: 'graph';
/** /**
* The graph to run * The graph to run

View File

@ -0,0 +1,33 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Applies HED edge detection to image
*/
export type HedImageprocessorInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
type?: 'hed_image_processor';
/**
* image to process
*/
image?: ImageField;
/**
* pixel resolution for edge detection
*/
detect_resolution?: number;
/**
* pixel resolution for output image
*/
image_resolution?: number;
/**
* whether to use scribble mode
*/
scribble?: boolean;
};

View File

@ -7,12 +7,16 @@ import type { ImageField } from './ImageField';
/** /**
* Blurs an image * Blurs an image
*/ */
export type BlurInvocation = { export type ImageBlurInvocation = {
/** /**
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
type?: 'blur'; /**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'img_blur';
/** /**
* The image to blur * The image to blur
*/ */

View File

@ -5,4 +5,4 @@
/** /**
* The category of an image. Use ImageCategory.OTHER for non-default categories. * The category of an image. Use ImageCategory.OTHER for non-default categories.
*/ */
export type ImageCategory = 'general' | 'control' | 'other'; export type ImageCategory = 'general' | 'control' | 'mask' | 'other';

View File

@ -0,0 +1,29 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Gets a channel from an image.
*/
export type ImageChannelInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'img_chan';
/**
* The image to get the channel from
*/
image?: ImageField;
/**
* The channel to get
*/
channel?: 'A' | 'R' | 'G' | 'B';
};

View File

@ -0,0 +1,29 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Converts an image to a different mode.
*/
export type ImageConvertInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'img_conv';
/**
* The image to convert
*/
image?: ImageField;
/**
* The mode to convert to
*/
mode?: 'L' | 'RGB' | 'RGBA' | 'CMYK' | 'YCbCr' | 'LAB' | 'HSV' | 'I' | 'F';
};

View File

@ -7,12 +7,16 @@ import type { ImageField } from './ImageField';
/** /**
* Crops an image to a specified box. The box can be outside of the image. * Crops an image to a specified box. The box can be outside of the image.
*/ */
export type CropImageInvocation = { export type ImageCropInvocation = {
/** /**
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
type?: 'crop'; /**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'img_crop';
/** /**
* The image to crop * The image to crop
*/ */

View File

@ -50,6 +50,10 @@ export type ImageDTO = {
* The deleted timestamp of the image. * The deleted timestamp of the image.
*/ */
deleted_at?: string; deleted_at?: string;
/**
* Whether this is an intermediate image.
*/
is_intermediate: boolean;
/** /**
* The session ID that generated this image, if it is a generated image. * The session ID that generated this image, if it is a generated image.
*/ */
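A minimal sketch of how a client might use the new required `is_intermediate` flag, for example to keep intermediates out of a gallery listing; the helper itself is hypothetical and not part of this diff:

// Hypothetical helper; relies only on ImageDTO.is_intermediate.
import type { ImageDTO } from './ImageDTO';

const filterGalleryImages = (images: ImageDTO[]): ImageDTO[] =>
  images.filter((image) => !image.is_intermediate);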

View File

@ -7,12 +7,16 @@ import type { ImageField } from './ImageField';
/** /**
* Inverse linear interpolation of all pixels of an image * Inverse linear interpolation of all pixels of an image
*/ */
export type InverseLerpInvocation = { export type ImageInverseLerpInvocation = {
/** /**
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
type?: 'ilerp'; /**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'img_ilerp';
/** /**
* The image to lerp * The image to lerp
*/ */

Some files were not shown because too many files have changed in this diff