This commit is contained in:
Sergey Borisov 2024-07-12 22:44:00 +03:00
parent bd8ae5d896
commit 3a9dda9177
4 changed files with 19 additions and 9 deletions

View File

@ -781,7 +781,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
unet_config = context.models.get_config(self.unet.unet.key)
# ext: t2i/ip adapter
ext_manager.callbacks.pre_unet_load(denoise_ctx, ext_manager)
ext_manager.callbacks.setup(denoise_ctx, ext_manager)
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)

View File

@ -129,6 +129,7 @@ class TextConditioningData:
device = unet_kwargs.sample.device
dtype = unet_kwargs.sample.dtype
# TODO: combine regions with conditionings
if conditioning_mode == "both":
conditionings = [self.uncond_text.embeds, self.cond_text.embeds]
c_regions = [self.uncond_regions, self.cond_regions]

View File

@ -61,7 +61,10 @@ class StableDiffusionBackend:
def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput:
ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep)
# TODO: conditionings as list
# TODO: conditionings as list(conditioning_data.to_unet_kwargs - ready)
# Note: The current handling of conditioning doesn't feel very future-proof.
# This might change in the future as new requirements come up, but for now,
# this is the rough plan.
if self.sequential_guidance:
ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, "negative")
ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, "positive")
@ -74,7 +77,7 @@ class StableDiffusionBackend:
# ext: cfg_rescale [modify_noise_prediction]
# TODO: rename
ext_manager.callbacks.modify_noise_prediction(ctx, ext_manager)
ext_manager.callbacks.post_apply_cfg(ctx, ext_manager)
# compute the previous noisy sample x_t -> x_t-1
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs)
@ -113,14 +116,16 @@ class StableDiffusionBackend:
ctx.conditioning_mode = conditioning_mode
ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
# ext: controlnet/ip/t2i [pre_unet_forward]
ext_manager.callbacks.pre_unet_forward(ctx, ext_manager)
# ext: controlnet/ip/t2i [pre_unet]
ext_manager.callbacks.pre_unet(ctx, ext_manager)
# ext: inpaint [pre_unet_forward, priority=low]
# ext: inpaint [pre_unet, priority=low]
# or
# ext: inpaint [override: unet_forward]
noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))
ext_manager.callbacks.post_unet(ctx, ext_manager)
del ctx.unet_kwargs
del ctx.conditioning_mode

View File

@ -16,6 +16,10 @@ if TYPE_CHECKING:
class ExtCallbacksApi(ABC):
@abstractmethod
def setup(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass
@abstractmethod
def pre_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass
@ -33,15 +37,15 @@ class ExtCallbacksApi(ABC):
pass
@abstractmethod
def modify_noise_prediction(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
def pre_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass
@abstractmethod
def pre_unet_forward(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
def post_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass
@abstractmethod
def pre_unet_load(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
def post_apply_cfg(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass