Replace ControlNet guess_mode with separate cfg_injection and soft_injection booleans, so the two control-mode behaviors can be tested independently

This commit is contained in:
user1 2023-06-12 23:57:57 -07:00
parent fd715026a7
commit 9e0e26f4c4
3 changed files with 26 additions and 8 deletions

View File

@ -105,7 +105,9 @@ class ControlField(BaseModel):
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
guess_mode: bool = Field(default=False, description="Toggle for guess mode") # guess_mode: bool = Field(default=False, description="Toggle for guess mode")
cfg_injection: bool = Field(default=False, description="Toggle for cfg injection")
soft_injection: bool = Field(default=False, description="Toggle for soft injection")
@validator("control_weight") @validator("control_weight")
def abs_le_one(cls, v): def abs_le_one(cls, v):
"""validate that all abs(values) are <=1""" """validate that all abs(values) are <=1"""
@ -151,7 +153,9 @@ class ControlNetInvocation(BaseInvocation):
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
guess_mode: bool = Field(default=False, description="Toggle for guess mode") # guess_mode: bool = Field(default=False, description="Toggle for guess mode")
cfg_injection: bool = Field(default=False, description="Toggle for cfg injection")
soft_injection: bool = Field(default=False, description="Toggle for soft injection")
# fmt: on # fmt: on
class Config(InvocationConfig): class Config(InvocationConfig):
@ -177,7 +181,9 @@ class ControlNetInvocation(BaseInvocation):
control_weight=self.control_weight, control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,
guess_mode=self.guess_mode, # guess_mode=self.guess_mode,
cfg_injection=self.cfg_injection,
soft_injection=self.soft_injection,
), ),
) )

View File

@ -337,14 +337,19 @@ class TextToLatentsInvocation(BaseInvocation):
# num_images_per_prompt=num_images_per_prompt, # num_images_per_prompt=num_images_per_prompt,
device=control_model.device, device=control_model.device,
dtype=control_model.dtype, dtype=control_model.dtype,
guess_mode=control_info.guess_mode, # guess_mode=control_info.guess_mode,
cfg_injection=control_info.cfg_injection,
soft_injection=control_info.soft_injection,
) )
control_item = ControlNetData(model=control_model, control_item = ControlNetData(model=control_model,
image_tensor=control_image, image_tensor=control_image,
weight=control_info.control_weight, weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent, begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent, end_step_percent=control_info.end_step_percent,
guess_mode=control_info.guess_mode,) # guess_mode=control_info.guess_mode,
cfg_injection=control_info.cfg_injection,
soft_injection=control_info.soft_injection,
)
control_data.append(control_item) control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data return control_data

View File

@ -223,6 +223,8 @@ class ControlNetData:
end_step_percent: float = Field(default=1.0) end_step_percent: float = Field(default=1.0)
# FIXME: replace guess_mode with enum control_mode: BALANCED, MORE_PROMPT, MORE_CONTROL # FIXME: replace guess_mode with enum control_mode: BALANCED, MORE_PROMPT, MORE_CONTROL
guess_mode: bool = Field(default=False) # guess_mode can work with or without prompt guess_mode: bool = Field(default=False) # guess_mode can work with or without prompt
cfg_injection: bool = Field(default=False)
soft_injection: bool = Field(default=False)
@dataclass(frozen=True) @dataclass(frozen=True)
class ConditioningData: class ConditioningData:
@ -695,7 +697,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count) last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range # only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step: if step_index >= first_control_step and step_index <= last_control_step:
guess_mode = control_datum.guess_mode # guess_mode = control_datum.guess_mode
guess_mode = control_datum.cfg_injection
if guess_mode: if guess_mode:
control_latent_input = unet_latent_input control_latent_input = unet_latent_input
else: else:
@ -740,7 +743,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
controlnet_cond=control_datum.image_tensor, controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
# cross_attention_kwargs, # cross_attention_kwargs,
guess_mode=guess_mode, # guess_mode=guess_mode,
guess_mode=control_datum.soft_injection,
return_dict=False, return_dict=False,
) )
print("finished ControlNetModel() call, step", step_index) print("finished ControlNetModel() call, step", step_index)
@ -1100,6 +1104,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
dtype=torch.float16, dtype=torch.float16,
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
guess_mode=False, guess_mode=False,
soft_injection=False,
cfg_injection=False,
): ):
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
@ -1130,6 +1136,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
repeat_by = num_images_per_prompt repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0) image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode: # if do_classifier_free_guidance and not guess_mode:
if do_classifier_free_guidance and not cfg_injection:
image = torch.cat([image] * 2) image = torch.cat([image] * 2)
return image return image