mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Separate inputs in denoise context
This commit is contained in:
parent
9f088d1bf5
commit
608cbe3f5c
@ -40,7 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
@ -768,13 +768,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
denoise_ctx = DenoiseContext(
|
||||
latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise,
|
||||
seed=seed,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
inputs=DenoiseInputs(
|
||||
orig_latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise,
|
||||
seed=seed,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
),
|
||||
unet=None,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
@ -30,8 +30,8 @@ class UNetKwargs:
|
||||
|
||||
|
||||
@dataclass
|
||||
class DenoiseContext:
|
||||
latents: torch.Tensor
|
||||
class DenoiseInputs:
|
||||
orig_latents: torch.Tensor
|
||||
scheduler_step_kwargs: dict[str, Any]
|
||||
conditioning_data: TextConditioningData
|
||||
noise: Optional[torch.Tensor]
|
||||
@ -39,10 +39,15 @@ class DenoiseContext:
|
||||
timesteps: torch.Tensor
|
||||
init_timestep: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DenoiseContext:
|
||||
inputs: DenoiseInputs
|
||||
|
||||
scheduler: SchedulerMixin
|
||||
unet: Optional[UNet2DConditionModel] = None
|
||||
|
||||
orig_latents: Optional[torch.Tensor] = None
|
||||
latents: Optional[torch.Tensor] = None
|
||||
step_index: Optional[int] = None
|
||||
timestep: Optional[torch.Tensor] = None
|
||||
unet_kwargs: Optional[UNetKwargs] = None
|
||||
|
@ -22,25 +22,27 @@ class StableDiffusionBackend:
|
||||
self.sequential_guidance = config.sequential_guidance
|
||||
|
||||
def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
if ctx.init_timestep.shape[0] == 0:
|
||||
return ctx.latents
|
||||
if ctx.inputs.init_timestep.shape[0] == 0:
|
||||
return ctx.inputs.orig_latents
|
||||
|
||||
ctx.orig_latents = ctx.latents.clone()
|
||||
ctx.latents = ctx.inputs.orig_latents.clone()
|
||||
|
||||
if ctx.noise is not None:
|
||||
if ctx.inputs.noise is not None:
|
||||
batch_size = ctx.latents.shape[0]
|
||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||
ctx.latents = ctx.scheduler.add_noise(ctx.latents, ctx.noise, ctx.init_timestep.expand(batch_size))
|
||||
ctx.latents = ctx.scheduler.add_noise(
|
||||
ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
|
||||
)
|
||||
|
||||
# if no work to do, return latents
|
||||
if ctx.timesteps.shape[0] == 0:
|
||||
if ctx.inputs.timesteps.shape[0] == 0:
|
||||
return ctx.latents
|
||||
|
||||
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
|
||||
# ext: preview[pre_denoise_loop, priority=low]
|
||||
ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager)
|
||||
|
||||
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)): # noqa: B020
|
||||
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020
|
||||
# ext: inpaint (apply mask to latents on non-inpaint models)
|
||||
ext_manager.callbacks.pre_step(ctx, ext_manager)
|
||||
|
||||
@ -80,7 +82,7 @@ class StableDiffusionBackend:
|
||||
ext_manager.callbacks.post_apply_cfg(ctx, ext_manager)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs)
|
||||
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
|
||||
|
||||
# clean up locals
|
||||
ctx.latent_model_input = None
|
||||
@ -92,7 +94,7 @@ class StableDiffusionBackend:
|
||||
|
||||
@staticmethod
|
||||
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
|
||||
guidance_scale = ctx.conditioning_data.guidance_scale
|
||||
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
|
||||
if isinstance(guidance_scale, list):
|
||||
guidance_scale = guidance_scale[ctx.step_index]
|
||||
|
||||
@ -109,12 +111,12 @@ class StableDiffusionBackend:
|
||||
timestep=ctx.timestep,
|
||||
encoder_hidden_states=None, # set later by conditoning
|
||||
cross_attention_kwargs=dict( # noqa: C408
|
||||
percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps,
|
||||
percent_through=ctx.step_index / len(ctx.inputs.timesteps),
|
||||
),
|
||||
)
|
||||
|
||||
ctx.conditioning_mode = conditioning_mode
|
||||
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
|
||||
ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
|
||||
|
||||
# ext: controlnet/ip/t2i [pre_unet]
|
||||
ext_manager.callbacks.pre_unet(ctx, ext_manager)
|
||||
|
@ -35,7 +35,7 @@ class PreviewExt(ExtensionBase):
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
order=ctx.scheduler.order,
|
||||
total_steps=len(ctx.timesteps),
|
||||
total_steps=len(ctx.inputs.timesteps),
|
||||
timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it?
|
||||
latents=ctx.latents,
|
||||
)
|
||||
@ -55,7 +55,7 @@ class PreviewExt(ExtensionBase):
|
||||
PipelineIntermediateState(
|
||||
step=ctx.step_index,
|
||||
order=ctx.scheduler.order,
|
||||
total_steps=len(ctx.timesteps),
|
||||
total_steps=len(ctx.inputs.timesteps),
|
||||
timestep=int(ctx.timestep), # TODO: is there any code which uses it?
|
||||
latents=ctx.step_output.prev_sample,
|
||||
predicted_original=predicted_original, # TODO: is there any reason for additional field?
|
||||
|
Loading…
Reference in New Issue
Block a user