mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
dev(InvokeAIDiffuserComponent): mollify type checker's concern about the optional argument
This commit is contained in:
parent
d0abe13b60
commit
6c8d4b091e
@ -6,12 +6,18 @@ from typing import Callable, Optional, Union, Any, Dict
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.cross_attention import AttnProcessor
|
from diffusers.models.cross_attention import AttnProcessor
|
||||||
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from ldm.models.diffusion.cross_attention_control import Arguments, \
|
from ldm.models.diffusion.cross_attention_control import Arguments, \
|
||||||
restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
|
restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
|
||||||
CrossAttentionType, SwapCrossAttnContext
|
CrossAttentionType, SwapCrossAttnContext
|
||||||
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||||
|
|
||||||
|
ModelForwardCallback: TypeAlias = Union[
|
||||||
|
# x, t, conditioning, Optional[cross-attention kwargs]
|
||||||
|
Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], torch.Tensor],
|
||||||
|
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||||
|
]
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class PostprocessingSettings:
|
class PostprocessingSettings:
|
||||||
@ -42,8 +48,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
return self.cross_attention_control_args is not None
|
return self.cross_attention_control_args is not None
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, model, model_forward_callback:
|
def __init__(self, model, model_forward_callback: ModelForwardCallback,
|
||||||
Callable[[torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str,Any]]], torch.Tensor],
|
|
||||||
is_running_diffusers: bool=False,
|
is_running_diffusers: bool=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user