mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Compare commits (1 commit): next-test-... -> feat/contr..., SHA1 e4a45341c8
@@ -46,6 +46,7 @@ from .diffusion import (
     AttentionMapSaver,
     InvokeAIDiffuserComponent,
     PostprocessingSettings,
+    ControlNetData,
 )
 from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
 from .textual_inversion_manager import TextualInversionManager
@@ -214,15 +215,6 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
             raise AssertionError("why was that an empty generator?")
         return result
 
-@dataclass
-class ControlNetData:
-    model: ControlNetModel = Field(default=None)
-    image_tensor: torch.Tensor = Field(default=None)
-    weight: Union[float, List[float]] = Field(default=1.0)
-    begin_step_percent: float = Field(default=0.0)
-    end_step_percent: float = Field(default=1.0)
-    control_mode: str = Field(default="balanced")
-
 
 @dataclass(frozen=True)
 class ConditioningData:
@@ -660,76 +652,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # i.e. before or after passing it to InvokeAIDiffuserComponent
         unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
 
-        # default is no controlnet, so set controlnet processing output to None
-        down_block_res_samples, mid_block_res_sample = None, None
-
-        if control_data is not None:
-            # control_data should be type List[ControlNetData]
-            # this loop covers both ControlNet (one ControlNetData in list)
-            # and MultiControlNet (multiple ControlNetData in list)
-            for i, control_datum in enumerate(control_data):
-                control_mode = control_datum.control_mode
-                # soft_injection and cfg_injection are the two ControlNet control_mode booleans
-                # that are combined at a higher level to make the control_mode enum
-                # soft_injection determines whether to do per-layer re-weighting adjustment (if True)
-                # or default weighting (if False)
-                soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
-                # cfg_injection determines whether to apply ControlNet only to the conditional (if True)
-                # or to both the conditional and unconditional (if False, the default)
-                cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
-
-                first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
-                last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
-                # only apply controlnet if current step is within the controlnet's begin/end step range
-                if step_index >= first_control_step and step_index <= last_control_step:
-
-                    if cfg_injection:
-                        control_latent_input = unet_latent_input
-                    else:
-                        # expand the latents input to the control model if doing classifier-free guidance
-                        # (which I think is always true for now; a conditional elsewhere stops execution
-                        # if the guidance scale is <= 1.0?)
-                        control_latent_input = torch.cat([unet_latent_input] * 2)
-
-                    if cfg_injection:  # only applying ControlNet to the conditional, not the unconditional
-                        encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
-                    else:
-                        encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
-                                                           conditioning_data.text_embeddings])
-                    if isinstance(control_datum.weight, list):
-                        # if controlnet has multiple weights, use the weight for the current step
-                        controlnet_weight = control_datum.weight[step_index]
-                    else:
-                        # if controlnet has a single weight, use it for all steps
-                        controlnet_weight = control_datum.weight
-
-                    # controlnet(s) inference
-                    down_samples, mid_sample = control_datum.model(
-                        sample=control_latent_input,
-                        timestep=timestep,
-                        encoder_hidden_states=encoder_hidden_states,
-                        controlnet_cond=control_datum.image_tensor,
-                        conditioning_scale=controlnet_weight,  # controlnet specific, NOT the guidance scale
-                        guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
-                        return_dict=False,
-                    )
-                    if cfg_injection:
-                        # Inferred ControlNet only for the conditional batch.
-                        # To apply the output of ControlNet to both the unconditional and conditional batches,
-                        # add 0 to the unconditional batch to keep it unchanged.
-                        down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
-                        mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
-
-                    if down_block_res_samples is None and mid_block_res_sample is None:
-                        down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
-                    else:
-                        # add controlnet outputs together if there are multiple controlnets
-                        down_block_res_samples = [
-                            samples_prev + samples_curr
-                            for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
-                        ]
-                        mid_block_res_sample += mid_sample
-
         # predict the noise residual
         noise_pred = self.invokeai_diffuser.do_diffusion_step(
             x=unet_latent_input,
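
The removed block above collapses the control_mode string into two booleans. As a minimal standalone sketch of that mapping (the helper name is hypothetical, not part of the commit):

def injection_flags(control_mode: str) -> tuple:
    # soft_injection: per-layer re-weighting (diffusers' "guess mode")
    soft_injection = control_mode in ("more_prompt", "more_control")
    # cfg_injection: run ControlNet on the conditional batch only
    cfg_injection = control_mode in ("more_control", "unbalanced")
    return soft_injection, cfg_injection

assert injection_flags("balanced") == (False, False)
assert injection_flags("more_prompt") == (True, False)
assert injection_flags("more_control") == (True, True)
assert injection_flags("unbalanced") == (False, True)
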
@@ -737,10 +659,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             unconditioning=conditioning_data.unconditioned_embeddings,
             conditioning=conditioning_data.text_embeddings,
             unconditional_guidance_scale=conditioning_data.guidance_scale,
+            control_data=control_data,
             step_index=step_index,
             total_step_count=total_step_count,
-            down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
-            mid_block_additional_residual=mid_block_res_sample,  # from controlnet(s)
         )
 
         # compute the previous noisy sample x_t -> x_t-1
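
step_index and total_step_count now flow into do_diffusion_step so the begin/end gating can happen inside the new _run_controlnet_* methods. The gating arithmetic reduces to the following standalone sketch (hypothetical helper):

import math

def controlnet_step_range(begin_pct: float, end_pct: float, total_steps: int) -> range:
    # mirrors first_control_step/last_control_step in the diff;
    # the range is inclusive of the last step (step_index <= last_control_step)
    first = math.floor(begin_pct * total_steps)
    last = math.ceil(end_pct * total_steps)
    return range(first, last + 1)

# e.g. apply control only over the middle half of a 20-step run:
assert list(controlnet_step_range(0.25, 0.75, 20)) == list(range(5, 16))
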
@@ -1091,7 +1012,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         repeat_by = num_images_per_prompt
         image = image.repeat_interleave(repeat_by, dim=0)
         image = image.to(device=device, dtype=dtype)
-        cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
-        if do_classifier_free_guidance and not cfg_injection:
-            image = torch.cat([image] * 2)
+        # cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
+        # if do_classifier_free_guidance and not cfg_injection:
+        #     image = torch.cat([image] * 2)
         return image
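
The batching that used to live here moves out of image preparation: under classifier-free guidance the UNet sees a doubled [uncond, cond] batch, so the control image must be doubled to match, except when cfg_injection runs ControlNet on the conditional half only. A small illustration (shapes are hypothetical):

import torch

image = torch.randn(1, 3, 512, 512)  # stand-in for a prepared control image
do_classifier_free_guidance, cfg_injection = True, False
if do_classifier_free_guidance and not cfg_injection:
    image = torch.cat([image] * 2)
assert image.shape[0] == 2  # one copy per half of the CFG batch
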
@@ -3,4 +3,4 @@ Initialization file for invokeai.models.diffusion
 """
 from .cross_attention_control import InvokeAICrossAttentionMixin
 from .cross_attention_map_saving import AttentionMapSaver
-from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
+from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings, ControlNetData
@@ -1,11 +1,14 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
+from pydantic import Field
 from math import ceil
 from typing import Any, Callable, Dict, Optional, Union, List
 
 import numpy as np
 import torch
+import math
 from diffusers import UNet2DConditionModel
+from diffusers.models.controlnet import ControlNetModel
 from diffusers.models.attention_processor import AttentionProcessor
 from typing_extensions import TypeAlias
@@ -40,6 +43,17 @@ class PostprocessingSettings:
     v_symmetry_time_pct: Optional[float]
 
+
+# TODO: does pydantic Field work with dataclasses?
+@dataclass
+class ControlNetData:
+    model: ControlNetModel = Field(default=None)
+    image_tensor: torch.Tensor = Field(default=None)
+    weight: Union[float, List[float]] = Field(default=1.0)
+    begin_step_percent: float = Field(default=0.0)
+    end_step_percent: float = Field(default=1.0)
+    control_mode: str = Field(default="balanced")
+
 
 class InvokeAIDiffuserComponent:
     """
     The aim of this component is to provide a single place for code that can be applied identically to
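
A sketch of how a caller might populate the new dataclass (the checkpoint id and image are placeholders, and note the TODO above: pydantic's Field is designed for BaseModel, so these Field defaults on a plain dataclass deserve scrutiny):

import torch
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
control_image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed canny map

control_data = [
    ControlNetData(
        model=controlnet,
        image_tensor=control_image,
        weight=0.8,  # or a per-step list, e.g. [1.0] * 10 + [0.5] * 10
        begin_step_percent=0.0,
        end_step_percent=0.75,  # stop applying control after 75% of the steps
        control_mode="balanced",
    )
]
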
@@ -182,8 +196,9 @@ class InvokeAIDiffuserComponent:
         conditioning: Union[torch.Tensor, dict],
         # unconditional_guidance_scale: float,
         unconditional_guidance_scale: Union[float, List[float]],
-        step_index: Optional[int] = None,
-        total_step_count: Optional[int] = None,
+        step_index: int,
+        total_step_count: int,
+        control_data: Optional[List[ControlNetData]],
         **kwargs,
     ):
         """
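
Both unconditional_guidance_scale and ControlNetData.weight now admit per-step lists. The indexing convention used throughout the diff reduces to (hypothetical helper):

from typing import List, Union

def value_at_step(value: Union[float, List[float]], step_index: int) -> float:
    # mirrors the isinstance(..., list) checks in the diff
    return value[step_index] if isinstance(value, list) else value

assert value_at_step(0.8, 5) == 0.8
assert value_at_step([1.0, 0.5, 0.25], 1) == 0.5
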
@@ -213,31 +228,56 @@ class InvokeAIDiffuserComponent:
                 )
             )
 
+        if self.sequential_guidance:
+            down_block_res_samples, mid_block_res_sample = self._run_controlnet_sequentially(
+                unconditioning=unconditioning,
+                conditioning=conditioning,
+                control_data=control_data,
+                sample=x,
+                timestep=sigma,
+                step_index=step_index,
+                total_step_count=total_step_count,
+            )
+        else:
+            down_block_res_samples, mid_block_res_sample = self._run_controlnet_normally(
+                unconditioning=unconditioning,
+                conditioning=conditioning,
+                control_data=control_data,
+                sample=x,
+                timestep=sigma,
+                step_index=step_index,
+                total_step_count=total_step_count,
+            )
+
         wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
         wants_hybrid_conditioning = isinstance(conditioning, dict)
 
         if wants_hybrid_conditioning:
             unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
-                x, sigma, unconditioning, conditioning, **kwargs,
+                x, sigma, unconditioning, conditioning,
+                down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
+                mid_block_additional_residual=mid_block_res_sample,  # from controlnet(s)
+                **kwargs,
             )
         elif wants_cross_attention_control:
             (
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_cross_attention_controlled_conditioning(
-                x,
-                sigma,
-                unconditioning,
-                conditioning,
-                cross_attention_control_types_to_do,
+                x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do,
+                down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
+                mid_block_additional_residual=mid_block_res_sample,  # from controlnet(s)
                 **kwargs,
             )
-        elif self.sequential_guidance:
+        elif True:  # self.sequential_guidance:
             (
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_standard_conditioning_sequentially(
-                x, sigma, unconditioning, conditioning, **kwargs,
+                x, sigma, unconditioning, conditioning,
+                down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
+                mid_block_additional_residual=mid_block_res_sample,  # from controlnet(s)
+                **kwargs,
             )
 
         else:
@@ -245,7 +285,10 @@ class InvokeAIDiffuserComponent:
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_standard_conditioning(
-                x, sigma, unconditioning, conditioning, **kwargs,
+                x, sigma, unconditioning, conditioning,
+                down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
+                mid_block_additional_residual=mid_block_res_sample,  # from controlnet(s)
+                **kwargs,
             )
 
         combined_next_x = self._combine(
@@ -293,16 +336,160 @@ class InvokeAIDiffuserComponent:
 
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
-    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
+    def _run_controlnet_normally(
+        self,
+        unconditioning: torch.Tensor,
+        conditioning: torch.Tensor,
+        control_data: List[ControlNetData],
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        step_index: int,
+        total_step_count: int,
+    ):
+        if control_data is None:
+            return (None, None)
+
+        down_block_res_samples, mid_block_res_sample = None, None
+
+        for i, control_datum in enumerate(control_data):
+            control_mode = control_datum.control_mode
+            soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
+            cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
+
+            first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
+            last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
+            # only apply controlnet if current step is within the controlnet's begin/end step range
+            if step_index >= first_control_step and step_index <= last_control_step:
+
+                if cfg_injection:
+                    control_sample = sample
+                    control_timestep = timestep
+                    control_image_tensor = control_datum.image_tensor
+                    encoder_hidden_states = conditioning  # TODO: ask bug
+                else:
+                    control_sample = torch.cat([sample] * 2)
+                    control_timestep = torch.cat([timestep] * 2)
+                    control_image_tensor = torch.cat([control_datum.image_tensor] * 2)
+                    encoder_hidden_states = torch.cat([unconditioning, conditioning])
+
+                if isinstance(control_datum.weight, list):
+                    weight = control_datum.weight[step_index]
+                else:
+                    weight = control_datum.weight
+
+                # controlnet(s) inference
+                down_samples, mid_sample = control_datum.model(
+                    sample=control_sample,
+                    timestep=control_timestep,
+                    encoder_hidden_states=encoder_hidden_states,
+                    controlnet_cond=control_image_tensor,
+                    conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
+                    guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
+                    return_dict=False,
+                )
+
+                if cfg_injection:
+                    down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
+                    mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
+
+                if down_block_res_samples is None and mid_block_res_sample is None:
+                    down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+                else:
+                    down_block_res_samples = [
+                        samples_prev + samples_curr
+                        for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                    ]
+                    mid_block_res_sample += mid_sample
+
+        return down_block_res_samples, mid_block_res_sample
+
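
The cfg_injection branch above leans on the fact that ControlNet residuals are added to the UNet block outputs, so padding the unconditional half with zeros leaves it untouched. In isolation (shapes hypothetical):

import torch

cond_residual = torch.randn(1, 320, 64, 64)  # ControlNet ran on the conditional half only

# pad so the residual lines up with the doubled [uncond, cond] UNet batch;
# adding zeros leaves the unconditional prediction unchanged
padded = torch.cat([torch.zeros_like(cond_residual), cond_residual])
assert padded.shape[0] == 2
assert padded[:1].abs().sum() == 0
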
+    def _run_controlnet_sequentially(
+        self,
+        unconditioning: torch.Tensor,
+        conditioning: torch.Tensor,
+        control_data: List[ControlNetData],
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        step_index: int,
+        total_step_count: int,
+    ):
+        if control_data is None:
+            return (None, None)
+
+        down_block_res_samples, mid_block_res_sample = None, None
+
+        for i, control_datum in enumerate(control_data):
+            control_mode = control_datum.control_mode
+            soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
+            cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
+
+            first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
+            last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
+            # only apply controlnet if current step is within the controlnet's begin/end step range
+            if step_index >= first_control_step and step_index <= last_control_step:
+
+                if isinstance(control_datum.weight, list):
+                    weight = control_datum.weight[step_index]
+                else:
+                    weight = control_datum.weight
+
+                # controlnet(s) inference
+                cond_down_samples, cond_mid_sample = control_datum.model(
+                    sample=sample,
+                    timestep=timestep,
+                    encoder_hidden_states=conditioning,  # TODO: ask bug
+                    controlnet_cond=control_datum.image_tensor,
+                    conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
+                    guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
+                    return_dict=False,
+                )
+
+                if cfg_injection:
+                    uncond_down_samples = [torch.zeros_like(d) for d in cond_down_samples]
+                    uncond_mid_sample = torch.zeros_like(cond_mid_sample)
+                else:
+                    uncond_down_samples, uncond_mid_sample = control_datum.model(
+                        sample=sample,
+                        timestep=timestep,
+                        encoder_hidden_states=unconditioning,
+                        controlnet_cond=control_datum.image_tensor,
+                        conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
+                        guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
+                        return_dict=False,
+                    )
+
+                down_samples = [torch.cat([ud, cd]) for ud, cd in zip(uncond_down_samples, cond_down_samples)]
+                mid_sample = torch.cat([uncond_mid_sample, cond_mid_sample])
+
+                if down_block_res_samples is None and mid_block_res_sample is None:
+                    down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+                else:
+                    down_block_res_samples = [
+                        samples_prev + samples_curr
+                        for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                    ]
+                    mid_block_res_sample += mid_sample
+
+        return down_block_res_samples, mid_block_res_sample
+
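
The sequential variant trades speed for peak memory: instead of one ControlNet call over the doubled batch, it makes two calls over single batches (or one plus zeros under cfg_injection) and concatenates afterwards. Schematically, with model standing in for any batched callable (hypothetical sketch):

import torch

def run_batched(model, sample, uncond, cond):
    # one forward pass over the doubled batch: faster, higher peak memory
    return model(torch.cat([sample] * 2), torch.cat([uncond, cond]))

def run_sequential(model, sample, uncond, cond):
    # two passes over single batches, same result,
    # roughly half the peak activation memory
    return torch.cat([model(sample, uncond), model(sample, cond)])
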
+    def _apply_standard_conditioning(
+        self,
+        x: torch.Tensor,
+        sigma: torch.Tensor,
+        unconditioning: torch.Tensor,
+        conditioning: torch.Tensor,
+        **kwargs
+    ):
         # fast batched path
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
         both_conditionings = torch.cat([unconditioning, conditioning])
-        both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings, **kwargs)
+        both_results = self.model_forward_callback(
+            x_twice, sigma_twice, both_conditionings, **kwargs,
+        )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
         if conditioned_next_x.device.type == "mps":
+            # TODO: check if this is still present
             # prevent a result filled with zeros. seems to be a torch bug.
             conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x
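
The batched path's cat/chunk round trip, in isolation (all shapes hypothetical):

import torch

x = torch.randn(1, 4, 64, 64)
uncond = torch.randn(1, 77, 768)
cond = torch.randn(1, 77, 768)

x_twice = torch.cat([x] * 2)              # (2, 4, 64, 64)
both = torch.cat([uncond, cond])          # (2, 77, 768)
result = torch.randn(2, 4, 64, 64)        # stand-in for one UNet forward pass
uncond_next, cond_next = result.chunk(2)  # split back into the two halves
assert uncond_next.shape == x.shape and cond_next.shape == x.shape
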
@@ -310,15 +497,43 @@ class InvokeAIDiffuserComponent:
     def _apply_standard_conditioning_sequentially(
         self,
         x: torch.Tensor,
-        sigma,
+        sigma: torch.Tensor,
         unconditioning: torch.Tensor,
         conditioning: torch.Tensor,
+        down_block_additional_residuals,  # from controlnet(s)
+        mid_block_additional_residual,  # from controlnet(s)
         **kwargs,
     ):
+        # split controlnet data into cond and uncond halves
+        if down_block_additional_residuals is None:
+            uncond_down_block_res_samples = None
+            cond_down_block_res_samples = None
+            uncond_mid_block_res_sample = None
+            cond_mid_block_res_sample = None
+        else:
+            uncond_down_block_res_samples = []
+            cond_down_block_res_samples = []
+            for d in down_block_additional_residuals:
+                ud, cd = d.chunk(2)
+                uncond_down_block_res_samples.append(ud)
+                cond_down_block_res_samples.append(cd)
+
+            uncond_mid_block_res_sample, cond_mid_block_res_sample = mid_block_additional_residual.chunk(2)
+
         # low-memory sequential path
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
-        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
+        unconditioned_next_x = self.model_forward_callback(
+            x, sigma, unconditioning, **kwargs,
+            down_block_additional_residuals=uncond_down_block_res_samples,
+            mid_block_additional_residual=uncond_mid_block_res_sample,
+        )
+        conditioned_next_x = self.model_forward_callback(
+            x, sigma, conditioning, **kwargs,
+            down_block_additional_residuals=cond_down_block_res_samples,
+            mid_block_additional_residual=cond_mid_block_res_sample,
+        )
         if conditioned_next_x.device.type == "mps":
+            # TODO: check if this is still present
             # prevent a result filled with zeros. seems to be a torch bug.
             conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x
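
For context, the unconditioned/conditioned pair returned by these methods is later merged by _combine with unconditional_guidance_scale. The commit does not show that method; the usual classifier-free guidance formula it presumably implements looks like:

def combine_cfg(uncond_next_x, cond_next_x, guidance_scale):
    # push the prediction away from the unconditional result,
    # toward the conditional one
    return uncond_next_x + guidance_scale * (cond_next_x - uncond_next_x)
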