mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): JIT graph nodes validation
We use pydantic to validate a union of valid invocations when instantiating a graph. Previously, we constructed the union while creating the `Graph` class. This introduces a dependency on the order of imports. For example, consider a setup where we have 3 invocations in the app: - Python executes the module where `FirstInvocation` is defined, registering `FirstInvocation`. - Python executes the module where `SecondInvocation` is defined, registering `SecondInvocation`. - Python executes the module where `Graph` is defined. A union of invocations is created and used to define the `Graph.nodes` field. The union contains `FirstInvocation` and `SecondInvocation`. - Python executes the module where `ThirdInvocation` is defined, registering `ThirdInvocation`. - A graph is created that includes `ThirdInvocation`. Pydantic validates the graph using the union, which does not know about `ThirdInvocation`, raising a `ValidationError` about an unknown invocation type. This scenario has been particularly problematic in tests, where we may create invocations dynamically. The test files have to be structured in such a way that the imports happen in the right order. It's a major pain. This PR refactors the validation of graph nodes to resolve this issue: - `BaseInvocation` gets a new method `get_typeadapter`. This builds a pydantic `TypeAdapter` for the union of all registered invocations, caching it after the first call. - `Graph.nodes`'s type is widened to `dict[str, BaseInvocation]`. This actually is a nice bonus, because we get better type hints whenever we reference `some_graph.nodes`. - A "plain" field validator takes over the validation logic for `Graph.nodes`. "Plain" validators totally override pydantic's own validation logic. The validator grabs the `TypeAdapter` from `BaseInvocation`, then validates each node with it. The validation is identical to the previous implementation - we get the same errors. `BaseInvocationOutput` gets the same treatment.
This commit is contained in:
@ -417,7 +417,7 @@ class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
|
||||
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
self.clip.skipped_layers += self.skipped_layers
|
||||
|
Reference in New Issue
Block a user