awkward workaround for double-Annotated in model_record route

Lincoln Stein 2023-11-10 21:32:44 -05:00
parent f2c3b7c317
commit 2b36565e9e
13 changed files with 24 additions and 39 deletions
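
The "double-Annotated" problem: AnyModelConfig is a discriminated union already wrapped in Annotated[...], and annotating it a second time at the route boundary (with Body(...)) produced a nested Annotated that FastAPI failed to resolve. Below is a minimal, self-contained sketch of the pattern this commit settles on; the class names and fields are illustrative toys, not the real InvokeAI schema.

    # Illustrative reproduction of the workaround (toy config classes;
    # the real union lives in invokeai/backend/model_manager/config.py).
    from typing import Literal, Union

    from fastapi import Body, FastAPI
    from pydantic import BaseModel
    from typing_extensions import Annotated

    class DiffusersConfig(BaseModel):
        type: Literal["diffusers"] = "diffusers"
        name: str

    class CheckpointConfig(BaseModel):
        type: Literal["checkpoint"] = "checkpoint"
        name: str

    # The workaround: attach Body(discriminator=...) once, inside the
    # alias, instead of Field(discriminator=...) here plus a second
    # Body(...) at each route, which nests Annotated inside Annotated.
    AnyModelConfig = Annotated[
        Union[DiffusersConfig, CheckpointConfig],
        Body(discriminator="type"),
    ]

    app = FastAPI()

    @app.post("/model_record")
    async def add_model_record(config: AnyModelConfig) -> AnyModelConfig:
        # FastAPI still reads the JSON body and dispatches on "type".
        return config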

@@ -2,14 +2,13 @@
 """FastAPI route for model configuration records."""
-from hashlib import sha1
-from random import randbytes
 from typing import List, Optional
 from fastapi import Body, Path, Query, Response
 from fastapi.routing import APIRouter
 from pydantic import BaseModel, ConfigDict, TypeAdapter
 from starlette.exceptions import HTTPException
+from typing_extensions import Annotated
 from invokeai.app.services.model_records import DuplicateModelException, InvalidModelException, UnknownModelException
 from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
@@ -85,8 +84,9 @@ async def get_model_record(
     response_model=AnyModelConfig,
 )
 async def update_model_record(
-    key: str = Path(description="Unique key of model"),
-    info: AnyModelConfig = Body(description="Model configuration"),
+    key: Annotated[str, Path(description="Unique key of model")],
+    # info: Annotated[AnyModelConfig, Body(description="Model configuration")],
+    info: AnyModelConfig,
 ) -> AnyModelConfig:
     """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
     logger = ApiDependencies.invoker.services.logger
@@ -134,7 +134,7 @@ async def del_model_record(
     status_code=201,
 )
 async def add_model_record(
-    config: AnyModelConfig = Body(description="Model configuration"),
+    config: AnyModelConfig,
 ) -> AnyModelConfig:
     """
     Add a model using the configuration information appropriate for its type.
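
Note the shape of the workaround in the two routes above: the Body(description=...) annotations are dropped and the parameters are declared as bare AnyModelConfig, because the alias itself now carries Body(discriminator="type") (see the config.py hunk below). The commented-out info: Annotated[...] line is left in place, presumably marking the form the route can return to once nested-Annotated handling is fixed upstream.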

@@ -755,10 +755,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 denoising_end=self.denoising_end,
             )
-            (
-                result_latents,
-                result_attention_map_saver,
-            ) = pipeline.latents_from_embeddings(
+            (result_latents, result_attention_map_saver,) = pipeline.latents_from_embeddings(
                 latents=latents,
                 timesteps=timesteps,
                 init_timestep=init_timestep,

@@ -207,7 +207,7 @@ class IntegerMathInvocation(BaseInvocation):
         elif self.operation == "DIV":
             return IntegerOutput(value=int(self.a / self.b))
         elif self.operation == "EXP":
-            return IntegerOutput(value=self.a**self.b)
+            return IntegerOutput(value=self.a ** self.b)
         elif self.operation == "MOD":
             return IntegerOutput(value=self.a % self.b)
         elif self.operation == "ABS":
@@ -281,7 +281,7 @@ class FloatMathInvocation(BaseInvocation):
         elif self.operation == "DIV":
             return FloatOutput(value=self.a / self.b)
         elif self.operation == "EXP":
-            return FloatOutput(value=self.a**self.b)
+            return FloatOutput(value=self.a ** self.b)
         elif self.operation == "SQRT":
             return FloatOutput(value=np.sqrt(self.a))
         elif self.operation == "ABS":

@@ -33,7 +33,7 @@ def reshape_tensor(x, heads):
 class PerceiverAttention(nn.Module):
     def __init__(self, *, dim, dim_head=64, heads=8):
         super().__init__()
-        self.scale = dim_head**-0.5
+        self.scale = dim_head ** -0.5
         self.dim_head = dim_head
         self.heads = heads
         inner_dim = dim_head * heads
@@ -91,7 +91,7 @@ class Resampler(nn.Module):
     ):
         super().__init__()
-        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
         self.proj_in = nn.Linear(embedding_dim, dim)

@@ -6,7 +6,7 @@ import torch
 from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
-GB = 2**30  # 1 GB
+GB = 2 ** 30  # 1 GB
 class MemorySnapshot:

@@ -49,7 +49,7 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
 # actual size of a gig
 GIG = 1073741824
 # Size of a MB in bytes.
-MB = 2**20
+MB = 2 ** 20
 @dataclass

@@ -22,6 +22,7 @@ Validation errors will raise an InvalidModelConfigException error.
 from enum import Enum
 from typing import Literal, Optional, Type, Union
+from fastapi import Body
 from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
 from typing_extensions import Annotated
@@ -268,7 +269,7 @@ AnyModelConfig = Annotated[
         CLIPVisionDiffusersConfig,
         T2IConfig,
     ],
-    Field(discriminator="type"),
+    Body(discriminator="type"),
 ]
 AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
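
Because AnyModelConfig stays a discriminated union (fastapi's Body is a FieldInfo subclass, so pydantic should read the discriminator the same way Field's is read), the TypeAdapter built on the last line still dispatches payloads to the concrete config class via the "type" field. A rough sketch of that behavior with toy classes and the plain pydantic v2 API (field names are illustrative):

    from typing import Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter
    from typing_extensions import Annotated

    class FooConfig(BaseModel):
        type: Literal["foo"] = "foo"
        name: str

    class BarConfig(BaseModel):
        type: Literal["bar"] = "bar"
        name: str

    AnyConfig = Annotated[Union[FooConfig, BarConfig], Field(discriminator="type")]
    validator = TypeAdapter(AnyConfig)

    config = validator.validate_python({"type": "bar", "name": "model1"})
    assert isinstance(config, BarConfig)  # concrete subclass, not a bare union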

@@ -261,7 +261,7 @@ class InvokeAICrossAttentionMixin:
         if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
             return self.einsum_lowest_level(q, k, v, None, None, None)
         else:
-            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+            slice_size = math.floor(2 ** 30 / (q.shape[0] * q.shape[1]))
             return self.einsum_op_slice_dim1(q, k, v, slice_size)
     def einsum_op_mps_v2(self, q, k, v):

@@ -175,10 +175,7 @@ class InvokeAIDiffuserComponent:
                 dim=0,
             ),
         }
-        (
-            encoder_hidden_states,
-            encoder_attention_mask,
-        ) = self._concat_conditionings_for_batch(
+        (encoder_hidden_states, encoder_attention_mask,) = self._concat_conditionings_for_batch(
             conditioning_data.unconditioned_embeddings.embeds,
             conditioning_data.text_embeddings.embeds,
         )
@@ -240,10 +237,7 @@ class InvokeAIDiffuserComponent:
         wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
         if wants_cross_attention_control:
-            (
-                unconditioned_next_x,
-                conditioned_next_x,
-            ) = self._apply_cross_attention_controlled_conditioning(
+            (unconditioned_next_x, conditioned_next_x,) = self._apply_cross_attention_controlled_conditioning(
                 sample,
                 timestep,
                 conditioning_data,
@ -251,20 +245,14 @@ class InvokeAIDiffuserComponent:
**kwargs, **kwargs,
) )
elif self.sequential_guidance: elif self.sequential_guidance:
( (unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning_sequentially(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,
**kwargs, **kwargs,
) )
else: else:
( (unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,

@@ -470,10 +470,7 @@ class TextualInversionDataset(Dataset):
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            (
-                h,
-                w,
-            ) = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )

@@ -203,7 +203,7 @@ class ChunkedSlicedAttnProcessor:
         if attn.upcast_attention:
             out_item_size = 4
-        chunk_size = 2**29
+        chunk_size = 2 ** 29
         out_size = query.shape[1] * key.shape[1] * out_item_size
         chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))

@@ -210,7 +210,7 @@ def parallel_data_prefetch(
     return gather_res
-def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
+def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
     delta = (res[0] / shape[0], res[1] / shape[1])
     d = (shape[0] // res[0], shape[1] // res[1])

@@ -16,6 +16,7 @@ from invokeai.app.services.model_records import (
 from invokeai.app.services.shared.sqlite import SqliteDatabase
 from invokeai.backend.model_manager.config import (
     BaseModelType,
+    MainCheckpointConfig,
     MainDiffusersConfig,
     ModelType,
     TextualInversionConfig,
@@ -57,6 +58,7 @@ def test_add(store: ModelRecordServiceBase):
     store.add_model("key1", raw)
     config1 = store.get_model("key1")
     assert config1 is not None
+    assert type(config1) == MainCheckpointConfig
     assert config1.base == BaseModelType("sd-1")
     assert config1.name == "model1"
     assert config1.original_hash == "111222333444"
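
Note that the new assertion checks type(config1) == MainCheckpointConfig rather than isinstance(config1, MainCheckpointConfig): the exact-class check is stricter (isinstance would also accept subclasses) and pins down that the discriminated union resolved the stored record to the concrete MainCheckpointConfig class rather than some other member of the union.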