Fixed use of ControlNet control_weight parameter

This commit is contained in:
user1 2023-05-09 00:30:45 -07:00 committed by Kent Keirsey
parent c0863fa20f
commit 714ad6dbb8
4 changed files with 14 additions and 11 deletions

View File

@@ -35,8 +35,8 @@ class ControlField(BaseModel):
# width: Optional[int] = Field(default=None, description="The width of the image in pixels")
# height: Optional[int] = Field(default=None, description="The height of the image in pixels")
# mode: Optional[str] = Field(default=None, description="The mode of the image")
control_model: Optional[str] = Field(default=None, description="The control model used")
control_weight: Optional[float] = Field(default=None, description="The control weight used")
control_model: Optional[str] = Field(default=None, description="control model used")
control_weight: Optional[float] = Field(default=None, description="weight given to controlnet")
class Config:
schema_extra = {
@@ -62,7 +62,7 @@ class ControlNetInvocation(BaseInvocation):
# Inputs
image: ImageField = Field(default=None, description="image to process")
control_model: str = Field(default=None, description="control model to use")
control_weight: float = Field(default=0.5, ge=0, le=1, description="control weight")
control_weight: float = Field(default=0.5, ge=0, le=1, description="weight given to controlnet")
# TODO: support additional ControlNet parameters (mostly just passthroughs to other nodes with ControlField inputs)
# begin_step_percent: float = Field(default=0, ge=0, le=1,
# description="% of total steps at which controlnet is first applied")
@@ -78,7 +78,7 @@ class ControlNetInvocation(BaseInvocation):
control=ControlField(
image=self.image,
control_model=self.control_model,
control_weight=self.control_weight,
control_weight=self.control_weight
),
)

View File

@@ -350,6 +350,7 @@ class TextToLatentsInvocation(BaseInvocation):
conditioning_data=conditioning_data,
callback=step_callback,
control_image=control_images,
control_weight=control_weights,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699

View File

@@ -647,11 +647,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if (self.control_model is not None) and (kwargs.get("control_image") is not None):
control_image = kwargs.get("control_image") # should be a processed tensor derived from the control image(s)
control_scale = kwargs.get("control_scale", 1.0) # control_scale default is 1.0
# handling case where using multiple control models but only specifying single control_scale
# so reshape control_scale to match number of control models
if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_scale, float):
control_scale = [control_scale] * len(self.control_model.nets)
control_weight = kwargs.get("control_weight", 1.0) # control_weight default is 1.0
# handling case where using multiple control models but only specifying single control_weight
# so reshape control_weight to match number of control models
if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_weight, float):
control_weight = [control_weight] * len(self.control_model.nets)
if conditioning_data.guidance_scale > 1.0:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
@@ -660,13 +660,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
else:
latent_control_input = latent_model_input
# controlnet inference
print("control_weight: ", control_weight)
down_block_res_samples, mid_block_res_sample = self.control_model(
latent_control_input,
timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]),
controlnet_cond=control_image,
conditioning_scale=control_scale,
conditioning_scale=control_weight,
return_dict=False,
)
else:

View File

@@ -41,7 +41,7 @@ print("testing Txt2Img.generate() with control_image arg")
outputs = txt2img_canny.generate(
prompt="old man",
control_image=canny_image,
control_scale=1.0,
control_weight=1.0,
seed=0,
num_steps=30,
precision="float16",