add note about discriminated union and Body() issue; blackified

This commit is contained in:
Lincoln Stein 2023-11-12 16:50:05 -05:00
parent ef8dcf5fae
commit 8afe517204
11 changed files with 41 additions and 18 deletions

View File

@ -755,7 +755,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end, denoising_end=self.denoising_end,
) )
(result_latents, result_attention_map_saver,) = pipeline.latents_from_embeddings( (
result_latents,
result_attention_map_saver,
) = pipeline.latents_from_embeddings(
latents=latents, latents=latents,
timesteps=timesteps, timesteps=timesteps,
init_timestep=init_timestep, init_timestep=init_timestep,

View File

@ -207,7 +207,7 @@ class IntegerMathInvocation(BaseInvocation):
elif self.operation == "DIV": elif self.operation == "DIV":
return IntegerOutput(value=int(self.a / self.b)) return IntegerOutput(value=int(self.a / self.b))
elif self.operation == "EXP": elif self.operation == "EXP":
return IntegerOutput(value=self.a ** self.b) return IntegerOutput(value=self.a**self.b)
elif self.operation == "MOD": elif self.operation == "MOD":
return IntegerOutput(value=self.a % self.b) return IntegerOutput(value=self.a % self.b)
elif self.operation == "ABS": elif self.operation == "ABS":
@ -281,7 +281,7 @@ class FloatMathInvocation(BaseInvocation):
elif self.operation == "DIV": elif self.operation == "DIV":
return FloatOutput(value=self.a / self.b) return FloatOutput(value=self.a / self.b)
elif self.operation == "EXP": elif self.operation == "EXP":
return FloatOutput(value=self.a ** self.b) return FloatOutput(value=self.a**self.b)
elif self.operation == "SQRT": elif self.operation == "SQRT":
return FloatOutput(value=np.sqrt(self.a)) return FloatOutput(value=np.sqrt(self.a))
elif self.operation == "ABS": elif self.operation == "ABS":

View File

@ -33,7 +33,7 @@ def reshape_tensor(x, heads):
class PerceiverAttention(nn.Module): class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8): def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__() super().__init__()
self.scale = dim_head ** -0.5 self.scale = dim_head**-0.5
self.dim_head = dim_head self.dim_head = dim_head
self.heads = heads self.heads = heads
inner_dim = dim_head * heads inner_dim = dim_head * heads
@ -91,7 +91,7 @@ class Resampler(nn.Module):
): ):
super().__init__() super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5) self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim) self.proj_in = nn.Linear(embedding_dim, dim)

View File

@ -6,7 +6,7 @@ import torch
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
GB = 2 ** 30 # 1 GB GB = 2**30 # 1 GB
class MemorySnapshot: class MemorySnapshot:

View File

@ -49,7 +49,7 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig # actual size of a gig
GIG = 1073741824 GIG = 1073741824
# Size of a MB in bytes. # Size of a MB in bytes.
MB = 2 ** 20 MB = 2**20
@dataclass @dataclass

View File

@ -268,7 +268,14 @@ AnyModelConfig = Union[
T2IConfig, T2IConfig,
] ]
# Preferred alternative is a discriminated Union, but it breaks FastAPI when applied to a route. AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
# https://github.com/tiangolo/fastapi/discussions/9761 and
# https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[ # AnyModelConfig = Annotated[
# Union[ # Union[
# _MainModelConfig, # _MainModelConfig,
@ -284,8 +291,6 @@ AnyModelConfig = Union[
# Field(discriminator="type"), # Field(discriminator="type"),
# ] # ]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
class ModelConfigFactory(object): class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects.""" """Class for parsing config dicts into StableDiffusion Config obects."""

View File

@ -261,7 +261,7 @@ class InvokeAICrossAttentionMixin:
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None) return self.einsum_lowest_level(q, k, v, None, None, None)
else: else:
slice_size = math.floor(2 ** 30 / (q.shape[0] * q.shape[1])) slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size) return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v): def einsum_op_mps_v2(self, q, k, v):

View File

@ -175,7 +175,10 @@ class InvokeAIDiffuserComponent:
dim=0, dim=0,
), ),
} }
(encoder_hidden_states, encoder_attention_mask,) = self._concat_conditionings_for_batch( (
encoder_hidden_states,
encoder_attention_mask,
) = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds, conditioning_data.text_embeddings.embeds,
) )
@ -237,7 +240,10 @@ class InvokeAIDiffuserComponent:
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0 wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control: if wants_cross_attention_control:
(unconditioned_next_x, conditioned_next_x,) = self._apply_cross_attention_controlled_conditioning( (
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,
@ -245,14 +251,20 @@ class InvokeAIDiffuserComponent:
**kwargs, **kwargs,
) )
elif self.sequential_guidance: elif self.sequential_guidance:
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning_sequentially( (
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,
**kwargs, **kwargs,
) )
else: else:
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning( (
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,

View File

@ -470,7 +470,10 @@ class TextualInversionDataset(Dataset):
if self.center_crop: if self.center_crop:
crop = min(img.shape[0], img.shape[1]) crop = min(img.shape[0], img.shape[1])
(h, w,) = ( (
h,
w,
) = (
img.shape[0], img.shape[0],
img.shape[1], img.shape[1],
) )

View File

@ -203,7 +203,7 @@ class ChunkedSlicedAttnProcessor:
if attn.upcast_attention: if attn.upcast_attention:
out_item_size = 4 out_item_size = 4
chunk_size = 2 ** 29 chunk_size = 2**29
out_size = query.shape[1] * key.shape[1] * out_item_size out_size = query.shape[1] * key.shape[1] * out_item_size
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size)) chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))

View File

@ -210,7 +210,7 @@ def parallel_data_prefetch(
return gather_res return gather_res
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
delta = (res[0] / shape[0], res[1] / shape[1]) delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1]) d = (shape[0] // res[0], shape[1] // res[1])