mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: use ModelValidator
naming convention for pydantic type adapters
This is the naming convention in the docs and is also clear.
This commit is contained in:
parent
3c4f43314c
commit
4012388f0a
@ -9,8 +9,8 @@ from pydantic import BaseModel, Field, ValidationError
|
|||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
MetadataField,
|
MetadataField,
|
||||||
type_adapter_MetadataField,
|
MetadataFieldValidator,
|
||||||
type_adapter_WorkflowField,
|
WorkflowFieldValidator,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||||
@ -66,7 +66,7 @@ async def upload_image(
|
|||||||
metadata_raw = pil_image.info.get("invokeai_metadata", None)
|
metadata_raw = pil_image.info.get("invokeai_metadata", None)
|
||||||
if metadata_raw:
|
if metadata_raw:
|
||||||
try:
|
try:
|
||||||
metadata = type_adapter_MetadataField.validate_json(metadata_raw)
|
metadata = MetadataFieldValidator.validate_json(metadata_raw)
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
||||||
pass
|
pass
|
||||||
@ -75,7 +75,7 @@ async def upload_image(
|
|||||||
workflow_raw = pil_image.info.get("invokeai_workflow", None)
|
workflow_raw = pil_image.info.get("invokeai_workflow", None)
|
||||||
if workflow_raw is not None:
|
if workflow_raw is not None:
|
||||||
try:
|
try:
|
||||||
workflow = type_adapter_WorkflowField.validate_json(workflow_raw)
|
workflow = WorkflowFieldValidator.validate_json(workflow_raw)
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
||||||
pass
|
pass
|
||||||
|
@ -23,13 +23,13 @@ from ..dependencies import ApiDependencies
|
|||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
update_models_response_adapter = TypeAdapter(UpdateModelResponse)
|
UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse)
|
||||||
|
|
||||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
import_models_response_adapter = TypeAdapter(ImportModelResponse)
|
ImportModelResponseValidator = TypeAdapter(ImportModelResponse)
|
||||||
|
|
||||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
convert_models_response_adapter = TypeAdapter(ConvertModelResponse)
|
ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse)
|
||||||
|
|
||||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
@ -41,7 +41,7 @@ class ModelsList(BaseModel):
|
|||||||
model_config = ConfigDict(use_enum_values=True)
|
model_config = ConfigDict(use_enum_values=True)
|
||||||
|
|
||||||
|
|
||||||
models_list_adapter = TypeAdapter(ModelsList)
|
ModelsListValidator = TypeAdapter(ModelsList)
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
@ -60,7 +60,7 @@ async def list_models(
|
|||||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||||
else:
|
else:
|
||||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||||
models = models_list_adapter.validate_python({"models": models_raw})
|
models = ModelsListValidator.validate_python({"models": models_raw})
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ async def update_model(
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
model_response = update_models_response_adapter.validate_python(model_raw)
|
model_response = UpdateModelResponseValidator.validate_python(model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@ -186,7 +186,7 @@ async def import_model(
|
|||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
||||||
)
|
)
|
||||||
return import_models_response_adapter.validate_python(model_raw)
|
return ImportModelResponseValidator.validate_python(model_raw)
|
||||||
|
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
@ -231,7 +231,7 @@ async def add_model(
|
|||||||
base_model=info.base_model,
|
base_model=info.base_model,
|
||||||
model_type=info.model_type,
|
model_type=info.model_type,
|
||||||
)
|
)
|
||||||
return import_models_response_adapter.validate_python(model_raw)
|
return ImportModelResponseValidator.validate_python(model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
@ -302,7 +302,7 @@ async def convert_model(
|
|||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name, base_model=base_model, model_type=model_type
|
model_name, base_model=base_model, model_type=model_type
|
||||||
)
|
)
|
||||||
response = convert_models_response_adapter.validate_python(model_raw)
|
response = ConvertModelResponseValidator.validate_python(model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@ -417,7 +417,7 @@ async def merge_models(
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Main,
|
model_type=ModelType.Main,
|
||||||
)
|
)
|
||||||
response = convert_models_response_adapter.validate_python(model_raw)
|
response = ConvertModelResponseValidator.validate_python(model_raw)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
|
@ -821,7 +821,7 @@ class WorkflowField(RootModel):
|
|||||||
root: dict[str, Any] = Field(description="The workflow")
|
root: dict[str, Any] = Field(description="The workflow")
|
||||||
|
|
||||||
|
|
||||||
type_adapter_WorkflowField = TypeAdapter(WorkflowField)
|
WorkflowFieldValidator = TypeAdapter(WorkflowField)
|
||||||
|
|
||||||
|
|
||||||
class WithWorkflow(BaseModel):
|
class WithWorkflow(BaseModel):
|
||||||
@ -837,7 +837,7 @@ class MetadataField(RootModel):
|
|||||||
root: dict[str, Any] = Field(description="The metadata")
|
root: dict[str, Any] = Field(description="The metadata")
|
||||||
|
|
||||||
|
|
||||||
type_adapter_MetadataField = TypeAdapter(MetadataField)
|
MetadataFieldValidator = TypeAdapter(MetadataField)
|
||||||
|
|
||||||
|
|
||||||
class WithMetadata(BaseModel):
|
class WithMetadata(BaseModel):
|
||||||
|
@ -3,7 +3,7 @@ import threading
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Union, cast
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import MetadataField, type_adapter_MetadataField
|
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
|
|
||||||
@ -170,7 +170,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
|
|
||||||
as_dict = dict(result)
|
as_dict = dict(result)
|
||||||
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
|
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
|
||||||
return type_adapter_MetadataField.validate_json(metadata_raw) if metadata_raw is not None else None
|
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
raise ImageRecordNotFoundException from e
|
raise ImageRecordNotFoundException from e
|
||||||
|
@ -18,7 +18,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_id_field: str
|
_id_field: str
|
||||||
_lock: threading.RLock
|
_lock: threading.RLock
|
||||||
_adapter: Optional[TypeAdapter[T]]
|
_validator: Optional[TypeAdapter[T]]
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -28,7 +28,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
self._id_field = id_field # TODO: validate that T has this field
|
self._id_field = id_field # TODO: validate that T has this field
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._adapter: Optional[TypeAdapter[T]] = None
|
self._validator: Optional[TypeAdapter[T]] = None
|
||||||
|
|
||||||
self._create_table()
|
self._create_table()
|
||||||
|
|
||||||
@ -47,14 +47,14 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
def _parse_item(self, item: str) -> T:
|
||||||
if self._adapter is None:
|
if self._validator is None:
|
||||||
"""
|
"""
|
||||||
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
|
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
|
||||||
we can create it when it is first needed instead.
|
we can create it when it is first needed instead.
|
||||||
__orig_class__ is technically an implementation detail of the typing module, not a supported API
|
__orig_class__ is technically an implementation detail of the typing module, not a supported API
|
||||||
"""
|
"""
|
||||||
self._adapter = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
|
self._validator = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
|
||||||
return self._adapter.validate_json(item)
|
return self._validator.validate_json(item)
|
||||||
|
|
||||||
def set(self, item: T):
|
def set(self, item: T):
|
||||||
try:
|
try:
|
||||||
|
@ -147,20 +147,20 @@ DEFAULT_QUEUE_ID = "default"
|
|||||||
|
|
||||||
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
|
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
|
||||||
|
|
||||||
adapter_NodeFieldValue = TypeAdapter(list[NodeFieldValue])
|
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
|
||||||
|
|
||||||
|
|
||||||
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
|
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
|
||||||
field_values_raw = queue_item_dict.get("field_values", None)
|
field_values_raw = queue_item_dict.get("field_values", None)
|
||||||
return adapter_NodeFieldValue.validate_json(field_values_raw) if field_values_raw is not None else None
|
return NodeFieldValueValidator.validate_json(field_values_raw) if field_values_raw is not None else None
|
||||||
|
|
||||||
|
|
||||||
adapter_GraphExecutionState = TypeAdapter(GraphExecutionState)
|
GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
|
||||||
|
|
||||||
|
|
||||||
def get_session(queue_item_dict: dict) -> GraphExecutionState:
|
def get_session(queue_item_dict: dict) -> GraphExecutionState:
|
||||||
session_raw = queue_item_dict.get("session", "{}")
|
session_raw = queue_item_dict.get("session", "{}")
|
||||||
session = adapter_GraphExecutionState.validate_json(session_raw, strict=False)
|
session = GraphExecutionStateValidator.validate_json(session_raw, strict=False)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import WorkflowField, type_adapter_WorkflowField
|
from invokeai.app.invocations.baseinvocation import WorkflowField, WorkflowFieldValidator
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||||
@ -39,7 +39,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
|||||||
row = self._cursor.fetchone()
|
row = self._cursor.fetchone()
|
||||||
if row is None:
|
if row is None:
|
||||||
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
||||||
return type_adapter_WorkflowField.validate_json(row[0])
|
return WorkflowFieldValidator.validate_json(row[0])
|
||||||
except Exception:
|
except Exception:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
raise
|
raise
|
||||||
|
@ -615,8 +615,8 @@ def test_graph_can_deserialize():
|
|||||||
g.add_edge(e)
|
g.add_edge(e)
|
||||||
|
|
||||||
json = g.model_dump_json()
|
json = g.model_dump_json()
|
||||||
adapter_graph = TypeAdapter(Graph)
|
GraphValidator = TypeAdapter(Graph)
|
||||||
g2 = adapter_graph.validate_json(json)
|
g2 = GraphValidator.validate_json(json)
|
||||||
|
|
||||||
assert g2 is not None
|
assert g2 is not None
|
||||||
assert g2.nodes["1"] is not None
|
assert g2.nodes["1"] is not None
|
||||||
|
@ -150,9 +150,9 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
|
|||||||
values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
|
values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
|
||||||
assert len(values) == 8
|
assert len(values) == 8
|
||||||
|
|
||||||
session_adapter = TypeAdapter(GraphExecutionState)
|
GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
|
||||||
# graph should be serialized
|
# graph should be serialized
|
||||||
ges = session_adapter.validate_json(values[0].session)
|
ges = GraphExecutionStateValidator.validate_json(values[0].session)
|
||||||
|
|
||||||
# graph values should be populated
|
# graph values should be populated
|
||||||
assert ges.graph.get_node("1").prompt == "Banana sushi"
|
assert ges.graph.get_node("1").prompt == "Banana sushi"
|
||||||
@ -161,16 +161,16 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
|
|||||||
assert ges.graph.get_node("4").prompt == "Nissan"
|
assert ges.graph.get_node("4").prompt == "Nissan"
|
||||||
|
|
||||||
# session ids should match deserialized graph
|
# session ids should match deserialized graph
|
||||||
assert [v.session_id for v in values] == [session_adapter.validate_json(v.session).id for v in values]
|
assert [v.session_id for v in values] == [GraphExecutionStateValidator.validate_json(v.session).id for v in values]
|
||||||
|
|
||||||
# should unique session ids
|
# should unique session ids
|
||||||
sids = [v.session_id for v in values]
|
sids = [v.session_id for v in values]
|
||||||
assert len(sids) == len(set(sids))
|
assert len(sids) == len(set(sids))
|
||||||
|
|
||||||
nfv_list_adapter = TypeAdapter(list[NodeFieldValue])
|
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
|
||||||
# should have 3 node field values
|
# should have 3 node field values
|
||||||
assert type(values[0].field_values) is str
|
assert type(values[0].field_values) is str
|
||||||
assert len(nfv_list_adapter.validate_json(values[0].field_values)) == 3
|
assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
|
||||||
|
|
||||||
# should have batch id and priority
|
# should have batch id and priority
|
||||||
assert all(v.batch_id == b.batch_id for v in values)
|
assert all(v.batch_id == b.batch_id for v in values)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user