mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat/nodes/freeu
This commit is contained in:
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import inspect
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
@ -11,8 +11,8 @@ from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator
|
||||
from pydantic.fields import _Unset
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
|
||||
from pydantic.fields import FieldInfo, _Unset
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
@ -26,6 +26,10 @@ class InvalidVersionError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidFieldError(TypeError):
|
||||
pass
|
||||
|
||||
|
||||
class FieldDescriptions:
|
||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||
@ -60,7 +64,12 @@ class FieldDescriptions:
|
||||
denoised_latents = "Denoised latents tensor"
|
||||
latents = "Latents tensor"
|
||||
strength = "Strength of denoising (proportional to steps)"
|
||||
core_metadata = "Optional core metadata to be written to image"
|
||||
metadata = "Optional metadata to be saved with the image"
|
||||
metadata_collection = "Collection of Metadata"
|
||||
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
|
||||
metadata_item_label = "Label for this metadata item"
|
||||
metadata_item_value = "The value for this metadata item (may be any type)"
|
||||
workflow = "Optional workflow to be saved with the image"
|
||||
interp_mode = "Interpolation mode"
|
||||
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||
fp32 = "Whether or not to use full float32 precision"
|
||||
@ -171,8 +180,12 @@ class UIType(str, Enum):
|
||||
Scheduler = "Scheduler"
|
||||
WorkflowField = "WorkflowField"
|
||||
IsIntermediate = "IsIntermediate"
|
||||
MetadataField = "MetadataField"
|
||||
BoardField = "BoardField"
|
||||
Any = "Any"
|
||||
MetadataItem = "MetadataItem"
|
||||
MetadataItemCollection = "MetadataItemCollection"
|
||||
MetadataItemPolymorphic = "MetadataItemPolymorphic"
|
||||
MetadataDict = "MetadataDict"
|
||||
# endregion
|
||||
|
||||
|
||||
@ -298,6 +311,7 @@ def InputField(
|
||||
ui_order=ui_order,
|
||||
item_default=item_default,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
_field_kind="input",
|
||||
)
|
||||
|
||||
field_args = dict(
|
||||
@ -440,6 +454,7 @@ def OutputField(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
_field_kind="output",
|
||||
),
|
||||
)
|
||||
|
||||
@ -523,6 +538,7 @@ class BaseInvocationOutput(BaseModel):
|
||||
schema["required"].extend(["type"])
|
||||
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
@ -545,9 +561,6 @@ class MissingInputException(Exception):
|
||||
|
||||
class BaseInvocation(ABC, BaseModel):
|
||||
"""
|
||||
A node to process inputs and produce outputs.
|
||||
May use dependency injection in __init__ to receive providers.
|
||||
|
||||
All invocations must use the `@invocation` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
@ -663,46 +676,93 @@ class BaseInvocation(ABC, BaseModel):
|
||||
id: str = Field(
|
||||
default_factory=uuid_string,
|
||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
||||
json_schema_extra=dict(_field_kind="internal"),
|
||||
)
|
||||
is_intermediate: Optional[bool] = Field(
|
||||
is_intermediate: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not this is an intermediate invocation.",
|
||||
json_schema_extra=dict(ui_type=UIType.IsIntermediate),
|
||||
json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"),
|
||||
)
|
||||
workflow: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The workflow to save with the image",
|
||||
json_schema_extra=dict(ui_type=UIType.WorkflowField),
|
||||
use_cache: bool = Field(
|
||||
default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal")
|
||||
)
|
||||
use_cache: Optional[bool] = Field(
|
||||
default=True,
|
||||
description="Whether or not to use the cache",
|
||||
)
|
||||
|
||||
@field_validator("workflow", mode="before")
|
||||
@classmethod
|
||||
def validate_workflow_is_json(cls, v):
|
||||
"""We don't have a workflow schema in the backend, so we just check that it's valid JSON"""
|
||||
if v is None:
|
||||
return None
|
||||
try:
|
||||
json.loads(v)
|
||||
except json.decoder.JSONDecodeError:
|
||||
raise ValueError("Workflow must be valid JSON")
|
||||
return v
|
||||
|
||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
validate_assignment=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
coerce_numbers_to_str=True,
|
||||
)
|
||||
|
||||
|
||||
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
|
||||
RESERVED_INPUT_FIELD_NAMES = {
|
||||
"id",
|
||||
"is_intermediate",
|
||||
"use_cache",
|
||||
"type",
|
||||
"workflow",
|
||||
"metadata",
|
||||
}
|
||||
|
||||
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
|
||||
|
||||
|
||||
class _Model(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
# Get all pydantic model attrs, methods, etc
|
||||
RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model())))
|
||||
|
||||
|
||||
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
||||
"""
|
||||
Validates the fields of an invocation or invocation output:
|
||||
- must not override any pydantic reserved fields
|
||||
- must be created via `InputField`, `OutputField`, or be an internal field defined in this file
|
||||
"""
|
||||
for name, field in model_fields.items():
|
||||
if name in RESERVED_PYDANTIC_FIELD_NAMES:
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
|
||||
|
||||
field_kind = (
|
||||
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
|
||||
field.json_schema_extra.get("_field_kind", None)
|
||||
if field.json_schema_extra
|
||||
else None
|
||||
)
|
||||
|
||||
# must have a field_kind
|
||||
if field_kind is None or field_kind not in {"input", "output", "internal"}:
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
|
||||
)
|
||||
|
||||
if field_kind == "input" and name in RESERVED_INPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
|
||||
|
||||
if field_kind == "output" and name in RESERVED_OUTPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
|
||||
|
||||
# internal fields *must* be in the reserved list
|
||||
if (
|
||||
field_kind == "internal"
|
||||
and name not in RESERVED_INPUT_FIELD_NAMES
|
||||
and name not in RESERVED_OUTPUT_FIELD_NAMES
|
||||
):
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def invocation(
|
||||
invocation_type: str,
|
||||
title: Optional[str] = None,
|
||||
@ -712,7 +772,7 @@ def invocation(
|
||||
use_cache: Optional[bool] = True,
|
||||
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
||||
"""
|
||||
Adds metadata to an invocation.
|
||||
Registers an invocation.
|
||||
|
||||
:param str invocation_type: The type of the invocation. Must be unique among all invocations.
|
||||
:param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
|
||||
@ -731,6 +791,8 @@ def invocation(
|
||||
if invocation_type in BaseInvocation.get_invocation_types():
|
||||
raise ValueError(f'Invocation type "{invocation_type}" already exists')
|
||||
|
||||
validate_fields(cls.model_fields, invocation_type)
|
||||
|
||||
# Add OpenAPI schema extras
|
||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||
@ -761,8 +823,7 @@ def invocation(
|
||||
|
||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||
invocation_type_field = Field(
|
||||
title="type",
|
||||
default=invocation_type,
|
||||
title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal")
|
||||
)
|
||||
|
||||
docstring = cls.__doc__
|
||||
@ -803,13 +864,12 @@ def invocation_output(
|
||||
if output_type in BaseInvocationOutput.get_output_types():
|
||||
raise ValueError(f'Invocation type "{output_type}" already exists')
|
||||
|
||||
validate_fields(cls.model_fields, output_type)
|
||||
|
||||
# Add the output type to the model.
|
||||
|
||||
output_type_annotation = Literal[output_type] # type: ignore
|
||||
output_type_field = Field(
|
||||
title="type",
|
||||
default=output_type,
|
||||
)
|
||||
output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal"))
|
||||
|
||||
docstring = cls.__doc__
|
||||
cls = create_model(
|
||||
@ -827,4 +887,37 @@ def invocation_output(
|
||||
return wrapper
|
||||
|
||||
|
||||
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
||||
class WorkflowField(RootModel):
|
||||
"""
|
||||
Pydantic model for workflows with custom root of type dict[str, Any].
|
||||
Workflows are stored without a strict schema.
|
||||
"""
|
||||
|
||||
root: dict[str, Any] = Field(description="The workflow")
|
||||
|
||||
|
||||
WorkflowFieldValidator = TypeAdapter(WorkflowField)
|
||||
|
||||
|
||||
class WithWorkflow(BaseModel):
|
||||
workflow: Optional[WorkflowField] = Field(
|
||||
default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal")
|
||||
)
|
||||
|
||||
|
||||
class MetadataField(RootModel):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
Metadata is stored without a strict schema.
|
||||
"""
|
||||
|
||||
root: dict[str, Any] = Field(description="The metadata")
|
||||
|
||||
|
||||
MetadataFieldValidator = TypeAdapter(MetadataField)
|
||||
|
||||
|
||||
class WithMetadata(BaseModel):
|
||||
metadata: Optional[MetadataField] = Field(
|
||||
default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal")
|
||||
)
|
||||
|
Reference in New Issue
Block a user