mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Separate inputs in denoise context
This commit is contained in:
parent
9f088d1bf5
commit
608cbe3f5c
@ -40,7 +40,7 @@ from invokeai.backend.lora import LoRAModelRaw
|
|||||||
from invokeai.backend.model_manager import BaseModelType
|
from invokeai.backend.model_manager import BaseModelType
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||||
ControlNetData,
|
ControlNetData,
|
||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
@ -768,13 +768,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
denoise_ctx = DenoiseContext(
|
denoise_ctx = DenoiseContext(
|
||||||
latents=latents,
|
inputs=DenoiseInputs(
|
||||||
|
orig_latents=latents,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
init_timestep=init_timestep,
|
init_timestep=init_timestep,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
|
),
|
||||||
unet=None,
|
unet=None,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
|
@ -30,8 +30,8 @@ class UNetKwargs:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DenoiseContext:
|
class DenoiseInputs:
|
||||||
latents: torch.Tensor
|
orig_latents: torch.Tensor
|
||||||
scheduler_step_kwargs: dict[str, Any]
|
scheduler_step_kwargs: dict[str, Any]
|
||||||
conditioning_data: TextConditioningData
|
conditioning_data: TextConditioningData
|
||||||
noise: Optional[torch.Tensor]
|
noise: Optional[torch.Tensor]
|
||||||
@ -39,10 +39,15 @@ class DenoiseContext:
|
|||||||
timesteps: torch.Tensor
|
timesteps: torch.Tensor
|
||||||
init_timestep: torch.Tensor
|
init_timestep: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DenoiseContext:
|
||||||
|
inputs: DenoiseInputs
|
||||||
|
|
||||||
scheduler: SchedulerMixin
|
scheduler: SchedulerMixin
|
||||||
unet: Optional[UNet2DConditionModel] = None
|
unet: Optional[UNet2DConditionModel] = None
|
||||||
|
|
||||||
orig_latents: Optional[torch.Tensor] = None
|
latents: Optional[torch.Tensor] = None
|
||||||
step_index: Optional[int] = None
|
step_index: Optional[int] = None
|
||||||
timestep: Optional[torch.Tensor] = None
|
timestep: Optional[torch.Tensor] = None
|
||||||
unet_kwargs: Optional[UNetKwargs] = None
|
unet_kwargs: Optional[UNetKwargs] = None
|
||||||
|
@ -22,25 +22,27 @@ class StableDiffusionBackend:
|
|||||||
self.sequential_guidance = config.sequential_guidance
|
self.sequential_guidance = config.sequential_guidance
|
||||||
|
|
||||||
def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||||
if ctx.init_timestep.shape[0] == 0:
|
if ctx.inputs.init_timestep.shape[0] == 0:
|
||||||
return ctx.latents
|
return ctx.inputs.orig_latents
|
||||||
|
|
||||||
ctx.orig_latents = ctx.latents.clone()
|
ctx.latents = ctx.inputs.orig_latents.clone()
|
||||||
|
|
||||||
if ctx.noise is not None:
|
if ctx.inputs.noise is not None:
|
||||||
batch_size = ctx.latents.shape[0]
|
batch_size = ctx.latents.shape[0]
|
||||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||||
ctx.latents = ctx.scheduler.add_noise(ctx.latents, ctx.noise, ctx.init_timestep.expand(batch_size))
|
ctx.latents = ctx.scheduler.add_noise(
|
||||||
|
ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
|
||||||
|
)
|
||||||
|
|
||||||
# if no work to do, return latents
|
# if no work to do, return latents
|
||||||
if ctx.timesteps.shape[0] == 0:
|
if ctx.inputs.timesteps.shape[0] == 0:
|
||||||
return ctx.latents
|
return ctx.latents
|
||||||
|
|
||||||
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it is needed)
|
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it is needed)
|
||||||
# ext: preview[pre_denoise_loop, priority=low]
|
# ext: preview[pre_denoise_loop, priority=low]
|
||||||
ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager)
|
ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager)
|
||||||
|
|
||||||
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)): # noqa: B020
|
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020
|
||||||
# ext: inpaint (apply mask to latents on non-inpaint models)
|
# ext: inpaint (apply mask to latents on non-inpaint models)
|
||||||
ext_manager.callbacks.pre_step(ctx, ext_manager)
|
ext_manager.callbacks.pre_step(ctx, ext_manager)
|
||||||
|
|
||||||
@ -80,7 +82,7 @@ class StableDiffusionBackend:
|
|||||||
ext_manager.callbacks.post_apply_cfg(ctx, ext_manager)
|
ext_manager.callbacks.post_apply_cfg(ctx, ext_manager)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs)
|
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
|
||||||
|
|
||||||
# clean up locals
|
# clean up locals
|
||||||
ctx.latent_model_input = None
|
ctx.latent_model_input = None
|
||||||
@ -92,7 +94,7 @@ class StableDiffusionBackend:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
|
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
|
||||||
guidance_scale = ctx.conditioning_data.guidance_scale
|
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
|
||||||
if isinstance(guidance_scale, list):
|
if isinstance(guidance_scale, list):
|
||||||
guidance_scale = guidance_scale[ctx.step_index]
|
guidance_scale = guidance_scale[ctx.step_index]
|
||||||
|
|
||||||
@ -109,12 +111,12 @@ class StableDiffusionBackend:
|
|||||||
timestep=ctx.timestep,
|
timestep=ctx.timestep,
|
||||||
encoder_hidden_states=None, # set later by conditioning
|
encoder_hidden_states=None, # set later by conditioning
|
||||||
cross_attention_kwargs=dict( # noqa: C408
|
cross_attention_kwargs=dict( # noqa: C408
|
||||||
percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps,
|
percent_through=ctx.step_index / len(ctx.inputs.timesteps),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx.conditioning_mode = conditioning_mode
|
ctx.conditioning_mode = conditioning_mode
|
||||||
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
|
ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
|
||||||
|
|
||||||
# ext: controlnet/ip/t2i [pre_unet]
|
# ext: controlnet/ip/t2i [pre_unet]
|
||||||
ext_manager.callbacks.pre_unet(ctx, ext_manager)
|
ext_manager.callbacks.pre_unet(ctx, ext_manager)
|
||||||
|
@ -35,7 +35,7 @@ class PreviewExt(ExtensionBase):
|
|||||||
PipelineIntermediateState(
|
PipelineIntermediateState(
|
||||||
step=-1,
|
step=-1,
|
||||||
order=ctx.scheduler.order,
|
order=ctx.scheduler.order,
|
||||||
total_steps=len(ctx.timesteps),
|
total_steps=len(ctx.inputs.timesteps),
|
||||||
timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it?
|
timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it?
|
||||||
latents=ctx.latents,
|
latents=ctx.latents,
|
||||||
)
|
)
|
||||||
@ -55,7 +55,7 @@ class PreviewExt(ExtensionBase):
|
|||||||
PipelineIntermediateState(
|
PipelineIntermediateState(
|
||||||
step=ctx.step_index,
|
step=ctx.step_index,
|
||||||
order=ctx.scheduler.order,
|
order=ctx.scheduler.order,
|
||||||
total_steps=len(ctx.timesteps),
|
total_steps=len(ctx.inputs.timesteps),
|
||||||
timestep=int(ctx.timestep), # TODO: is there any code which uses it?
|
timestep=int(ctx.timestep), # TODO: is there any code which uses it?
|
||||||
latents=ctx.step_output.prev_sample,
|
latents=ctx.step_output.prev_sample,
|
||||||
predicted_original=predicted_original, # TODO: is there any reason for additional field?
|
predicted_original=predicted_original, # TODO: is there any reason for additional field?
|
||||||
|
Loading…
Reference in New Issue
Block a user