diff --git a/invokeai/app/invocations/tiled_stable_diffusion_refine.py b/invokeai/app/invocations/tiled_stable_diffusion_refine.py index 592a00f71c..7c4cf6fe92 100644 --- a/invokeai/app/invocations/tiled_stable_diffusion_refine.py +++ b/invokeai/app/invocations/tiled_stable_diffusion_refine.py @@ -55,6 +55,10 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation): negative_conditioning: ConditioningField = InputField( description=FieldDescriptions.negative_cond, input=Input.Connection ) + # TODO(ryand): Add multiple-of validation. + tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.") + tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.") + tile_min_overlap: int = InputField(default=16, gt=0, description="Minimum overlap between tiles.") steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) cfg_scale: float | list[float] = InputField(default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale") denoising_start: float = InputField( @@ -93,6 +97,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation): control_model: ModelIdentifierField = InputField( description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel ) + control_weight: float = InputField(default=0.6) @field_validator("cfg_scale") def ge_one(cls, v: list[float] | float) -> list[float] | float: @@ -173,9 +178,9 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation): tiles = calc_tiles_min_overlap( image_height=input_image.height, image_width=input_image.width, - tile_height=512, - tile_width=512, - min_overlap=128, + tile_height=self.tile_height, + tile_width=self.tile_width, + min_overlap=self.tile_min_overlap, ) # Convert the input image to a torch.Tensor. @@ -299,7 +304,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation): controlnet_data = self.run_controlnet( image=image_tile_pil, controlnet_model=controlnet_model, - weight=1.0, + weight=self.control_weight, do_classifier_free_guidance=True, width=width, height=height,