mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Simplify guidance modes
This commit is contained in:
parent
87e96e1be2
commit
bd8ae5d896
@ -61,18 +61,19 @@ class StableDiffusionBackend:
|
|||||||
def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput:
|
def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput:
|
||||||
ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep)
|
ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep)
|
||||||
|
|
||||||
|
# TODO: conditionings as list
|
||||||
if self.sequential_guidance:
|
if self.sequential_guidance:
|
||||||
conditioning_call = self._apply_standard_conditioning_sequentially
|
ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, "negative")
|
||||||
|
ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, "positive")
|
||||||
else:
|
else:
|
||||||
conditioning_call = self._apply_standard_conditioning
|
both_noise_pred = self.run_unet(ctx, ext_manager, "both")
|
||||||
|
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
||||||
# not sure if here needed override
|
|
||||||
ctx.negative_noise_pred, ctx.positive_noise_pred = conditioning_call(ctx, ext_manager)
|
|
||||||
|
|
||||||
# ext: override apply_cfg
|
# ext: override apply_cfg
|
||||||
ctx.noise_pred = ext_manager.overrides.apply_cfg(self.apply_cfg, ctx)
|
ctx.noise_pred = ext_manager.overrides.apply_cfg(self.apply_cfg, ctx)
|
||||||
|
|
||||||
# ext: cfg_rescale [modify_noise_prediction]
|
# ext: cfg_rescale [modify_noise_prediction]
|
||||||
|
# TODO: rename
|
||||||
ext_manager.callbacks.modify_noise_prediction(ctx, ext_manager)
|
ext_manager.callbacks.modify_noise_prediction(ctx, ext_manager)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
@ -95,15 +96,13 @@ class StableDiffusionBackend:
|
|||||||
return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
||||||
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
||||||
|
|
||||||
def _apply_standard_conditioning(
|
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: str):
|
||||||
self, ctx: DenoiseContext, ext_manager: ExtensionsManager
|
sample = ctx.latent_model_input
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
if conditioning_mode == "both":
|
||||||
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
|
sample = torch.cat([sample] * 2)
|
||||||
the cost of higher memory usage.
|
|
||||||
"""
|
|
||||||
|
|
||||||
ctx.unet_kwargs = UNetKwargs(
|
ctx.unet_kwargs = UNetKwargs(
|
||||||
sample=torch.cat([ctx.latent_model_input] * 2),
|
sample=sample,
|
||||||
timestep=ctx.timestep,
|
timestep=ctx.timestep,
|
||||||
encoder_hidden_states=None, # set later by conditoning
|
encoder_hidden_states=None, # set later by conditoning
|
||||||
cross_attention_kwargs=dict( # noqa: C408
|
cross_attention_kwargs=dict( # noqa: C408
|
||||||
@ -111,7 +110,7 @@ class StableDiffusionBackend:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx.conditioning_mode = "both"
|
ctx.conditioning_mode = conditioning_mode
|
||||||
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
|
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
|
||||||
|
|
||||||
# ext: controlnet/ip/t2i [pre_unet_forward]
|
# ext: controlnet/ip/t2i [pre_unet_forward]
|
||||||
@ -120,75 +119,12 @@ class StableDiffusionBackend:
|
|||||||
# ext: inpaint [pre_unet_forward, priority=low]
|
# ext: inpaint [pre_unet_forward, priority=low]
|
||||||
# or
|
# or
|
||||||
# ext: inpaint [override: unet_forward]
|
# ext: inpaint [override: unet_forward]
|
||||||
both_results = self._unet_forward(**vars(ctx.unet_kwargs))
|
noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))
|
||||||
negative_next_x, positive_next_x = both_results.chunk(2)
|
|
||||||
# del locals
|
|
||||||
del ctx.unet_kwargs
|
|
||||||
del ctx.conditioning_mode
|
|
||||||
return negative_next_x, positive_next_x
|
|
||||||
|
|
||||||
def _apply_standard_conditioning_sequentially(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
|
||||||
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
|
|
||||||
slower execution speed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
###################
|
|
||||||
# Negative pass
|
|
||||||
###################
|
|
||||||
|
|
||||||
ctx.unet_kwargs = UNetKwargs(
|
|
||||||
sample=ctx.latent_model_input,
|
|
||||||
timestep=ctx.timestep,
|
|
||||||
encoder_hidden_states=None, # set later by conditoning
|
|
||||||
cross_attention_kwargs=dict( # noqa: C408
|
|
||||||
percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.conditioning_mode = "negative"
|
|
||||||
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "negative")
|
|
||||||
|
|
||||||
# ext: controlnet/ip/t2i [pre_unet_forward]
|
|
||||||
ext_manager.callbacks.pre_unet_forward(ctx, ext_manager)
|
|
||||||
|
|
||||||
# ext: inpaint [pre_unet_forward, priority=low]
|
|
||||||
# or
|
|
||||||
# ext: inpaint [override: unet_forward]
|
|
||||||
negative_next_x = self._unet_forward(**vars(ctx.unet_kwargs))
|
|
||||||
|
|
||||||
del ctx.unet_kwargs
|
del ctx.unet_kwargs
|
||||||
del ctx.conditioning_mode
|
del ctx.conditioning_mode
|
||||||
# TODO: gc.collect() ?
|
|
||||||
|
|
||||||
###################
|
return noise_pred
|
||||||
# Positive pass
|
|
||||||
###################
|
|
||||||
|
|
||||||
ctx.unet_kwargs = UNetKwargs(
|
|
||||||
sample=ctx.latent_model_input,
|
|
||||||
timestep=ctx.timestep,
|
|
||||||
encoder_hidden_states=None, # set later by conditoning
|
|
||||||
cross_attention_kwargs=dict( # noqa: C408
|
|
||||||
percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.conditioning_mode = "positive"
|
|
||||||
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "positive")
|
|
||||||
|
|
||||||
# ext: controlnet/ip/t2i [pre_unet_forward]
|
|
||||||
ext_manager.callbacks.pre_unet_forward(ctx, ext_manager)
|
|
||||||
|
|
||||||
# ext: inpaint [pre_unet_forward, priority=low]
|
|
||||||
# or
|
|
||||||
# ext: inpaint [override: unet_forward]
|
|
||||||
positive_next_x = self._unet_forward(**vars(ctx.unet_kwargs))
|
|
||||||
|
|
||||||
del ctx.unet_kwargs
|
|
||||||
del ctx.conditioning_mode
|
|
||||||
# TODO: gc.collect() ?
|
|
||||||
|
|
||||||
return negative_next_x, positive_next_x
|
|
||||||
|
|
||||||
def _unet_forward(self, **kwargs) -> torch.Tensor:
|
def _unet_forward(self, **kwargs) -> torch.Tensor:
|
||||||
return self.unet(**kwargs).sample
|
return self.unet(**kwargs).sample
|
||||||
|
Loading…
x
Reference in New Issue
Block a user