Compare commits


1 Commit

Author SHA1 Message Date
e4a45341c8 Controlnet implementation for sequential execution 2023-06-16 02:42:32 +03:00
3 changed files with 239 additions and 103 deletions

View File

@@ -46,6 +46,7 @@ from .diffusion import (
AttentionMapSaver,
InvokeAIDiffuserComponent,
PostprocessingSettings,
ControlNetData,
)
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
from .textual_inversion_manager import TextualInversionManager
@@ -214,15 +215,6 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
raise AssertionError("why was that an empty generator?")
return result
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor = Field(default=None)
weight: Union[float, List[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced")
@dataclass(frozen=True)
class ConditioningData:
@@ -660,76 +652,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# i.e. before or after passing it to InvokeAIDiffuserComponent
unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
# default is no controlnet, so set controlnet processing output to None
down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None:
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
control_mode = control_datum.control_mode
# soft_injection and cfg_injection are the two ControlNet control_mode booleans
# that are combined at a higher level to make the control_mode enum
# soft_injection determines whether to do per-layer re-weighting adjustment (if True)
# or default weighting (if False)
soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
# cfg_injection determines whether to apply ControlNet only to the conditional (if True)
# or to both the conditional and unconditional, the default (if False)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
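# worked example: with total_step_count=30, begin_step_percent=0.0 and
# end_step_percent=0.75, this gives first_control_step=0 and
# last_control_step=ceil(22.5)=23, so control is applied on steps 0..23 inclusive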
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
if cfg_injection:
control_latent_input = unet_latent_input
else:
# expand the latents input to control model if doing classifier free guidance
# (which for now appears to always be true; a conditional elsewhere seems to
# stop execution if classifier_free_guidance is <= 1.0)
control_latent_input = torch.cat([unet_latent_input] * 2)
if cfg_injection: # only applying ControlNet to the conditional, not the unconditional
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
else:
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings])
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index]
else:
# if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight
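# worked example: weight=[1.0]*10 + [0.5]*20 on a 30-step run keeps full
# control strength for steps 0..9, then halves it for steps 10..29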
# controlnet(s) inference
down_samples, mid_sample = control_datum.model(
sample=control_latent_input,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
if cfg_injection:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
# add controlnet outputs together if there are multiple controlnets
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
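# A standalone sketch of the accumulation pattern above (not part of this
# diff): each ControlNet yields a list of down-block residuals plus one
# mid-block residual, and multiple nets are combined by elementwise sum so
# a single UNet call can consume them all.
def sum_controlnet_residuals(per_net_outputs):
    total_down, total_mid = None, None
    for down_samples, mid_sample in per_net_outputs:
        if total_down is None:
            total_down, total_mid = list(down_samples), mid_sample
        else:
            total_down = [prev + curr for prev, curr in zip(total_down, down_samples)]
            total_mid = total_mid + mid_sample
    return total_down, total_mid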
# predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step(
x=unet_latent_input,
@@ -737,10 +659,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
unconditioning=conditioning_data.unconditioned_embeddings,
conditioning=conditioning_data.text_embeddings,
unconditional_guidance_scale=conditioning_data.guidance_scale,
control_data=control_data,
step_index=step_index,
total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
)
# compute the previous noisy sample x_t -> x_t-1
@@ -1091,7 +1012,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
if do_classifier_free_guidance and not cfg_injection:
image = torch.cat([image] * 2)
#cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
#if do_classifier_free_guidance and not cfg_injection:
# image = torch.cat([image] * 2)
return image
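# A minimal sketch of the batching above (shapes illustrative, not part of
# this diff): each control image is repeated once per requested image, then
# doubled for classifier-free guidance so it lines up with the
# [uncond, cond] latent batch.
import torch

image = torch.randn(1, 3, 512, 512)        # one prepared control image
image = image.repeat_interleave(2, dim=0)  # num_images_per_prompt=2 -> (2, 3, 512, 512)
image = torch.cat([image] * 2)             # CFG doubling -> (4, 3, 512, 512)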

View File

@@ -3,4 +3,4 @@ Initialization file for invokeai.models.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin
from .cross_attention_map_saving import AttentionMapSaver
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings, ControlNetData

View File

@@ -1,11 +1,14 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pydantic import Field
from math import ceil
from typing import Any, Callable, Dict, Optional, Union, List
import numpy as np
import torch
import math
from diffusers import UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias
@@ -40,6 +43,17 @@ class PostprocessingSettings:
v_symmetry_time_pct: Optional[float]
# TODO: does pydantic Field work with dataclasses?
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor = Field(default=None)
weight: Union[float, List[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced")
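# A minimal usage sketch (the model id and tensor values are hypothetical,
# not part of this diff): one ControlNetData is built per active ControlNet
# and the list is handed to do_diffusion_step via control_data.
import torch
from diffusers import ControlNetModel

canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
control_data = [
    ControlNetData(
        model=canny,
        image_tensor=torch.randn(1, 3, 512, 512, dtype=torch.float16),  # stand-in for a prepared control image
        weight=0.8,                # one weight for all steps; a list gives per-step weights
        begin_step_percent=0.0,
        end_step_percent=0.75,     # release control for the last quarter of steps
        control_mode="balanced",
    )
]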
class InvokeAIDiffuserComponent:
"""
The aim of this component is to provide a single place for code that can be applied identically to
@@ -182,8 +196,9 @@ class InvokeAIDiffuserComponent:
conditioning: Union[torch.Tensor, dict],
# unconditional_guidance_scale: float,
unconditional_guidance_scale: Union[float, List[float]],
step_index: Optional[int] = None,
total_step_count: Optional[int] = None,
step_index: int,
total_step_count: int,
control_data: Optional[List[ControlNetData]],
**kwargs,
):
"""
@@ -213,31 +228,56 @@
)
)
if self.sequential_guidance:
down_block_res_samples, mid_block_res_sample = self._run_controlnet_sequentially(
unconditioning=unconditioning,
conditioning=conditioning,
control_data=control_data,
sample=x,
timestep=sigma,
step_index=step_index,
total_step_count=total_step_count,
)
else:
down_block_res_samples, mid_block_res_sample = self._run_controlnet_normally(
unconditioning=unconditioning,
conditioning=conditioning,
control_data=control_data,
sample=x,
timestep=sigma,
step_index=step_index,
total_step_count=total_step_count,
)
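# note: both helpers return residuals batched as [uncond, cond] along dim 0
# (or (None, None) when no ControlNet is active); the sequential
# conditioning path below chunks them back apart before its two UNet calls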
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
wants_hybrid_conditioning = isinstance(conditioning, dict)
if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
x, sigma, unconditioning, conditioning, **kwargs,
x, sigma, unconditioning, conditioning,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
**kwargs,
)
elif wants_cross_attention_control:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
x,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do,
x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
**kwargs,
)
elif self.sequential_guidance:
elif self.sequential_guidance:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
x, sigma, unconditioning, conditioning, **kwargs,
x, sigma, unconditioning, conditioning,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
**kwargs,
)
else:
@@ -245,7 +285,10 @@
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
x, sigma, unconditioning, conditioning, **kwargs,
x, sigma, unconditioning, conditioning,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
**kwargs,
)
combined_next_x = self._combine(
@@ -293,16 +336,160 @@
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
def _run_controlnet_normally(
self,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
control_data: List[ControlNetData],
sample: torch.Tensor,
timestep: torch.Tensor,
step_index: int,
total_step_count: int,
):
if control_data is None:
return (None, None)
down_block_res_samples, mid_block_res_sample = None, None
for i, control_datum in enumerate(control_data):
control_mode = control_datum.control_mode
soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
if cfg_injection:
control_sample = sample
control_timestep = timestep
control_image_tensor = control_datum.image_tensor
encoder_hidden_states = conditioning # TODO: ask about possible bug
else:
control_sample = torch.cat([sample] * 2)
control_timestep = torch.cat([timestep] * 2)
control_image_tensor = torch.cat([control_datum.image_tensor] * 2)
encoder_hidden_states = torch.cat([unconditioning, conditioning])
if isinstance(control_datum.weight, list):
weight = control_datum.weight[step_index]
else:
weight = control_datum.weight
# controlnet(s) inference
down_samples, mid_sample = control_datum.model(
sample=control_sample,
timestep=control_timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_image_tensor,
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
if cfg_injection:
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
return down_block_res_samples, mid_block_res_sample
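# note: the sequential variant below trades this single batched ControlNet
# call for two half-batch calls (or one plus zeros under cfg_injection),
# lowering peak activation memory at the cost of an extra forward pass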
def _run_controlnet_sequentially(
self,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
control_data: List[ControlNetData],
sample: torch.Tensor,
timestep: torch.Tensor,
step_index: int,
total_step_count: int,
):
if control_data is None:
return (None, None)
down_block_res_samples, mid_block_res_sample = None, None
for i, control_datum in enumerate(control_data):
control_mode = control_datum.control_mode
soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
if isinstance(control_datum.weight, list):
weight = control_datum.weight[step_index]
else:
weight = control_datum.weight
# controlnet(s) inference
cond_down_samples, cond_mid_sample = control_datum.model(
sample=sample,
timestep=timestep,
encoder_hidden_states=conditioning, # TODO: ask about possible bug
controlnet_cond=control_datum.image_tensor,
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
if cfg_injection:
uncond_down_samples = [torch.zeros_like(d) for d in cond_down_samples]
uncond_mid_sample = torch.zeros_like(cond_mid_sample)
else:
uncond_down_samples, uncond_mid_sample = control_datum.model(
sample=sample,
timestep=timestep,
encoder_hidden_states=unconditioning,
controlnet_cond=control_datum.image_tensor,
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
down_samples = [torch.cat([ud, cd]) for ud, cd in zip(uncond_down_samples, cond_down_samples)]
mid_sample = torch.cat([uncond_mid_sample, cond_mid_sample])
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample
return down_block_res_samples, mid_block_res_sample
def _apply_standard_conditioning(
self,
x: torch.Tensor,
sigma: torch.Tensor,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
**kwargs
):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = torch.cat([unconditioning, conditioning])
both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings, **kwargs,
)
both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings, **kwargs)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
if conditioned_next_x.device.type == "mps":
# TODO: check if this is still present
# prevent a result filled with zeros; seems to be a torch bug
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x
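# A minimal sketch of what the subsequent self._combine step is expected to
# do with this pair (assuming standard classifier-free guidance; not part
# of this diff):
def combine(unconditioned_next_x, conditioned_next_x, guidance_scale):
    # x = uncond + scale * (cond - uncond)
    return unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)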
@@ -310,15 +497,43 @@
def _apply_standard_conditioning_sequentially(
self,
x: torch.Tensor,
sigma,
sigma: torch.Tensor,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
down_block_additional_residuals, # from controlnet(s)
mid_block_additional_residual, # from controlnet(s)
**kwargs,
):
# split controlnet data into cond and uncond
if down_block_additional_residuals is None:
uncond_down_block_res_samples = None
cond_down_block_res_samples = None
uncond_mid_block_res_sample = None
cond_mid_block_res_sample = None
else:
uncond_down_block_res_samples = []
cond_down_block_res_samples = []
for d in down_block_additional_residuals:
ud, cd = d.chunk(2)
uncond_down_block_res_samples.append(ud)
cond_down_block_res_samples.append(cd)
uncond_mid_block_res_sample, cond_mid_block_res_sample = mid_block_additional_residual.chunk(2)
# low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
unconditioned_next_x = self.model_forward_callback(
x, sigma, unconditioning, **kwargs,
down_block_additional_residuals=uncond_down_block_res_samples,
mid_block_additional_residual=uncond_mid_block_res_sample,
)
conditioned_next_x = self.model_forward_callback(
x, sigma, conditioning, **kwargs,
down_block_additional_residuals=cond_down_block_res_samples,
mid_block_additional_residual=cond_mid_block_res_sample,
)
if conditioned_next_x.device.type == "mps":
# TODO: check if this is still present
# prevent a result filled with zeros; seems to be a torch bug
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x
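# A minimal sketch of the chunk(2) split above (shapes illustrative, not
# part of this diff): residuals arrive batched as [uncond, cond] along
# dim 0 and chunk(2) recovers the two halves.
import torch

mid = torch.cat([torch.zeros(1, 1280, 8, 8), torch.ones(1, 1280, 8, 8)])
uncond_mid, cond_mid = mid.chunk(2)
assert torch.equal(uncond_mid, torch.zeros(1, 1280, 8, 8))
assert torch.equal(cond_mid, torch.ones(1, 1280, 8, 8))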