Fixed use of ControlNet control_weight parameter

user1 2023-05-09 00:30:45 -07:00 committed by Kent Keirsey
parent c0863fa20f
commit 714ad6dbb8
4 changed files with 14 additions and 11 deletions

View File

@@ -35,8 +35,8 @@ class ControlField(BaseModel):
     # width: Optional[int] = Field(default=None, description="The width of the image in pixels")
     # height: Optional[int] = Field(default=None, description="The height of the image in pixels")
     # mode: Optional[str] = Field(default=None, description="The mode of the image")
-    control_model: Optional[str] = Field(default=None, description="The control model used")
-    control_weight: Optional[float] = Field(default=None, description="The control weight used")
+    control_model: Optional[str] = Field(default=None, description="control model used")
+    control_weight: Optional[float] = Field(default=None, description="weight given to controlnet")

     class Config:
         schema_extra = {
@@ -62,7 +62,7 @@ class ControlNetInvocation(BaseInvocation):
     # Inputs
     image: ImageField = Field(default=None, description="image to process")
     control_model: str = Field(default=None, description="control model to use")
-    control_weight: float = Field(default=0.5, ge=0, le=1, description="control weight")
+    control_weight: float = Field(default=0.5, ge=0, le=1, description="weight given to controlnet")
     # TODO: support additional ControlNet parameters (mostly just passthroughs to other nodes with ControlField inputs)
     # begin_step_percent: float = Field(default=0, ge=0, le=1,
     #                                   description="% of total steps at which controlnet is first applied")
@@ -78,7 +78,7 @@ class ControlNetInvocation(BaseInvocation):
             control=ControlField(
                 image=self.image,
                 control_model=self.control_model,
-                control_weight=self.control_weight,
+                control_weight=self.control_weight
             ),
         )
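For reference, invoke() passes the weight straight through into the ControlField payload. A minimal sketch of the resulting object, assuming pydantic v1, trimmed to the two renamed fields; the model name is illustrative, not from this commit:

from typing import Optional
from pydantic import BaseModel, Field

class ControlField(BaseModel):  # trimmed; the real class also carries an image field
    control_model: Optional[str] = Field(default=None, description="control model used")
    control_weight: Optional[float] = Field(default=None, description="weight given to controlnet")

control = ControlField(control_model="canny", control_weight=0.5)  # "canny" is illustrative
print(control.dict())  # {'control_model': 'canny', 'control_weight': 0.5}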

View File

@@ -350,6 +350,7 @@ class TextToLatentsInvocation(BaseInvocation):
             conditioning_data=conditioning_data,
             callback=step_callback,
             control_image=control_images,
+            control_weight=control_weights,
         )
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
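The diff does not show how control_weights is assembled; given the plural name and the per-net list the pipeline accepts (see the next file), it plausibly projects one weight per control input. A hedged sketch, where the ControlField list is an assumption:

def collect_control_args(control_list):
    # control_list: assumed list of ControlField inputs (not shown in this diff)
    control_images = [c.image for c in control_list]
    control_weights = [c.control_weight for c in control_list]
    return control_images, control_weights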

View File

@@ -647,11 +647,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if (self.control_model is not None) and (kwargs.get("control_image") is not None):
             control_image = kwargs.get("control_image")  # should be a processed tensor derived from the control image(s)
-            control_scale = kwargs.get("control_scale", 1.0)  # control_scale default is 1.0
-            # handling case where using multiple control models but only specifying single control_scale
-            # so reshape control_scale to match number of control models
-            if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_scale, float):
-                control_scale = [control_scale] * len(self.control_model.nets)
+            control_weight = kwargs.get("control_weight", 1.0)  # control_weight default is 1.0
+            # handling case where using multiple control models but only specifying single control_weight
+            # so reshape control_weight to match number of control models
+            if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_weight, float):
+                control_weight = [control_weight] * len(self.control_model.nets)
             if conditioning_data.guidance_scale > 1.0:
                 # expand the latents input to control model if doing classifier free guidance
                 # (which I think for now is always true, there is conditional elsewhere that stops execution if
@@ -660,13 +660,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             else:
                 latent_control_input = latent_model_input
             # controlnet inference
+            print("control_weight: ", control_weight)
             down_block_res_samples, mid_block_res_sample = self.control_model(
                 latent_control_input,
                 timestep,
                 encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
                                                  conditioning_data.text_embeddings]),
                 controlnet_cond=control_image,
-                conditioning_scale=control_scale,
+                conditioning_scale=control_weight,
                 return_dict=False,
             )
         else:
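The broadcast above is the one place where control_weight's type matters: a scalar is expanded to one weight per stacked ControlNet before being handed to diffusers as conditioning_scale. A self-contained toy of that logic, with a stand-in for the diffusers MultiControlNetModel class:

class MultiControlNetModel:  # stand-in for the diffusers class; only .nets is relied on
    def __init__(self, nets):
        self.nets = nets

def normalize_control_weight(control_model, control_weight):
    # mirror of the pipeline logic: broadcast a single float across all nets
    if isinstance(control_model, MultiControlNetModel) and isinstance(control_weight, float):
        return [control_weight] * len(control_model.nets)
    return control_weight

multi = MultiControlNetModel(nets=["canny", "depth", "openpose"])
print(normalize_control_weight(multi, 1.0))              # [1.0, 1.0, 1.0]
print(normalize_control_weight(multi, [0.5, 1.0, 0.8]))  # a list passes through unchanged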

View File

@@ -41,7 +41,7 @@ print("testing Txt2Img.generate() with control_image arg")
 outputs = txt2img_canny.generate(
     prompt="old man",
     control_image=canny_image,
-    control_scale=1.0,
+    control_weight=1.0,
     seed=0,
     num_steps=30,
     precision="float16",