Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit d37b08a7dd: Merge branch 'main' into release/make-web-dist-startable
@ -96,6 +96,7 @@ class UIConfig(TypedDict, total=False):
            "image",
            "latents",
            "model",
            "control",
        ],
    ]
    tags: List[str]
@ -22,6 +22,14 @@ class IntCollectionOutput(BaseInvocationOutput):
    # Outputs
    collection: list[int] = Field(default=[], description="The int collection")


class FloatCollectionOutput(BaseInvocationOutput):
    """A collection of floats"""

    type: Literal["float_collection"] = "float_collection"

    # Outputs
    collection: list[float] = Field(default=[], description="The float collection")


class RangeInvocation(BaseInvocation):
    """Creates a range of numbers from start to stop with step"""
428  invokeai/app/invocations/controlnet_image_processors.py  Normal file
@ -0,0 +1,428 @@
# InvokeAI nodes for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux

import numpy as np
from typing import Literal, Optional, Union, List
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field

from ..models.image import ImageField, ImageType, ImageCategory
from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    InvocationContext,
    InvocationConfig,
)
from controlnet_aux import (
    CannyDetector,
    HEDdetector,
    LineartDetector,
    LineartAnimeDetector,
    MidasDetector,
    MLSDdetector,
    NormalBaeDetector,
    OpenposeDetector,
    PidiNetDetector,
    ContentShuffleDetector,
    ZoeDetector,
    MediapipeFaceDetector,
)

from .image import ImageOutput, PILInvocationConfig

CONTROLNET_DEFAULT_MODELS = [
    ###########################################
    # lllyasviel sd v1.5, ControlNet v1.0 models
    ##############################################
    "lllyasviel/sd-controlnet-canny",
    "lllyasviel/sd-controlnet-depth",
    "lllyasviel/sd-controlnet-hed",
    "lllyasviel/sd-controlnet-seg",
    "lllyasviel/sd-controlnet-openpose",
    "lllyasviel/sd-controlnet-scribble",
    "lllyasviel/sd-controlnet-normal",
    "lllyasviel/sd-controlnet-mlsd",

    #############################################
    # lllyasviel sd v1.5, ControlNet v1.1 models
    #############################################
    "lllyasviel/control_v11p_sd15_canny",
    "lllyasviel/control_v11p_sd15_openpose",
    "lllyasviel/control_v11p_sd15_seg",
    # "lllyasviel/control_v11p_sd15_depth",  # broken
    "lllyasviel/control_v11f1p_sd15_depth",
    "lllyasviel/control_v11p_sd15_normalbae",
    "lllyasviel/control_v11p_sd15_scribble",
    "lllyasviel/control_v11p_sd15_mlsd",
    "lllyasviel/control_v11p_sd15_softedge",
    "lllyasviel/control_v11p_sd15s2_lineart_anime",
    "lllyasviel/control_v11p_sd15_lineart",
    "lllyasviel/control_v11p_sd15_inpaint",
    # "lllyasviel/control_v11u_sd15_tile",
    # problem (temporary?) with huggingface "lllyasviel/control_v11u_sd15_tile",
    # so for now replaced by "lllyasviel/control_v11f1e_sd15_tile" below
    "lllyasviel/control_v11e_sd15_shuffle",
    "lllyasviel/control_v11e_sd15_ip2p",
    "lllyasviel/control_v11f1e_sd15_tile",

    #################################################
    # thibaud sd v2.1 models (ControlNet v1.0? or v1.1?)
    ##################################################
    "thibaud/controlnet-sd21-openpose-diffusers",
    "thibaud/controlnet-sd21-canny-diffusers",
    "thibaud/controlnet-sd21-depth-diffusers",
    "thibaud/controlnet-sd21-scribble-diffusers",
    "thibaud/controlnet-sd21-hed-diffusers",
    "thibaud/controlnet-sd21-zoedepth-diffusers",
    "thibaud/controlnet-sd21-color-diffusers",
    "thibaud/controlnet-sd21-openposev2-diffusers",
    "thibaud/controlnet-sd21-lineart-diffusers",
    "thibaud/controlnet-sd21-normalbae-diffusers",
    "thibaud/controlnet-sd21-ade20k-diffusers",

    ##############################################
    # ControlNetMediaPipeface, ControlNet v1.1
    ##############################################
    # ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"],  # SD 1.5
    # diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
    # hacked t2l to split to model & subfolder if format is "model,subfolder"
    "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15",  # SD 1.5
    "CrucibleAI/ControlNetMediaPipeFace",  # SD 2.1?
]

CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
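# Note on the "model,subfolder" entries above (e.g. the CrucibleAI SD 1.5 one): the name is
# split on the comma before loading, and the second part is passed to from_pretrained() as
# the subfolder argument.  A rough sketch of that handling, as implemented in
# prep_control_data() further down in this commit:
#
#     name, subfolder = "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15".split(",")
#     model = ControlNetModel.from_pretrained(name, subfolder=subfolder)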
class ControlField(BaseModel):
    image: ImageField = Field(default=None, description="processed image")
    control_model: Optional[str] = Field(default=None, description="control model used")
    control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
    begin_step_percent: float = Field(default=0, ge=0, le=1,
                                      description="% of total steps at which controlnet is first applied")
    end_step_percent: float = Field(default=1, ge=0, le=1,
                                    description="% of total steps at which controlnet is last applied")

    class Config:
        schema_extra = {
            "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"]
        }


class ControlOutput(BaseInvocationOutput):
    """node output for ControlNet info"""
    # fmt: off
    type: Literal["control_output"] = "control_output"
    control: ControlField = Field(default=None, description="The control info dict")
    # fmt: on


class ControlNetInvocation(BaseInvocation):
    """Collects ControlNet info to pass to other nodes"""
    # fmt: off
    type: Literal["controlnet"] = "controlnet"
    # Inputs
    image: ImageField = Field(default=None, description="image to process")
    control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
                                                  description="control model used")
    control_weight: float = Field(default=1.0, ge=0, le=1, description="weight given to controlnet")
    # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
    begin_step_percent: float = Field(default=0, ge=0, le=1,
                                      description="% of total steps at which controlnet is first applied")
    end_step_percent: float = Field(default=1, ge=0, le=1,
                                    description="% of total steps at which controlnet is last applied")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ControlOutput:
        return ControlOutput(
            control=ControlField(
                image=self.image,
                control_model=self.control_model,
                control_weight=self.control_weight,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
            ),
        )
# TODO: move image processors to separate file (image_analysis.py)
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
    """Base class for invocations that preprocess images for ControlNet"""

    # fmt: off
    type: Literal["image_processor"] = "image_processor"
    # Inputs
    image: ImageField = Field(default=None, description="image to process")
    # fmt: on

    def run_processor(self, image):
        # superclass just passes through image without processing
        return image

    def invoke(self, context: InvocationContext) -> ImageOutput:
        raw_image = context.services.images.get_pil_image(
            self.image.image_type, self.image.image_name
        )
        # image type should be PIL.PngImagePlugin.PngImageFile ?
        processed_image = self.run_processor(raw_image)

        # FIXME: what happened to image metadata?
        # metadata = context.services.metadata.build_metadata(
        #     session_id=context.graph_execution_state_id, node=self
        # )

        # currently can't see processed image in node UI without a showImage node,
        # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
        image_dto = context.services.images.create(
            image=processed_image,
            image_type=ImageType.RESULT,
            image_category=ImageCategory.GENERAL,
            session_id=context.graph_execution_state_id,
            node_id=self.id,
            is_intermediate=self.is_intermediate
        )

        """Builds an ImageOutput and its ImageField"""
        processed_image_field = ImageField(
            image_name=image_dto.image_name,
            image_type=image_dto.image_type,
        )
        return ImageOutput(
            image=processed_image_field,
            # width=processed_image.width,
            width=image_dto.width,
            # height=processed_image.height,
            height=image_dto.height,
            # mode=processed_image.mode,
        )
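# A minimal sketch (hypothetical, not part of this commit) of how a new preprocessor plugs
# into this hierarchy: subclass ImageProcessorInvocation, declare a unique type literal plus
# any inputs, and override run_processor(); the invoke()/image-saving plumbing is inherited
# from the base class above.
class BlurImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies a Gaussian blur to image (illustrative example only)"""
    # fmt: off
    type: Literal["blur_image_processor"] = "blur_image_processor"
    # Inputs
    radius: float = Field(default=2.0, ge=0, description="blur radius in pixels")
    # fmt: on

    def run_processor(self, image):
        # ImageFilter is already imported from PIL at the top of this module
        return image.filter(ImageFilter.GaussianBlur(radius=self.radius))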
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Canny edge detection for ControlNet"""
    # fmt: off
    type: Literal["canny_image_processor"] = "canny_image_processor"
    # Input
    low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient")
    high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient")
    # fmt: on

    def run_processor(self, image):
        canny_processor = CannyDetector()
        processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
        return processed_image


class HedImageprocessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies HED edge detection to image"""
    # fmt: off
    type: Literal["hed_image_processor"] = "hed_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    # safe not supported in controlnet_aux v0.0.3
    # safe: bool = Field(default=False, description="whether to use safe mode")
    scribble: bool = Field(default=False, description="whether to use scribble mode")
    # fmt: on

    def run_processor(self, image):
        hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = hed_processor(image,
                                        detect_resolution=self.detect_resolution,
                                        image_resolution=self.image_resolution,
                                        # safe not supported in controlnet_aux v0.0.3
                                        # safe=self.safe,
                                        scribble=self.scribble,
                                        )
        return processed_image


class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies line art processing to image"""
    # fmt: off
    type: Literal["lineart_image_processor"] = "lineart_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    coarse: bool = Field(default=False, description="whether to use coarse mode")
    # fmt: on

    def run_processor(self, image):
        lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = lineart_processor(image,
                                            detect_resolution=self.detect_resolution,
                                            image_resolution=self.image_resolution,
                                            coarse=self.coarse)
        return processed_image


class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies line art anime processing to image"""
    # fmt: off
    type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    # fmt: on

    def run_processor(self, image):
        processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = processor(image,
                                    detect_resolution=self.detect_resolution,
                                    image_resolution=self.image_resolution,
                                    )
        return processed_image


class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies Openpose processing to image"""
    # fmt: off
    type: Literal["openpose_image_processor"] = "openpose_image_processor"
    # Inputs
    hand_and_face: bool = Field(default=False, description="whether to use hands and face mode")
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    # fmt: on

    def run_processor(self, image):
        openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = openpose_processor(image,
                                             detect_resolution=self.detect_resolution,
                                             image_resolution=self.image_resolution,
                                             hand_and_face=self.hand_and_face,
                                             )
        return processed_image
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies Midas depth processing to image"""
    # fmt: off
    type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
    # Inputs
    a_mult: float = Field(default=2.0, ge=0, description="Midas parameter a = amult * PI")
    bg_th: float = Field(default=0.1, ge=0, description="Midas parameter bg_th")
    # depth_and_normal not supported in controlnet_aux v0.0.3
    # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
    # fmt: on

    def run_processor(self, image):
        midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = midas_processor(image,
                                          a=np.pi * self.a_mult,
                                          bg_th=self.bg_th,
                                          # depth_and_normal not supported in controlnet_aux v0.0.3
                                          # depth_and_normal=self.depth_and_normal,
                                          )
        return processed_image


class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies NormalBae processing to image"""
    # fmt: off
    type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    # fmt: on

    def run_processor(self, image):
        normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = normalbae_processor(image,
                                              detect_resolution=self.detect_resolution,
                                              image_resolution=self.image_resolution)
        return processed_image


class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies MLSD processing to image"""
    # fmt: off
    type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter thr_v")
    thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter thr_d")
    # fmt: on

    def run_processor(self, image):
        mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = mlsd_processor(image,
                                         detect_resolution=self.detect_resolution,
                                         image_resolution=self.image_resolution,
                                         thr_v=self.thr_v,
                                         thr_d=self.thr_d)
        return processed_image


class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies PIDI processing to image"""
    # fmt: off
    type: Literal["pidi_image_processor"] = "pidi_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    safe: bool = Field(default=False, description="whether to use safe mode")
    scribble: bool = Field(default=False, description="whether to use scribble mode")
    # fmt: on

    def run_processor(self, image):
        pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = pidi_processor(image,
                                         detect_resolution=self.detect_resolution,
                                         image_resolution=self.image_resolution,
                                         safe=self.safe,
                                         scribble=self.scribble)
        return processed_image
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies content shuffle processing to image"""
    # fmt: off
    type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
    # Inputs
    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
    h: Union[int | None] = Field(default=512, ge=0, description="content shuffle h parameter")
    w: Union[int | None] = Field(default=512, ge=0, description="content shuffle w parameter")
    f: Union[int | None] = Field(default=256, ge=0, description="cont")
    # fmt: on

    def run_processor(self, image):
        content_shuffle_processor = ContentShuffleDetector()
        processed_image = content_shuffle_processor(image,
                                                    detect_resolution=self.detect_resolution,
                                                    image_resolution=self.image_resolution,
                                                    h=self.h,
                                                    w=self.w,
                                                    f=self.f
                                                    )
        return processed_image


# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies Zoe depth processing to image"""
    # fmt: off
    type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
    # fmt: on

    def run_processor(self, image):
        zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = zoe_depth_processor(image)
        return processed_image


class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies mediapipe face processing to image"""
    # fmt: off
    type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
    # Inputs
    max_faces: int = Field(default=1, ge=1, description="maximum number of faces to detect")
    min_confidence: float = Field(default=0.5, ge=0, le=1, description="minimum confidence for face detection")
    # fmt: on

    def run_processor(self, image):
        mediapipe_face_processor = MediapipeFaceDetector()
        processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
        return processed_image
@ -4,7 +4,9 @@ from functools import partial
from typing import Literal, Optional, Union, get_args

import numpy as np
from diffusers import ControlNetModel
from torch import Tensor
import torch

from pydantic import BaseModel, Field
@ -58,6 +60,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
    cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
    model: str = Field(default="", description="The model to use (currently ignored)")
    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
    control_model: Optional[str] = Field(default=None, description="The control model to use")
    control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
    # fmt: on

    # TODO: pass this an emitter method or something? or a session for dispatching?
@ -78,17 +83,35 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
        # Handle invalid model parameter
        model = choose_model(context.services.model_manager, self.model)

        # loading controlnet image (currently requires pre-processed image)
        control_image = (
            None if self.control_image is None
            else context.services.images.get(
                self.control_image.image_type, self.control_image.image_name
            )
        )
        # loading controlnet model
        if (self.control_model is None or self.control_model == ''):
            control_model = None
        else:
            # FIXME: change this to dropdown menu?
            # FIXME: generalize so don't have to hardcode torch_dtype and device
            control_model = ControlNetModel.from_pretrained(self.control_model,
                                                            torch_dtype=torch.float16).to("cuda")

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(
            context.graph_execution_state_id
        )
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        outputs = Txt2Img(model).generate(
        txt2img = Txt2Img(model, control_model=control_model)
        outputs = txt2img.generate(
            prompt=self.prompt,
            step_callback=partial(self.dispatch_progress, context, source_node_id),
            control_image=control_image,
            **self.dict(
                exclude={"prompt"}
                exclude={"prompt", "control_image"}
            ),  # Shorthand for passing all of the parameters above manually
        )
        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
@ -1,8 +1,11 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

import random
from typing import Literal, Optional, Union
import einops
from typing import Literal, Optional, Union, List

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel

from pydantic import BaseModel, Field, validator
import torch
@ -11,14 +14,18 @@ from invokeai.app.models.image import ImageCategory
from invokeai.app.util.misc import SEED_MAX, get_random_seed

from invokeai.app.util.step_callback import stable_diffusion_step_callback
from .controlnet_image_processors import ControlField

from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec

from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData

from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_file_storage import ImageType
@ -28,7 +35,7 @@ from .compel import ConditioningField
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline
from diffusers import DiffusionPipeline, ControlNetModel


class LatentsField(BaseModel):
@ -84,13 +91,13 @@ SAMPLER_NAME_VALUES = Literal[

def get_scheduler(scheduler_name: str, model: StableDiffusionGeneratorPipeline) -> Scheduler:
    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])

    scheduler_config = model.scheduler.config
    if "_backup" in scheduler_config:
        scheduler_config = scheduler_config["_backup"]
    scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
    scheduler = scheduler_class.from_config(scheduler_config)

    # hack copied over from generate.py
    if not hasattr(scheduler, 'uses_inpainting_model'):
        scheduler.uses_inpainting_model = lambda: False
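The "_backup" key above keeps the pipeline's original scheduler config intact so the next
scheduler swap starts from pristine settings rather than an already-overridden dict. A rough
sketch of the merge (the dict values are illustrative, not taken from this commit):

    original = {"num_train_timesteps": 1000}
    extra = {"algorithm_type": "dpmsolver++"}
    merged = {**original, **extra, "_backup": original}
    # merged["_backup"] holds the untouched original, restored on the next call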
@ -169,6 +176,7 @@ class TextToLatentsInvocation(BaseInvocation):
    cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
    model: str = Field(default="", description="The model to use (currently ignored)")
    control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
    # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on
@ -179,7 +187,8 @@ class TextToLatentsInvocation(BaseInvocation):
            "ui": {
                "tags": ["latents", "image"],
                "type_hints": {
                    "model": "model"
                    "model": "model",
                    "control": "control",
                }
            },
        }
@ -238,6 +247,81 @@ class TextToLatentsInvocation(BaseInvocation):
        ).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)  # ddim_eta
        return conditioning_data

    def prep_control_data(self,
                          context: InvocationContext,
                          model: StableDiffusionGeneratorPipeline,  # really only need model for dtype and device
                          control_input: List[ControlField],
                          latents_shape: List[int],
                          do_classifier_free_guidance: bool = True,
                          ) -> List[ControlNetData]:
        # assuming fixed dimensional scaling of 8:1 for image:latents
        control_height_resize = latents_shape[2] * 8
        control_width_resize = latents_shape[3] * 8
        if control_input is None:
            # print("control input is None")
            control_list = None
        elif isinstance(control_input, list) and len(control_input) == 0:
            # print("control input is empty list")
            control_list = None
        elif isinstance(control_input, ControlField):
            # print("control input is ControlField")
            control_list = [control_input]
        elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
            # print("control input is list[ControlField]")
            control_list = control_input
        else:
            # print("input control is unrecognized:", type(self.control))
            control_list = None
        if (control_list is None):
            control_data = None
        # from above handling, any control that is not None should now be of type list[ControlField]
        else:
            # FIXME: add checks to skip entry if model or image is None
            #        and if weight is None, populate with default 1.0?
            control_data = []
            control_models = []
            for control_info in control_list:
                # handle control models
                if ("," in control_info.control_model):
                    control_model_split = control_info.control_model.split(",")
                    control_name = control_model_split[0]
                    control_subfolder = control_model_split[1]
                    print("Using HF model subfolders")
                    print(" control_name: ", control_name)
                    print(" control_subfolder: ", control_subfolder)
                    control_model = ControlNetModel.from_pretrained(control_name,
                                                                    subfolder=control_subfolder,
                                                                    torch_dtype=model.unet.dtype).to(model.device)
                else:
                    control_model = ControlNetModel.from_pretrained(control_info.control_model,
                                                                    torch_dtype=model.unet.dtype).to(model.device)
                control_models.append(control_model)
                control_image_field = control_info.image
                input_image = context.services.images.get_pil_image(control_image_field.image_type,
                                                                    control_image_field.image_name)
                # self.image.image_type, self.image.image_name
                # FIXME: still need to test with different widths, heights, devices, dtypes
                #        and add in batch_size, num_images_per_prompt?
                #        and do real check for classifier_free_guidance?
                # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
                control_image = model.prepare_control_image(
                    image=input_image,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    width=control_width_resize,
                    height=control_height_resize,
                    # batch_size=batch_size * num_images_per_prompt,
                    # num_images_per_prompt=num_images_per_prompt,
                    device=control_model.device,
                    dtype=control_model.dtype,
                )
                control_item = ControlNetData(model=control_model,
                                              image_tensor=control_image,
                                              weight=control_info.control_weight,
                                              begin_step_percent=control_info.begin_step_percent,
                                              end_step_percent=control_info.end_step_percent)
                control_data.append(control_item)
                # MultiControlNetModel has been refactored out, just need list[ControlNetData]
        return control_data
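    # Note on the 8:1 scaling above: Stable Diffusion's VAE downsamples by a factor of 8, so
    # latents are 1/8 of the pixel resolution and the control image is resized back up to match.
    # Worked example (illustrative shape):
    #     latents_shape = [1, 4, 64, 64]  ->  control_height_resize = control_width_resize = 512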
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        noise = context.services.latents.get(self.noise.latents_name)
@ -252,14 +336,19 @@ class TextToLatentsInvocation(BaseInvocation):
        model = self.get_model(context.services.model_manager)
        conditioning_data = self.get_conditioning_data(context, model)

        # TODO: Verify the noise is the right size
        print("type of control input: ", type(self.control))
        control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
                                              latents_shape=noise.shape,
                                              do_classifier_free_guidance=(self.cfg_scale >= 1.0))

        # TODO: Verify the noise is the right size
        result_latents, result_attention_map_saver = model.latents_from_embeddings(
            latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
            noise=noise,
            num_inference_steps=self.steps,
            conditioning_data=conditioning_data,
            callback=step_callback
            control_data=control_data,  # list[ControlNetData]
            callback=step_callback,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
@ -285,7 +374,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
            "ui": {
                "tags": ["latents"],
                "type_hints": {
                    "model": "model"
                    "model": "model",
                    "control": "control",
                }
            },
        }
@ -304,6 +394,11 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
        model = self.get_model(context.services.model_manager)
        conditioning_data = self.get_conditioning_data(context, model)

        print("type of control input: ", type(self.control))
        control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
                                              latents_shape=noise.shape,
                                              do_classifier_free_guidance=(self.cfg_scale >= 1.0))

        # TODO: Verify the noise is the right size

        initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
@ -318,6 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
            noise=noise,
            num_inference_steps=self.steps,
            conditioning_data=conditioning_data,
            control_data=control_data,  # list[ControlNetData]
            callback=step_callback
        )
@ -362,8 +458,14 @@ class LatentsToImageInvocation(BaseInvocation):
        np_image = model.decode_latents(latents)
        image = model.numpy_to_pil(np_image)[0]

        # what happened to metadata?
        # metadata = context.services.metadata.build_metadata(
        #     session_id=context.graph_execution_state_id, node=self

        torch.cuda.empty_cache()

        # new (post Image service refactor) way of using services to save image
        # and generate unique image_name
        image_dto = context.services.images.create(
            image=image,
            image_type=ImageType.RESULT,
@ -414,6 +516,7 @@ class ResizeLatentsInvocation(BaseInvocation):
        torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        # context.services.latents.set(name, resized_latents)
        context.services.latents.save(name, resized_latents)
        return build_latents_output(latents_name=name, latents=resized_latents)
@ -444,6 +547,7 @@ class ScaleLatentsInvocation(BaseInvocation):
        torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        # context.services.latents.set(name, resized_latents)
        context.services.latents.save(name, resized_latents)
        return build_latents_output(latents_name=name, latents=resized_latents)
@ -468,6 +572,9 @@ class ImageToLatentsInvocation(BaseInvocation):

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        # image = context.services.images.get(
        #     self.image.image_type, self.image.image_name
        # )
        image = context.services.images.get_pil_image(
            self.image.image_type, self.image.image_name
        )
@ -488,6 +595,6 @@ class ImageToLatentsInvocation(BaseInvocation):
        )

        name = f"{context.graph_execution_state_id}__{self.id}"
        # context.services.latents.set(name, latents)
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=latents)
@ -34,6 +34,15 @@ class IntOutput(BaseInvocationOutput):
    # fmt: on


class FloatOutput(BaseInvocationOutput):
    """A float output"""

    # fmt: off
    type: Literal["float_output"] = "float_output"
    param: float = Field(default=None, description="The output float")
    # fmt: on


class AddInvocation(BaseInvocation, MathInvocationConfig):
    """Adds two numbers"""
@ -3,7 +3,7 @@
from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from .math import IntOutput
from .math import IntOutput, FloatOutput

# Pass-through parameter nodes - used by subgraphs
@ -16,3 +16,13 @@ class ParamIntInvocation(BaseInvocation):

    def invoke(self, context: InvocationContext) -> IntOutput:
        return IntOutput(a=self.a)


class ParamFloatInvocation(BaseInvocation):
    """A float parameter"""
    # fmt: off
    type: Literal["param_float"] = "param_float"
    param: float = Field(default=0.0, description="The float value")
    # fmt: on

    def invoke(self, context: InvocationContext) -> FloatOutput:
        return FloatOutput(param=self.param)
@ -60,6 +60,35 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
    node_input_field = node_inputs.get(field) or None
    return node_input_field

from typing import Optional, Union, List, get_args

def is_union_subtype(t1, t2):
    t1_args = get_args(t1)
    t2_args = get_args(t2)

    if not t1_args:
        # t1 is a single type
        return t1 in t2_args
    else:
        # t1 is a Union, check that all of its types are in t2_args
        return all(arg in t2_args for arg in t1_args)

def is_list_or_contains_list(t):
    t_args = get_args(t)

    # If the type is a List
    if get_origin(t) is list:
        return True

    # If the type is a Union
    elif t_args:
        # Check if any of the types in the Union is a List
        for arg in t_args:
            if get_origin(arg) is list:
                return True

    return False
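# Quick illustration of the two helpers above (Union and List are already imported here):
#     is_union_subtype(int, Union[int, str])                    -> True  (int is a member)
#     is_union_subtype(Union[int, str], Union[int, str, None])  -> True  (all members covered)
#     is_list_or_contains_list(List[int])                       -> True
#     is_list_or_contains_list(Union[List[int], int])           -> True  (the Union contains a list)
#     is_list_or_contains_list(int)                              -> False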
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
    if not from_type:
@ -85,7 +114,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
    if to_type in get_args(from_type):
        return True

    if not issubclass(from_type, to_type):
    # if not issubclass(from_type, to_type):
    if not is_union_subtype(from_type, to_type):
        return False
    else:
        return False
@ -694,7 +724,11 @@ class Graph(BaseModel):
        input_root_type = next(t[0] for t in type_degrees if t[1] == 0)  # type: ignore

        # Verify that all outputs are lists
        if not all((get_origin(f) == list for f in output_fields)):
        # if not all((get_origin(f) == list for f in output_fields)):
        #     return False

        # Verify that all outputs are lists
        if not all(is_list_or_contains_list(f) for f in output_fields):
            return False

        # Verify that all outputs match the input type (are a base class or the same class)
@ -75,9 +75,11 @@ class InvokeAIGenerator(metaclass=ABCMeta):
    def __init__(self,
                 model_info: dict,
                 params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
                 **kwargs,
                 ):
        self.model_info=model_info
        self.params=params
        self.kwargs = kwargs

    def generate(self,
                 prompt: str='',
@ -118,9 +120,12 @@ class InvokeAIGenerator(metaclass=ABCMeta):
            model=model,
            scheduler_name=generator_args.get('scheduler')
        )
        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)

        # get conditioning from prompt via Compel package
        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model)

        gen_class = self._generator_class()
        generator = gen_class(model, self.params.precision)
        generator = gen_class(model, self.params.precision, **self.kwargs)
        if self.params.variation_amount > 0:
            generator.set_variation(generator_args.get('seed'),
                                    generator_args.get('variation_amount'),
@ -276,7 +281,7 @@ class Generator:
    precision: str
    model: DiffusionPipeline

    def __init__(self, model: DiffusionPipeline, precision: str):
    def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
        self.model = model
        self.precision = precision
        self.seed = None
@ -4,6 +4,10 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
import PIL.Image
import torch

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel

from ..stable_diffusion import (
    ConditioningData,
    PostprocessingSettings,
@ -13,8 +17,13 @@ from .base import Generator


class Txt2Img(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)
    def __init__(self, model, precision,
                 control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
                 **kwargs):
        self.control_model = control_model
        if isinstance(self.control_model, list):
            self.control_model = MultiControlNetModel(self.control_model)
        super().__init__(model, precision, **kwargs)

    @torch.no_grad()
    def get_make_image(
@ -42,9 +51,12 @@ class Txt2Img(Generator):
        kwargs are 'width' and 'height'
        """
        self.perlin = perlin
        control_image = kwargs.get("control_image", None)
        do_classifier_free_guidance = cfg_scale > 1.0

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.control_model = self.control_model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
@ -61,6 +73,37 @@ class Txt2Img(Generator):
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        # FIXME: still need to test with different widths, heights, devices, dtypes
        #        and add in batch_size, num_images_per_prompt?
        if control_image is not None:
            if isinstance(self.control_model, ControlNetModel):
                control_image = pipeline.prepare_control_image(
                    image=control_image,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    width=width,
                    height=height,
                    # batch_size=batch_size * num_images_per_prompt,
                    # num_images_per_prompt=num_images_per_prompt,
                    device=self.control_model.device,
                    dtype=self.control_model.dtype,
                )
            elif isinstance(self.control_model, MultiControlNetModel):
                images = []
                for image_ in control_image:
                    image_ = self.model.prepare_control_image(
                        image=image_,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        width=width,
                        height=height,
                        # batch_size=batch_size * num_images_per_prompt,
                        # num_images_per_prompt=num_images_per_prompt,
                        device=self.control_model.device,
                        dtype=self.control_model.dtype,
                    )
                    images.append(image_)
                control_image = images
            kwargs["control_image"] = control_image

        def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
            pipeline_output = pipeline.image_from_embeddings(
                latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
@ -68,6 +111,7 @@ class Txt2Img(Generator):
                num_inference_steps=steps,
                conditioning_data=conditioning_data,
                callback=step_callback,
                **kwargs,
            )

            if (
@ -2,23 +2,29 @@ from __future__ import annotations

import dataclasses
import inspect
import math
import secrets
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, Field

import einops
import PIL.Image
import numpy as np
from accelerate.utils import set_seed
import psutil
import torch
import torchvision.transforms as T
from compel import EmbeddingsProvider
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    StableDiffusionImg2ImgPipeline,
)
@ -27,6 +33,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils import PIL_INTERPOLATION
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
@ -207,6 +214,13 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
            raise AssertionError("why was that an empty generator?")
        return result

@dataclass
class ControlNetData:
    model: ControlNetModel = Field(default=None)
    image_tensor: torch.Tensor = Field(default=None)
    weight: float = Field(default=1.0)
    begin_step_percent: float = Field(default=0.0)
    end_step_percent: float = Field(default=1.0)

@dataclass(frozen=True)
class ConditioningData:
@ -302,6 +316,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        feature_extractor: Optional[CLIPFeatureExtractor],
        requires_safety_checker: bool = False,
        precision: str = "float32",
        control_model: ControlNetModel = None,
    ):
        super().__init__(
            vae,
@ -322,6 +337,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            # FIXME: can't currently register control module
            # control_model=control_model,
        )
        self.invokeai_diffuser = InvokeAIDiffuserComponent(
            self.unet, self._unet_forward, is_running_diffusers=True
@ -341,6 +358,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

        self._model_group = FullyLoadedModelGroup(self.unet.device)
        self._model_group.install(*self._submodels)
        self.control_model = control_model

    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
        """
@ -463,6 +481,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        noise: torch.Tensor,
        callback: Callable[[PipelineIntermediateState], None] = None,
        run_id=None,
        **kwargs,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        r"""
        Function invoked when calling the pipeline for generation.
@ -483,6 +502,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            noise=noise,
            run_id=run_id,
            callback=callback,
            **kwargs,
        )
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()
@ -507,6 +527,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        additional_guidance: List[Callable] = None,
        run_id=None,
        callback: Callable[[PipelineIntermediateState], None] = None,
        control_data: List[ControlNetData] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
        if self.scheduler.config.get("cpu_only", False):
            scheduler_device = torch.device('cpu')
@ -527,6 +549,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            additional_guidance=additional_guidance,
            run_id=run_id,
            callback=callback,
            control_data=control_data,
            **kwargs,
        )
        return result.latents, result.attention_map_saver
@ -539,6 +563,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        noise: torch.Tensor,
        run_id: str = None,
        additional_guidance: List[Callable] = None,
        control_data: List[ControlNetData] = None,
        **kwargs,
    ):
        self._adjust_memory_efficient_attention(latents)
        if run_id is None:
@ -568,7 +594,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        latents = self.scheduler.add_noise(latents, noise, batched_t)

        attention_map_saver: Optional[AttentionMapSaver] = None

        # print("timesteps:", timesteps)
        for i, t in enumerate(self.progress_bar(timesteps)):
            batched_t.fill_(t)
            step_output = self.step(
@ -578,6 +604,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                step_index=i,
                total_step_count=len(timesteps),
                additional_guidance=additional_guidance,
                control_data=control_data,
                **kwargs,
            )
            latents = step_output.prev_sample
@ -618,10 +646,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        step_index: int,
        total_step_count: int,
        additional_guidance: List[Callable] = None,
        control_data: List[ControlNetData] = None,
        **kwargs,
    ):
        # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
        timestep = t[0]

        if additional_guidance is None:
            additional_guidance = []
@ -629,6 +658,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        # i.e. before or after passing it to InvokeAIDiffuserComponent
        latent_model_input = self.scheduler.scale_model_input(latents, timestep)

        # default is no controlnet, so set controlnet processing output to None
        down_block_res_samples, mid_block_res_sample = None, None

        if control_data is not None:
            if conditioning_data.guidance_scale > 1.0:
                # expand the latents input to control model if doing classifier free guidance
                # (which I think for now is always true, there is conditional elsewhere that stops execution if
                #  classifier_free_guidance is <= 1.0 ?)
                latent_control_input = torch.cat([latent_model_input] * 2)
            else:
                latent_control_input = latent_model_input
            # control_data should be type List[ControlNetData]
            # this loop covers both ControlNet (one ControlNetData in list)
            # and MultiControlNet (multiple ControlNetData in list)
            for i, control_datum in enumerate(control_data):
                # print("controlnet", i, "==>", type(control_datum))
                first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
                last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
                # only apply controlnet if current step is within the controlnet's begin/end step range
                if step_index >= first_control_step and step_index <= last_control_step:
                    # print("running controlnet", i, "for step", step_index)
                    down_samples, mid_sample = control_datum.model(
                        sample=latent_control_input,
                        timestep=timestep,
                        encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
                                                         conditioning_data.text_embeddings]),
                        controlnet_cond=control_datum.image_tensor,
                        conditioning_scale=control_datum.weight,
                        # cross_attention_kwargs,
                        guess_mode=False,
                        return_dict=False,
                    )
                    if down_block_res_samples is None and mid_block_res_sample is None:
                        down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
                    else:
                        # add controlnet outputs together if have multiple controlnets
                        down_block_res_samples = [
                            samples_prev + samples_curr
                            for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                        ]
                        mid_block_res_sample += mid_sample

        # predict the noise residual
        noise_pred = self.invokeai_diffuser.do_diffusion_step(
            latent_model_input,
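The begin/end step percentages translate into an inclusive step window for each ControlNet.
With the defaults (0.0 and 1.0) the residuals are applied on every step; with illustrative
values total_step_count=20, begin_step_percent=0.25 and end_step_percent=0.75:

    first_control_step = math.floor(0.25 * 20)   # 5
    last_control_step = math.ceil(0.75 * 20)     # 15
    # ControlNet residuals are added only while 5 <= step_index <= 15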
@ -638,6 +709,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            conditioning_data.guidance_scale,
            step_index=step_index,
            total_step_count=total_step_count,
            down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
            mid_block_additional_residual=mid_block_res_sample,      # from controlnet(s)
        )

        # compute the previous noisy sample x_t -> x_t-1
@ -659,6 +732,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        t,
        text_embeddings,
        cross_attention_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ):
        """predict the noise residual"""
        if is_inpainting_model(self.unet) and latents.size(1) == 4:
@ -678,7 +752,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

        # First three args should be positional, not keywords, so torch hooks can see them.
        return self.unet(
            latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs
            latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
            **kwargs,
        ).sample

    def img2img_from_embeddings(
@ -728,7 +803,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        noise: torch.Tensor,
        run_id=None,
        callback=None,
        ) -> InvokeAIStableDiffusionPipelineOutput:
    ) -> InvokeAIStableDiffusionPipelineOutput:
        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
        result_latents, result_attention_maps = self.latents_from_embeddings(
            latents=initial_latents if strength < 1.0 else torch.zeros_like(
@ -940,3 +1015,51 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            debug_image(
                img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
            )

    # Copied from diffusers pipeline_stable_diffusion_controlnet.py
    # Returns torch.Tensor of shape (batch_size, 3, height, width)
    def prepare_control_image(
        self,
        image,
        # FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
        # latents,
        width=512,  # should be 8 * latent.shape[3]
        height=512,  # should be 8 * latent.shape[2]
        batch_size=1,
        num_images_per_prompt=1,
        device="cuda",
        dtype=torch.float16,
        do_classifier_free_guidance=True,
    ):

        if not isinstance(image, torch.Tensor):
            if isinstance(image, PIL.Image.Image):
                image = [image]

            if isinstance(image[0], PIL.Image.Image):
                images = []
                for image_ in image:
                    image_ = image_.convert("RGB")
                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                    image_ = np.array(image_)
                    image_ = image_[None, :]
                    images.append(image_)
                image = images
                image = np.concatenate(image, axis=0)
                image = np.array(image).astype(np.float32) / 255.0
                image = image.transpose(0, 3, 1, 2)
                image = torch.from_numpy(image)
            elif isinstance(image[0], torch.Tensor):
                image = torch.cat(image, dim=0)

        image_batch_size = image.shape[0]
        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)
        image = image.to(device=device, dtype=dtype)
        if do_classifier_free_guidance:
            image = torch.cat([image] * 2)
        return image
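Net effect of prepare_control_image(): a single PIL control image comes out as a float tensor
of shape (batch_size, 3, height, width) on the requested device/dtype, and with classifier-free
guidance it is duplicated along the batch dimension so it lines up with the unconditioned plus
conditioned UNet batch. A rough sketch of the expected shapes (values illustrative):

    tensor = pipeline.prepare_control_image(image=pil_image, width=512, height=512,
                                             do_classifier_free_guidance=True)
    # tensor.shape == (2, 3, 512, 512): one copy for the unconditioned pass, one for the conditioned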
@ -181,6 +181,7 @@ class InvokeAIDiffuserComponent:
        unconditional_guidance_scale: float,
        step_index: Optional[int] = None,
        total_step_count: Optional[int] = None,
        **kwargs,
    ):
        """
        :param x: current latents
@ -209,7 +210,7 @@ class InvokeAIDiffuserComponent:

        if wants_hybrid_conditioning:
            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
                x, sigma, unconditioning, conditioning
                x, sigma, unconditioning, conditioning, **kwargs,
            )
        elif wants_cross_attention_control:
            (
@ -221,13 +222,14 @@ class InvokeAIDiffuserComponent:
                unconditioning,
                conditioning,
                cross_attention_control_types_to_do,
                **kwargs,
            )
        elif self.sequential_guidance:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning_sequentially(
                x, sigma, unconditioning, conditioning
                x, sigma, unconditioning, conditioning, **kwargs,
            )

        else:
@ -235,7 +237,7 @@ class InvokeAIDiffuserComponent:
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_standard_conditioning(
|
||||
x, sigma, unconditioning, conditioning
|
||||
x, sigma, unconditioning, conditioning, **kwargs,
|
||||
)
|
||||
|
||||
combined_next_x = self._combine(
|
||||
@ -282,13 +284,13 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||
# fast batched path
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||
both_results = self.model_forward_callback(
|
||||
x_twice, sigma_twice, both_conditionings
|
||||
x_twice, sigma_twice, both_conditionings, **kwargs,
|
||||
)
|
||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||
if conditioned_next_x.device.type == "mps":
|
||||
@ -302,16 +304,17 @@ class InvokeAIDiffuserComponent:
|
||||
sigma,
|
||||
unconditioning: torch.Tensor,
|
||||
conditioning: torch.Tensor,
|
||||
**kwargs,
|
||||
):
|
||||
# low-memory sequential path
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
||||
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
||||
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
|
||||
if conditioned_next_x.device.type == "mps":
|
||||
# prevent a result filled with zeros. seems to be a torch bug.
|
||||
conditioned_next_x = conditioned_next_x.clone()
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||
assert isinstance(conditioning, dict)
|
||||
assert isinstance(unconditioning, dict)
|
||||
x_twice = torch.cat([x] * 2)
|
||||
@ -326,7 +329,7 @@ class InvokeAIDiffuserComponent:
|
||||
else:
|
||||
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
|
||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
|
||||
x_twice, sigma_twice, both_conditionings
|
||||
x_twice, sigma_twice, both_conditionings, **kwargs,
|
||||
).chunk(2)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
@ -337,6 +340,7 @@ class InvokeAIDiffuserComponent:
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
if self.is_running_diffusers:
|
||||
return self._apply_cross_attention_controlled_conditioning__diffusers(
|
||||
@ -345,6 +349,7 @@ class InvokeAIDiffuserComponent:
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return self._apply_cross_attention_controlled_conditioning__compvis(
|
||||
@ -353,6 +358,7 @@ class InvokeAIDiffuserComponent:
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _apply_cross_attention_controlled_conditioning__diffusers(
|
||||
@ -362,6 +368,7 @@ class InvokeAIDiffuserComponent:
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
context: Context = self.cross_attention_control_context
|
||||
|
||||
@ -377,6 +384,7 @@ class InvokeAIDiffuserComponent:
|
||||
sigma,
|
||||
unconditioning,
|
||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# do requested cross attention types for conditioning (positive prompt)
|
||||
@ -388,6 +396,7 @@ class InvokeAIDiffuserComponent:
|
||||
sigma,
|
||||
conditioning,
|
||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||
**kwargs,
|
||||
)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
@ -398,6 +407,7 @@ class InvokeAIDiffuserComponent:
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||
# slower non-batched path (20% slower on mac MPS)
|
||||
@ -411,13 +421,13 @@ class InvokeAIDiffuserComponent:
|
||||
context: Context = self.cross_attention_control_context
|
||||
|
||||
try:
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
||||
|
||||
# process x using the original prompt, saving the attention maps
|
||||
# print("saving attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_save_attention_maps(ca_type)
|
||||
_ = self.model_forward_callback(x, sigma, conditioning)
|
||||
_ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
|
||||
context.clear_requests(cleanup=False)
|
||||
|
||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||
@ -428,7 +438,7 @@ class InvokeAIDiffuserComponent:
|
||||
self.conditioning.cross_attention_control_args.edited_conditioning
|
||||
)
|
||||
conditioned_next_x = self.model_forward_callback(
|
||||
x, sigma, edited_conditioning
|
||||
x, sigma, edited_conditioning, **kwargs,
|
||||
)
|
||||
context.clear_requests(cleanup=True)
|
||||
|
||||
|
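The changes above simply thread **kwargs from do_diffusion_step down to every model_forward_callback call; a minimal standalone sketch of the pattern (names are illustrative, not the component's real internals):

# Illustration only: any extra keyword arguments given to do_step reach the forward
# callback unchanged, which is how per-step control info can be passed through
# without widening every intermediate method signature.
class KwargsThreadingSketch:
    def __init__(self, model_forward_callback):
        self.model_forward_callback = model_forward_callback

    def do_step(self, x, sigma, unconditioning, conditioning, **kwargs):
        unconditioned = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
        conditioned = self.model_forward_callback(x, sigma, conditioning, **kwargs)
        return unconditioned, conditioned

# e.g. sketch.do_step(x, sigma, uc, c, down_block_additional_residuals=residuals)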
@ -1,5 +1,10 @@
import { forEach, size } from 'lodash-es';
import { ImageField, LatentsField, ConditioningField } from 'services/api';
import {
  ImageField,
  LatentsField,
  ConditioningField,
  ControlField,
} from 'services/api';

const OBJECT_TYPESTRING = '[object Object]';
const STRING_TYPESTRING = '[object String]';
@ -98,6 +103,24 @@ const parseConditioningField = (
  };
};

const parseControlField = (controlField: unknown): ControlField | undefined => {
  // Must be an object
  if (!isObject(controlField)) {
    return;
  }

  // A ControlField must have a `control`
  if (!('control' in controlField)) {
    return;
  }
  // console.log(typeof controlField.control);

  // Build a valid ControlField
  return {
    control: controlField.control,
  };
};

type NodeMetadata = {
  [key: string]:
    | string
@ -105,7 +128,8 @@ type NodeMetadata = {
    | boolean
    | ImageField
    | LatentsField
    | ConditioningField;
    | ConditioningField
    | ControlField;
};

type InvokeAIMetadata = {
@ -131,7 +155,7 @@ export const parseNodeMetadata = (
      return;
    }

    // the only valid object types are ImageField, LatentsField and ConditioningField
    // the only valid object types are ImageField, LatentsField, ConditioningField, ControlField
    if (isObject(nodeItem)) {
      if ('image_name' in nodeItem || 'image_type' in nodeItem) {
        const imageField = parseImageField(nodeItem);
@ -156,6 +180,14 @@ export const parseNodeMetadata = (
        }
        return;
      }

      if ('control' in nodeItem) {
        const controlField = parseControlField(nodeItem);
        if (controlField) {
          parsed[nodeKey] = controlField;
        }
        return;
      }
    }

    // otherwise we accept any string, number or boolean
@ -7,6 +7,7 @@ import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent';
@ -97,6 +98,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
    );
  }

  if (type === 'control' && template.type === 'control') {
    return (
      <ControlInputFieldComponent
        nodeId={nodeId}
        field={field}
        template={template}
      />
    );
  }

  if (type === 'model' && template.type === 'model') {
    return (
      <ModelInputFieldComponent
@ -0,0 +1,16 @@
import {
  ControlInputFieldTemplate,
  ControlInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';

const ControlInputFieldComponent = (
  props: FieldComponentProps<ControlInputFieldValue, ControlInputFieldTemplate>
) => {
  const { nodeId, field } = props;

  return null;
};

export default memo(ControlInputFieldComponent);
@ -4,6 +4,7 @@ export const HANDLE_TOOLTIP_OPEN_DELAY = 500;

export const FIELD_TYPE_MAP: Record<string, FieldType> = {
  integer: 'integer',
  float: 'float',
  number: 'float',
  string: 'string',
  boolean: 'boolean',
@ -15,6 +16,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
  array: 'array',
  item: 'item',
  ColorField: 'color',
  ControlField: 'control',
  control: 'control',
};

const COLOR_TOKEN_VALUE = 500;
@ -22,6 +25,9 @@ const COLOR_TOKEN_VALUE = 500;
const getColorTokenCssVariable = (color: string) =>
  `var(--invokeai-colors-${color}-${COLOR_TOKEN_VALUE})`;

// @ts-ignore
// @ts-ignore
// @ts-ignore
export const FIELDS: Record<FieldType, FieldUIConfig> = {
  integer: {
    color: 'red',
@ -71,6 +77,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
    title: 'Conditioning',
    description: 'Conditioning may be passed between nodes.',
  },
  control: {
    color: 'cyan',
    colorCssVar: getColorTokenCssVariable('cyan'), // TODO: no free color left
    title: 'Control',
    description: 'Control info passed between nodes.',
  },
  model: {
    color: 'teal',
    colorCssVar: getColorTokenCssVariable('teal'),
@ -61,6 +61,7 @@ export type FieldType =
  | 'image'
  | 'latents'
  | 'conditioning'
  | 'control'
  | 'model'
  | 'array'
  | 'item'
@ -82,6 +83,7 @@ export type InputFieldValue =
  | ImageInputFieldValue
  | LatentsInputFieldValue
  | ConditioningInputFieldValue
  | ControlInputFieldValue
  | EnumInputFieldValue
  | ModelInputFieldValue
  | ArrayInputFieldValue
@ -102,6 +104,7 @@ export type InputFieldTemplate =
  | ImageInputFieldTemplate
  | LatentsInputFieldTemplate
  | ConditioningInputFieldTemplate
  | ControlInputFieldTemplate
  | EnumInputFieldTemplate
  | ModelInputFieldTemplate
  | ArrayInputFieldTemplate
@ -177,6 +180,11 @@ export type LatentsInputFieldValue = FieldValueBase & {

export type ConditioningInputFieldValue = FieldValueBase & {
  type: 'conditioning';
  value?: string;
};

export type ControlInputFieldValue = FieldValueBase & {
  type: 'control';
  value?: undefined;
};

@ -262,6 +270,11 @@ export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
  type: 'conditioning';
};

export type ControlInputFieldTemplate = InputFieldTemplateBase & {
  default: undefined;
  type: 'control';
};

export type EnumInputFieldTemplate = InputFieldTemplateBase & {
  default: string | number;
  type: 'enum';
@ -10,6 +10,7 @@ import {
  IntegerInputFieldTemplate,
  LatentsInputFieldTemplate,
  ConditioningInputFieldTemplate,
  ControlInputFieldTemplate,
  StringInputFieldTemplate,
  ModelInputFieldTemplate,
  ArrayInputFieldTemplate,
@ -215,6 +216,21 @@ const buildConditioningInputFieldTemplate = ({
  return template;
};

const buildControlInputFieldTemplate = ({
  schemaObject,
  baseField,
}: BuildInputFieldArg): ControlInputFieldTemplate => {
  const template: ControlInputFieldTemplate = {
    ...baseField,
    type: 'control',
    inputRequirement: 'always',
    inputKind: 'connection',
    default: schemaObject.default ?? undefined,
  };

  return template;
};

const buildEnumInputFieldTemplate = ({
  schemaObject,
  baseField,
@ -286,9 +302,20 @@ export const getFieldType = (
  if (typeHints && name in typeHints) {
    rawFieldType = typeHints[name];
  } else if (!schemaObject.type) {
    rawFieldType = refObjectToFieldType(
      schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
    );
    // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
    if (schemaObject.allOf) {
      rawFieldType = refObjectToFieldType(
        schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
      );
    } else if (schemaObject.anyOf) {
      rawFieldType = refObjectToFieldType(
        schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
      );
    } else if (schemaObject.oneOf) {
      rawFieldType = refObjectToFieldType(
        schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
      );
    }
  } else if (schemaObject.enum) {
    rawFieldType = 'enum';
  } else if (schemaObject.type) {
@ -331,6 +358,9 @@ export const buildInputFieldTemplate = (
  if (['conditioning'].includes(fieldType)) {
    return buildConditioningInputFieldTemplate({ schemaObject, baseField });
  }
  if (['control'].includes(fieldType)) {
    return buildControlInputFieldTemplate({ schemaObject, baseField });
  }
  if (['model'].includes(fieldType)) {
    return buildModelInputFieldTemplate({ schemaObject, baseField });
  }
@ -52,6 +52,10 @@ export const buildInputFieldValue = (
    fieldValue.value = undefined;
  }

  if (template.type === 'control') {
    fieldValue.value = undefined;
  }

  if (template.type === 'model') {
    fieldValue.value = undefined;
  }
@ -0,0 +1,29 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Canny edge detection for ControlNet
 */
export type CannyImageProcessorInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'canny_image_processor';
  /**
   * image to process
   */
  image?: ImageField;
  /**
   * low threshold of Canny pixel gradient
   */
  low_threshold?: number;
  /**
   * high threshold of Canny pixel gradient
   */
  high_threshold?: number;
};
@ -0,0 +1,41 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Applies content shuffle processing to image
 */
export type ContentShuffleImageProcessorInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'content_shuffle_image_processor';
  /**
   * image to process
   */
  image?: ImageField;
  /**
   * pixel resolution for edge detection
   */
  detect_resolution?: number;
  /**
   * pixel resolution for output image
   */
  image_resolution?: number;
  /**
   * content shuffle h parameter
   */
  'h'?: number;
  /**
   * content shuffle w parameter
   */
  'w'?: number;
  /**
   * cont
   */
  'f'?: number;
};
@ -0,0 +1,29 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

export type ControlField = {
  /**
   * processed image
   */
  image: ImageField;
  /**
   * control model used
   */
  control_model: string;
  /**
   * weight given to controlnet
   */
  control_weight: number;
  /**
   * % of total steps at which controlnet is first applied
   */
  begin_step_percent: number;
  /**
   * % of total steps at which controlnet is last applied
   */
  end_step_percent: number;
};
@ -0,0 +1,37 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Collects ControlNet info to pass to other nodes
 */
export type ControlNetInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'controlnet';
  /**
   * image to process
   */
  image?: ImageField;
  /**
   * control model used
   */
  control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace';
  /**
   * weight given to controlnet
   */
  control_weight?: number;
  /**
   * % of total steps at which controlnet is first applied
   */
  begin_step_percent?: number;
  /**
   * % of total steps at which controlnet is last applied
   */
  end_step_percent?: number;
};
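For orientation, a hedged sketch of a node payload matching the ControlNetInvocation type above; the id, image reference, and numeric values are illustrative assumptions, not documented defaults.

# Hypothetical ControlNet node dict mirroring the generated type; values are examples only.
controlnet_node = {
    "id": "controlnet_1",
    "type": "controlnet",
    "image": {"image_name": "canny_edges.png", "image_type": "results"},
    "control_model": "lllyasviel/sd-controlnet-canny",
    "control_weight": 1.0,
    "begin_step_percent": 0.0,
    "end_step_percent": 1.0,
}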
@ -0,0 +1,17 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ControlField } from './ControlField';

/**
 * node output for ControlNet info
 */
export type ControlOutput = {
  type?: 'control_output';
  /**
   * The control info dict
   */
  control?: ControlField;
};
@ -0,0 +1,33 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Applies HED edge detection to image
 */
export type HedImageprocessorInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'hed_image_processor';
  /**
   * image to process
   */
  image?: ImageField;
  /**
   * pixel resolution for edge detection
   */
  detect_resolution?: number;
  /**
   * pixel resolution for output image
   */
  image_resolution?: number;
  /**
   * whether to use scribble mode
   */
  scribble?: boolean;
};
@ -0,0 +1,21 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Base class for invocations that preprocess images for ControlNet
 */
export type ImageProcessorInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'image_processor';
  /**
   * image to process
   */
  image?: ImageField;
};
@ -49,6 +49,18 @@ export type ImageToImageInvocation = {
   * The model to use (currently ignored)
   */
  model?: string;
  /**
   * Whether or not to produce progress images during generation
   */
  progress_images?: boolean;
  /**
   * The control model to use
   */
  control_model?: string;
  /**
   * The processed control image
   */
  control_image?: ImageField;
  /**
   * The input image
   */
@ -50,6 +50,18 @@ export type InpaintInvocation = {
   * The model to use (currently ignored)
   */
  model?: string;
  /**
   * Whether or not to produce progress images during generation
   */
  progress_images?: boolean;
  /**
   * The control model to use
   */
  control_model?: string;
  /**
   * The processed control image
   */
  control_image?: ImageField;
  /**
   * The input image
   */
@ -0,0 +1,29 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Applies line art anime processing to image
 */
export type LineartAnimeImageProcessorInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'lineart_anime_image_processor';
  /**
   * image to process
   */
  image?: ImageField;
  /**
   * pixel resolution for edge detection
   */
  detect_resolution?: number;
  /**
   * pixel resolution for output image
   */
  image_resolution?: number;
};
@ -0,0 +1,33 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Applies line art processing to image
 */
export type LineartImageProcessorInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'lineart_image_processor';
  /**
   * image to process
   */
  image?: ImageField;
  /**
   * pixel resolution for edge detection
   */
  detect_resolution?: number;
  /**
   * pixel resolution for output image
   */
  image_resolution?: number;
  /**
   * whether to use coarse mode
   */
  coarse?: boolean;
};
@ -0,0 +1,29 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Applies Midas depth processing to image
 */
export type MidasDepthImageProcessorInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'midas_depth_image_processor';
  /**
   * image to process
   */
  image?: ImageField;
  /**
   * Midas parameter a = amult * PI
   */
  a_mult?: number;
  /**
   * Midas parameter bg_th
   */
  bg_th?: number;
};
@ -0,0 +1,37 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Applies MLSD processing to image
 */
export type MlsdImageProcessorInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'mlsd_image_processor';
  /**
   * image to process
   */
  image?: ImageField;
  /**
   * pixel resolution for edge detection
   */
  detect_resolution?: number;
  /**
   * pixel resolution for output image
   */
  image_resolution?: number;
  /**
   * MLSD parameter thr_v
   */
  thr_v?: number;
  /**
   * MLSD parameter thr_d
   */
  thr_d?: number;
};
@ -0,0 +1,29 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Applies NormalBae processing to image
 */
export type NormalbaeImageProcessorInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'normalbae_image_processor';
  /**
   * image to process
   */
  image?: ImageField;
  /**
   * pixel resolution for edge detection
   */
  detect_resolution?: number;
  /**
   * pixel resolution for output image
   */
  image_resolution?: number;
};
@ -0,0 +1,33 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Applies Openpose processing to image
 */
export type OpenposeImageProcessorInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'openpose_image_processor';
  /**
   * image to process
   */
  image?: ImageField;
  /**
   * whether to use hands and face mode
   */
  hand_and_face?: boolean;
  /**
   * pixel resolution for edge detection
   */
  detect_resolution?: number;
  /**
   * pixel resolution for output image
   */
  image_resolution?: number;
};
@ -0,0 +1,37 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Applies PIDI processing to image
 */
export type PidiImageProcessorInvocation = {
  /**
   * The id of this node. Must be unique among all nodes.
   */
  id: string;
  type?: 'pidi_image_processor';
  /**
   * image to process
   */
  image?: ImageField;
  /**
   * pixel resolution for edge detection
   */
  detect_resolution?: number;
  /**
   * pixel resolution for output image
   */
  image_resolution?: number;
  /**
   * whether to use safe mode
   */
  safe?: boolean;
  /**
   * whether to use scribble mode
   */
  scribble?: boolean;
};
@ -2,6 +2,8 @@
/* tslint:disable */
/* eslint-disable */

import type { ImageField } from './ImageField';

/**
 * Generates an image using text2img.
 */
@ -47,5 +49,17 @@ export type TextToImageInvocation = {
   * The model to use (currently ignored)
   */
  model?: string;
  /**
   * Whether or not to produce progress images during generation
   */
  progress_images?: boolean;
  /**
   * The control model to use
   */
  control_model?: string;
  /**
   * The processed control image
   */
  control_image?: ImageField;
};
@ -0,0 +1,31 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $CannyImageProcessorInvocation = {
  description: `Canny edge detection for ControlNet`,
  properties: {
    id: {
      type: 'string',
      description: `The id of this node. Must be unique among all nodes.`,
      isRequired: true,
    },
    type: {
      type: 'Enum',
    },
    image: {
      type: 'all-of',
      description: `image to process`,
      contains: [{
        type: 'ImageField',
      }],
    },
    low_threshold: {
      type: 'number',
      description: `low threshold of Canny pixel gradient`,
    },
    high_threshold: {
      type: 'number',
      description: `high threshold of Canny pixel gradient`,
    },
  },
} as const;
@ -0,0 +1,43 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $ContentShuffleImageProcessorInvocation = {
  description: `Applies content shuffle processing to image`,
  properties: {
    id: {
      type: 'string',
      description: `The id of this node. Must be unique among all nodes.`,
      isRequired: true,
    },
    type: {
      type: 'Enum',
    },
    image: {
      type: 'all-of',
      description: `image to process`,
      contains: [{
        type: 'ImageField',
      }],
    },
    detect_resolution: {
      type: 'number',
      description: `pixel resolution for edge detection`,
    },
    image_resolution: {
      type: 'number',
      description: `pixel resolution for output image`,
    },
    'h': {
      type: 'number',
      description: `content shuffle h parameter`,
    },
    'w': {
      type: 'number',
      description: `content shuffle w parameter`,
    },
    'f': {
      type: 'number',
      description: `cont`,
    },
  },
} as const;
@ -0,0 +1,37 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $ControlField = {
  properties: {
    image: {
      type: 'all-of',
      description: `processed image`,
      contains: [{
        type: 'ImageField',
      }],
      isRequired: true,
    },
    control_model: {
      type: 'string',
      description: `control model used`,
      isRequired: true,
    },
    control_weight: {
      type: 'number',
      description: `weight given to controlnet`,
      isRequired: true,
    },
    begin_step_percent: {
      type: 'number',
      description: `% of total steps at which controlnet is first applied`,
      isRequired: true,
      maximum: 1,
    },
    end_step_percent: {
      type: 'number',
      description: `% of total steps at which controlnet is last applied`,
      isRequired: true,
      maximum: 1,
    },
  },
} as const;
@ -0,0 +1,41 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $ControlNetInvocation = {
  description: `Collects ControlNet info to pass to other nodes`,
  properties: {
    id: {
      type: 'string',
      description: `The id of this node. Must be unique among all nodes.`,
      isRequired: true,
    },
    type: {
      type: 'Enum',
    },
    image: {
      type: 'all-of',
      description: `image to process`,
      contains: [{
        type: 'ImageField',
      }],
    },
    control_model: {
      type: 'Enum',
    },
    control_weight: {
      type: 'number',
      description: `weight given to controlnet`,
      maximum: 1,
    },
    begin_step_percent: {
      type: 'number',
      description: `% of total steps at which controlnet is first applied`,
      maximum: 1,
    },
    end_step_percent: {
      type: 'number',
      description: `% of total steps at which controlnet is last applied`,
      maximum: 1,
    },
  },
} as const;
@ -0,0 +1,28 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $ControlOutput = {
  description: `node output for ControlNet info`,
  properties: {
    type: {
      type: 'Enum',
    },
    control: {
      type: 'all-of',
      description: `The control info dict`,
      contains: [{
        type: 'ControlField',
      }],
    },
    width: {
      type: 'number',
      description: `The width of the noise in pixels`,
      isRequired: true,
    },
    height: {
      type: 'number',
      description: `The height of the noise in pixels`,
      isRequired: true,
    },
  },
} as const;
@ -0,0 +1,35 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $HedImageprocessorInvocation = {
  description: `Applies HED edge detection to image`,
  properties: {
    id: {
      type: 'string',
      description: `The id of this node. Must be unique among all nodes.`,
      isRequired: true,
    },
    type: {
      type: 'Enum',
    },
    image: {
      type: 'all-of',
      description: `image to process`,
      contains: [{
        type: 'ImageField',
      }],
    },
    detect_resolution: {
      type: 'number',
      description: `pixel resolution for edge detection`,
    },
    image_resolution: {
      type: 'number',
      description: `pixel resolution for output image`,
    },
    scribble: {
      type: 'boolean',
      description: `whether to use scribble mode`,
    },
  },
} as const;
@ -0,0 +1,23 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $ImageProcessorInvocation = {
  description: `Base class for invocations that preprocess images for ControlNet`,
  properties: {
    id: {
      type: 'string',
      description: `The id of this node. Must be unique among all nodes.`,
      isRequired: true,
    },
    type: {
      type: 'Enum',
    },
    image: {
      type: 'all-of',
      description: `image to process`,
      contains: [{
        type: 'ImageField',
      }],
    },
  },
} as const;
@ -0,0 +1,31 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $LineartAnimeImageProcessorInvocation = {
  description: `Applies line art anime processing to image`,
  properties: {
    id: {
      type: 'string',
      description: `The id of this node. Must be unique among all nodes.`,
      isRequired: true,
    },
    type: {
      type: 'Enum',
    },
    image: {
      type: 'all-of',
      description: `image to process`,
      contains: [{
        type: 'ImageField',
      }],
    },
    detect_resolution: {
      type: 'number',
      description: `pixel resolution for edge detection`,
    },
    image_resolution: {
      type: 'number',
      description: `pixel resolution for output image`,
    },
  },
} as const;
@ -0,0 +1,35 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $LineartImageProcessorInvocation = {
  description: `Applies line art processing to image`,
  properties: {
    id: {
      type: 'string',
      description: `The id of this node. Must be unique among all nodes.`,
      isRequired: true,
    },
    type: {
      type: 'Enum',
    },
    image: {
      type: 'all-of',
      description: `image to process`,
      contains: [{
        type: 'ImageField',
      }],
    },
    detect_resolution: {
      type: 'number',
      description: `pixel resolution for edge detection`,
    },
    image_resolution: {
      type: 'number',
      description: `pixel resolution for output image`,
    },
    coarse: {
      type: 'boolean',
      description: `whether to use coarse mode`,
    },
  },
} as const;
@ -0,0 +1,31 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $MidasDepthImageProcessorInvocation = {
  description: `Applies Midas depth processing to image`,
  properties: {
    id: {
      type: 'string',
      description: `The id of this node. Must be unique among all nodes.`,
      isRequired: true,
    },
    type: {
      type: 'Enum',
    },
    image: {
      type: 'all-of',
      description: `image to process`,
      contains: [{
        type: 'ImageField',
      }],
    },
    a_mult: {
      type: 'number',
      description: `Midas parameter a = amult * PI`,
    },
    bg_th: {
      type: 'number',
      description: `Midas parameter bg_th`,
    },
  },
} as const;
@ -0,0 +1,39 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $MlsdImageProcessorInvocation = {
  description: `Applies MLSD processing to image`,
  properties: {
    id: {
      type: 'string',
      description: `The id of this node. Must be unique among all nodes.`,
      isRequired: true,
    },
    type: {
      type: 'Enum',
    },
    image: {
      type: 'all-of',
      description: `image to process`,
      contains: [{
        type: 'ImageField',
      }],
    },
    detect_resolution: {
      type: 'number',
      description: `pixel resolution for edge detection`,
    },
    image_resolution: {
      type: 'number',
      description: `pixel resolution for output image`,
    },
    thr_v: {
      type: 'number',
      description: `MLSD parameter thr_v`,
    },
    thr_d: {
      type: 'number',
      description: `MLSD parameter thr_d`,
    },
  },
} as const;
@ -0,0 +1,31 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $NormalbaeImageProcessorInvocation = {
  description: `Applies NormalBae processing to image`,
  properties: {
    id: {
      type: 'string',
      description: `The id of this node. Must be unique among all nodes.`,
      isRequired: true,
    },
    type: {
      type: 'Enum',
    },
    image: {
      type: 'all-of',
      description: `image to process`,
      contains: [{
        type: 'ImageField',
      }],
    },
    detect_resolution: {
      type: 'number',
      description: `pixel resolution for edge detection`,
    },
    image_resolution: {
      type: 'number',
      description: `pixel resolution for output image`,
    },
  },
} as const;
@ -0,0 +1,35 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $OpenposeImageProcessorInvocation = {
  description: `Applies Openpose processing to image`,
  properties: {
    id: {
      type: 'string',
      description: `The id of this node. Must be unique among all nodes.`,
      isRequired: true,
    },
    type: {
      type: 'Enum',
    },
    image: {
      type: 'all-of',
      description: `image to process`,
      contains: [{
        type: 'ImageField',
      }],
    },
    hand_and_face: {
      type: 'boolean',
      description: `whether to use hands and face mode`,
    },
    detect_resolution: {
      type: 'number',
      description: `pixel resolution for edge detection`,
    },
    image_resolution: {
      type: 'number',
      description: `pixel resolution for output image`,
    },
  },
} as const;
@ -0,0 +1,39 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $PidiImageProcessorInvocation = {
  description: `Applies PIDI processing to image`,
  properties: {
    id: {
      type: 'string',
      description: `The id of this node. Must be unique among all nodes.`,
      isRequired: true,
    },
    type: {
      type: 'Enum',
    },
    image: {
      type: 'all-of',
      description: `image to process`,
      contains: [{
        type: 'ImageField',
      }],
    },
    detect_resolution: {
      type: 'number',
      description: `pixel resolution for edge detection`,
    },
    image_resolution: {
      type: 'number',
      description: `pixel resolution for output image`,
    },
    safe: {
      type: 'boolean',
      description: `whether to use safe mode`,
    },
    scribble: {
      type: 'boolean',
      description: `whether to use scribble mode`,
    },
  },
} as const;
@ -0,0 +1,16 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $RandomIntInvocation = {
  description: `Outputs a single random integer.`,
  properties: {
    id: {
      type: 'string',
      description: `The id of this node. Must be unique among all nodes.`,
      isRequired: true,
    },
    type: {
      type: 'Enum',
    },
  },
} as const;
@ -39,6 +39,8 @@ dependencies = [
  "click",
  "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
  "compel~=1.1.5",
  "controlnet-aux>=0.0.4",
  "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
  "datasets",
  "diffusers[torch]~=0.16.1",
  "dnspython==2.2.1",
@ -54,6 +56,7 @@ dependencies = [
  "flaskwebgui==1.0.3",
  "gfpgan==1.3.8",
  "huggingface-hub>=0.11.1",
  "mediapipe", # needed for "mediapipeface" controlnet model
  "npyscreen",
  "numpy<1.24",
  "omegaconf",
54
scripts/controlnet_legacy_txt2img_example.py
Normal file
@ -0,0 +1,54 @@
import os
import torch
import cv2
import numpy as np
from PIL import Image

from diffusers.utils import load_image
from diffusers.models.controlnet import ControlNetModel
from invokeai.backend.generator import Txt2Img
from invokeai.backend.model_management import ModelManager


print("loading 'Girl with a Pearl Earring' image")
image = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
image.show()

print("preprocessing image with Canny edge detection")
image_np = np.array(image)
low_threshold = 100
high_threshold = 200
canny_np = cv2.Canny(image_np, low_threshold, high_threshold)
canny_image = Image.fromarray(canny_np)
canny_image.show()

# using invokeai model management for base model
print("loading base model stable-diffusion-1.5")
model_config_path = os.getcwd() + "/../configs/models.yaml"
model_manager = ModelManager(model_config_path)
model = model_manager.get_model('stable-diffusion-1.5')

print("loading control model lllyasviel/sd-controlnet-canny")
canny_controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny",
                                                   torch_dtype=torch.float16).to("cuda")

print("testing Txt2Img() constructor with control_model arg")
txt2img_canny = Txt2Img(model, control_model=canny_controlnet)

print("testing Txt2Img.generate() with control_image arg")
outputs = txt2img_canny.generate(
    prompt="old man",
    control_image=canny_image,
    control_weight=1.0,
    seed=0,
    num_steps=30,
    precision="float16",
)
generate_output = next(outputs)
out_image = generate_output.image
out_image.show()