awkward workaround for double-Annotated in model_record route

This commit is contained in:
Lincoln Stein 2023-11-10 21:32:44 -05:00
parent f2c3b7c317
commit 2b36565e9e
13 changed files with 24 additions and 39 deletions

View File

@ -2,14 +2,13 @@
"""FastAPI route for model configuration records."""
from hashlib import sha1
from random import randbytes
from typing import List, Optional
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, TypeAdapter
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_records import DuplicateModelException, InvalidModelException, UnknownModelException
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
@ -85,8 +84,9 @@ async def get_model_record(
response_model=AnyModelConfig,
)
async def update_model_record(
key: str = Path(description="Unique key of model"),
info: AnyModelConfig = Body(description="Model configuration"),
key: Annotated[str, Path(description="Unique key of model")],
# info: Annotated[AnyModelConfig, Body(description="Model configuration")],
info: AnyModelConfig,
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
@ -134,7 +134,7 @@ async def del_model_record(
status_code=201,
)
async def add_model_record(
config: AnyModelConfig = Body(description="Model configuration"),
config: AnyModelConfig,
) -> AnyModelConfig:
"""
Add a model using the configuration information appropriate for its type.

View File

@ -755,10 +755,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end,
)
(
result_latents,
result_attention_map_saver,
) = pipeline.latents_from_embeddings(
(result_latents, result_attention_map_saver,) = pipeline.latents_from_embeddings(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,

View File

@ -22,6 +22,7 @@ Validation errors will raise an InvalidModelConfigException error.
from enum import Enum
from typing import Literal, Optional, Type, Union
from fastapi import Body
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated
@ -268,7 +269,7 @@ AnyModelConfig = Annotated[
CLIPVisionDiffusersConfig,
T2IConfig,
],
Field(discriminator="type"),
Body(discriminator="type"),
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)

View File

@ -175,10 +175,7 @@ class InvokeAIDiffuserComponent:
dim=0,
),
}
(
encoder_hidden_states,
encoder_attention_mask,
) = self._concat_conditionings_for_batch(
(encoder_hidden_states, encoder_attention_mask,) = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds,
)
@ -240,10 +237,7 @@ class InvokeAIDiffuserComponent:
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
(unconditioned_next_x, conditioned_next_x,) = self._apply_cross_attention_controlled_conditioning(
sample,
timestep,
conditioning_data,
@ -251,20 +245,14 @@ class InvokeAIDiffuserComponent:
**kwargs,
)
elif self.sequential_guidance:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning_sequentially(
sample,
timestep,
conditioning_data,
**kwargs,
)
else:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning(
sample,
timestep,
conditioning_data,

View File

@ -470,10 +470,7 @@ class TextualInversionDataset(Dataset):
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
(
h,
w,
) = (
(h, w,) = (
img.shape[0],
img.shape[1],
)

View File

@ -16,6 +16,7 @@ from invokeai.app.services.model_records import (
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.model_manager.config import (
BaseModelType,
MainCheckpointConfig,
MainDiffusersConfig,
ModelType,
TextualInversionConfig,
@ -57,6 +58,7 @@ def test_add(store: ModelRecordServiceBase):
store.add_model("key1", raw)
config1 = store.get_model("key1")
assert config1 is not None
assert type(config1) == MainCheckpointConfig
assert config1.base == BaseModelType("sd-1")
assert config1.name == "model1"
assert config1.original_hash == "111222333444"