Expose a few more params from TiledStableDiffusionRefineInvocation.

This commit is contained in:
Ryan Dick 2024-06-10 15:38:55 -04:00
parent 9567c6e196
commit 911792f258

View File

@ -55,6 +55,10 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
negative_conditioning: ConditioningField = InputField( negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection description=FieldDescriptions.negative_cond, input=Input.Connection
) )
# TODO(ryand): Add multiple-of validation.
tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
tile_min_overlap: int = InputField(default=16, gt=0, description="Minimum overlap between tiles.")
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale") cfg_scale: float | list[float] = InputField(default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField( denoising_start: float = InputField(
@ -93,6 +97,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
control_model: ModelIdentifierField = InputField( control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
) )
control_weight: float = InputField(default=0.6)
@field_validator("cfg_scale") @field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float: def ge_one(cls, v: list[float] | float) -> list[float] | float:
@ -173,9 +178,9 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
tiles = calc_tiles_min_overlap( tiles = calc_tiles_min_overlap(
image_height=input_image.height, image_height=input_image.height,
image_width=input_image.width, image_width=input_image.width,
tile_height=512, tile_height=self.tile_height,
tile_width=512, tile_width=self.tile_width,
min_overlap=128, min_overlap=self.tile_min_overlap,
) )
# Convert the input image to a torch.Tensor. # Convert the input image to a torch.Tensor.
@ -299,7 +304,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
controlnet_data = self.run_controlnet( controlnet_data = self.run_controlnet(
image=image_tile_pil, image=image_tile_pil,
controlnet_model=controlnet_model, controlnet_model=controlnet_model,
weight=1.0, weight=self.control_weight,
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
width=width, width=width,
height=height, height=height,