mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Cleaning up prior to submitting ControlNet PR. Mostly turning off diagnostic printing. Also fixed error when there is no controlnet input.
This commit is contained in:
parent
f613c073c1
commit
297931f5d9
@ -34,7 +34,7 @@ from .image import ImageOutput, build_image_output, PILInvocationConfig
|
|||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
image: ImageField = Field(default=None, description="processed image")
|
image: ImageField = Field(default=None, description="processed image")
|
||||||
control_model: Optional[str] = Field(default=None, description="control model used")
|
control_model: Optional[str] = Field(default=None, description="control model used")
|
||||||
control_weight: Optional[float] = Field(default=None, description="weight given to controlnet")
|
control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||||
description="% of total steps at which controlnet is first applied")
|
description="% of total steps at which controlnet is first applied")
|
||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
@ -61,7 +61,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="image to process")
|
image: ImageField = Field(default=None, description="image to process")
|
||||||
control_model: str = Field(default=None, description="control model to use")
|
control_model: str = Field(default=None, description="control model to use")
|
||||||
control_weight: float = Field(default=0.5, ge=0, le=1, description="weight given to controlnet")
|
control_weight: float = Field(default=1.0, ge=0, le=1, description="weight given to controlnet")
|
||||||
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
|
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
|
||||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||||
description="% of total steps at which controlnet is first applied")
|
description="% of total steps at which controlnet is first applied")
|
||||||
|
@ -265,24 +265,25 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
conditioning_data = self.get_conditioning_data(context, model)
|
conditioning_data = self.get_conditioning_data(context, model)
|
||||||
# print("type of control input: ", type(self.control))
|
# print("type of control input: ", type(self.control))
|
||||||
if self.control is None:
|
if self.control is None:
|
||||||
print("control input is None")
|
# print("control input is None")
|
||||||
control_list = None
|
control_list = None
|
||||||
elif isinstance(self.control, list) and len(self.control) == 0:
|
elif isinstance(self.control, list) and len(self.control) == 0:
|
||||||
print("control input is empty list")
|
# print("control input is empty list")
|
||||||
control_list = None
|
control_list = None
|
||||||
elif isinstance(self.control, ControlField):
|
elif isinstance(self.control, ControlField):
|
||||||
print("control input is ControlField")
|
# print("control input is ControlField")
|
||||||
control_list = [self.control]
|
control_list = [self.control]
|
||||||
elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
|
elif isinstance(self.control, list) and len(self.control) > 0 and isinstance(self.control[0], ControlField):
|
||||||
print("control input is list[ControlField]")
|
# print("control input is list[ControlField]")
|
||||||
control_list = self.control
|
control_list = self.control
|
||||||
else:
|
else:
|
||||||
print("input control is unrecognized:", type(self.control))
|
#print("input control is unrecognized:", type(self.control))
|
||||||
control_list = None
|
control_list = None
|
||||||
|
|
||||||
#if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
|
#if (self.control is None or (isinstance(self.control, list) and len(self.control)==0)):
|
||||||
if (control_list is None):
|
if (control_list is None):
|
||||||
control_models = None
|
control_models = None
|
||||||
|
control_data = None
|
||||||
# from above handling, any control that is not None should now be of type list[ControlField]
|
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||||
else:
|
else:
|
||||||
# FIXME: add checks to skip entry if model or image is None
|
# FIXME: add checks to skip entry if model or image is None
|
||||||
|
@ -670,15 +670,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
else:
|
else:
|
||||||
latent_control_input = latent_model_input
|
latent_control_input = latent_model_input
|
||||||
# control_data should be type List[ControlNetData]
|
# control_data should be type List[ControlNetData]
|
||||||
# this loop covers both ControlNet (1 ControlNetData in list)
|
# this loop covers both ControlNet (one ControlNetData in list)
|
||||||
# and MultiControlNet (multiple ControlNetData in list)
|
# and MultiControlNet (multiple ControlNetData in list)
|
||||||
for i, control_datum in enumerate(control_data):
|
for i, control_datum in enumerate(control_data):
|
||||||
# print("controlnet", i, "==>", type(control_datum))
|
# print("controlnet", i, "==>", type(control_datum))
|
||||||
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
|
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
|
||||||
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
|
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
|
||||||
# apply_control_this_step = step_index >= first_control_step and step_index <= last_control_step
|
# only apply controlnet if current step is within the controlnet's begin/end step range
|
||||||
if step_index >= first_control_step and step_index <= last_control_step:
|
if step_index >= first_control_step and step_index <= last_control_step:
|
||||||
print("running controlnet", i, "for step", step_index)
|
# print("running controlnet", i, "for step", step_index)
|
||||||
down_samples, mid_sample = control_datum.model(
|
down_samples, mid_sample = control_datum.model(
|
||||||
sample=latent_control_input,
|
sample=latent_control_input,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
@ -709,8 +709,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
conditioning_data.guidance_scale,
|
conditioning_data.guidance_scale,
|
||||||
step_index=step_index,
|
step_index=step_index,
|
||||||
total_step_count=total_step_count,
|
total_step_count=total_step_count,
|
||||||
down_block_additional_residuals=down_block_res_samples,
|
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
|
||||||
mid_block_additional_residual=mid_block_res_sample,
|
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
Loading…
Reference in New Issue
Block a user