diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 625fb3c43b..c27ec1e0d9 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -9,8 +9,8 @@ from pydantic import BaseModel, Field, ValidationError from invokeai.app.invocations.baseinvocation import ( MetadataField, - type_adapter_MetadataField, - type_adapter_WorkflowField, + MetadataFieldValidator, + WorkflowFieldValidator, ) from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO @@ -66,7 +66,7 @@ async def upload_image( metadata_raw = pil_image.info.get("invokeai_metadata", None) if metadata_raw: try: - metadata = type_adapter_MetadataField.validate_json(metadata_raw) + metadata = MetadataFieldValidator.validate_json(metadata_raw) except ValidationError: ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image") pass @@ -75,7 +75,7 @@ async def upload_image( workflow_raw = pil_image.info.get("invokeai_workflow", None) if workflow_raw is not None: try: - workflow = type_adapter_WorkflowField.validate_json(workflow_raw) + workflow = WorkflowFieldValidator.validate_json(workflow_raw) except ValidationError: ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image") pass diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 018f3af02b..afa7d8df82 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -23,13 +23,13 @@ from ..dependencies import ApiDependencies models_router = APIRouter(prefix="/v1/models", tags=["models"]) UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -update_models_response_adapter = TypeAdapter(UpdateModelResponse) +UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse) ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -import_models_response_adapter = TypeAdapter(ImportModelResponse) +ImportModelResponseValidator = TypeAdapter(ImportModelResponse) ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] -convert_models_response_adapter = TypeAdapter(ConvertModelResponse) +ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse) MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] @@ -41,7 +41,7 @@ class ModelsList(BaseModel): model_config = ConfigDict(use_enum_values=True) -models_list_adapter = TypeAdapter(ModelsList) +ModelsListValidator = TypeAdapter(ModelsList) @models_router.get( @@ -60,7 +60,7 @@ async def list_models( models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)) else: models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type) - models = models_list_adapter.validate_python({"models": models_raw}) + models = ModelsListValidator.validate_python({"models": models_raw}) return models @@ -131,7 +131,7 @@ async def update_model( base_model=base_model, model_type=model_type, ) - model_response = update_models_response_adapter.validate_python(model_raw) + model_response = UpdateModelResponseValidator.validate_python(model_raw) except ModelNotFoundException as e: raise HTTPException(status_code=404, detail=str(e)) except ValueError as e: @@ -186,7 +186,7 @@ async def import_model( model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_name=info.name, base_model=info.base_model, model_type=info.model_type ) - return import_models_response_adapter.validate_python(model_raw) + return ImportModelResponseValidator.validate_python(model_raw) except ModelNotFoundException as e: logger.error(str(e)) @@ -231,7 +231,7 @@ async def add_model( base_model=info.base_model, model_type=info.model_type, ) - return import_models_response_adapter.validate_python(model_raw) + return ImportModelResponseValidator.validate_python(model_raw) except ModelNotFoundException as e: logger.error(str(e)) raise HTTPException(status_code=404, detail=str(e)) @@ -302,7 +302,7 @@ async def convert_model( model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_name, base_model=base_model, model_type=model_type ) - response = convert_models_response_adapter.validate_python(model_raw) + response = ConvertModelResponseValidator.validate_python(model_raw) except ModelNotFoundException as e: raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}") except ValueError as e: @@ -417,7 +417,7 @@ async def merge_models( base_model=base_model, model_type=ModelType.Main, ) - response = convert_models_response_adapter.validate_python(model_raw) + response = ConvertModelResponseValidator.validate_python(model_raw) except ModelNotFoundException: raise HTTPException( status_code=404, diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 50ce8de7d3..5f1ff0395f 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -821,7 +821,7 @@ class WorkflowField(RootModel): root: dict[str, Any] = Field(description="The workflow") -type_adapter_WorkflowField = TypeAdapter(WorkflowField) +WorkflowFieldValidator = TypeAdapter(WorkflowField) class WithWorkflow(BaseModel): @@ -837,7 +837,7 @@ class MetadataField(RootModel): root: dict[str, Any] = Field(description="The metadata") -type_adapter_MetadataField = TypeAdapter(MetadataField) +MetadataFieldValidator = TypeAdapter(MetadataField) class WithMetadata(BaseModel): diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index 7b60ec3d5b..dcabe55829 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -3,7 +3,7 @@ import threading from datetime import datetime from typing import Optional, Union, cast -from invokeai.app.invocations.baseinvocation import MetadataField, type_adapter_MetadataField +from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.sqlite import SqliteDatabase @@ -170,7 +170,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): as_dict = dict(result) metadata_raw = cast(Optional[str], as_dict.get("metadata", None)) - return type_adapter_MetadataField.validate_json(metadata_raw) if metadata_raw is not None else None + return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None except sqlite3.Error as e: self._conn.rollback() raise ImageRecordNotFoundException from e diff --git a/invokeai/app/services/item_storage/item_storage_sqlite.py b/invokeai/app/services/item_storage/item_storage_sqlite.py index 1bb9429130..d0249ebfa6 100644 --- a/invokeai/app/services/item_storage/item_storage_sqlite.py +++ b/invokeai/app/services/item_storage/item_storage_sqlite.py @@ -18,7 +18,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): _cursor: sqlite3.Cursor _id_field: str _lock: threading.RLock - _adapter: Optional[TypeAdapter[T]] + _validator: Optional[TypeAdapter[T]] def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"): super().__init__() @@ -28,7 +28,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): self._table_name = table_name self._id_field = id_field # TODO: validate that T has this field self._cursor = self._conn.cursor() - self._adapter: Optional[TypeAdapter[T]] = None + self._validator: Optional[TypeAdapter[T]] = None self._create_table() @@ -47,14 +47,14 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): self._lock.release() def _parse_item(self, item: str) -> T: - if self._adapter is None: + if self._validator is None: """ We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so we can create it when it is first needed instead. __orig_class__ is technically an implementation detail of the typing module, not a supported API """ - self._adapter = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined] - return self._adapter.validate_json(item) + self._validator = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined] + return self._validator.validate_json(item) def set(self, item: T): try: diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index cbf2154b66..69e6a3ab87 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -147,20 +147,20 @@ DEFAULT_QUEUE_ID = "default" QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"] -adapter_NodeFieldValue = TypeAdapter(list[NodeFieldValue]) +NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue]) def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]: field_values_raw = queue_item_dict.get("field_values", None) - return adapter_NodeFieldValue.validate_json(field_values_raw) if field_values_raw is not None else None + return NodeFieldValueValidator.validate_json(field_values_raw) if field_values_raw is not None else None -adapter_GraphExecutionState = TypeAdapter(GraphExecutionState) +GraphExecutionStateValidator = TypeAdapter(GraphExecutionState) def get_session(queue_item_dict: dict) -> GraphExecutionState: session_raw = queue_item_dict.get("session", "{}") - session = adapter_GraphExecutionState.validate_json(session_raw, strict=False) + session = GraphExecutionStateValidator.validate_json(session_raw, strict=False) return session diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py index e3c11cfa4b..2d9e1f26e8 100644 --- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py +++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py @@ -1,7 +1,7 @@ import sqlite3 import threading -from invokeai.app.invocations.baseinvocation import WorkflowField, type_adapter_WorkflowField +from invokeai.app.invocations.baseinvocation import WorkflowField, WorkflowFieldValidator from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.sqlite import SqliteDatabase from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase @@ -39,7 +39,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): row = self._cursor.fetchone() if row is None: raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found") - return type_adapter_WorkflowField.validate_json(row[0]) + return WorkflowFieldValidator.validate_json(row[0]) except Exception: self._conn.rollback() raise diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index d1ece0336a..e2a50e61e5 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -615,8 +615,8 @@ def test_graph_can_deserialize(): g.add_edge(e) json = g.model_dump_json() - adapter_graph = TypeAdapter(Graph) - g2 = adapter_graph.validate_json(json) + GraphValidator = TypeAdapter(Graph) + g2 = GraphValidator.validate_json(json) assert g2 is not None assert g2.nodes["1"] is not None diff --git a/tests/nodes/test_session_queue.py b/tests/nodes/test_session_queue.py index 731316068c..cdab5729f8 100644 --- a/tests/nodes/test_session_queue.py +++ b/tests/nodes/test_session_queue.py @@ -150,9 +150,9 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph): values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000) assert len(values) == 8 - session_adapter = TypeAdapter(GraphExecutionState) + GraphExecutionStateValidator = TypeAdapter(GraphExecutionState) # graph should be serialized - ges = session_adapter.validate_json(values[0].session) + ges = GraphExecutionStateValidator.validate_json(values[0].session) # graph values should be populated assert ges.graph.get_node("1").prompt == "Banana sushi" @@ -161,16 +161,16 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph): assert ges.graph.get_node("4").prompt == "Nissan" # session ids should match deserialized graph - assert [v.session_id for v in values] == [session_adapter.validate_json(v.session).id for v in values] + assert [v.session_id for v in values] == [GraphExecutionStateValidator.validate_json(v.session).id for v in values] # should unique session ids sids = [v.session_id for v in values] assert len(sids) == len(set(sids)) - nfv_list_adapter = TypeAdapter(list[NodeFieldValue]) + NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue]) # should have 3 node field values assert type(values[0].field_values) is str - assert len(nfv_list_adapter.validate_json(values[0].field_values)) == 3 + assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3 # should have batch id and priority assert all(v.batch_id == b.batch_id for v in values)