awkward workaround for double-Annotated in model_record route

Lincoln Stein
2023-11-10 21:32:44 -05:00
parent f2c3b7c317
commit 2b36565e9e
13 changed files with 24 additions and 39 deletions

@@ -33,7 +33,7 @@ def reshape_tensor(x, heads):
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.scale = dim_head ** -0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
@@ -91,7 +91,7 @@ class Resampler(nn.Module):
):
super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
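The two hunks above only change spacing around the power operator; the value itself is the standard 1/sqrt(d) scaled-dot-product attention factor. A one-line illustration, using the dim_head=64 default from the signature above:

dim_head = 64
scale = dim_head ** -0.5   # == 1 / sqrt(64) == 0.125, the standard 1/sqrt(d) attention scale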

@@ -6,7 +6,7 @@ import torch
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
GB = 2**30 # 1 GB
GB = 2 ** 30 # 1 GB
class MemorySnapshot:

@@ -49,7 +49,7 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
# Size of a MB in bytes.
MB = 2**20
MB = 2 ** 20
@dataclass
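For reference, the two constants in this hunk are byte counts written two different ways; a quick check that they line up:

assert 2 ** 30 == 1073741824   # GIG, one gibibyte in bytes
assert 2 ** 20 == 1048576      # MB, one mebibyte in bytes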

@@ -22,6 +22,7 @@ Validation errors will raise an InvalidModelConfigException error.
from enum import Enum
from typing import Literal, Optional, Type, Union
from fastapi import Body
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated
@@ -268,7 +269,7 @@ AnyModelConfig = Annotated[
CLIPVisionDiffusersConfig,
T2IConfig,
],
Field(discriminator="type"),
Body(discriminator="type"),
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
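The Field-to-Body swap in this hunk appears to be the "awkward workaround" of the commit title: by baking Body(discriminator="type") into AnyModelConfig itself, the model_record route can take the discriminated union directly as its body parameter instead of wrapping the already-Annotated alias in a second Annotated[..., Body(...)]. A minimal sketch of that usage; the import location, route path, and handler name are assumptions for illustration, not copied from this diff:

from fastapi import FastAPI

# Assumed import location for the aliases defined in the hunk above.
from invokeai.backend.model_manager.config import AnyModelConfig, AnyModelConfigValidator

app = FastAPI()

@app.patch("/model_records/{key}")  # illustrative route, not the project's actual path
async def update_model_record(key: str, config: AnyModelConfig) -> AnyModelConfig:
    # Body(discriminator="type") already lives inside AnyModelConfig, so no extra
    # Annotated[..., Body(...)] wrapper is needed on this parameter.
    return config

Outside of FastAPI, AnyModelConfigValidator.validate_python(...) applies the same discriminated-union parsing, for example when loading records from storage.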

@@ -261,7 +261,7 @@ class InvokeAICrossAttentionMixin:
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
slice_size = math.floor(2 ** 30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
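The constant re-spaced here bounds the intermediate attention matrix at roughly 2**30 elements by slicing along the query dimension. A worked example of the heuristic, with illustrative shapes:

import math

batch_heads, seq_len = 8, 16384                    # stand-ins for q.shape[0], q.shape[1]
slice_size = math.floor(2 ** 30 / (batch_heads * seq_len))
print(slice_size)                                  # 8192: each slice covers 8192 of the 16384 query rows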

@@ -175,10 +175,7 @@ class InvokeAIDiffuserComponent:
dim=0,
),
}
(
encoder_hidden_states,
encoder_attention_mask,
) = self._concat_conditionings_for_batch(
(encoder_hidden_states, encoder_attention_mask,) = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds,
)
@@ -240,10 +237,7 @@ class InvokeAIDiffuserComponent:
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
(unconditioned_next_x, conditioned_next_x,) = self._apply_cross_attention_controlled_conditioning(
sample,
timestep,
conditioning_data,
@@ -251,20 +245,14 @@
**kwargs,
)
elif self.sequential_guidance:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning_sequentially(
sample,
timestep,
conditioning_data,
**kwargs,
)
else:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning(
sample,
timestep,
conditioning_data,

@@ -470,10 +470,7 @@ class TextualInversionDataset(Dataset):
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
(
h,
w,
) = (
(h, w,) = (
img.shape[0],
img.shape[1],
)

@@ -203,7 +203,7 @@ class ChunkedSlicedAttnProcessor:
if attn.upcast_attention:
out_item_size = 4
chunk_size = 2**29
chunk_size = 2 ** 29
out_size = query.shape[1] * key.shape[1] * out_item_size
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
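For context on the constant touched here: with fp32 upcasting (out_item_size = 4 bytes per element) the processor appears to cap the attention-score buffer near 2**29 bytes and derives the chunk count from that cap. A worked example with illustrative sequence lengths:

import math

q_len, k_len, out_item_size = 16384, 16384, 4      # illustrative shapes, fp32 upcast
chunk_size = 2 ** 29
out_size = q_len * k_len * out_item_size           # 1073741824 bytes for the full score matrix
chunks_count = min(q_len, math.ceil((out_size - 1) / chunk_size))
print(chunks_count)                                # 2: the query dimension is processed in two chunks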

@@ -210,7 +210,7 @@ def parallel_data_prefetch(
return gather_res
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
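The fade lambda re-spaced in this last hunk is Perlin's quintic smoothstep, 6t^5 - 15t^4 + 10t^3: it maps 0 to 0 and 1 to 1 with zero first and second derivatives at both endpoints, which is what keeps the interpolated noise free of visible grid seams. A quick check:

fade = lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3
assert fade(0.0) == 0.0 and fade(1.0) == 1.0
print(fade(0.5))   # 0.5: the curve is symmetric about the midpoint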