Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00

Commit c238a7f18b
Upgrade pydantic and fastapi to latest.

- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1

**Big Changes**

There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.

**Invocations**

The biggest change relates to invocation creation, instantiation and validation. Because pydantic v2 moves all validation logic into the Rust pydantic-core, we may no longer directly stick our fingers into the validation pie. Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.

With pydantic v2, this is much more involved. Changes to the Python wrapper do not propagate down to the Rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.

This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and the `BaseInvocation.invoke_internal()` method. In the end, this implementation is cleaner.

**Invocation Fields**

In pydantic v2, you can no longer directly add or remove fields from a model. Previously, we did this to add the `type` field to invocations.

**Invocation Decorators**

With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper. A similar technique is used for `invocation_output()`.

**Minor Changes**

There are a number of minor changes around the pydantic v2 models API.

**Protected `model_` Namespace**

All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".

Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.

```py
class IPAdapterModelField(BaseModel):
    model_name: str = Field(description="Name of the IP-Adapter model")
    base_model: BaseModelType = Field(description="Base model")

    model_config = ConfigDict(protected_namespaces=())
```

**Model Serialization**

Pydantic models no longer have `Model.dict()` or `Model.json()`. Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.

**Model Deserialization**

Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions. Instead, you need to create a `TypeAdapter` object to parse Python objects or JSON into a model.

```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```

**Field Customisation**

Pydantic `Field`s no longer accept arbitrary args. Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
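As a quick illustration (the field and the `ui_hidden`/`ui_order` keys below are made up for this example, not taken from the InvokeAI codebase), extra UI metadata that might previously have been passed straight to `Field()` now has to be wrapped in `json_schema_extra`:

```py
from pydantic import BaseModel, Field


class ExampleField(BaseModel):
    # In pydantic v1 these extras could be passed as loose kwargs to Field();
    # in v2 they must be collected under json_schema_extra.
    strength: float = Field(
        default=0.75,
        description="Denoising strength",
        json_schema_extra={"ui_hidden": False, "ui_order": 3},
    )
```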
**Schema Customisation**

FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec. This necessitates two changes:

- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised

The specifics aren't important, but this does present additional surface area for bugs.

**Performance Improvements**

Pydantic v2 is a full rewrite with a Rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very, very often: a couple of times per node. I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs, such as those with massive iterators, are much, much faster.
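For reference, this is the serialization/deserialization round trip that the performance note refers to, sketched with a stand-in model rather than a real session graph:

```py
from pydantic import BaseModel, TypeAdapter


class Point(BaseModel):
    x: int
    y: int


p = Point(x=1, y=2)

# v1 `p.dict()` / `p.json()` become `model_dump()` / `model_dump_json()` in v2
as_dict = p.model_dump()
as_json = p.model_dump_json()

# v1 `parse_obj()` / `parse_raw()` are replaced by a TypeAdapter
adapter = TypeAdapter(Point)
assert adapter.validate_python(as_dict) == p
assert adapter.validate_json(as_json) == p
```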
316 lines
11 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Any, Optional

from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
    BatchStatus,
    EnqueueBatchResult,
    SessionQueueItem,
    SessionQueueStatus,
)
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_management.model_manager import ModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType


class EventServiceBase:
    """Basic event bus, to have an empty stand-in when not needed"""

    queue_event: str = "queue_event"

    def dispatch(self, event_name: str, payload: Any) -> None:
        pass

    def __emit_queue_event(self, event_name: str, payload: dict) -> None:
        """Queue events are emitted to a room with queue_id as the room name"""
        payload["timestamp"] = get_timestamp()
        self.dispatch(
            event_name=EventServiceBase.queue_event,
            payload=dict(event=event_name, data=payload),
        )

    # Define events here for every event in the system.
    # This will make them easier to integrate until we find a schema generator.
    def emit_generator_progress(
        self,
        queue_id: str,
        queue_item_id: int,
        queue_batch_id: str,
        graph_execution_state_id: str,
        node: dict,
        source_node_id: str,
        progress_image: Optional[ProgressImage],
        step: int,
        order: int,
        total_steps: int,
    ) -> None:
        """Emitted when there is generation progress"""
        self.__emit_queue_event(
            event_name="generator_progress",
            payload=dict(
                queue_id=queue_id,
                queue_item_id=queue_item_id,
                queue_batch_id=queue_batch_id,
                graph_execution_state_id=graph_execution_state_id,
                node_id=node.get("id"),
                source_node_id=source_node_id,
                progress_image=progress_image.model_dump() if progress_image is not None else None,
                step=step,
                order=order,
                total_steps=total_steps,
            ),
        )

    def emit_invocation_complete(
        self,
        queue_id: str,
        queue_item_id: int,
        queue_batch_id: str,
        graph_execution_state_id: str,
        result: dict,
        node: dict,
        source_node_id: str,
    ) -> None:
        """Emitted when an invocation has completed"""
        self.__emit_queue_event(
            event_name="invocation_complete",
            payload=dict(
                queue_id=queue_id,
                queue_item_id=queue_item_id,
                queue_batch_id=queue_batch_id,
                graph_execution_state_id=graph_execution_state_id,
                node=node,
                source_node_id=source_node_id,
                result=result,
            ),
        )

    def emit_invocation_error(
        self,
        queue_id: str,
        queue_item_id: int,
        queue_batch_id: str,
        graph_execution_state_id: str,
        node: dict,
        source_node_id: str,
        error_type: str,
        error: str,
    ) -> None:
        """Emitted when an invocation has encountered an error"""
        self.__emit_queue_event(
            event_name="invocation_error",
            payload=dict(
                queue_id=queue_id,
                queue_item_id=queue_item_id,
                queue_batch_id=queue_batch_id,
                graph_execution_state_id=graph_execution_state_id,
                node=node,
                source_node_id=source_node_id,
                error_type=error_type,
                error=error,
            ),
        )

    def emit_invocation_started(
        self,
        queue_id: str,
        queue_item_id: int,
        queue_batch_id: str,
        graph_execution_state_id: str,
        node: dict,
        source_node_id: str,
    ) -> None:
        """Emitted when an invocation has started"""
        self.__emit_queue_event(
            event_name="invocation_started",
            payload=dict(
                queue_id=queue_id,
                queue_item_id=queue_item_id,
                queue_batch_id=queue_batch_id,
                graph_execution_state_id=graph_execution_state_id,
                node=node,
                source_node_id=source_node_id,
            ),
        )

    def emit_graph_execution_complete(
        self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
    ) -> None:
        """Emitted when a session has completed all invocations"""
        self.__emit_queue_event(
            event_name="graph_execution_state_complete",
            payload=dict(
                queue_id=queue_id,
                queue_item_id=queue_item_id,
                queue_batch_id=queue_batch_id,
                graph_execution_state_id=graph_execution_state_id,
            ),
        )

    def emit_model_load_started(
        self,
        queue_id: str,
        queue_item_id: int,
        queue_batch_id: str,
        graph_execution_state_id: str,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        submodel: SubModelType,
    ) -> None:
        """Emitted when a model is requested"""
        self.__emit_queue_event(
            event_name="model_load_started",
            payload=dict(
                queue_id=queue_id,
                queue_item_id=queue_item_id,
                queue_batch_id=queue_batch_id,
                graph_execution_state_id=graph_execution_state_id,
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=submodel,
            ),
        )

    def emit_model_load_completed(
        self,
        queue_id: str,
        queue_item_id: int,
        queue_batch_id: str,
        graph_execution_state_id: str,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        submodel: SubModelType,
        model_info: ModelInfo,
    ) -> None:
        """Emitted when a model is correctly loaded (returns model info)"""
        self.__emit_queue_event(
            event_name="model_load_completed",
            payload=dict(
                queue_id=queue_id,
                queue_item_id=queue_item_id,
                queue_batch_id=queue_batch_id,
                graph_execution_state_id=graph_execution_state_id,
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=submodel,
                hash=model_info.hash,
                location=str(model_info.location),
                precision=str(model_info.precision),
            ),
        )

    def emit_session_retrieval_error(
        self,
        queue_id: str,
        queue_item_id: int,
        queue_batch_id: str,
        graph_execution_state_id: str,
        error_type: str,
        error: str,
    ) -> None:
        """Emitted when session retrieval fails"""
        self.__emit_queue_event(
            event_name="session_retrieval_error",
            payload=dict(
                queue_id=queue_id,
                queue_item_id=queue_item_id,
                queue_batch_id=queue_batch_id,
                graph_execution_state_id=graph_execution_state_id,
                error_type=error_type,
                error=error,
            ),
        )

    def emit_invocation_retrieval_error(
        self,
        queue_id: str,
        queue_item_id: int,
        queue_batch_id: str,
        graph_execution_state_id: str,
        node_id: str,
        error_type: str,
        error: str,
    ) -> None:
        """Emitted when invocation retrieval fails"""
        self.__emit_queue_event(
            event_name="invocation_retrieval_error",
            payload=dict(
                queue_id=queue_id,
                queue_item_id=queue_item_id,
                queue_batch_id=queue_batch_id,
                graph_execution_state_id=graph_execution_state_id,
                node_id=node_id,
                error_type=error_type,
                error=error,
            ),
        )

    def emit_session_canceled(
        self,
        queue_id: str,
        queue_item_id: int,
        queue_batch_id: str,
        graph_execution_state_id: str,
    ) -> None:
        """Emitted when a session is canceled"""
        self.__emit_queue_event(
            event_name="session_canceled",
            payload=dict(
                queue_id=queue_id,
                queue_item_id=queue_item_id,
                queue_batch_id=queue_batch_id,
                graph_execution_state_id=graph_execution_state_id,
            ),
        )

    def emit_queue_item_status_changed(
        self,
        session_queue_item: SessionQueueItem,
        batch_status: BatchStatus,
        queue_status: SessionQueueStatus,
    ) -> None:
        """Emitted when a queue item's status changes"""
        self.__emit_queue_event(
            event_name="queue_item_status_changed",
            payload=dict(
                queue_id=queue_status.queue_id,
                queue_item=dict(
                    queue_id=session_queue_item.queue_id,
                    item_id=session_queue_item.item_id,
                    status=session_queue_item.status,
                    batch_id=session_queue_item.batch_id,
                    session_id=session_queue_item.session_id,
                    error=session_queue_item.error,
                    created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
                    updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
                    started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
                    completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
                ),
                batch_status=batch_status.model_dump(),
                queue_status=queue_status.model_dump(),
            ),
        )

    def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
        """Emitted when a batch is enqueued"""
        self.__emit_queue_event(
            event_name="batch_enqueued",
            payload=dict(
                queue_id=enqueue_result.queue_id,
                batch_id=enqueue_result.batch.batch_id,
                enqueued=enqueue_result.enqueued,
            ),
        )

    def emit_queue_cleared(self, queue_id: str) -> None:
        """Emitted when the queue is cleared"""
        self.__emit_queue_event(
            event_name="queue_cleared",
            payload=dict(queue_id=queue_id),
        )
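As a usage sketch (not part of the file above, and the `InMemoryEventService` name is hypothetical): a concrete event service only needs to override `dispatch()`, since every `emit_*` helper funnels through `__emit_queue_event()` and ultimately calls `dispatch()`.

```py
from typing import Any


class InMemoryEventService(EventServiceBase):
    """Collects dispatched events in a list, e.g. for tests."""

    def __init__(self) -> None:
        self.events: list[dict[str, Any]] = []

    def dispatch(self, event_name: str, payload: Any) -> None:
        # Every queue event arrives as ("queue_event", {"event": <name>, "data": {...}}).
        self.events.append({"event_name": event_name, "payload": payload})


events = InMemoryEventService()
events.emit_queue_cleared(queue_id="default")
assert events.events[0]["payload"]["event"] == "queue_cleared"
```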