Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Compare commits: onnx-testi ... feat/contr (1 commit)

Commit: e4a45341c8

@@ -46,6 +46,7 @@ from .diffusion import (
     AttentionMapSaver,
     InvokeAIDiffuserComponent,
     PostprocessingSettings,
+    ControlNetData,
 )
 from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
 from .textual_inversion_manager import TextualInversionManager

@@ -214,15 +215,6 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
             raise AssertionError("why was that an empty generator?")
         return result


-@dataclass
-class ControlNetData:
-    model: ControlNetModel = Field(default=None)
-    image_tensor: torch.Tensor = Field(default=None)
-    weight: Union[float, List[float]] = Field(default=1.0)
-    begin_step_percent: float = Field(default=0.0)
-    end_step_percent: float = Field(default=1.0)
-    control_mode: str = Field(default="balanced")
-
 @dataclass(frozen=True)
 class ConditioningData:

@@ -660,76 +652,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # i.e. before or after passing it to InvokeAIDiffuserComponent
         unet_latent_input = self.scheduler.scale_model_input(latents, timestep)

-        # default is no controlnet, so set controlnet processing output to None
-        down_block_res_samples, mid_block_res_sample = None, None
-
-        if control_data is not None:
-            # control_data should be type List[ControlNetData]
-            # this loop covers both ControlNet (one ControlNetData in list)
-            # and MultiControlNet (multiple ControlNetData in list)
-            for i, control_datum in enumerate(control_data):
-                control_mode = control_datum.control_mode
-                # soft_injection and cfg_injection are the two ControlNet control_mode booleans
-                # that are combined at a higher level to make the control_mode enum
-                # soft_injection determines whether to do per-layer re-weighting adjustment (if True)
-                # or default weighting (if False)
-                soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
-                # cfg_injection determines whether to apply ControlNet to only the conditional (if True)
-                # or to both the conditional and unconditional, the default (if False)
-                cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
-
-                first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
-                last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
-                # only apply controlnet if the current step is within the controlnet's begin/end step range
-                if step_index >= first_control_step and step_index <= last_control_step:
-
-                    if cfg_injection:
-                        control_latent_input = unet_latent_input
-                    else:
-                        # expand the latents input to the control model if doing classifier-free guidance
-                        # (which I think for now is always true, there is a conditional elsewhere that stops
-                        # execution if classifier_free_guidance is <= 1.0?)
-                        control_latent_input = torch.cat([unet_latent_input] * 2)
-
-                    if cfg_injection:  # only applying ControlNet to conditional instead of unconditioned
-                        encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
-                    else:
-                        encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
-                                                           conditioning_data.text_embeddings])
-                    if isinstance(control_datum.weight, list):
-                        # if the controlnet has multiple weights, use the weight for the current step
-                        controlnet_weight = control_datum.weight[step_index]
-                    else:
-                        # if the controlnet has a single weight, use it for all steps
-                        controlnet_weight = control_datum.weight
-
-                    # controlnet(s) inference
-                    down_samples, mid_sample = control_datum.model(
-                        sample=control_latent_input,
-                        timestep=timestep,
-                        encoder_hidden_states=encoder_hidden_states,
-                        controlnet_cond=control_datum.image_tensor,
-                        conditioning_scale=controlnet_weight,  # controlnet specific, NOT the guidance scale
-                        guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
-                        return_dict=False,
-                    )
-                    if cfg_injection:
-                        # Inferred ControlNet only for the conditional batch.
-                        # To apply the output of ControlNet to both the unconditional and conditional batches,
-                        # add 0 to the unconditional batch to keep it unchanged.
-                        down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
-                        mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
-
-                    if down_block_res_samples is None and mid_block_res_sample is None:
-                        down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
-                    else:
-                        # add controlnet outputs together if we have multiple controlnets
-                        down_block_res_samples = [
-                            samples_prev + samples_curr
-                            for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
-                        ]
-                        mid_block_res_sample += mid_sample
-
         # predict the noise residual
         noise_pred = self.invokeai_diffuser.do_diffusion_step(
             x=unet_latent_input,

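The block deleted above encodes the four control_mode strings as two booleans. A minimal sketch of that mapping, using a hypothetical helper name (injection_flags is not in the codebase; the logic is taken from the soft_injection / cfg_injection lines above):

# Hypothetical helper, for illustration only.
def injection_flags(control_mode: str):
    soft_injection = control_mode in ("more_prompt", "more_control")  # per-layer re-weighting
    cfg_injection = control_mode in ("more_control", "unbalanced")    # apply ControlNet to cond only
    return soft_injection, cfg_injection

assert injection_flags("balanced") == (False, False)
assert injection_flags("more_prompt") == (True, False)
assert injection_flags("more_control") == (True, True)
assert injection_flags("unbalanced") == (False, True)
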
@@ -737,10 +659,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             unconditioning=conditioning_data.unconditioned_embeddings,
             conditioning=conditioning_data.text_embeddings,
             unconditional_guidance_scale=conditioning_data.guidance_scale,
+            control_data=control_data,
             step_index=step_index,
             total_step_count=total_step_count,
-            down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
-            mid_block_additional_residual=mid_block_res_sample,  # from controlnet(s)
         )

         # compute the previous noisy sample x_t -> x_t-1

@@ -1091,7 +1012,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         repeat_by = num_images_per_prompt
         image = image.repeat_interleave(repeat_by, dim=0)
         image = image.to(device=device, dtype=dtype)
-        cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
-        if do_classifier_free_guidance and not cfg_injection:
-            image = torch.cat([image] * 2)
+        # cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
+        # if do_classifier_free_guidance and not cfg_injection:
+        #     image = torch.cat([image] * 2)
         return image

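Commenting out the duplication here is consistent with the new flow: when classifier-free guidance needs a doubled batch, _run_controlnet_normally (in the hunks below) concatenates the control image itself, so this method can return a single-batch tensor. A sketch of the doubling as it now happens downstream (shapes are stand-ins, not from the diff):

import torch

image = torch.randn(1, 3, 512, 512)  # stand-in control image tensor returned here

# On the non-cfg_injection path, the diffuser component doubles it so the two
# halves line up with the [unconditioned, conditioned] latent batch.
control_image_tensor = torch.cat([image] * 2)
assert control_image_tensor.shape[0] == 2
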
@@ -3,4 +3,4 @@ Initialization file for invokeai.models.diffusion
 """
 from .cross_attention_control import InvokeAICrossAttentionMixin
 from .cross_attention_map_saving import AttentionMapSaver
-from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
+from .shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings, ControlNetData

@@ -1,11 +1,14 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
+from pydantic import Field
 from math import ceil
 from typing import Any, Callable, Dict, Optional, Union, List

 import numpy as np
 import torch
+import math
 from diffusers import UNet2DConditionModel
+from diffusers.models.controlnet import ControlNetModel
 from diffusers.models.attention_processor import AttentionProcessor
 from typing_extensions import TypeAlias

@@ -40,6 +43,17 @@ class PostprocessingSettings:
     v_symmetry_time_pct: Optional[float]


+# TODO: pydantic Field work with dataclasses?
+@dataclass
+class ControlNetData:
+    model: ControlNetModel = Field(default=None)
+    image_tensor: torch.Tensor = Field(default=None)
+    weight: Union[float, List[float]] = Field(default=1.0)
+    begin_step_percent: float = Field(default=0.0)
+    end_step_percent: float = Field(default=1.0)
+    control_mode: str = Field(default="balanced")
+
+
 class InvokeAIDiffuserComponent:
     """
     The aim of this component is to provide a single place for code that can be applied identically to

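The TODO above is warranted: with a plain stdlib @dataclass, pydantic's Field(...) sentinels are not interpreted, so the FieldInfo object itself becomes the attribute default rather than the declared value. A minimal sketch of the pitfall (illustrative only; Example is a made-up class):

from dataclasses import dataclass
from pydantic import Field

@dataclass
class Example:
    weight: float = Field(default=1.0)

# The default is the FieldInfo object, not 1.0 -- stdlib dataclasses do not
# understand pydantic's Field sentinel. Explicitly passed values behave normally.
print(type(Example().weight))      # a pydantic FieldInfo, not float
print(Example(weight=0.5).weight)  # 0.5

In practice this means callers must pass every ControlNetData field explicitly, or the class should be converted to a pydantic dataclass/BaseModel.
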
@@ -182,8 +196,9 @@ class InvokeAIDiffuserComponent:
         conditioning: Union[torch.Tensor, dict],
         # unconditional_guidance_scale: float,
         unconditional_guidance_scale: Union[float, List[float]],
-        step_index: Optional[int] = None,
-        total_step_count: Optional[int] = None,
+        step_index: int,
+        total_step_count: int,
+        control_data: Optional[List[ControlNetData]],
         **kwargs,
     ):
         """

@@ -213,31 +228,56 @@ class InvokeAIDiffuserComponent:
             )
         )

+        if self.sequential_guidance:
+            down_block_res_samples, mid_block_res_sample = self._run_controlnet_sequentially(
+                unconditioning=unconditioning,
+                conditioning=conditioning,
+                control_data=control_data,
+                sample=x,
+                timestep=sigma,
+                step_index=step_index,
+                total_step_count=total_step_count,
+            )
+        else:
+            down_block_res_samples, mid_block_res_sample = self._run_controlnet_normally(
+                unconditioning=unconditioning,
+                conditioning=conditioning,
+                control_data=control_data,
+                sample=x,
+                timestep=sigma,
+                step_index=step_index,
+                total_step_count=total_step_count,
+            )
+
         wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
         wants_hybrid_conditioning = isinstance(conditioning, dict)

         if wants_hybrid_conditioning:
             unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
-                x, sigma, unconditioning, conditioning, **kwargs,
+                x, sigma, unconditioning, conditioning,
+                down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
+                mid_block_additional_residual=mid_block_res_sample,  # from controlnet(s)
+                **kwargs,
             )
         elif wants_cross_attention_control:
             (
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_cross_attention_controlled_conditioning(
-                x,
-                sigma,
-                unconditioning,
-                conditioning,
-                cross_attention_control_types_to_do,
+                x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do,
+                down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
+                mid_block_additional_residual=mid_block_res_sample,  # from controlnet(s)
                 **kwargs,
             )
-        elif self.sequential_guidance:
+        elif True:  # self.sequential_guidance:
             (
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_standard_conditioning_sequentially(
-                x, sigma, unconditioning, conditioning, **kwargs,
+                x, sigma, unconditioning, conditioning,
+                down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
+                mid_block_additional_residual=mid_block_res_sample,  # from controlnet(s)
+                **kwargs,
            )

         else:

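Note that the change to `elif True:  # self.sequential_guidance:` forces the sequential branch regardless of configuration, which reads as a temporary override on this work-in-progress branch. Whichever branch runs, the two predictions it returns are merged downstream by _combine; a sketch of the standard classifier-free-guidance formula that step is expected to implement (an assumption, the combine code is not shown in this diff):

def combine(unconditioned_next_x, conditioned_next_x, guidance_scale):
    # Standard CFG: move from the unconditioned prediction toward the
    # conditioned one, scaled by the guidance strength.
    delta = conditioned_next_x - unconditioned_next_x
    return unconditioned_next_x + guidance_scale * delta
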
@@ -245,7 +285,10 @@ class InvokeAIDiffuserComponent:
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_standard_conditioning(
-                x, sigma, unconditioning, conditioning, **kwargs,
+                x, sigma, unconditioning, conditioning,
+                down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
+                mid_block_additional_residual=mid_block_res_sample,  # from controlnet(s)
+                **kwargs,
             )

         combined_next_x = self._combine(

@@ -293,16 +336,160 @@ class InvokeAIDiffuserComponent:

     # methods below are called from do_diffusion_step and should be considered private to this class.

-    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
+    def _run_controlnet_normally(
+        self,
+        unconditioning: torch.Tensor,
+        conditioning: torch.Tensor,
+        control_data: List[ControlNetData],
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        step_index: int,
+        total_step_count: int,
+    ):
+        if control_data is None:
+            return (None, None)
+
+        down_block_res_samples, mid_block_res_sample = None, None
+
+        for i, control_datum in enumerate(control_data):
+            control_mode = control_datum.control_mode
+            soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
+            cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
+
+            first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
+            last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
+            # only apply controlnet if current step is within the controlnet's begin/end step range
+            if step_index >= first_control_step and step_index <= last_control_step:
+
+                if cfg_injection:
+                    control_sample = sample
+                    control_timestep = timestep
+                    control_image_tensor = control_datum.image_tensor
+                    encoder_hidden_states = conditioning  # TODO: ask bug
+                else:
+                    control_sample = torch.cat([sample] * 2)
+                    control_timestep = torch.cat([timestep] * 2)
+                    control_image_tensor = torch.cat([control_datum.image_tensor] * 2)
+                    encoder_hidden_states = torch.cat([unconditioning, conditioning])
+
+                if isinstance(control_datum.weight, list):
+                    weight = control_datum.weight[step_index]
+                else:
+                    weight = control_datum.weight
+
+                # controlnet(s) inference
+                down_samples, mid_sample = control_datum.model(
+                    sample=control_sample,
+                    timestep=control_timestep,
+                    encoder_hidden_states=encoder_hidden_states,
+                    controlnet_cond=control_image_tensor,
+                    conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
+                    guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
+                    return_dict=False,
+                )
+
+                if cfg_injection:
+                    down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
+                    mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
+
+                if down_block_res_samples is None and mid_block_res_sample is None:
+                    down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+                else:
+                    down_block_res_samples = [
+                        samples_prev + samples_curr
+                        for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                    ]
+                    mid_block_res_sample += mid_sample
+
+        return down_block_res_samples, mid_block_res_sample
+
+    def _run_controlnet_sequentially(
+        self,
+        unconditioning: torch.Tensor,
+        conditioning: torch.Tensor,
+        control_data: List[ControlNetData],
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        step_index: int,
+        total_step_count: int,
+    ):
+        if control_data is None:
+            return (None, None)
+
+        down_block_res_samples, mid_block_res_sample = None, None
+
+        for i, control_datum in enumerate(control_data):
+            control_mode = control_datum.control_mode
+            soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
+            cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
+
+            first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
+            last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
+            # only apply controlnet if current step is within the controlnet's begin/end step range
+            if step_index >= first_control_step and step_index <= last_control_step:
+
+                if isinstance(control_datum.weight, list):
+                    weight = control_datum.weight[step_index]
+                else:
+                    weight = control_datum.weight
+
+                # controlnet(s) inference
+                cond_down_samples, cond_mid_sample = control_datum.model(
+                    sample=sample,
+                    timestep=timestep,
+                    encoder_hidden_states=conditioning,  # TODO: ask bug
+                    controlnet_cond=control_datum.image_tensor,
+                    conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
+                    guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
+                    return_dict=False,
+                )
+
+                if cfg_injection:
+                    uncond_down_samples = [torch.zeros_like(d) for d in cond_down_samples]
+                    uncond_mid_sample = torch.zeros_like(cond_mid_sample)
+                else:
+                    uncond_down_samples, uncond_mid_sample = control_datum.model(
+                        sample=sample,
+                        timestep=timestep,
+                        encoder_hidden_states=unconditioning,
+                        controlnet_cond=control_datum.image_tensor,
+                        conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
+                        guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
+                        return_dict=False,
+                    )
+
+                down_samples = [torch.cat([ud, cd]) for ud, cd in zip(uncond_down_samples, cond_down_samples)]
+                mid_sample = torch.cat([uncond_mid_sample, cond_mid_sample])
+
+                if down_block_res_samples is None and mid_block_res_sample is None:
+                    down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+                else:
+                    down_block_res_samples = [
+                        samples_prev + samples_curr
+                        for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                    ]
+                    mid_block_res_sample += mid_sample
+
+        return down_block_res_samples, mid_block_res_sample
+
+    def _apply_standard_conditioning(
+        self,
+        x: torch.Tensor,
+        sigma: torch.Tensor,
+        unconditioning: torch.Tensor,
+        conditioning: torch.Tensor,
+        **kwargs,
+    ):
         # fast batched path
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
         both_conditionings = torch.cat([unconditioning, conditioning])
-        both_results = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings, **kwargs,
-        )
+        both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings, **kwargs)
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
         if conditioned_next_x.device.type == "mps":
+            # TODO: check if this still present
             # prevent a result filled with zeros. seems to be a torch bug.
             conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x

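In _run_controlnet_normally, the cfg_injection branch prepends zero tensors so that the batched residuals only affect the conditional half: ControlNet residuals are added to the UNet's activations, so a zero residual is a no-op for the unconditional half. A small self-contained check of that batching trick (stand-in shapes, not from the diff):

import torch

cond_residual = torch.randn(1, 320, 64, 64)  # stand-in ControlNet residual
batched = torch.cat([torch.zeros_like(cond_residual), cond_residual])

# Splitting on the batch dimension recovers an all-zero unconditional half
# (a no-op when added to activations) and the untouched conditional half.
uncond_half, cond_half = batched.chunk(2)
assert torch.equal(uncond_half, torch.zeros_like(cond_residual))
assert torch.equal(cond_half, cond_residual)

This same chunk(2) split is exactly what _apply_standard_conditioning_sequentially relies on in the hunk below.
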
@@ -310,15 +497,43 @@ class InvokeAIDiffuserComponent:
     def _apply_standard_conditioning_sequentially(
         self,
         x: torch.Tensor,
-        sigma,
+        sigma: torch.Tensor,
         unconditioning: torch.Tensor,
         conditioning: torch.Tensor,
+        down_block_additional_residuals,  # from controlnet(s)
+        mid_block_additional_residual,  # from controlnet(s)
         **kwargs,
     ):
+        # split controlnet data into cond and uncond
+        if down_block_additional_residuals is None:
+            uncond_down_block_res_samples = None
+            cond_down_block_res_samples = None
+            uncond_mid_block_res_sample = None
+            cond_mid_block_res_sample = None
+        else:
+            uncond_down_block_res_samples = []
+            cond_down_block_res_samples = []
+            for d in down_block_additional_residuals:
+                ud, cd = d.chunk(2)
+                uncond_down_block_res_samples.append(ud)
+                cond_down_block_res_samples.append(cd)
+
+            uncond_mid_block_res_sample, cond_mid_block_res_sample = mid_block_additional_residual.chunk(2)
+
         # low-memory sequential path
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
-        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
+        unconditioned_next_x = self.model_forward_callback(
+            x, sigma, unconditioning, **kwargs,
+            down_block_additional_residuals=uncond_down_block_res_samples,
+            mid_block_additional_residual=uncond_mid_block_res_sample,
+        )
+        conditioned_next_x = self.model_forward_callback(
+            x, sigma, conditioning, **kwargs,
+            down_block_additional_residuals=cond_down_block_res_samples,
+            mid_block_additional_residual=cond_mid_block_res_sample,
+        )
         if conditioned_next_x.device.type == "mps":
+            # TODO: check if still present
             # prevent a result filled with zeros. seems to be a torch bug.
             conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x

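One subtlety in the rewritten calls above: keyword arguments appear after **kwargs in the call expressions. That is valid Python at a call site (though not in a def), so the residual kwargs reliably reach model_forward_callback. A one-line demonstration:

def f(*args, **kw):
    return kw

# Keyword arguments may follow **kwargs in a call; a duplicate key would
# raise TypeError rather than being silently overwritten.
assert f(1, **{"a": 1}, b=2) == {"a": 1, "b": 2}
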