From 6c8d4b091eb0a8752b252905dd471a298c4dcb9a Mon Sep 17 00:00:00 2001
From: Kevin Turner <83819+keturn@users.noreply.github.com>
Date: Sun, 19 Feb 2023 16:58:54 -0800
Subject: [PATCH] dev(InvokeAIDiffuserComponent): mollify type checker's
 concern about the optional argument

---
 ldm/models/diffusion/shared_invokeai_diffusion.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py
index 2e513d3f5a..d962932484 100644
--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -6,12 +6,18 @@ from typing import Callable, Optional, Union, Any, Dict
 import numpy as np
 import torch
 from diffusers.models.cross_attention import AttnProcessor
+from typing_extensions import TypeAlias
 
 from ldm.models.diffusion.cross_attention_control import Arguments, \
     restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
     CrossAttentionType, SwapCrossAttnContext
 from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 
+ModelForwardCallback: TypeAlias = Union[
+    # x, t, conditioning, Optional[cross-attention kwargs]
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
+]
 
 @dataclass(frozen=True)
 class PostprocessingSettings:
@@ -42,8 +48,7 @@ class InvokeAIDiffuserComponent:
 
             return self.cross_attention_control_args is not None
 
-    def __init__(self, model, model_forward_callback:
-                 Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str,Any]]], torch.Tensor],
+    def __init__(self, model, model_forward_callback: ModelForwardCallback,
                  is_running_diffusers: bool=False,
                  ):
         """
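
A minimal usage sketch, separate from the patch above: the Union of two Callable shapes
lets __init__ accept forward callbacks that either do or do not take the optional
cross-attention kwargs dict, which is the optional argument the type checker flagged.
The callback names below are hypothetical and exist only to illustrate the two signatures
the alias covers.

    from typing import Any, Callable, Optional, Union

    import torch
    from typing_extensions import TypeAlias

    ModelForwardCallback: TypeAlias = Union[
        # x, t, conditioning, Optional[cross-attention kwargs]
        Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor],
        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
    ]

    # Hypothetical callbacks: either shape satisfies the alias, so callers that
    # forward cross-attention kwargs and callers that do not both type-check.
    def forward_with_kwargs(x: torch.Tensor, t: torch.Tensor, conditioning: torch.Tensor,
                            cross_attention_kwargs: Optional[dict[str, Any]] = None) -> torch.Tensor:
        return x  # placeholder body

    def forward_plain(x: torch.Tensor, t: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        return x  # placeholder body

    callbacks: list[ModelForwardCallback] = [forward_with_kwargs, forward_plain]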