mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Moving from ControlNet guess_mode to separate booleans for cfg_injection and soft_injection for testing control modes
This commit is contained in:
parent
8b7fac75ed
commit
8495764d45
@ -105,7 +105,9 @@ class ControlField(BaseModel):
|
|||||||
description="When the ControlNet is first applied (% of total steps)")
|
description="When the ControlNet is first applied (% of total steps)")
|
||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
description="When the ControlNet is last applied (% of total steps)")
|
description="When the ControlNet is last applied (% of total steps)")
|
||||||
guess_mode: bool = Field(default=False, description="Toggle for guess mode")
|
# guess_mode: bool = Field(default=False, description="Toggle for guess mode")
|
||||||
|
cfg_injection: bool = Field(default=False, description="Toggle for cfg injection")
|
||||||
|
soft_injection: bool = Field(default=False, description="Toggle for soft injection")
|
||||||
@validator("control_weight")
|
@validator("control_weight")
|
||||||
def abs_le_one(cls, v):
|
def abs_le_one(cls, v):
|
||||||
"""validate that all abs(values) are <=1"""
|
"""validate that all abs(values) are <=1"""
|
||||||
@ -151,7 +153,9 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
description="When the ControlNet is first applied (% of total steps)")
|
description="When the ControlNet is first applied (% of total steps)")
|
||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
description="When the ControlNet is last applied (% of total steps)")
|
description="When the ControlNet is last applied (% of total steps)")
|
||||||
guess_mode: bool = Field(default=False, description="Toggle for guess mode")
|
# guess_mode: bool = Field(default=False, description="Toggle for guess mode")
|
||||||
|
cfg_injection: bool = Field(default=False, description="Toggle for cfg injection")
|
||||||
|
soft_injection: bool = Field(default=False, description="Toggle for soft injection")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@ -177,7 +181,9 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
control_weight=self.control_weight,
|
control_weight=self.control_weight,
|
||||||
begin_step_percent=self.begin_step_percent,
|
begin_step_percent=self.begin_step_percent,
|
||||||
end_step_percent=self.end_step_percent,
|
end_step_percent=self.end_step_percent,
|
||||||
guess_mode=self.guess_mode,
|
# guess_mode=self.guess_mode,
|
||||||
|
cfg_injection=self.cfg_injection,
|
||||||
|
soft_injection=self.soft_injection,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -337,14 +337,19 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
# num_images_per_prompt=num_images_per_prompt,
|
# num_images_per_prompt=num_images_per_prompt,
|
||||||
device=control_model.device,
|
device=control_model.device,
|
||||||
dtype=control_model.dtype,
|
dtype=control_model.dtype,
|
||||||
guess_mode=control_info.guess_mode,
|
# guess_mode=control_info.guess_mode,
|
||||||
|
cfg_injection=control_info.cfg_injection,
|
||||||
|
soft_injection=control_info.soft_injection,
|
||||||
)
|
)
|
||||||
control_item = ControlNetData(model=control_model,
|
control_item = ControlNetData(model=control_model,
|
||||||
image_tensor=control_image,
|
image_tensor=control_image,
|
||||||
weight=control_info.control_weight,
|
weight=control_info.control_weight,
|
||||||
begin_step_percent=control_info.begin_step_percent,
|
begin_step_percent=control_info.begin_step_percent,
|
||||||
end_step_percent=control_info.end_step_percent,
|
end_step_percent=control_info.end_step_percent,
|
||||||
guess_mode=control_info.guess_mode,)
|
# guess_mode=control_info.guess_mode,
|
||||||
|
cfg_injection=control_info.cfg_injection,
|
||||||
|
soft_injection=control_info.soft_injection,
|
||||||
|
)
|
||||||
control_data.append(control_item)
|
control_data.append(control_item)
|
||||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||||
return control_data
|
return control_data
|
||||||
|
@ -223,6 +223,8 @@ class ControlNetData:
|
|||||||
end_step_percent: float = Field(default=1.0)
|
end_step_percent: float = Field(default=1.0)
|
||||||
# FIXME: replace with guess_mode with enum control_mode: BALANCED, MORE_PROMPT, MORE_CONTROL
|
# FIXME: replace with guess_mode with enum control_mode: BALANCED, MORE_PROMPT, MORE_CONTROL
|
||||||
guess_mode: bool = Field(default=False) # guess_mode can work with or without prompt
|
guess_mode: bool = Field(default=False) # guess_mode can work with or without prompt
|
||||||
|
cfg_injection: bool = Field(default=False)
|
||||||
|
soft_injection: bool = Field(default=False)
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ConditioningData:
|
class ConditioningData:
|
||||||
@ -695,7 +697,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
|
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
|
||||||
# only apply controlnet if current step is within the controlnet's begin/end step range
|
# only apply controlnet if current step is within the controlnet's begin/end step range
|
||||||
if step_index >= first_control_step and step_index <= last_control_step:
|
if step_index >= first_control_step and step_index <= last_control_step:
|
||||||
guess_mode = control_datum.guess_mode
|
# guess_mode = control_datum.guess_mode
|
||||||
|
guess_mode = control_datum.cfg_injection
|
||||||
if guess_mode:
|
if guess_mode:
|
||||||
control_latent_input = unet_latent_input
|
control_latent_input = unet_latent_input
|
||||||
else:
|
else:
|
||||||
@ -740,7 +743,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
controlnet_cond=control_datum.image_tensor,
|
controlnet_cond=control_datum.image_tensor,
|
||||||
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
||||||
# cross_attention_kwargs,
|
# cross_attention_kwargs,
|
||||||
guess_mode=guess_mode,
|
# guess_mode=guess_mode,
|
||||||
|
guess_mode=control_datum.soft_injection,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)
|
)
|
||||||
print("finished ControlNetModel() call, step", step_index)
|
print("finished ControlNetModel() call, step", step_index)
|
||||||
@ -1100,6 +1104,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
do_classifier_free_guidance=True,
|
do_classifier_free_guidance=True,
|
||||||
guess_mode=False,
|
guess_mode=False,
|
||||||
|
soft_injection=False,
|
||||||
|
cfg_injection=False,
|
||||||
):
|
):
|
||||||
|
|
||||||
if not isinstance(image, torch.Tensor):
|
if not isinstance(image, torch.Tensor):
|
||||||
@ -1130,6 +1136,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
repeat_by = num_images_per_prompt
|
repeat_by = num_images_per_prompt
|
||||||
image = image.repeat_interleave(repeat_by, dim=0)
|
image = image.repeat_interleave(repeat_by, dim=0)
|
||||||
image = image.to(device=device, dtype=dtype)
|
image = image.to(device=device, dtype=dtype)
|
||||||
if do_classifier_free_guidance and not guess_mode:
|
# if do_classifier_free_guidance and not guess_mode:
|
||||||
|
if do_classifier_free_guidance and not cfg_injection:
|
||||||
image = torch.cat([image] * 2)
|
image = torch.cat([image] * 2)
|
||||||
return image
|
return image
|
||||||
|
Loading…
Reference in New Issue
Block a user