mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
awkward workaround for double-Annotated in model_record route
This commit is contained in:
parent
f2c3b7c317
commit
2b36565e9e
@ -2,14 +2,13 @@
|
|||||||
"""FastAPI route for model configuration records."""
|
"""FastAPI route for model configuration records."""
|
||||||
|
|
||||||
|
|
||||||
from hashlib import sha1
|
|
||||||
from random import randbytes
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
from fastapi import Body, Path, Query, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from invokeai.app.services.model_records import DuplicateModelException, InvalidModelException, UnknownModelException
|
from invokeai.app.services.model_records import DuplicateModelException, InvalidModelException, UnknownModelException
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
||||||
@ -85,8 +84,9 @@ async def get_model_record(
|
|||||||
response_model=AnyModelConfig,
|
response_model=AnyModelConfig,
|
||||||
)
|
)
|
||||||
async def update_model_record(
|
async def update_model_record(
|
||||||
key: str = Path(description="Unique key of model"),
|
key: Annotated[str, Path(description="Unique key of model")],
|
||||||
info: AnyModelConfig = Body(description="Model configuration"),
|
# info: Annotated[AnyModelConfig, Body(description="Model configuration")],
|
||||||
|
info: AnyModelConfig,
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
@ -134,7 +134,7 @@ async def del_model_record(
|
|||||||
status_code=201,
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def add_model_record(
|
async def add_model_record(
|
||||||
config: AnyModelConfig = Body(description="Model configuration"),
|
config: AnyModelConfig,
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Add a model using the configuration information appropriate for its type.
|
Add a model using the configuration information appropriate for its type.
|
||||||
|
@ -755,10 +755,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
denoising_end=self.denoising_end,
|
denoising_end=self.denoising_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
(
|
(result_latents, result_attention_map_saver,) = pipeline.latents_from_embeddings(
|
||||||
result_latents,
|
|
||||||
result_attention_map_saver,
|
|
||||||
) = pipeline.latents_from_embeddings(
|
|
||||||
latents=latents,
|
latents=latents,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
init_timestep=init_timestep,
|
init_timestep=init_timestep,
|
||||||
|
@ -207,7 +207,7 @@ class IntegerMathInvocation(BaseInvocation):
|
|||||||
elif self.operation == "DIV":
|
elif self.operation == "DIV":
|
||||||
return IntegerOutput(value=int(self.a / self.b))
|
return IntegerOutput(value=int(self.a / self.b))
|
||||||
elif self.operation == "EXP":
|
elif self.operation == "EXP":
|
||||||
return IntegerOutput(value=self.a**self.b)
|
return IntegerOutput(value=self.a ** self.b)
|
||||||
elif self.operation == "MOD":
|
elif self.operation == "MOD":
|
||||||
return IntegerOutput(value=self.a % self.b)
|
return IntegerOutput(value=self.a % self.b)
|
||||||
elif self.operation == "ABS":
|
elif self.operation == "ABS":
|
||||||
@ -281,7 +281,7 @@ class FloatMathInvocation(BaseInvocation):
|
|||||||
elif self.operation == "DIV":
|
elif self.operation == "DIV":
|
||||||
return FloatOutput(value=self.a / self.b)
|
return FloatOutput(value=self.a / self.b)
|
||||||
elif self.operation == "EXP":
|
elif self.operation == "EXP":
|
||||||
return FloatOutput(value=self.a**self.b)
|
return FloatOutput(value=self.a ** self.b)
|
||||||
elif self.operation == "SQRT":
|
elif self.operation == "SQRT":
|
||||||
return FloatOutput(value=np.sqrt(self.a))
|
return FloatOutput(value=np.sqrt(self.a))
|
||||||
elif self.operation == "ABS":
|
elif self.operation == "ABS":
|
||||||
|
@ -33,7 +33,7 @@ def reshape_tensor(x, heads):
|
|||||||
class PerceiverAttention(nn.Module):
|
class PerceiverAttention(nn.Module):
|
||||||
def __init__(self, *, dim, dim_head=64, heads=8):
|
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = dim_head**-0.5
|
self.scale = dim_head ** -0.5
|
||||||
self.dim_head = dim_head
|
self.dim_head = dim_head
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
@ -91,7 +91,7 @@ class Resampler(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
|
||||||
|
|
||||||
self.proj_in = nn.Linear(embedding_dim, dim)
|
self.proj_in = nn.Linear(embedding_dim, dim)
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ import torch
|
|||||||
|
|
||||||
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
|
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
|
||||||
|
|
||||||
GB = 2**30 # 1 GB
|
GB = 2 ** 30 # 1 GB
|
||||||
|
|
||||||
|
|
||||||
class MemorySnapshot:
|
class MemorySnapshot:
|
||||||
|
@ -49,7 +49,7 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
|||||||
# actual size of a gig
|
# actual size of a gig
|
||||||
GIG = 1073741824
|
GIG = 1073741824
|
||||||
# Size of a MB in bytes.
|
# Size of a MB in bytes.
|
||||||
MB = 2**20
|
MB = 2 ** 20
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -22,6 +22,7 @@ Validation errors will raise an InvalidModelConfigException error.
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal, Optional, Type, Union
|
from typing import Literal, Optional, Type, Union
|
||||||
|
|
||||||
|
from fastapi import Body
|
||||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
@ -268,7 +269,7 @@ AnyModelConfig = Annotated[
|
|||||||
CLIPVisionDiffusersConfig,
|
CLIPVisionDiffusersConfig,
|
||||||
T2IConfig,
|
T2IConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="type"),
|
Body(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||||
|
@ -261,7 +261,7 @@ class InvokeAICrossAttentionMixin:
|
|||||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||||
else:
|
else:
|
||||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
slice_size = math.floor(2 ** 30 / (q.shape[0] * q.shape[1]))
|
||||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||||
|
|
||||||
def einsum_op_mps_v2(self, q, k, v):
|
def einsum_op_mps_v2(self, q, k, v):
|
||||||
|
@ -175,10 +175,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
(
|
(encoder_hidden_states, encoder_attention_mask,) = self._concat_conditionings_for_batch(
|
||||||
encoder_hidden_states,
|
|
||||||
encoder_attention_mask,
|
|
||||||
) = self._concat_conditionings_for_batch(
|
|
||||||
conditioning_data.unconditioned_embeddings.embeds,
|
conditioning_data.unconditioned_embeddings.embeds,
|
||||||
conditioning_data.text_embeddings.embeds,
|
conditioning_data.text_embeddings.embeds,
|
||||||
)
|
)
|
||||||
@ -240,10 +237,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
||||||
|
|
||||||
if wants_cross_attention_control:
|
if wants_cross_attention_control:
|
||||||
(
|
(unconditioned_next_x, conditioned_next_x,) = self._apply_cross_attention_controlled_conditioning(
|
||||||
unconditioned_next_x,
|
|
||||||
conditioned_next_x,
|
|
||||||
) = self._apply_cross_attention_controlled_conditioning(
|
|
||||||
sample,
|
sample,
|
||||||
timestep,
|
timestep,
|
||||||
conditioning_data,
|
conditioning_data,
|
||||||
@ -251,20 +245,14 @@ class InvokeAIDiffuserComponent:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
elif self.sequential_guidance:
|
elif self.sequential_guidance:
|
||||||
(
|
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning_sequentially(
|
||||||
unconditioned_next_x,
|
|
||||||
conditioned_next_x,
|
|
||||||
) = self._apply_standard_conditioning_sequentially(
|
|
||||||
sample,
|
sample,
|
||||||
timestep,
|
timestep,
|
||||||
conditioning_data,
|
conditioning_data,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
(
|
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning(
|
||||||
unconditioned_next_x,
|
|
||||||
conditioned_next_x,
|
|
||||||
) = self._apply_standard_conditioning(
|
|
||||||
sample,
|
sample,
|
||||||
timestep,
|
timestep,
|
||||||
conditioning_data,
|
conditioning_data,
|
||||||
|
@ -470,10 +470,7 @@ class TextualInversionDataset(Dataset):
|
|||||||
|
|
||||||
if self.center_crop:
|
if self.center_crop:
|
||||||
crop = min(img.shape[0], img.shape[1])
|
crop = min(img.shape[0], img.shape[1])
|
||||||
(
|
(h, w,) = (
|
||||||
h,
|
|
||||||
w,
|
|
||||||
) = (
|
|
||||||
img.shape[0],
|
img.shape[0],
|
||||||
img.shape[1],
|
img.shape[1],
|
||||||
)
|
)
|
||||||
|
@ -203,7 +203,7 @@ class ChunkedSlicedAttnProcessor:
|
|||||||
if attn.upcast_attention:
|
if attn.upcast_attention:
|
||||||
out_item_size = 4
|
out_item_size = 4
|
||||||
|
|
||||||
chunk_size = 2**29
|
chunk_size = 2 ** 29
|
||||||
|
|
||||||
out_size = query.shape[1] * key.shape[1] * out_item_size
|
out_size = query.shape[1] * key.shape[1] * out_item_size
|
||||||
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
|
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
|
||||||
|
@ -210,7 +210,7 @@ def parallel_data_prefetch(
|
|||||||
return gather_res
|
return gather_res
|
||||||
|
|
||||||
|
|
||||||
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
|
||||||
delta = (res[0] / shape[0], res[1] / shape[1])
|
delta = (res[0] / shape[0], res[1] / shape[1])
|
||||||
d = (shape[0] // res[0], shape[1] // res[1])
|
d = (shape[0] // res[0], shape[1] // res[1])
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ from invokeai.app.services.model_records import (
|
|||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
|
MainCheckpointConfig,
|
||||||
MainDiffusersConfig,
|
MainDiffusersConfig,
|
||||||
ModelType,
|
ModelType,
|
||||||
TextualInversionConfig,
|
TextualInversionConfig,
|
||||||
@ -57,6 +58,7 @@ def test_add(store: ModelRecordServiceBase):
|
|||||||
store.add_model("key1", raw)
|
store.add_model("key1", raw)
|
||||||
config1 = store.get_model("key1")
|
config1 = store.get_model("key1")
|
||||||
assert config1 is not None
|
assert config1 is not None
|
||||||
|
assert type(config1) == MainCheckpointConfig
|
||||||
assert config1.base == BaseModelType("sd-1")
|
assert config1.base == BaseModelType("sd-1")
|
||||||
assert config1.name == "model1"
|
assert config1.name == "model1"
|
||||||
assert config1.original_hash == "111222333444"
|
assert config1.original_hash == "111222333444"
|
||||||
|
Loading…
Reference in New Issue
Block a user