Separate inputs in denoise context

This commit is contained in:
Sergey Borisov 2024-07-16 19:30:29 +03:00
parent 9f088d1bf5
commit 608cbe3f5c
4 changed files with 33 additions and 24 deletions

View File

@ -40,7 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import ( from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData, ControlNetData,
StableDiffusionGeneratorPipeline, StableDiffusionGeneratorPipeline,
@ -768,13 +768,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
denoise_ctx = DenoiseContext( denoise_ctx = DenoiseContext(
latents=latents, inputs=DenoiseInputs(
orig_latents=latents,
timesteps=timesteps, timesteps=timesteps,
init_timestep=init_timestep, init_timestep=init_timestep,
noise=noise, noise=noise,
seed=seed, seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs, scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
),
unet=None, unet=None,
scheduler=scheduler, scheduler=scheduler,
) )

View File

@ -30,8 +30,8 @@ class UNetKwargs:
@dataclass @dataclass
class DenoiseContext: class DenoiseInputs:
latents: torch.Tensor orig_latents: torch.Tensor
scheduler_step_kwargs: dict[str, Any] scheduler_step_kwargs: dict[str, Any]
conditioning_data: TextConditioningData conditioning_data: TextConditioningData
noise: Optional[torch.Tensor] noise: Optional[torch.Tensor]
@ -39,10 +39,15 @@ class DenoiseContext:
timesteps: torch.Tensor timesteps: torch.Tensor
init_timestep: torch.Tensor init_timestep: torch.Tensor
@dataclass
class DenoiseContext:
inputs: DenoiseInputs
scheduler: SchedulerMixin scheduler: SchedulerMixin
unet: Optional[UNet2DConditionModel] = None unet: Optional[UNet2DConditionModel] = None
orig_latents: Optional[torch.Tensor] = None latents: Optional[torch.Tensor] = None
step_index: Optional[int] = None step_index: Optional[int] = None
timestep: Optional[torch.Tensor] = None timestep: Optional[torch.Tensor] = None
unet_kwargs: Optional[UNetKwargs] = None unet_kwargs: Optional[UNetKwargs] = None

View File

@ -22,25 +22,27 @@ class StableDiffusionBackend:
self.sequential_guidance = config.sequential_guidance self.sequential_guidance = config.sequential_guidance
def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
if ctx.init_timestep.shape[0] == 0: if ctx.inputs.init_timestep.shape[0] == 0:
return ctx.latents return ctx.inputs.orig_latents
ctx.orig_latents = ctx.latents.clone() ctx.latents = ctx.inputs.orig_latents.clone()
if ctx.noise is not None: if ctx.inputs.noise is not None:
batch_size = ctx.latents.shape[0] batch_size = ctx.latents.shape[0]
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
ctx.latents = ctx.scheduler.add_noise(ctx.latents, ctx.noise, ctx.init_timestep.expand(batch_size)) ctx.latents = ctx.scheduler.add_noise(
ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
)
# if no work to do, return latents # if no work to do, return latents
if ctx.timesteps.shape[0] == 0: if ctx.inputs.timesteps.shape[0] == 0:
return ctx.latents return ctx.latents
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed) # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
# ext: preview[pre_denoise_loop, priority=low] # ext: preview[pre_denoise_loop, priority=low]
ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager) ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager)
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)): # noqa: B020 for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020
# ext: inpaint (apply mask to latents on non-inpaint models) # ext: inpaint (apply mask to latents on non-inpaint models)
ext_manager.callbacks.pre_step(ctx, ext_manager) ext_manager.callbacks.pre_step(ctx, ext_manager)
@ -80,7 +82,7 @@ class StableDiffusionBackend:
ext_manager.callbacks.post_apply_cfg(ctx, ext_manager) ext_manager.callbacks.post_apply_cfg(ctx, ext_manager)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs) step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
# clean up locals # clean up locals
ctx.latent_model_input = None ctx.latent_model_input = None
@ -92,7 +94,7 @@ class StableDiffusionBackend:
@staticmethod @staticmethod
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor: def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
guidance_scale = ctx.conditioning_data.guidance_scale guidance_scale = ctx.inputs.conditioning_data.guidance_scale
if isinstance(guidance_scale, list): if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[ctx.step_index] guidance_scale = guidance_scale[ctx.step_index]
@ -109,12 +111,12 @@ class StableDiffusionBackend:
timestep=ctx.timestep, timestep=ctx.timestep,
encoder_hidden_states=None, # set later by conditoning encoder_hidden_states=None, # set later by conditoning
cross_attention_kwargs=dict( # noqa: C408 cross_attention_kwargs=dict( # noqa: C408
percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps, percent_through=ctx.step_index / len(ctx.inputs.timesteps),
), ),
) )
ctx.conditioning_mode = conditioning_mode ctx.conditioning_mode = conditioning_mode
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode) ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
# ext: controlnet/ip/t2i [pre_unet] # ext: controlnet/ip/t2i [pre_unet]
ext_manager.callbacks.pre_unet(ctx, ext_manager) ext_manager.callbacks.pre_unet(ctx, ext_manager)

View File

@ -35,7 +35,7 @@ class PreviewExt(ExtensionBase):
PipelineIntermediateState( PipelineIntermediateState(
step=-1, step=-1,
order=ctx.scheduler.order, order=ctx.scheduler.order,
total_steps=len(ctx.timesteps), total_steps=len(ctx.inputs.timesteps),
timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it? timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it?
latents=ctx.latents, latents=ctx.latents,
) )
@ -55,7 +55,7 @@ class PreviewExt(ExtensionBase):
PipelineIntermediateState( PipelineIntermediateState(
step=ctx.step_index, step=ctx.step_index,
order=ctx.scheduler.order, order=ctx.scheduler.order,
total_steps=len(ctx.timesteps), total_steps=len(ctx.inputs.timesteps),
timestep=int(ctx.timestep), # TODO: is there any code which uses it? timestep=int(ctx.timestep), # TODO: is there any code which uses it?
latents=ctx.step_output.prev_sample, latents=ctx.step_output.prev_sample,
predicted_original=predicted_original, # TODO: is there any reason for additional field? predicted_original=predicted_original, # TODO: is there any reason for additional field?