Separate inputs in denoise context

This commit is contained in:
Sergey Borisov 2024-07-16 19:30:29 +03:00
parent 9f088d1bf5
commit 608cbe3f5c
4 changed files with 33 additions and 24 deletions

View File

@ -40,7 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
StableDiffusionGeneratorPipeline,
@ -768,13 +768,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
denoise_ctx = DenoiseContext(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
inputs=DenoiseInputs(
orig_latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
),
unet=None,
scheduler=scheduler,
)

View File

@ -30,8 +30,8 @@ class UNetKwargs:
@dataclass
class DenoiseContext:
latents: torch.Tensor
class DenoiseInputs:
orig_latents: torch.Tensor
scheduler_step_kwargs: dict[str, Any]
conditioning_data: TextConditioningData
noise: Optional[torch.Tensor]
@ -39,10 +39,15 @@ class DenoiseContext:
timesteps: torch.Tensor
init_timestep: torch.Tensor
@dataclass
class DenoiseContext:
inputs: DenoiseInputs
scheduler: SchedulerMixin
unet: Optional[UNet2DConditionModel] = None
orig_latents: Optional[torch.Tensor] = None
latents: Optional[torch.Tensor] = None
step_index: Optional[int] = None
timestep: Optional[torch.Tensor] = None
unet_kwargs: Optional[UNetKwargs] = None

View File

@ -22,25 +22,27 @@ class StableDiffusionBackend:
self.sequential_guidance = config.sequential_guidance
def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
if ctx.init_timestep.shape[0] == 0:
return ctx.latents
if ctx.inputs.init_timestep.shape[0] == 0:
return ctx.inputs.orig_latents
ctx.orig_latents = ctx.latents.clone()
ctx.latents = ctx.inputs.orig_latents.clone()
if ctx.noise is not None:
if ctx.inputs.noise is not None:
batch_size = ctx.latents.shape[0]
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
ctx.latents = ctx.scheduler.add_noise(ctx.latents, ctx.noise, ctx.init_timestep.expand(batch_size))
ctx.latents = ctx.scheduler.add_noise(
ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
)
# if no work to do, return latents
if ctx.timesteps.shape[0] == 0:
if ctx.inputs.timesteps.shape[0] == 0:
return ctx.latents
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
# ext: preview[pre_denoise_loop, priority=low]
ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager)
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)): # noqa: B020
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020
# ext: inpaint (apply mask to latents on non-inpaint models)
ext_manager.callbacks.pre_step(ctx, ext_manager)
@ -80,7 +82,7 @@ class StableDiffusionBackend:
ext_manager.callbacks.post_apply_cfg(ctx, ext_manager)
# compute the previous noisy sample x_t -> x_t-1
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs)
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
# clean up locals
ctx.latent_model_input = None
@ -92,7 +94,7 @@ class StableDiffusionBackend:
@staticmethod
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
guidance_scale = ctx.conditioning_data.guidance_scale
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[ctx.step_index]
@ -109,12 +111,12 @@ class StableDiffusionBackend:
timestep=ctx.timestep,
encoder_hidden_states=None, # set later by conditoning
cross_attention_kwargs=dict( # noqa: C408
percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps,
percent_through=ctx.step_index / len(ctx.inputs.timesteps),
),
)
ctx.conditioning_mode = conditioning_mode
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
# ext: controlnet/ip/t2i [pre_unet]
ext_manager.callbacks.pre_unet(ctx, ext_manager)

View File

@ -35,7 +35,7 @@ class PreviewExt(ExtensionBase):
PipelineIntermediateState(
step=-1,
order=ctx.scheduler.order,
total_steps=len(ctx.timesteps),
total_steps=len(ctx.inputs.timesteps),
timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it?
latents=ctx.latents,
)
@ -55,7 +55,7 @@ class PreviewExt(ExtensionBase):
PipelineIntermediateState(
step=ctx.step_index,
order=ctx.scheduler.order,
total_steps=len(ctx.timesteps),
total_steps=len(ctx.inputs.timesteps),
timestep=int(ctx.timestep), # TODO: is there any code which uses it?
latents=ctx.step_output.prev_sample,
predicted_original=predicted_original, # TODO: is there any reason for additional field?