mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixed use of ControlNet control_weight parameter
This commit is contained in:
parent
c0863fa20f
commit
714ad6dbb8
@ -35,8 +35,8 @@ class ControlField(BaseModel):
|
|||||||
# width: Optional[int] = Field(default=None, description="The width of the image in pixels")
|
# width: Optional[int] = Field(default=None, description="The width of the image in pixels")
|
||||||
# height: Optional[int] = Field(default=None, description="The height of the image in pixels")
|
# height: Optional[int] = Field(default=None, description="The height of the image in pixels")
|
||||||
# mode: Optional[str] = Field(default=None, description="The mode of the image")
|
# mode: Optional[str] = Field(default=None, description="The mode of the image")
|
||||||
control_model: Optional[str] = Field(default=None, description="The control model used")
|
control_model: Optional[str] = Field(default=None, description="control model used")
|
||||||
control_weight: Optional[float] = Field(default=None, description="The control weight used")
|
control_weight: Optional[float] = Field(default=None, description="weight given to controlnet")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
@ -62,7 +62,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="image to process")
|
image: ImageField = Field(default=None, description="image to process")
|
||||||
control_model: str = Field(default=None, description="control model to use")
|
control_model: str = Field(default=None, description="control model to use")
|
||||||
control_weight: float = Field(default=0.5, ge=0, le=1, description="control weight")
|
control_weight: float = Field(default=0.5, ge=0, le=1, description="weight given to controlnet")
|
||||||
# TODO: support additional ControlNet parameters (mostly just passthroughs to other nodes with ControlField inputs)
|
# TODO: support additional ControlNet parameters (mostly just passthroughs to other nodes with ControlField inputs)
|
||||||
# begin_step_percent: float = Field(default=0, ge=0, le=1,
|
# begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||||
# description="% of total steps at which controlnet is first applied")
|
# description="% of total steps at which controlnet is first applied")
|
||||||
@ -78,7 +78,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
control=ControlField(
|
control=ControlField(
|
||||||
image=self.image,
|
image=self.image,
|
||||||
control_model=self.control_model,
|
control_model=self.control_model,
|
||||||
control_weight=self.control_weight,
|
control_weight=self.control_weight
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -350,6 +350,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
callback=step_callback,
|
callback=step_callback,
|
||||||
control_image=control_images,
|
control_image=control_images,
|
||||||
|
control_weight=control_weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
|
@ -647,11 +647,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
if (self.control_model is not None) and (kwargs.get("control_image") is not None):
|
if (self.control_model is not None) and (kwargs.get("control_image") is not None):
|
||||||
control_image = kwargs.get("control_image") # should be a processed tensor derived from the control image(s)
|
control_image = kwargs.get("control_image") # should be a processed tensor derived from the control image(s)
|
||||||
control_scale = kwargs.get("control_scale", 1.0) # control_scale default is 1.0
|
control_weight = kwargs.get("control_weight", 1.0) # control_weight default is 1.0
|
||||||
# handling case where using multiple control models but only specifying single control_scale
|
# handling case where using multiple control models but only specifying single control_weight
|
||||||
# so reshape control_scale to match number of control models
|
# so reshape control_weight to match number of control models
|
||||||
if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_scale, float):
|
if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_weight, float):
|
||||||
control_scale = [control_scale] * len(self.control_model.nets)
|
control_weight = [control_weight] * len(self.control_model.nets)
|
||||||
if conditioning_data.guidance_scale > 1.0:
|
if conditioning_data.guidance_scale > 1.0:
|
||||||
# expand the latents input to control model if doing classifier free guidance
|
# expand the latents input to control model if doing classifier free guidance
|
||||||
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
||||||
@ -660,13 +660,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
else:
|
else:
|
||||||
latent_control_input = latent_model_input
|
latent_control_input = latent_model_input
|
||||||
# controlnet inference
|
# controlnet inference
|
||||||
|
|
||||||
|
print("control_weight: ", control_weight)
|
||||||
down_block_res_samples, mid_block_res_sample = self.control_model(
|
down_block_res_samples, mid_block_res_sample = self.control_model(
|
||||||
latent_control_input,
|
latent_control_input,
|
||||||
timestep,
|
timestep,
|
||||||
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
||||||
conditioning_data.text_embeddings]),
|
conditioning_data.text_embeddings]),
|
||||||
controlnet_cond=control_image,
|
controlnet_cond=control_image,
|
||||||
conditioning_scale=control_scale,
|
conditioning_scale=control_weight,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -41,7 +41,7 @@ print("testing Txt2Img.generate() with control_image arg")
|
|||||||
outputs = txt2img_canny.generate(
|
outputs = txt2img_canny.generate(
|
||||||
prompt="old man",
|
prompt="old man",
|
||||||
control_image=canny_image,
|
control_image=canny_image,
|
||||||
control_scale=1.0,
|
control_weight=1.0,
|
||||||
seed=0,
|
seed=0,
|
||||||
num_steps=30,
|
num_steps=30,
|
||||||
precision="float16",
|
precision="float16",
|
||||||
|
Loading…
Reference in New Issue
Block a user