dev(InvokeAIDiffuserComponent): mollify type checker's concern about the optional argument

This commit is contained in:
Kevin Turner 2023-02-19 16:58:54 -08:00
parent d0abe13b60
commit 6c8d4b091e

View File

@ -6,12 +6,18 @@ from typing import Callable, Optional, Union, Any, Dict
import numpy as np import numpy as np
import torch import torch
from diffusers.models.cross_attention import AttnProcessor from diffusers.models.cross_attention import AttnProcessor
from typing_extensions import TypeAlias
from ldm.models.diffusion.cross_attention_control import Arguments, \ from ldm.models.diffusion.cross_attention_control import Arguments, \
restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \ restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
CrossAttentionType, SwapCrossAttnContext CrossAttentionType, SwapCrossAttnContext
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
ModelForwardCallback: TypeAlias = Union[
# x, t, conditioning, Optional[cross-attention kwargs]
Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor],
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
]
@dataclass(frozen=True) @dataclass(frozen=True)
class PostprocessingSettings: class PostprocessingSettings:
@ -42,8 +48,7 @@ class InvokeAIDiffuserComponent:
return self.cross_attention_control_args is not None return self.cross_attention_control_args is not None
def __init__(self, model, model_forward_callback: def __init__(self, model, model_forward_callback: ModelForwardCallback,
Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str,Any]]], torch.Tensor],
is_running_diffusers: bool=False, is_running_diffusers: bool=False,
): ):
""" """