feat: use ModelValidator naming convention for pydantic type adapters

This is the naming convention in the docs and is also clear.
This commit is contained in:
psychedelicious 2023-10-17 19:46:37 +11:00
parent 3c4f43314c
commit 4012388f0a
9 changed files with 36 additions and 36 deletions

View File

@ -9,8 +9,8 @@ from pydantic import BaseModel, Field, ValidationError
from invokeai.app.invocations.baseinvocation import (
MetadataField,
type_adapter_MetadataField,
type_adapter_WorkflowField,
MetadataFieldValidator,
WorkflowFieldValidator,
)
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
@ -66,7 +66,7 @@ async def upload_image(
metadata_raw = pil_image.info.get("invokeai_metadata", None)
if metadata_raw:
try:
metadata = type_adapter_MetadataField.validate_json(metadata_raw)
metadata = MetadataFieldValidator.validate_json(metadata_raw)
except ValidationError:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass
@ -75,7 +75,7 @@ async def upload_image(
workflow_raw = pil_image.info.get("invokeai_workflow", None)
if workflow_raw is not None:
try:
workflow = type_adapter_WorkflowField.validate_json(workflow_raw)
workflow = WorkflowFieldValidator.validate_json(workflow_raw)
except ValidationError:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass

View File

@ -23,13 +23,13 @@ from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
update_models_response_adapter = TypeAdapter(UpdateModelResponse)
UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse)
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
import_models_response_adapter = TypeAdapter(ImportModelResponse)
ImportModelResponseValidator = TypeAdapter(ImportModelResponse)
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
convert_models_response_adapter = TypeAdapter(ConvertModelResponse)
ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse)
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
@ -41,7 +41,7 @@ class ModelsList(BaseModel):
model_config = ConfigDict(use_enum_values=True)
models_list_adapter = TypeAdapter(ModelsList)
ModelsListValidator = TypeAdapter(ModelsList)
@models_router.get(
@ -60,7 +60,7 @@ async def list_models(
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
else:
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
models = models_list_adapter.validate_python({"models": models_raw})
models = ModelsListValidator.validate_python({"models": models_raw})
return models
@ -131,7 +131,7 @@ async def update_model(
base_model=base_model,
model_type=model_type,
)
model_response = update_models_response_adapter.validate_python(model_raw)
model_response = UpdateModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
@ -186,7 +186,7 @@ async def import_model(
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.name, base_model=info.base_model, model_type=info.model_type
)
return import_models_response_adapter.validate_python(model_raw)
return ImportModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException as e:
logger.error(str(e))
@ -231,7 +231,7 @@ async def add_model(
base_model=info.base_model,
model_type=info.model_type,
)
return import_models_response_adapter.validate_python(model_raw)
return ImportModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@ -302,7 +302,7 @@ async def convert_model(
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name, base_model=base_model, model_type=model_type
)
response = convert_models_response_adapter.validate_python(model_raw)
response = ConvertModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
except ValueError as e:
@ -417,7 +417,7 @@ async def merge_models(
base_model=base_model,
model_type=ModelType.Main,
)
response = convert_models_response_adapter.validate_python(model_raw)
response = ConvertModelResponseValidator.validate_python(model_raw)
except ModelNotFoundException:
raise HTTPException(
status_code=404,

View File

@ -821,7 +821,7 @@ class WorkflowField(RootModel):
root: dict[str, Any] = Field(description="The workflow")
type_adapter_WorkflowField = TypeAdapter(WorkflowField)
WorkflowFieldValidator = TypeAdapter(WorkflowField)
class WithWorkflow(BaseModel):
@ -837,7 +837,7 @@ class MetadataField(RootModel):
root: dict[str, Any] = Field(description="The metadata")
type_adapter_MetadataField = TypeAdapter(MetadataField)
MetadataFieldValidator = TypeAdapter(MetadataField)
class WithMetadata(BaseModel):

View File

@ -3,7 +3,7 @@ import threading
from datetime import datetime
from typing import Optional, Union, cast
from invokeai.app.invocations.baseinvocation import MetadataField, type_adapter_MetadataField
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
@ -170,7 +170,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
as_dict = dict(result)
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
return type_adapter_MetadataField.validate_json(metadata_raw) if metadata_raw is not None else None
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e

View File

@ -18,7 +18,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
_cursor: sqlite3.Cursor
_id_field: str
_lock: threading.RLock
_adapter: Optional[TypeAdapter[T]]
_validator: Optional[TypeAdapter[T]]
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
super().__init__()
@ -28,7 +28,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._cursor = self._conn.cursor()
self._adapter: Optional[TypeAdapter[T]] = None
self._validator: Optional[TypeAdapter[T]] = None
self._create_table()
@ -47,14 +47,14 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._lock.release()
def _parse_item(self, item: str) -> T:
if self._adapter is None:
if self._validator is None:
"""
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
we can create it when it is first needed instead.
__orig_class__ is technically an implementation detail of the typing module, not a supported API
"""
self._adapter = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
return self._adapter.validate_json(item)
self._validator = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
return self._validator.validate_json(item)
def set(self, item: T):
try:

View File

@ -147,20 +147,20 @@ DEFAULT_QUEUE_ID = "default"
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
adapter_NodeFieldValue = TypeAdapter(list[NodeFieldValue])
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
field_values_raw = queue_item_dict.get("field_values", None)
return adapter_NodeFieldValue.validate_json(field_values_raw) if field_values_raw is not None else None
return NodeFieldValueValidator.validate_json(field_values_raw) if field_values_raw is not None else None
adapter_GraphExecutionState = TypeAdapter(GraphExecutionState)
GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
def get_session(queue_item_dict: dict) -> GraphExecutionState:
session_raw = queue_item_dict.get("session", "{}")
session = adapter_GraphExecutionState.validate_json(session_raw, strict=False)
session = GraphExecutionStateValidator.validate_json(session_raw, strict=False)
return session

View File

@ -1,7 +1,7 @@
import sqlite3
import threading
from invokeai.app.invocations.baseinvocation import WorkflowField, type_adapter_WorkflowField
from invokeai.app.invocations.baseinvocation import WorkflowField, WorkflowFieldValidator
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
@ -39,7 +39,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
row = self._cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return type_adapter_WorkflowField.validate_json(row[0])
return WorkflowFieldValidator.validate_json(row[0])
except Exception:
self._conn.rollback()
raise

View File

@ -615,8 +615,8 @@ def test_graph_can_deserialize():
g.add_edge(e)
json = g.model_dump_json()
adapter_graph = TypeAdapter(Graph)
g2 = adapter_graph.validate_json(json)
GraphValidator = TypeAdapter(Graph)
g2 = GraphValidator.validate_json(json)
assert g2 is not None
assert g2.nodes["1"] is not None

View File

@ -150,9 +150,9 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
assert len(values) == 8
session_adapter = TypeAdapter(GraphExecutionState)
GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
# graph should be serialized
ges = session_adapter.validate_json(values[0].session)
ges = GraphExecutionStateValidator.validate_json(values[0].session)
# graph values should be populated
assert ges.graph.get_node("1").prompt == "Banana sushi"
@ -161,16 +161,16 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
assert ges.graph.get_node("4").prompt == "Nissan"
# session ids should match deserialized graph
assert [v.session_id for v in values] == [session_adapter.validate_json(v.session).id for v in values]
assert [v.session_id for v in values] == [GraphExecutionStateValidator.validate_json(v.session).id for v in values]
# should unique session ids
sids = [v.session_id for v in values]
assert len(sids) == len(set(sids))
nfv_list_adapter = TypeAdapter(list[NodeFieldValue])
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
# should have 3 node field values
assert type(values[0].field_values) is str
assert len(nfv_list_adapter.validate_json(values[0].field_values)) == 3
assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
# should have batch id and priority
assert all(v.batch_id == b.batch_id for v in values)