mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
First pass at ControlNet "guess mode" implementation.
This commit is contained in:
parent
c647056287
commit
fd715026a7
@ -1,7 +1,7 @@
|
|||||||
# InvokeAI nodes for ControlNet image preprocessors
|
# InvokeAI nodes for ControlNet image preprocessors
|
||||||
# initial implementation by Gregg Helt, 2023
|
# initial implementation by Gregg Helt, 2023
|
||||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||||
from builtins import float
|
from builtins import float, bool
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Literal, Optional, Union, List
|
from typing import Literal, Optional, Union, List
|
||||||
@ -94,6 +94,7 @@ CONTROLNET_DEFAULT_MODELS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||||
|
# CONTROLNET_MODE_VALUES = Literal[tuple(["BALANCED", "PROMPT", "CONTROL"])]
|
||||||
|
|
||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
image: ImageField = Field(default=None, description="The control image")
|
image: ImageField = Field(default=None, description="The control image")
|
||||||
@ -104,6 +105,7 @@ class ControlField(BaseModel):
|
|||||||
description="When the ControlNet is first applied (% of total steps)")
|
description="When the ControlNet is first applied (% of total steps)")
|
||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
description="When the ControlNet is last applied (% of total steps)")
|
description="When the ControlNet is last applied (% of total steps)")
|
||||||
|
guess_mode: bool = Field(default=False, description="Toggle for guess mode")
|
||||||
@validator("control_weight")
|
@validator("control_weight")
|
||||||
def abs_le_one(cls, v):
|
def abs_le_one(cls, v):
|
||||||
"""validate that all abs(values) are <=1"""
|
"""validate that all abs(values) are <=1"""
|
||||||
@ -149,6 +151,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
description="When the ControlNet is first applied (% of total steps)")
|
description="When the ControlNet is first applied (% of total steps)")
|
||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
description="When the ControlNet is last applied (% of total steps)")
|
description="When the ControlNet is last applied (% of total steps)")
|
||||||
|
guess_mode: bool = Field(default=False, description="Toggle for guess mode")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@ -174,6 +177,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
control_weight=self.control_weight,
|
control_weight=self.control_weight,
|
||||||
begin_step_percent=self.begin_step_percent,
|
begin_step_percent=self.begin_step_percent,
|
||||||
end_step_percent=self.end_step_percent,
|
end_step_percent=self.end_step_percent,
|
||||||
|
guess_mode=self.guess_mode,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -337,12 +337,14 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
# num_images_per_prompt=num_images_per_prompt,
|
# num_images_per_prompt=num_images_per_prompt,
|
||||||
device=control_model.device,
|
device=control_model.device,
|
||||||
dtype=control_model.dtype,
|
dtype=control_model.dtype,
|
||||||
|
guess_mode=control_info.guess_mode,
|
||||||
)
|
)
|
||||||
control_item = ControlNetData(model=control_model,
|
control_item = ControlNetData(model=control_model,
|
||||||
image_tensor=control_image,
|
image_tensor=control_image,
|
||||||
weight=control_info.control_weight,
|
weight=control_info.control_weight,
|
||||||
begin_step_percent=control_info.begin_step_percent,
|
begin_step_percent=control_info.begin_step_percent,
|
||||||
end_step_percent=control_info.end_step_percent)
|
end_step_percent=control_info.end_step_percent,
|
||||||
|
guess_mode=control_info.guess_mode,)
|
||||||
control_data.append(control_item)
|
control_data.append(control_item)
|
||||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||||
return control_data
|
return control_data
|
||||||
|
@ -221,6 +221,8 @@ class ControlNetData:
|
|||||||
weight: Union[float, List[float]] = Field(default=1.0)
|
weight: Union[float, List[float]] = Field(default=1.0)
|
||||||
begin_step_percent: float = Field(default=0.0)
|
begin_step_percent: float = Field(default=0.0)
|
||||||
end_step_percent: float = Field(default=1.0)
|
end_step_percent: float = Field(default=1.0)
|
||||||
|
# FIXME: replace with guess_mode with enum control_mode: BALANCED, MORE_PROMPT, MORE_CONTROL
|
||||||
|
guess_mode: bool = Field(default=False) # guess_mode can work with or without prompt
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ConditioningData:
|
class ConditioningData:
|
||||||
@ -656,21 +658,34 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
# TODO: should this scaling happen here or inside self._unet_forward?
|
# TODO: should this scaling happen here or inside self._unet_forward?
|
||||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||||
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
|
||||||
|
|
||||||
|
# # guess mode handling from diffusers
|
||||||
|
# if guess_mode and do_classifier_free_guidance:
|
||||||
|
# # Infer ControlNet only for the conditional batch.
|
||||||
|
# control_model_input = latents
|
||||||
|
# control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
||||||
|
# controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
||||||
|
# else:
|
||||||
|
# control_model_input = unet_latent_input
|
||||||
|
# controlnet_prompt_embeds = prompt_embeds
|
||||||
|
|
||||||
# default is no controlnet, so set controlnet processing output to None
|
# default is no controlnet, so set controlnet processing output to None
|
||||||
down_block_res_samples, mid_block_res_sample = None, None
|
down_block_res_samples, mid_block_res_sample = None, None
|
||||||
|
|
||||||
if control_data is not None:
|
if control_data is not None:
|
||||||
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
|
# FIXME: make sure guidance_scale <= 1.0 is handled correctly if doing per-step guidance setting
|
||||||
|
# UPDATE: I think this is fixed now with pydantic validator for cfg_scale?
|
||||||
|
# So we should _never_ have guidance_scale <= 1.0
|
||||||
# if conditioning_data.guidance_scale > 1.0:
|
# if conditioning_data.guidance_scale > 1.0:
|
||||||
if conditioning_data.guidance_scale is not None:
|
# if conditioning_data.guidance_scale is not None:
|
||||||
# expand the latents input to control model if doing classifier free guidance
|
# if guess_mode is False:
|
||||||
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
# # expand the latents input to control model if doing classifier free guidance
|
||||||
# classifier_free_guidance is <= 1.0 ?)
|
# # (which I think for now is always true, there is conditional elsewhere that stops execution if
|
||||||
latent_control_input = torch.cat([latent_model_input] * 2)
|
# # classifier_free_guidance is <= 1.0 ?)
|
||||||
else:
|
# control_latent_input = torch.cat([unet_latent_input] * 2)
|
||||||
latent_control_input = latent_model_input
|
# else:
|
||||||
|
# control_latent_input = unet_latent_input
|
||||||
# control_data should be type List[ControlNetData]
|
# control_data should be type List[ControlNetData]
|
||||||
# this loop covers both ControlNet (one ControlNetData in list)
|
# this loop covers both ControlNet (one ControlNetData in list)
|
||||||
# and MultiControlNet (multiple ControlNetData in list)
|
# and MultiControlNet (multiple ControlNetData in list)
|
||||||
@ -680,24 +695,62 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
|
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
|
||||||
# only apply controlnet if current step is within the controlnet's begin/end step range
|
# only apply controlnet if current step is within the controlnet's begin/end step range
|
||||||
if step_index >= first_control_step and step_index <= last_control_step:
|
if step_index >= first_control_step and step_index <= last_control_step:
|
||||||
# print("running controlnet", i, "for step", step_index)
|
guess_mode = control_datum.guess_mode
|
||||||
|
if guess_mode:
|
||||||
|
control_latent_input = unet_latent_input
|
||||||
|
else:
|
||||||
|
# expand the latents input to control model if doing classifier free guidance
|
||||||
|
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
||||||
|
# classifier_free_guidance is <= 1.0 ?)
|
||||||
|
control_latent_input = torch.cat([unet_latent_input] * 2)
|
||||||
|
|
||||||
|
print("running controlnet", i, "for step", step_index)
|
||||||
|
print("guess mode: ", guess_mode)
|
||||||
|
print("guess mode type: ", type(guess_mode))
|
||||||
|
if guess_mode: # only using prompt conditioning in unconditioned
|
||||||
|
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
|
||||||
|
else:
|
||||||
|
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
|
||||||
|
conditioning_data.text_embeddings])
|
||||||
|
print("encoder_hidden_states.shape", encoder_hidden_states.shape)
|
||||||
if isinstance(control_datum.weight, list):
|
if isinstance(control_datum.weight, list):
|
||||||
# if controlnet has multiple weights, use the weight for the current step
|
# if controlnet has multiple weights, use the weight for the current step
|
||||||
controlnet_weight = control_datum.weight[step_index]
|
controlnet_weight = control_datum.weight[step_index]
|
||||||
else:
|
else:
|
||||||
# if controlnet has a single weight, use it for all steps
|
# if controlnet has a single weight, use it for all steps
|
||||||
controlnet_weight = control_datum.weight
|
controlnet_weight = control_datum.weight
|
||||||
|
|
||||||
|
# guess mode handling from diffusers controlnet pipeline:
|
||||||
|
# if guess_mode and do_classifier_free_guidance:
|
||||||
|
# # Infer ControlNet only for the conditional batch.
|
||||||
|
# latent_control_input = latents
|
||||||
|
# latent_control_input = self.scheduler.scale_model_input(control_model_input, t)
|
||||||
|
# controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
||||||
|
# else:
|
||||||
|
# control_model_input = unet_latent_input
|
||||||
|
# controlnet_prompt_embeds = prompt_embeds
|
||||||
|
|
||||||
|
# controlnet(s) inference
|
||||||
down_samples, mid_sample = control_datum.model(
|
down_samples, mid_sample = control_datum.model(
|
||||||
sample=latent_control_input,
|
sample=control_latent_input,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
# encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
||||||
conditioning_data.text_embeddings]),
|
# conditioning_data.text_embeddings]),
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
controlnet_cond=control_datum.image_tensor,
|
controlnet_cond=control_datum.image_tensor,
|
||||||
conditioning_scale=controlnet_weight,
|
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
||||||
# cross_attention_kwargs,
|
# cross_attention_kwargs,
|
||||||
guess_mode=False,
|
guess_mode=guess_mode,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)
|
)
|
||||||
|
print("finished ControlNetModel() call, step", step_index)
|
||||||
|
if guess_mode:
|
||||||
|
# Inferred ControlNet only for the conditional batch.
|
||||||
|
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||||
|
# add 0 to the unconditional batch to keep it unchanged.
|
||||||
|
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
|
||||||
|
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
|
||||||
|
|
||||||
if down_block_res_samples is None and mid_block_res_sample is None:
|
if down_block_res_samples is None and mid_block_res_sample is None:
|
||||||
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
||||||
else:
|
else:
|
||||||
@ -708,13 +761,21 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
]
|
]
|
||||||
mid_block_res_sample += mid_sample
|
mid_block_res_sample += mid_sample
|
||||||
|
|
||||||
|
# guess mode handling from diffusers controlnet pipeline:
|
||||||
|
# if guess_mode and do_classifier_free_guidance:
|
||||||
|
# # Inferred ControlNet only for the conditional batch.
|
||||||
|
# # To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||||
|
# # add 0 to the unconditional batch to keep it unchanged.
|
||||||
|
# down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
||||||
|
# mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.invokeai_diffuser.do_diffusion_step(
|
noise_pred = self.invokeai_diffuser.do_diffusion_step(
|
||||||
latent_model_input,
|
x=unet_latent_input,
|
||||||
t,
|
sigma=t,
|
||||||
conditioning_data.unconditioned_embeddings,
|
unconditioning=conditioning_data.unconditioned_embeddings,
|
||||||
conditioning_data.text_embeddings,
|
conditioning=conditioning_data.text_embeddings,
|
||||||
conditioning_data.guidance_scale,
|
unconditional_guidance_scale=conditioning_data.guidance_scale,
|
||||||
step_index=step_index,
|
step_index=step_index,
|
||||||
total_step_count=total_step_count,
|
total_step_count=total_step_count,
|
||||||
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
|
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
|
||||||
@ -1038,6 +1099,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
do_classifier_free_guidance=True,
|
do_classifier_free_guidance=True,
|
||||||
|
guess_mode=False,
|
||||||
):
|
):
|
||||||
|
|
||||||
if not isinstance(image, torch.Tensor):
|
if not isinstance(image, torch.Tensor):
|
||||||
@ -1068,6 +1130,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
repeat_by = num_images_per_prompt
|
repeat_by = num_images_per_prompt
|
||||||
image = image.repeat_interleave(repeat_by, dim=0)
|
image = image.repeat_interleave(repeat_by, dim=0)
|
||||||
image = image.to(device=device, dtype=dtype)
|
image = image.to(device=device, dtype=dtype)
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance and not guess_mode:
|
||||||
image = torch.cat([image] * 2)
|
image = torch.cat([image] * 2)
|
||||||
return image
|
return image
|
||||||
|
Loading…
Reference in New Issue
Block a user