Added support for specifying which step to start using each ControlNet and which step to end using it (both specified as a fraction of total steps)
user1 2023-05-12 04:01:35 -07:00 committed by Kent Keirsey
parent 63d248622c
commit f613c073c1
3 changed files with 46 additions and 37 deletions


@@ -33,16 +33,16 @@ from .image import ImageOutput, build_image_output, PILInvocationConfig
 class ControlField(BaseModel):
     image: ImageField = Field(default=None, description="processed image")
-    # width: Optional[int] = Field(default=None, description="The width of the image in pixels")
-    # height: Optional[int] = Field(default=None, description="The height of the image in pixels")
-    # mode: Optional[str] = Field(default=None, description="The mode of the image")
     control_model: Optional[str] = Field(default=None, description="control model used")
     control_weight: Optional[float] = Field(default=None, description="weight given to controlnet")
+    begin_step_percent: float = Field(default=0, ge=0, le=1,
+                                      description="% of total steps at which controlnet is first applied")
+    end_step_percent: float = Field(default=1, ge=0, le=1,
+                                    description="% of total steps at which controlnet is last applied")

     class Config:
         schema_extra = {
-            "required": ["image", "control_model", "control_weight"]
-            # "required": ["type", "image", "width", "height", "mode"]
+            "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"]
         }
@@ -62,12 +62,11 @@ class ControlNetInvocation(BaseInvocation):
     image: ImageField = Field(default=None, description="image to process")
     control_model: str = Field(default=None, description="control model to use")
     control_weight: float = Field(default=0.5, ge=0, le=1, description="weight given to controlnet")
-    # TODO: support additional ControlNet parameters (mostly just passthroughs to other nodes with ControlField inputs)
-    # begin_step_percent: float = Field(default=0, ge=0, le=1,
-    #                                   description="% of total steps at which controlnet is first applied")
-    # end_step_percent: float = Field(default=1, ge=0, le=1,
-    #                                 description="% of total steps at which controlnet is last applied")
-    # guess_mode: bool = Field(default=False, description="use guess mode (controlnet ignores prompt)")
+    # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
+    begin_step_percent: float = Field(default=0, ge=0, le=1,
+                                      description="% of total steps at which controlnet is first applied")
+    end_step_percent: float = Field(default=1, ge=0, le=1,
+                                    description="% of total steps at which controlnet is last applied")
     # fmt: on
@@ -77,7 +76,9 @@ class ControlNetInvocation(BaseInvocation):
             control=ControlField(
                 image=self.image,
                 control_model=self.control_model,
-                control_weight=self.control_weight
+                control_weight=self.control_weight,
+                begin_step_percent=self.begin_step_percent,
+                end_step_percent=self.end_step_percent,
             ),
         )
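Since ControlField is a pydantic model, the ge/le bounds on the two new fields reject out-of-range fractions at validation time rather than clamping them. A minimal runnable sketch of that behavior (StepWindow is a hypothetical stand-in for the new fields, not part of this commit):

    from pydantic import BaseModel, Field, ValidationError

    # Hypothetical stand-in mirroring the two new ControlField fields.
    class StepWindow(BaseModel):
        begin_step_percent: float = Field(default=0, ge=0, le=1)
        end_step_percent: float = Field(default=1, ge=0, le=1)

    StepWindow(begin_step_percent=0.2, end_step_percent=0.8)  # accepted
    try:
        StepWindow(end_step_percent=1.5)  # rejected: violates le=1
    except ValidationError as err:
        print(err)

Note that neither field constrains the other, so nothing stops begin_step_percent from exceeding end_step_percent; with the floor/ceil gating applied later in this commit, an inverted window simply means the ControlNet never runs.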


@@ -317,7 +317,9 @@ class TextToLatentsInvocation(BaseInvocation):
             )
             control_item = ControlNetData(model=control_model,
                                           image_tensor=control_image,
-                                          weight=control_info.control_weight)
+                                          weight=control_info.control_weight,
+                                          begin_step_percent=control_info.begin_step_percent,
+                                          end_step_percent=control_info.end_step_percent)
             control_data.append(control_item)
         # multi_control = MultiControlNetModel(control_models)
         # model.control_model = multi_control


@@ -219,7 +219,8 @@ class ControlNetData:
     model: ControlNetModel = Field(default=None)
     image_tensor: torch.Tensor = Field(default=None)
     weight: float = Field(default=1.0)
+    begin_step_percent: float = Field(default=0.0)
+    end_step_percent: float = Field(default=1.0)

 @dataclass(frozen=True)
 class ConditioningData:
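The 0.0 and 1.0 defaults make the new window a no-op for any caller that never sets it, so existing graphs keep the old always-on behavior. An illustrative check against the floor/ceil gating introduced below:

    import math

    total_step_count = 30  # illustrative value
    assert math.floor(0.0 * total_step_count) == 0                # first_control_step
    assert math.ceil(1.0 * total_step_count) == total_step_count  # last_control_step
    # every step_index in range(total_step_count) falls inside [0, total_step_count]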
@@ -657,7 +658,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         #     i.e. before or after passing it to InvokeAIDiffuserComponent
         latent_model_input = self.scheduler.scale_model_input(latents, timestep)

-        # if (self.control_model is not None) and (control_image is not None):
+        # default is no controlnet, so set controlnet processing output to None
+        down_block_res_samples, mid_block_res_sample = None, None
         if control_data is not None:
             if conditioning_data.guidance_scale > 1.0:
                 # expand the latents input to control model if doing classifier free guidance
@@ -671,6 +674,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             #     and MultiControlNet (multiple ControlNetData in list)
             for i, control_datum in enumerate(control_data):
                 # print("controlnet", i, "==>", type(control_datum))
+                first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
+                last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
+                # apply_control_this_step = step_index >= first_control_step and step_index <= last_control_step
+                if step_index >= first_control_step and step_index <= last_control_step:
+                    print("running controlnet", i, "for step", step_index)
                     down_samples, mid_sample = control_datum.model(
                         sample=latent_control_input,
                         timestep=timestep,
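A standalone sketch of the gating arithmetic above, with illustrative numbers. Flooring the start and ceiling the end round the window outward, so a ControlNet is applied at boundary steps rather than skipped:

    import math

    total_step_count = 30
    begin_step_percent, end_step_percent = 0.2, 0.8  # illustrative fractions

    first_control_step = math.floor(begin_step_percent * total_step_count)  # 6
    last_control_step = math.ceil(end_step_percent * total_step_count)      # 24

    active = [s for s in range(total_step_count)
              if first_control_step <= s <= last_control_step]
    print(active[0], active[-1], len(active))  # 6 24 19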
@@ -682,7 +690,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                         guess_mode=False,
                         return_dict=False,
                     )
-                    if i == 0:
+                    if down_block_res_samples is None and mid_block_res_sample is None:
                         down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
                     else:
                         # add controlnet outputs together if have multiple controlnets
@@ -691,8 +699,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                             for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                         ]
                         mid_block_res_sample += mid_sample
-        else:
-            down_block_res_samples, mid_block_res_sample = None, None

         # predict the noise residual
         noise_pred = self.invokeai_diffuser.do_diffusion_step(
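The change from "if i == 0:" to the is-None check in the hunk above matters once windows can differ per ControlNet: the first ControlNet in the list may be inactive at the current step, so the first active one has to initialize the accumulators instead. A minimal standalone sketch of the resulting pattern (residuals stands in for the per-ControlNet outputs; toy tensors only):

    import torch

    def accumulate(residuals):
        """residuals: one (down_samples, mid_sample) pair per ControlNet that was
        active at this step; each down_samples is a list of tensors."""
        down_block_res_samples, mid_block_res_sample = None, None  # default: no ControlNet ran
        for down_samples, mid_sample in residuals:
            if down_block_res_samples is None:
                # first active ControlNet initializes the accumulators
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                # later active ControlNets are summed elementwise
                down_block_res_samples = [
                    prev + curr for prev, curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample = mid_block_res_sample + mid_sample
        return down_block_res_samples, mid_block_res_sample

    # illustrative call with toy tensors
    out = accumulate([([torch.ones(2)], torch.ones(2)), ([torch.ones(2)], torch.ones(2))])
    print(out)  # ([tensor([2., 2.])], tensor([2., 2.]))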