mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
awkward workaround for double-Annotated in model_record route
This commit is contained in:
parent
f2c3b7c317
commit
2b36565e9e
@ -2,14 +2,13 @@
|
||||
"""FastAPI route for model configuration records."""
|
||||
|
||||
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_records import DuplicateModelException, InvalidModelException, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
||||
@ -85,8 +84,9 @@ async def get_model_record(
|
||||
response_model=AnyModelConfig,
|
||||
)
|
||||
async def update_model_record(
|
||||
key: str = Path(description="Unique key of model"),
|
||||
info: AnyModelConfig = Body(description="Model configuration"),
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
# info: Annotated[AnyModelConfig, Body(description="Model configuration")],
|
||||
info: AnyModelConfig,
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
@ -134,7 +134,7 @@ async def del_model_record(
|
||||
status_code=201,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: AnyModelConfig = Body(description="Model configuration"),
|
||||
config: AnyModelConfig,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model using the configuration information appropriate for its type.
|
||||
|
@ -755,10 +755,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
(
|
||||
result_latents,
|
||||
result_attention_map_saver,
|
||||
) = pipeline.latents_from_embeddings(
|
||||
(result_latents, result_attention_map_saver,) = pipeline.latents_from_embeddings(
|
||||
latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
|
@ -22,6 +22,7 @@ Validation errors will raise an InvalidModelConfigException error.
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
|
||||
from fastapi import Body
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from typing_extensions import Annotated
|
||||
|
||||
@ -268,7 +269,7 @@ AnyModelConfig = Annotated[
|
||||
CLIPVisionDiffusersConfig,
|
||||
T2IConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
Body(discriminator="type"),
|
||||
]
|
||||
|
||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
|
@ -175,10 +175,7 @@ class InvokeAIDiffuserComponent:
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
(
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = self._concat_conditionings_for_batch(
|
||||
(encoder_hidden_states, encoder_attention_mask,) = self._concat_conditionings_for_batch(
|
||||
conditioning_data.unconditioned_embeddings.embeds,
|
||||
conditioning_data.text_embeddings.embeds,
|
||||
)
|
||||
@ -240,10 +237,7 @@ class InvokeAIDiffuserComponent:
|
||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
||||
|
||||
if wants_cross_attention_control:
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_cross_attention_controlled_conditioning(
|
||||
(unconditioned_next_x, conditioned_next_x,) = self._apply_cross_attention_controlled_conditioning(
|
||||
sample,
|
||||
timestep,
|
||||
conditioning_data,
|
||||
@ -251,20 +245,14 @@ class InvokeAIDiffuserComponent:
|
||||
**kwargs,
|
||||
)
|
||||
elif self.sequential_guidance:
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_standard_conditioning_sequentially(
|
||||
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning_sequentially(
|
||||
sample,
|
||||
timestep,
|
||||
conditioning_data,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_standard_conditioning(
|
||||
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning(
|
||||
sample,
|
||||
timestep,
|
||||
conditioning_data,
|
||||
|
@ -470,10 +470,7 @@ class TextualInversionDataset(Dataset):
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
(h, w,) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
|
@ -16,6 +16,7 @@ from invokeai.app.services.model_records import (
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
MainCheckpointConfig,
|
||||
MainDiffusersConfig,
|
||||
ModelType,
|
||||
TextualInversionConfig,
|
||||
@ -57,6 +58,7 @@ def test_add(store: ModelRecordServiceBase):
|
||||
store.add_model("key1", raw)
|
||||
config1 = store.get_model("key1")
|
||||
assert config1 is not None
|
||||
assert type(config1) == MainCheckpointConfig
|
||||
assert config1.base == BaseModelType("sd-1")
|
||||
assert config1.name == "model1"
|
||||
assert config1.original_hash == "111222333444"
|
||||
|
Loading…
x
Reference in New Issue
Block a user