feat: use ModelValidator naming convention for pydantic type adapters

This matches the naming convention used in the pydantic docs and is also clearer.
parent 3c4f43314c
commit 4012388f0a
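
For context: a pydantic v2 TypeAdapter wraps an arbitrary type with validate_python/validate_json methods. A minimal sketch of the convention this commit adopts, naming each module-level adapter <Type>Validator (the Foo model is a hypothetical stand-in, not InvokeAI code):

from pydantic import BaseModel, TypeAdapter


class Foo(BaseModel):
    bar: str


# Module-level adapter named <Type>Validator, built once and reused;
# building a TypeAdapter is comparatively expensive, so it should not
# be recreated per call.
FooValidator = TypeAdapter(Foo)

foo = FooValidator.validate_json('{"bar": "baz"}')
assert foo.bar == "baz"
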
@@ -9,8 +9,8 @@ from pydantic import BaseModel, Field, ValidationError
 from invokeai.app.invocations.baseinvocation import (
     MetadataField,
-    type_adapter_MetadataField,
-    type_adapter_WorkflowField,
+    MetadataFieldValidator,
+    WorkflowFieldValidator,
 )
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
 from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
@@ -66,7 +66,7 @@ async def upload_image(
     metadata_raw = pil_image.info.get("invokeai_metadata", None)
     if metadata_raw:
         try:
-            metadata = type_adapter_MetadataField.validate_json(metadata_raw)
+            metadata = MetadataFieldValidator.validate_json(metadata_raw)
         except ValidationError:
             ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
             pass
@@ -75,7 +75,7 @@ async def upload_image(
     workflow_raw = pil_image.info.get("invokeai_workflow", None)
     if workflow_raw is not None:
         try:
-            workflow = type_adapter_WorkflowField.validate_json(workflow_raw)
+            workflow = WorkflowFieldValidator.validate_json(workflow_raw)
         except ValidationError:
             ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
             pass

@@ -23,13 +23,13 @@ from ..dependencies import ApiDependencies
 models_router = APIRouter(prefix="/v1/models", tags=["models"])
 
 UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
-update_models_response_adapter = TypeAdapter(UpdateModelResponse)
+UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse)
 
 ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
-import_models_response_adapter = TypeAdapter(ImportModelResponse)
+ImportModelResponseValidator = TypeAdapter(ImportModelResponse)
 
 ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
-convert_models_response_adapter = TypeAdapter(ConvertModelResponse)
+ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse)
 
 MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
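
Each *Response alias above is a Union over the classes in OPENAPI_MODEL_CONFIGS, so one TypeAdapter validates a raw dict into whichever union member fits. A minimal sketch with two hypothetical config models standing in for the real ones:

from typing import Union

from pydantic import BaseModel, TypeAdapter


class MainModelConfig(BaseModel):
    format: str


class LoRAModelConfig(BaseModel):
    rank: int


# One adapter over the whole union, as with UpdateModelResponseValidator.
UpdateModelResponseValidator = TypeAdapter(Union[MainModelConfig, LoRAModelConfig])

config = UpdateModelResponseValidator.validate_python({"rank": 8})
assert isinstance(config, LoRAModelConfig)
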
@@ -41,7 +41,7 @@ class ModelsList(BaseModel):
     model_config = ConfigDict(use_enum_values=True)
 
 
-models_list_adapter = TypeAdapter(ModelsList)
+ModelsListValidator = TypeAdapter(ModelsList)
 
 
 @models_router.get(
@@ -60,7 +60,7 @@ async def list_models(
             models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
     else:
         models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
-    models = models_list_adapter.validate_python({"models": models_raw})
+    models = ModelsListValidator.validate_python({"models": models_raw})
     return models
@@ -131,7 +131,7 @@ async def update_model(
             base_model=base_model,
             model_type=model_type,
         )
-        model_response = update_models_response_adapter.validate_python(model_raw)
+        model_response = UpdateModelResponseValidator.validate_python(model_raw)
     except ModelNotFoundException as e:
         raise HTTPException(status_code=404, detail=str(e))
     except ValueError as e:
@@ -186,7 +186,7 @@ async def import_model(
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
             model_name=info.name, base_model=info.base_model, model_type=info.model_type
         )
-        return import_models_response_adapter.validate_python(model_raw)
+        return ImportModelResponseValidator.validate_python(model_raw)
 
     except ModelNotFoundException as e:
         logger.error(str(e))
@@ -231,7 +231,7 @@ async def add_model(
             base_model=info.base_model,
             model_type=info.model_type,
         )
-        return import_models_response_adapter.validate_python(model_raw)
+        return ImportModelResponseValidator.validate_python(model_raw)
     except ModelNotFoundException as e:
         logger.error(str(e))
         raise HTTPException(status_code=404, detail=str(e))
@@ -302,7 +302,7 @@ async def convert_model(
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
             model_name, base_model=base_model, model_type=model_type
         )
-        response = convert_models_response_adapter.validate_python(model_raw)
+        response = ConvertModelResponseValidator.validate_python(model_raw)
     except ModelNotFoundException as e:
         raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
     except ValueError as e:
@@ -417,7 +417,7 @@ async def merge_models(
             base_model=base_model,
             model_type=ModelType.Main,
         )
-        response = convert_models_response_adapter.validate_python(model_raw)
+        response = ConvertModelResponseValidator.validate_python(model_raw)
     except ModelNotFoundException:
         raise HTTPException(
             status_code=404,

@@ -821,7 +821,7 @@ class WorkflowField(RootModel):
     root: dict[str, Any] = Field(description="The workflow")
 
 
-type_adapter_WorkflowField = TypeAdapter(WorkflowField)
+WorkflowFieldValidator = TypeAdapter(WorkflowField)
 
 
 class WithWorkflow(BaseModel):
@@ -837,7 +837,7 @@ class MetadataField(RootModel):
     root: dict[str, Any] = Field(description="The metadata")
 
 
-type_adapter_MetadataField = TypeAdapter(MetadataField)
+MetadataFieldValidator = TypeAdapter(MetadataField)
 
 
 class WithMetadata(BaseModel):
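
WorkflowField and MetadataField are pydantic RootModels over a plain dict; the module-level TypeAdapter is what parses raw JSON (from PNG metadata or a database row) into them. A minimal self-contained sketch of the same pattern:

from typing import Any

from pydantic import Field, RootModel, TypeAdapter


class WorkflowField(RootModel):
    root: dict[str, Any] = Field(description="The workflow")


# Built once at module scope and shared by all callers.
WorkflowFieldValidator = TypeAdapter(WorkflowField)

wf = WorkflowFieldValidator.validate_json('{"nodes": [], "edges": []}')
assert wf.root["nodes"] == []
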
@@ -3,7 +3,7 @@ import threading
 from datetime import datetime
 from typing import Optional, Union, cast
 
-from invokeai.app.invocations.baseinvocation import MetadataField, type_adapter_MetadataField
+from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
 from invokeai.app.services.shared.pagination import OffsetPaginatedResults
 from invokeai.app.services.shared.sqlite import SqliteDatabase
@@ -170,7 +170,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
 
             as_dict = dict(result)
             metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
-            return type_adapter_MetadataField.validate_json(metadata_raw) if metadata_raw is not None else None
+            return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
         except sqlite3.Error as e:
             self._conn.rollback()
             raise ImageRecordNotFoundException from e

@@ -18,7 +18,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
     _cursor: sqlite3.Cursor
     _id_field: str
     _lock: threading.RLock
-    _adapter: Optional[TypeAdapter[T]]
+    _validator: Optional[TypeAdapter[T]]
 
     def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
         super().__init__()
@@ -28,7 +28,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
         self._table_name = table_name
         self._id_field = id_field  # TODO: validate that T has this field
         self._cursor = self._conn.cursor()
-        self._adapter: Optional[TypeAdapter[T]] = None
+        self._validator: Optional[TypeAdapter[T]] = None
 
         self._create_table()
@@ -47,14 +47,14 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
             self._lock.release()
 
     def _parse_item(self, item: str) -> T:
-        if self._adapter is None:
+        if self._validator is None:
             """
             We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
             we can create it when it is first needed instead.
             __orig_class__ is technically an implementation detail of the typing module, not a supported API
             """
-            self._adapter = TypeAdapter(get_args(self.__orig_class__)[0])  # type: ignore [attr-defined]
-        return self._adapter.validate_json(item)
+            self._validator = TypeAdapter(get_args(self.__orig_class__)[0])  # type: ignore [attr-defined]
+        return self._validator.validate_json(item)
 
     def set(self, item: T):
         try:
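
The _parse_item change keeps the lazy-construction trick intact: the generic argument T is only reachable through self.__orig_class__, which the typing machinery sets after __init__ returns, so the validator has to be built on first use. A runnable sketch of that pattern (JsonStore and Item are hypothetical names, not InvokeAI code):

from typing import Generic, Optional, TypeVar, get_args

from pydantic import BaseModel, TypeAdapter

T = TypeVar("T", bound=BaseModel)


class JsonStore(Generic[T]):
    def __init__(self) -> None:
        # __orig_class__ is not yet set inside __init__, so the
        # validator must be created lazily on first use.
        self._validator: Optional[TypeAdapter[T]] = None

    def parse(self, raw: str) -> T:
        if self._validator is None:
            # JsonStore[Item]() records Item in __orig_class__ once the
            # instance exists; get_args() recovers it.
            self._validator = TypeAdapter(get_args(self.__orig_class__)[0])  # type: ignore [attr-defined]
        return self._validator.validate_json(raw)


class Item(BaseModel):
    id: str


store = JsonStore[Item]()
assert store.parse('{"id": "abc"}').id == "abc"
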
@@ -147,20 +147,20 @@ DEFAULT_QUEUE_ID = "default"
 
 QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
 
-adapter_NodeFieldValue = TypeAdapter(list[NodeFieldValue])
+NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
 
 
 def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
     field_values_raw = queue_item_dict.get("field_values", None)
-    return adapter_NodeFieldValue.validate_json(field_values_raw) if field_values_raw is not None else None
+    return NodeFieldValueValidator.validate_json(field_values_raw) if field_values_raw is not None else None
 
 
-adapter_GraphExecutionState = TypeAdapter(GraphExecutionState)
+GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
 
 
 def get_session(queue_item_dict: dict) -> GraphExecutionState:
     session_raw = queue_item_dict.get("session", "{}")
-    session = adapter_GraphExecutionState.validate_json(session_raw, strict=False)
+    session = GraphExecutionStateValidator.validate_json(session_raw, strict=False)
     return session
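
NodeFieldValueValidator above shows that a TypeAdapter can wrap non-model types such as list[NodeFieldValue], parsing a whole JSON array in one call. A minimal sketch, with a simplified stand-in for the real NodeFieldValue model:

from pydantic import BaseModel, TypeAdapter


class NodeFieldValue(BaseModel):
    node_path: str
    field_name: str
    value: str


NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])

raw = '[{"node_path": "1", "field_name": "prompt", "value": "Banana sushi"}]'
values = NodeFieldValueValidator.validate_json(raw)
assert values[0].field_name == "prompt"
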
@@ -1,7 +1,7 @@
 import sqlite3
 import threading
 
-from invokeai.app.invocations.baseinvocation import WorkflowField, type_adapter_WorkflowField
+from invokeai.app.invocations.baseinvocation import WorkflowField, WorkflowFieldValidator
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.shared.sqlite import SqliteDatabase
 from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
@@ -39,7 +39,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
             row = self._cursor.fetchone()
             if row is None:
                 raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
-            return type_adapter_WorkflowField.validate_json(row[0])
+            return WorkflowFieldValidator.validate_json(row[0])
         except Exception:
             self._conn.rollback()
             raise

@@ -615,8 +615,8 @@ def test_graph_can_deserialize():
     g.add_edge(e)
 
     json = g.model_dump_json()
-    adapter_graph = TypeAdapter(Graph)
-    g2 = adapter_graph.validate_json(json)
+    GraphValidator = TypeAdapter(Graph)
+    g2 = GraphValidator.validate_json(json)
 
     assert g2 is not None
     assert g2.nodes["1"] is not None
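
The test exercises a full round trip: model_dump_json() serializes, then the TypeAdapter deserializes and revalidates. The same shape in miniature, with a hypothetical Point model in place of the InvokeAI Graph:

from pydantic import BaseModel, TypeAdapter


class Point(BaseModel):
    x: int
    y: int


PointValidator = TypeAdapter(Point)

p = Point(x=1, y=2)
p2 = PointValidator.validate_json(p.model_dump_json())
assert p2 == p
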
@@ -150,9 +150,9 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
     values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
     assert len(values) == 8
 
-    session_adapter = TypeAdapter(GraphExecutionState)
+    GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
     # graph should be serialized
-    ges = session_adapter.validate_json(values[0].session)
+    ges = GraphExecutionStateValidator.validate_json(values[0].session)
 
     # graph values should be populated
     assert ges.graph.get_node("1").prompt == "Banana sushi"
@@ -161,16 +161,16 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
     assert ges.graph.get_node("4").prompt == "Nissan"
 
     # session ids should match deserialized graph
-    assert [v.session_id for v in values] == [session_adapter.validate_json(v.session).id for v in values]
+    assert [v.session_id for v in values] == [GraphExecutionStateValidator.validate_json(v.session).id for v in values]
 
     # should unique session ids
     sids = [v.session_id for v in values]
     assert len(sids) == len(set(sids))
 
-    nfv_list_adapter = TypeAdapter(list[NodeFieldValue])
+    NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
     # should have 3 node field values
     assert type(values[0].field_values) is str
-    assert len(nfv_list_adapter.validate_json(values[0].field_values)) == 3
+    assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
 
     # should have batch id and priority
     assert all(v.batch_id == b.batch_id for v in values)