mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into ebr/docker-py311
This commit is contained in:
commit
1177234931
@ -2,6 +2,7 @@
|
||||
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.workflow_image_records.workflow_image_records_sqlite import SqliteWorkflowImageRecordsStorage
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
@ -30,6 +31,7 @@ from ..services.shared.default_graphs import create_system_graphs
|
||||
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.shared.sqlite import SqliteDatabase
|
||||
from ..services.urls.urls_default import LocalUrlService
|
||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@ -90,6 +92,8 @@ class ApiDependencies:
|
||||
session_processor = DefaultSessionProcessor()
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
workflow_image_records = SqliteWorkflowImageRecordsStorage(db=db)
|
||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||
|
||||
services = InvocationServices(
|
||||
board_image_records=board_image_records,
|
||||
@ -114,6 +118,8 @@ class ApiDependencies:
|
||||
session_processor=session_processor,
|
||||
session_queue=session_queue,
|
||||
urls=urls,
|
||||
workflow_image_records=workflow_image_records,
|
||||
workflow_records=workflow_records,
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
@ -1,13 +1,14 @@
|
||||
import io
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from invokeai.app.invocations.metadata import ImageMetadata
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator, WorkflowFieldValidator
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
@ -45,17 +46,38 @@ async def upload_image(
|
||||
if not file.content_type or not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await file.read()
|
||||
metadata = None
|
||||
workflow = None
|
||||
|
||||
contents = await file.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
if crop_visible:
|
||||
bbox = pil_image.getbbox()
|
||||
pil_image = pil_image.crop(bbox)
|
||||
except Exception:
|
||||
# Error opening the image
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
# TODO: retain non-invokeai metadata on upload?
|
||||
# attempt to parse metadata from image
|
||||
metadata_raw = pil_image.info.get("invokeai_metadata", None)
|
||||
if metadata_raw:
|
||||
try:
|
||||
metadata = MetadataFieldValidator.validate_json(metadata_raw)
|
||||
except ValidationError:
|
||||
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
||||
pass
|
||||
|
||||
# attempt to parse workflow from image
|
||||
workflow_raw = pil_image.info.get("invokeai_workflow", None)
|
||||
if workflow_raw is not None:
|
||||
try:
|
||||
workflow = WorkflowFieldValidator.validate_json(workflow_raw)
|
||||
except ValidationError:
|
||||
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
||||
pass
|
||||
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.create(
|
||||
image=pil_image,
|
||||
@ -63,6 +85,8 @@ async def upload_image(
|
||||
image_category=image_category,
|
||||
session_id=session_id,
|
||||
board_id=board_id,
|
||||
metadata=metadata,
|
||||
workflow=workflow,
|
||||
is_intermediate=is_intermediate,
|
||||
)
|
||||
|
||||
@ -71,6 +95,7 @@ async def upload_image(
|
||||
|
||||
return image_dto
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||
|
||||
|
||||
@ -146,11 +171,11 @@ async def get_image_dto(
|
||||
@images_router.get(
|
||||
"/i/{image_name}/metadata",
|
||||
operation_id="get_image_metadata",
|
||||
response_model=ImageMetadata,
|
||||
response_model=Optional[MetadataField],
|
||||
)
|
||||
async def get_image_metadata(
|
||||
image_name: str = Path(description="The name of image to get"),
|
||||
) -> ImageMetadata:
|
||||
) -> Optional[MetadataField]:
|
||||
"""Gets an image's metadata"""
|
||||
|
||||
try:
|
||||
|
@ -23,13 +23,13 @@ from ..dependencies import ApiDependencies
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
update_models_response_adapter = TypeAdapter(UpdateModelResponse)
|
||||
UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse)
|
||||
|
||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
import_models_response_adapter = TypeAdapter(ImportModelResponse)
|
||||
ImportModelResponseValidator = TypeAdapter(ImportModelResponse)
|
||||
|
||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
convert_models_response_adapter = TypeAdapter(ConvertModelResponse)
|
||||
ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse)
|
||||
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
@ -41,7 +41,7 @@ class ModelsList(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
models_list_adapter = TypeAdapter(ModelsList)
|
||||
ModelsListValidator = TypeAdapter(ModelsList)
|
||||
|
||||
|
||||
@models_router.get(
|
||||
@ -60,7 +60,7 @@ async def list_models(
|
||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||
else:
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||
models = models_list_adapter.validate_python({"models": models_raw})
|
||||
models = ModelsListValidator.validate_python({"models": models_raw})
|
||||
return models
|
||||
|
||||
|
||||
@ -131,7 +131,7 @@ async def update_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
model_response = update_models_response_adapter.validate_python(model_raw)
|
||||
model_response = UpdateModelResponseValidator.validate_python(model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
@ -186,7 +186,7 @@ async def import_model(
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
||||
)
|
||||
return import_models_response_adapter.validate_python(model_raw)
|
||||
return ImportModelResponseValidator.validate_python(model_raw)
|
||||
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
@ -231,7 +231,7 @@ async def add_model(
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type,
|
||||
)
|
||||
return import_models_response_adapter.validate_python(model_raw)
|
||||
return ImportModelResponseValidator.validate_python(model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -302,7 +302,7 @@ async def convert_model(
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
response = convert_models_response_adapter.validate_python(model_raw)
|
||||
response = ConvertModelResponseValidator.validate_python(model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||
except ValueError as e:
|
||||
@ -417,7 +417,7 @@ async def merge_models(
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
response = convert_models_response_adapter.validate_python(model_raw)
|
||||
response = ConvertModelResponseValidator.validate_python(model_raw)
|
||||
except ModelNotFoundException:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
|
20
invokeai/app/api/routers/workflows.py
Normal file
20
invokeai/app/api/routers/workflows.py
Normal file
@ -0,0 +1,20 @@
|
||||
from fastapi import APIRouter, Path
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.invocations.baseinvocation import WorkflowField
|
||||
|
||||
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||
|
||||
|
||||
@workflows_router.get(
|
||||
"/i/{workflow_id}",
|
||||
operation_id="get_workflow",
|
||||
responses={
|
||||
200: {"model": WorkflowField},
|
||||
},
|
||||
)
|
||||
async def get_workflow(
|
||||
workflow_id: str = Path(description="The workflow to get"),
|
||||
) -> WorkflowField:
|
||||
"""Gets a workflow"""
|
||||
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
@ -38,7 +38,17 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
|
||||
from .api.routers import (
|
||||
app_info,
|
||||
board_images,
|
||||
boards,
|
||||
images,
|
||||
models,
|
||||
session_queue,
|
||||
sessions,
|
||||
utilities,
|
||||
workflows,
|
||||
)
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||
|
||||
@ -95,18 +105,13 @@ async def shutdown_event() -> None:
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
@ -166,7 +171,6 @@ def custom_openapi() -> dict[str, Any]:
|
||||
# print(f"Config with name {name} already defined")
|
||||
continue
|
||||
|
||||
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
|
||||
openapi_schema["components"]["schemas"][name] = dict(
|
||||
title=name,
|
||||
description="An enumeration.",
|
||||
|
@ -1,8 +1,28 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
__all__ = []
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
|
||||
dirname = os.path.dirname(os.path.abspath(__file__))
|
||||
for f in os.listdir(dirname):
|
||||
if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py":
|
||||
__all__.append(f[:-3])
|
||||
custom_nodes_path = Path(InvokeAIAppConfig.get_config().custom_nodes_path.absolute())
|
||||
custom_nodes_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
custom_nodes_init_path = str(custom_nodes_path / "__init__.py")
|
||||
custom_nodes_readme_path = str(custom_nodes_path / "README.md")
|
||||
|
||||
# copy our custom nodes __init__.py to the custom nodes directory
|
||||
shutil.copy(Path(__file__).parent / "custom_nodes/init.py", custom_nodes_init_path)
|
||||
shutil.copy(Path(__file__).parent / "custom_nodes/README.md", custom_nodes_readme_path)
|
||||
|
||||
# Import custom nodes, see https://docs.python.org/3/library/importlib.html#importing-programmatically
|
||||
spec = spec_from_file_location("custom_nodes", custom_nodes_init_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise RuntimeError(f"Could not load custom nodes from {custom_nodes_init_path}")
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# add core nodes to __all__
|
||||
python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
|
||||
__all__ = list(f.stem for f in python_files) # type: ignore
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import inspect
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
@ -11,8 +11,8 @@ from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator
|
||||
from pydantic.fields import _Unset
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
|
||||
from pydantic.fields import FieldInfo, _Unset
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
@ -26,6 +26,10 @@ class InvalidVersionError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidFieldError(TypeError):
|
||||
pass
|
||||
|
||||
|
||||
class FieldDescriptions:
|
||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||
@ -60,7 +64,12 @@ class FieldDescriptions:
|
||||
denoised_latents = "Denoised latents tensor"
|
||||
latents = "Latents tensor"
|
||||
strength = "Strength of denoising (proportional to steps)"
|
||||
core_metadata = "Optional core metadata to be written to image"
|
||||
metadata = "Optional metadata to be saved with the image"
|
||||
metadata_collection = "Collection of Metadata"
|
||||
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
|
||||
metadata_item_label = "Label for this metadata item"
|
||||
metadata_item_value = "The value for this metadata item (may be any type)"
|
||||
workflow = "Optional workflow to be saved with the image"
|
||||
interp_mode = "Interpolation mode"
|
||||
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||
fp32 = "Whether or not to use full float32 precision"
|
||||
@ -167,8 +176,12 @@ class UIType(str, Enum):
|
||||
Scheduler = "Scheduler"
|
||||
WorkflowField = "WorkflowField"
|
||||
IsIntermediate = "IsIntermediate"
|
||||
MetadataField = "MetadataField"
|
||||
BoardField = "BoardField"
|
||||
Any = "Any"
|
||||
MetadataItem = "MetadataItem"
|
||||
MetadataItemCollection = "MetadataItemCollection"
|
||||
MetadataItemPolymorphic = "MetadataItemPolymorphic"
|
||||
MetadataDict = "MetadataDict"
|
||||
# endregion
|
||||
|
||||
|
||||
@ -294,6 +307,7 @@ def InputField(
|
||||
ui_order=ui_order,
|
||||
item_default=item_default,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
_field_kind="input",
|
||||
)
|
||||
|
||||
field_args = dict(
|
||||
@ -436,6 +450,7 @@ def OutputField(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
_field_kind="output",
|
||||
),
|
||||
)
|
||||
|
||||
@ -519,6 +534,7 @@ class BaseInvocationOutput(BaseModel):
|
||||
schema["required"].extend(["type"])
|
||||
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
@ -541,9 +557,6 @@ class MissingInputException(Exception):
|
||||
|
||||
class BaseInvocation(ABC, BaseModel):
|
||||
"""
|
||||
A node to process inputs and produce outputs.
|
||||
May use dependency injection in __init__ to receive providers.
|
||||
|
||||
All invocations must use the `@invocation` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
@ -659,37 +672,21 @@ class BaseInvocation(ABC, BaseModel):
|
||||
id: str = Field(
|
||||
default_factory=uuid_string,
|
||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
||||
json_schema_extra=dict(_field_kind="internal"),
|
||||
)
|
||||
is_intermediate: Optional[bool] = Field(
|
||||
is_intermediate: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not this is an intermediate invocation.",
|
||||
json_schema_extra=dict(ui_type=UIType.IsIntermediate),
|
||||
json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"),
|
||||
)
|
||||
workflow: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The workflow to save with the image",
|
||||
json_schema_extra=dict(ui_type=UIType.WorkflowField),
|
||||
use_cache: bool = Field(
|
||||
default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal")
|
||||
)
|
||||
use_cache: Optional[bool] = Field(
|
||||
default=True,
|
||||
description="Whether or not to use the cache",
|
||||
)
|
||||
|
||||
@field_validator("workflow", mode="before")
|
||||
@classmethod
|
||||
def validate_workflow_is_json(cls, v):
|
||||
"""We don't have a workflow schema in the backend, so we just check that it's valid JSON"""
|
||||
if v is None:
|
||||
return None
|
||||
try:
|
||||
json.loads(v)
|
||||
except json.decoder.JSONDecodeError:
|
||||
raise ValueError("Workflow must be valid JSON")
|
||||
return v
|
||||
|
||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
validate_assignment=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
@ -700,6 +697,68 @@ class BaseInvocation(ABC, BaseModel):
|
||||
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
|
||||
RESERVED_INPUT_FIELD_NAMES = {
|
||||
"id",
|
||||
"is_intermediate",
|
||||
"use_cache",
|
||||
"type",
|
||||
"workflow",
|
||||
"metadata",
|
||||
}
|
||||
|
||||
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
|
||||
|
||||
|
||||
class _Model(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
# Get all pydantic model attrs, methods, etc
|
||||
RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model())))
|
||||
|
||||
|
||||
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
||||
"""
|
||||
Validates the fields of an invocation or invocation output:
|
||||
- must not override any pydantic reserved fields
|
||||
- must be created via `InputField`, `OutputField`, or be an internal field defined in this file
|
||||
"""
|
||||
for name, field in model_fields.items():
|
||||
if name in RESERVED_PYDANTIC_FIELD_NAMES:
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
|
||||
|
||||
field_kind = (
|
||||
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
|
||||
field.json_schema_extra.get("_field_kind", None)
|
||||
if field.json_schema_extra
|
||||
else None
|
||||
)
|
||||
|
||||
# must have a field_kind
|
||||
if field_kind is None or field_kind not in {"input", "output", "internal"}:
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
|
||||
)
|
||||
|
||||
if field_kind == "input" and name in RESERVED_INPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
|
||||
|
||||
if field_kind == "output" and name in RESERVED_OUTPUT_FIELD_NAMES:
|
||||
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
|
||||
|
||||
# internal fields *must* be in the reserved list
|
||||
if (
|
||||
field_kind == "internal"
|
||||
and name not in RESERVED_INPUT_FIELD_NAMES
|
||||
and name not in RESERVED_OUTPUT_FIELD_NAMES
|
||||
):
|
||||
raise InvalidFieldError(
|
||||
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def invocation(
|
||||
invocation_type: str,
|
||||
title: Optional[str] = None,
|
||||
@ -709,7 +768,7 @@ def invocation(
|
||||
use_cache: Optional[bool] = True,
|
||||
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
||||
"""
|
||||
Adds metadata to an invocation.
|
||||
Registers an invocation.
|
||||
|
||||
:param str invocation_type: The type of the invocation. Must be unique among all invocations.
|
||||
:param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
|
||||
@ -728,6 +787,8 @@ def invocation(
|
||||
if invocation_type in BaseInvocation.get_invocation_types():
|
||||
raise ValueError(f'Invocation type "{invocation_type}" already exists')
|
||||
|
||||
validate_fields(cls.model_fields, invocation_type)
|
||||
|
||||
# Add OpenAPI schema extras
|
||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||
@ -758,8 +819,7 @@ def invocation(
|
||||
|
||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||
invocation_type_field = Field(
|
||||
title="type",
|
||||
default=invocation_type,
|
||||
title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal")
|
||||
)
|
||||
|
||||
docstring = cls.__doc__
|
||||
@ -800,13 +860,12 @@ def invocation_output(
|
||||
if output_type in BaseInvocationOutput.get_output_types():
|
||||
raise ValueError(f'Invocation type "{output_type}" already exists')
|
||||
|
||||
validate_fields(cls.model_fields, output_type)
|
||||
|
||||
# Add the output type to the model.
|
||||
|
||||
output_type_annotation = Literal[output_type] # type: ignore
|
||||
output_type_field = Field(
|
||||
title="type",
|
||||
default=output_type,
|
||||
)
|
||||
output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal"))
|
||||
|
||||
docstring = cls.__doc__
|
||||
cls = create_model(
|
||||
@ -824,4 +883,37 @@ def invocation_output(
|
||||
return wrapper
|
||||
|
||||
|
||||
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
||||
class WorkflowField(RootModel):
|
||||
"""
|
||||
Pydantic model for workflows with custom root of type dict[str, Any].
|
||||
Workflows are stored without a strict schema.
|
||||
"""
|
||||
|
||||
root: dict[str, Any] = Field(description="The workflow")
|
||||
|
||||
|
||||
WorkflowFieldValidator = TypeAdapter(WorkflowField)
|
||||
|
||||
|
||||
class WithWorkflow(BaseModel):
|
||||
workflow: Optional[WorkflowField] = Field(
|
||||
default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal")
|
||||
)
|
||||
|
||||
|
||||
class MetadataField(RootModel):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
Metadata is stored without a strict schema.
|
||||
"""
|
||||
|
||||
root: dict[str, Any] = Field(description="The metadata")
|
||||
|
||||
|
||||
MetadataFieldValidator = TypeAdapter(MetadataField)
|
||||
|
||||
|
||||
class WithMetadata(BaseModel):
|
||||
metadata: Optional[MetadataField] = Field(
|
||||
default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal")
|
||||
)
|
||||
|
@ -38,6 +38,8 @@ from .baseinvocation import (
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -127,12 +129,12 @@ class ControlNetInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||
class ImageProcessorInvocation(BaseInvocation):
|
||||
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# superclass just passes through image without processing
|
||||
return image
|
||||
|
||||
@ -150,6 +152,7 @@ class ImageProcessorInvocation(BaseInvocation):
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
|
51
invokeai/app/invocations/custom_nodes/README.md
Normal file
51
invokeai/app/invocations/custom_nodes/README.md
Normal file
@ -0,0 +1,51 @@
|
||||
# Custom Nodes / Node Packs
|
||||
|
||||
Copy your node packs to this directory.
|
||||
|
||||
When nodes are added or changed, you must restart the app to see the changes.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
For a node pack to be loaded, it must be placed in a directory alongside this
|
||||
file. Here's an example structure:
|
||||
|
||||
```py
|
||||
.
|
||||
├── __init__.py # Invoke-managed custom node loader
|
||||
│
|
||||
├── cool_node
|
||||
│ ├── __init__.py # see example below
|
||||
│ └── cool_node.py
|
||||
│
|
||||
└── my_node_pack
|
||||
├── __init__.py # see example below
|
||||
├── tasty_node.py
|
||||
├── bodacious_node.py
|
||||
├── utils.py
|
||||
└── extra_nodes
|
||||
└── fancy_node.py
|
||||
```
|
||||
|
||||
## Node Pack `__init__.py`
|
||||
|
||||
Each node pack must have an `__init__.py` file that imports its nodes.
|
||||
|
||||
The structure of each node or node pack is otherwise not important.
|
||||
|
||||
Here are examples, based on the example directory structure.
|
||||
|
||||
### `cool_node/__init__.py`
|
||||
|
||||
```py
|
||||
from .cool_node import CoolInvocation
|
||||
```
|
||||
|
||||
### `my_node_pack/__init__.py`
|
||||
|
||||
```py
|
||||
from .tasty_node import TastyInvocation
|
||||
from .bodacious_node import BodaciousInvocation
|
||||
from .extra_nodes.fancy_node import FancyInvocation
|
||||
```
|
||||
|
||||
Only nodes imported in the `__init__.py` file are loaded.
|
51
invokeai/app/invocations/custom_nodes/init.py
Normal file
51
invokeai/app/invocations/custom_nodes/init.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""
|
||||
Invoke-managed custom node loader. See README.md for more information.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
loaded_count = 0
|
||||
|
||||
|
||||
for d in Path(__file__).parent.iterdir():
|
||||
# skip files
|
||||
if not d.is_dir():
|
||||
continue
|
||||
|
||||
# skip hidden directories
|
||||
if d.name.startswith("_") or d.name.startswith("."):
|
||||
continue
|
||||
|
||||
# skip directories without an `__init__.py`
|
||||
init = d / "__init__.py"
|
||||
if not init.exists():
|
||||
continue
|
||||
|
||||
module_name = init.parent.stem
|
||||
|
||||
# skip if already imported
|
||||
if module_name in globals():
|
||||
continue
|
||||
|
||||
# we have a legit module to import
|
||||
spec = spec_from_file_location(module_name, init.absolute())
|
||||
|
||||
if spec is None or spec.loader is None:
|
||||
logger.warn(f"Could not load {init}")
|
||||
continue
|
||||
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
loaded_count += 1
|
||||
|
||||
del init, module_name
|
||||
|
||||
|
||||
logger.info(f"Loaded {loaded_count} modules from {Path(__file__).parent}")
|
@ -8,11 +8,11 @@ from PIL import Image, ImageOps
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
|
||||
|
||||
|
||||
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.0.0")
|
||||
class CvInpaintInvocation(BaseInvocation):
|
||||
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
"""Simple inpaint using opencv."""
|
||||
|
||||
image: ImageField = InputField(description="The image to inpaint")
|
||||
|
@ -16,6 +16,8 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -437,7 +439,7 @@ def get_faces_list(
|
||||
|
||||
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.0.2")
|
||||
class FaceOffInvocation(BaseInvocation):
|
||||
class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||
|
||||
image: ImageField = InputField(description="Image for face detection")
|
||||
@ -531,7 +533,7 @@ class FaceOffInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.0.2")
|
||||
class FaceMaskInvocation(BaseInvocation):
|
||||
class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Face mask creation using mediapipe face detection"""
|
||||
|
||||
image: ImageField = InputField(description="Image to face detect")
|
||||
@ -650,7 +652,7 @@ class FaceMaskInvocation(BaseInvocation):
|
||||
@invocation(
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.0.2"
|
||||
)
|
||||
class FaceIdentifierInvocation(BaseInvocation):
|
||||
class FaceIdentifierInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
||||
|
||||
image: ImageField = InputField(description="Image to face detect")
|
||||
|
@ -7,13 +7,21 @@ import cv2
|
||||
import numpy
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
|
||||
from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, invocation
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
invocation,
|
||||
)
|
||||
|
||||
|
||||
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
|
||||
@ -36,14 +44,8 @@ class ShowImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"blank_image",
|
||||
title="Blank Image",
|
||||
tags=["image"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class BlankImageInvocation(BaseInvocation):
|
||||
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0")
|
||||
class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
"""Creates a blank image and forwards it to the pipeline"""
|
||||
|
||||
width: int = InputField(default=512, description="The width of the image")
|
||||
@ -61,6 +63,7 @@ class BlankImageInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -71,14 +74,8 @@ class BlankImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_crop",
|
||||
title="Crop Image",
|
||||
tags=["image", "crop"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageCropInvocation(BaseInvocation):
|
||||
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0")
|
||||
class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to crop")
|
||||
@ -100,6 +97,7 @@ class ImageCropInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -110,14 +108,8 @@ class ImageCropInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_paste",
|
||||
title="Paste Image",
|
||||
tags=["image", "paste"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
)
|
||||
class ImagePasteInvocation(BaseInvocation):
|
||||
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.1")
|
||||
class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Pastes an image into another image."""
|
||||
|
||||
base_image: ImageField = InputField(description="The base image")
|
||||
@ -159,6 +151,7 @@ class ImagePasteInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -169,14 +162,8 @@ class ImagePasteInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"tomask",
|
||||
title="Mask from Alpha",
|
||||
tags=["image", "mask"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MaskFromAlphaInvocation(BaseInvocation):
|
||||
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0")
|
||||
class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
|
||||
image: ImageField = InputField(description="The image to create the mask from")
|
||||
@ -196,6 +183,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -206,14 +194,8 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_mul",
|
||||
title="Multiply Images",
|
||||
tags=["image", "multiply"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageMultiplyInvocation(BaseInvocation):
|
||||
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0")
|
||||
class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
image1: ImageField = InputField(description="The first image to multiply")
|
||||
@ -232,6 +214,7 @@ class ImageMultiplyInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -245,14 +228,8 @@ class ImageMultiplyInvocation(BaseInvocation):
|
||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_chan",
|
||||
title="Extract Image Channel",
|
||||
tags=["image", "channel"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageChannelInvocation(BaseInvocation):
|
||||
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0")
|
||||
class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Gets a channel from an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to get the channel from")
|
||||
@ -270,6 +247,7 @@ class ImageChannelInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -283,14 +261,8 @@ class ImageChannelInvocation(BaseInvocation):
|
||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_conv",
|
||||
title="Convert Image Mode",
|
||||
tags=["image", "convert"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageConvertInvocation(BaseInvocation):
|
||||
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0")
|
||||
class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Converts an image to a different mode."""
|
||||
|
||||
image: ImageField = InputField(description="The image to convert")
|
||||
@ -308,6 +280,7 @@ class ImageConvertInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -318,14 +291,8 @@ class ImageConvertInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_blur",
|
||||
title="Blur Image",
|
||||
tags=["image", "blur"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageBlurInvocation(BaseInvocation):
|
||||
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0")
|
||||
class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Blurs an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to blur")
|
||||
@ -348,6 +315,7 @@ class ImageBlurInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -378,23 +346,14 @@ PIL_RESAMPLING_MAP = {
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_resize",
|
||||
title="Resize Image",
|
||||
tags=["image", "resize"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageResizeInvocation(BaseInvocation):
|
||||
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0")
|
||||
class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
"""Resizes an image to specific dimensions"""
|
||||
|
||||
image: ImageField = InputField(description="The image to resize")
|
||||
width: int = InputField(default=512, gt=0, description="The width to resize to (px)")
|
||||
height: int = InputField(default=512, gt=0, description="The height to resize to (px)")
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -413,7 +372,7 @@ class ImageResizeInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -424,14 +383,8 @@ class ImageResizeInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_scale",
|
||||
title="Scale Image",
|
||||
tags=["image", "scale"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageScaleInvocation(BaseInvocation):
|
||||
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0")
|
||||
class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
"""Scales an image by a factor"""
|
||||
|
||||
image: ImageField = InputField(description="The image to scale")
|
||||
@ -461,6 +414,7 @@ class ImageScaleInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -471,14 +425,8 @@ class ImageScaleInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_lerp",
|
||||
title="Lerp Image",
|
||||
tags=["image", "lerp"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageLerpInvocation(BaseInvocation):
|
||||
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0")
|
||||
class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to lerp")
|
||||
@ -500,6 +448,7 @@ class ImageLerpInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -510,14 +459,8 @@ class ImageLerpInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_ilerp",
|
||||
title="Inverse Lerp Image",
|
||||
tags=["image", "ilerp"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageInverseLerpInvocation(BaseInvocation):
|
||||
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0")
|
||||
class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to lerp")
|
||||
@ -539,6 +482,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -549,20 +493,11 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_nsfw",
|
||||
title="Blur NSFW Image",
|
||||
tags=["image", "nsfw"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0")
|
||||
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
"""Add blur to NSFW-flagged images"""
|
||||
|
||||
image: ImageField = InputField(description="The image to check")
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -583,7 +518,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -607,14 +542,11 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageWatermarkInvocation(BaseInvocation):
|
||||
class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
"""Add an invisible watermark to an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to check")
|
||||
text: str = InputField(default="InvokeAI", description="Watermark text")
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -626,7 +558,7 @@ class ImageWatermarkInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -637,14 +569,8 @@ class ImageWatermarkInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"mask_edge",
|
||||
title="Mask Edge",
|
||||
tags=["image", "mask", "inpaint"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MaskEdgeInvocation(BaseInvocation):
|
||||
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0")
|
||||
class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Applies an edge mask to an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to apply the mask to")
|
||||
@ -678,6 +604,7 @@ class MaskEdgeInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -695,7 +622,7 @@ class MaskEdgeInvocation(BaseInvocation):
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MaskCombineInvocation(BaseInvocation):
|
||||
class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
mask1: ImageField = InputField(description="The first mask to combine")
|
||||
@ -714,6 +641,7 @@ class MaskCombineInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -724,14 +652,8 @@ class MaskCombineInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"color_correct",
|
||||
title="Color Correct",
|
||||
tags=["image", "color"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ColorCorrectInvocation(BaseInvocation):
|
||||
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0")
|
||||
class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""
|
||||
Shifts the colors of a target image to match the reference image, optionally
|
||||
using a mask to only color-correct certain regions of the target image.
|
||||
@ -830,6 +752,7 @@ class ColorCorrectInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -840,14 +763,8 @@ class ColorCorrectInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_hue_adjust",
|
||||
title="Adjust Image Hue",
|
||||
tags=["image", "hue"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0")
|
||||
class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Adjusts the Hue of an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
@ -875,6 +792,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
session_id=context.graph_execution_state_id,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -950,7 +868,7 @@ CHANNEL_FORMATS = {
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageChannelOffsetInvocation(BaseInvocation):
|
||||
class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Add or subtract a value from a specific color channel of an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
@ -984,6 +902,7 @@ class ImageChannelOffsetInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
session_id=context.graph_execution_state_id,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -1020,7 +939,7 @@ class ImageChannelOffsetInvocation(BaseInvocation):
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageChannelMultiplyInvocation(BaseInvocation):
|
||||
class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Scale a specific color channel of an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
@ -1060,6 +979,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation):
|
||||
is_intermediate=self.is_intermediate,
|
||||
session_id=context.graph_execution_state_id,
|
||||
workflow=self.workflow,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -1079,16 +999,11 @@ class ImageChannelMultiplyInvocation(BaseInvocation):
|
||||
version="1.0.1",
|
||||
use_cache=False,
|
||||
)
|
||||
class SaveImageInvocation(BaseInvocation):
|
||||
class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
|
||||
|
||||
image: ImageField = InputField(description=FieldDescriptions.image)
|
||||
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.core_metadata,
|
||||
ui_hidden=True,
|
||||
)
|
||||
board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@ -1101,7 +1016,7 @@ class SaveImageInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
|
@ -13,7 +13,7 @@ from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
|
||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||
|
||||
|
||||
@ -119,7 +119,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
||||
|
||||
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||
class InfillColorInvocation(BaseInvocation):
|
||||
class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@ -143,6 +143,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -154,7 +155,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||
class InfillTileInvocation(BaseInvocation):
|
||||
class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@ -179,6 +180,7 @@ class InfillTileInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -192,7 +194,7 @@ class InfillTileInvocation(BaseInvocation):
|
||||
@invocation(
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0"
|
||||
)
|
||||
class InfillPatchMatchInvocation(BaseInvocation):
|
||||
class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@ -232,6 +234,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
@ -243,7 +246,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||
class LaMaInfillInvocation(BaseInvocation):
|
||||
class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@ -260,6 +263,8 @@ class LaMaInfillInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -269,8 +274,8 @@ class LaMaInfillInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||
class CV2InfillInvocation(BaseInvocation):
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint")
|
||||
class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@ -287,6 +292,8 @@ class CV2InfillInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
|
@ -23,7 +23,6 @@ from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import (
|
||||
DenoiseMaskField,
|
||||
DenoiseMaskOutput,
|
||||
@ -64,6 +63,8 @@ from .baseinvocation import (
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -792,7 +793,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LatentsToImageInvocation(BaseInvocation):
|
||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(
|
||||
@ -805,11 +806,6 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.core_metadata,
|
||||
ui_hidden=True,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -878,7 +874,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
|
@ -1,13 +1,16 @@
|
||||
from typing import Optional
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
MetadataField,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -16,116 +19,99 @@ from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
from ...version import __version__
|
||||
|
||||
|
||||
class LoRAMetadataField(BaseModelExcludeNull):
|
||||
"""LoRA metadata for an image generated in InvokeAI."""
|
||||
|
||||
lora: LoRAModelField = Field(description="The LoRA model")
|
||||
weight: float = Field(description="The weight of the LoRA model")
|
||||
class MetadataItemField(BaseModel):
|
||||
label: str = Field(description=FieldDescriptions.metadata_item_label)
|
||||
value: Any = Field(description=FieldDescriptions.metadata_item_value)
|
||||
|
||||
|
||||
class IPAdapterMetadataField(BaseModelExcludeNull):
|
||||
class LoRAMetadataField(BaseModel):
|
||||
"""LoRA Metadata Field"""
|
||||
|
||||
lora: LoRAModelField = Field(description=FieldDescriptions.lora_model)
|
||||
weight: float = Field(description=FieldDescriptions.lora_weight)
|
||||
|
||||
|
||||
class IPAdapterMetadataField(BaseModel):
|
||||
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
|
||||
|
||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
||||
weight: float = Field(description="The weight of the IP-Adapter model")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||
ip_adapter_model: IPAdapterModelField = Field(
|
||||
description="The IP-Adapter model.",
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
||||
weight: Union[float, list[float]] = Field(
|
||||
description="The weight given to the IP-Adapter",
|
||||
)
|
||||
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||
|
||||
|
||||
@invocation_output("metadata_item_output")
|
||||
class MetadataItemOutput(BaseInvocationOutput):
|
||||
"""Metadata Item Output"""
|
||||
|
||||
item: MetadataItemField = OutputField(description="Metadata Item")
|
||||
|
||||
|
||||
@invocation("metadata_item", title="Metadata Item", tags=["metadata"], category="metadata", version="1.0.0")
|
||||
class MetadataItemInvocation(BaseInvocation):
|
||||
"""Used to create an arbitrary metadata item. Provide "label" and make a connection to "value" to store that data as the value."""
|
||||
|
||||
label: str = InputField(description=FieldDescriptions.metadata_item_label)
|
||||
value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MetadataItemOutput:
|
||||
return MetadataItemOutput(item=MetadataItemField(label=self.label, value=self.value))
|
||||
|
||||
|
||||
@invocation_output("metadata_output")
|
||||
class MetadataOutput(BaseInvocationOutput):
|
||||
metadata: MetadataField = OutputField(description="Metadata Dict")
|
||||
|
||||
|
||||
@invocation("metadata", title="Metadata", tags=["metadata"], category="metadata", version="1.0.0")
|
||||
class MetadataInvocation(BaseInvocation):
|
||||
"""Takes a MetadataItem or collection of MetadataItems and outputs a MetadataDict."""
|
||||
|
||||
items: Union[list[MetadataItemField], MetadataItemField] = InputField(
|
||||
description=FieldDescriptions.metadata_item_polymorphic
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||
if isinstance(self.items, MetadataItemField):
|
||||
# single metadata item
|
||||
data = {self.items.label: self.items.value}
|
||||
else:
|
||||
# collection of metadata items
|
||||
data = {item.label: item.value for item in self.items}
|
||||
|
||||
class CoreMetadata(BaseModelExcludeNull):
|
||||
"""Core generation metadata for an image generated in InvokeAI."""
|
||||
|
||||
app_version: str = Field(default=__version__, description="The version of InvokeAI used to generate this image")
|
||||
generation_mode: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
created_by: Optional[str] = Field(default=None, description="The name of the creator of the image")
|
||||
positive_prompt: Optional[str] = Field(default=None, description="The positive prompt parameter")
|
||||
negative_prompt: Optional[str] = Field(default=None, description="The negative prompt parameter")
|
||||
width: Optional[int] = Field(default=None, description="The width parameter")
|
||||
height: Optional[int] = Field(default=None, description="The height parameter")
|
||||
seed: Optional[int] = Field(default=None, description="The seed used for noise generation")
|
||||
rand_device: Optional[str] = Field(default=None, description="The device used for random number generation")
|
||||
cfg_scale: Optional[float] = Field(default=None, description="The classifier-free guidance scale parameter")
|
||||
steps: Optional[int] = Field(default=None, description="The number of steps used for inference")
|
||||
scheduler: Optional[str] = Field(default=None, description="The scheduler used for inference")
|
||||
clip_skip: Optional[int] = Field(
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: Optional[MainModelField] = Field(default=None, description="The main model used for inference")
|
||||
controlnets: Optional[list[ControlField]] = Field(default=None, description="The ControlNets used for inference")
|
||||
ipAdapters: Optional[list[IPAdapterMetadataField]] = Field(
|
||||
default=None, description="The IP Adapters used for inference"
|
||||
)
|
||||
t2iAdapters: Optional[list[T2IAdapterField]] = Field(default=None, description="The IP Adapters used for inference")
|
||||
loras: Optional[list[LoRAMetadataField]] = Field(default=None, description="The LoRAs used for inference")
|
||||
vae: Optional[VAEModelField] = Field(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
|
||||
# Latents-to-Latents
|
||||
strength: Optional[float] = Field(
|
||||
default=None,
|
||||
description="The strength used for latents-to-latents",
|
||||
)
|
||||
init_image: Optional[str] = Field(default=None, description="The name of the initial image")
|
||||
|
||||
# SDXL
|
||||
positive_style_prompt: Optional[str] = Field(default=None, description="The positive style prompt parameter")
|
||||
negative_style_prompt: Optional[str] = Field(default=None, description="The negative style prompt parameter")
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Optional[MainModelField] = Field(default=None, description="The SDXL Refiner model used")
|
||||
refiner_cfg_scale: Optional[float] = Field(
|
||||
default=None,
|
||||
description="The classifier-free guidance scale parameter used for the refiner",
|
||||
)
|
||||
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
|
||||
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
|
||||
refiner_positive_aesthetic_score: Optional[float] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_negative_aesthetic_score: Optional[float] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
|
||||
# add app version
|
||||
data.update({"app_version": __version__})
|
||||
return MetadataOutput(metadata=MetadataField.model_validate(data))
|
||||
|
||||
|
||||
class ImageMetadata(BaseModelExcludeNull):
|
||||
"""An image's generation metadata"""
|
||||
@invocation("merge_metadata", title="Metadata Merge", tags=["metadata"], category="metadata", version="1.0.0")
|
||||
class MergeMetadataInvocation(BaseInvocation):
|
||||
"""Merged a collection of MetadataDict into a single MetadataDict."""
|
||||
|
||||
metadata: Optional[dict] = Field(
|
||||
default=None,
|
||||
description="The image's core metadata, if it was created in the Linear or Canvas UI",
|
||||
)
|
||||
graph: Optional[dict] = Field(default=None, description="The graph that created the image")
|
||||
collection: list[MetadataField] = InputField(description=FieldDescriptions.metadata_collection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||
data = {}
|
||||
for item in self.collection:
|
||||
data.update(item.model_dump())
|
||||
|
||||
return MetadataOutput(metadata=MetadataField.model_validate(data))
|
||||
|
||||
|
||||
@invocation_output("metadata_accumulator_output")
|
||||
class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||
"""The output of the MetadataAccumulator node"""
|
||||
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.0.0")
|
||||
class CoreMetadataInvocation(BaseInvocation):
|
||||
"""Collects core generation metadata into a MetadataField"""
|
||||
|
||||
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
|
||||
|
||||
|
||||
@invocation(
|
||||
"metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata", version="1.0.0"
|
||||
)
|
||||
class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
"""Outputs a Core Metadata Object"""
|
||||
|
||||
generation_mode: Optional[str] = InputField(
|
||||
generation_mode: Literal["txt2img", "img2img", "inpaint", "outpaint"] = InputField(
|
||||
default=None,
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
@ -138,6 +124,8 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter")
|
||||
steps: Optional[int] = InputField(default=None, description="The number of steps used for inference")
|
||||
scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference")
|
||||
seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis")
|
||||
seamless_y: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the Y axis")
|
||||
clip_skip: Optional[int] = InputField(
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
@ -220,7 +208,13 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
description="The start value used for refiner denoising",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||
"""Collects and outputs a CoreMetadata object"""
|
||||
|
||||
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.model_dump()))
|
||||
return MetadataOutput(
|
||||
metadata=MetadataField.model_validate(
|
||||
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
||||
)
|
||||
)
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
@ -4,7 +4,7 @@ import inspect
|
||||
import re
|
||||
|
||||
# from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
from typing import List, Literal, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -12,7 +12,6 @@ from diffusers.image_processor import VaeImageProcessor
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
@ -31,6 +30,8 @@ from .baseinvocation import (
|
||||
OutputField,
|
||||
UIComponent,
|
||||
UIType,
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -327,7 +328,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(
|
||||
@ -338,11 +339,6 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.core_metadata,
|
||||
ui_hidden=True,
|
||||
)
|
||||
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -381,7 +377,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
|
@ -251,7 +251,9 @@ class ImageCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0")
|
||||
class ImageInvocation(BaseInvocation):
|
||||
class ImageInvocation(
|
||||
BaseInvocation,
|
||||
):
|
||||
"""An image primitive value"""
|
||||
|
||||
image: ImageField = InputField(description="The image to load")
|
||||
|
@ -14,7 +14,7 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
|
||||
|
||||
# TODO: Populate this from disk?
|
||||
# TODO: Use model manager to load?
|
||||
@ -30,7 +30,7 @@ if choose_torch_device() == torch.device("mps"):
|
||||
|
||||
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.1.0")
|
||||
class ESRGANInvocation(BaseInvocation):
|
||||
class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"""Upscales an image using RealESRGAN."""
|
||||
|
||||
image: ImageField = InputField(description="The input image")
|
||||
@ -123,6 +123,7 @@ class ESRGANInvocation(BaseInvocation):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
|
@ -243,6 +243,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
db_dir : Optional[Path] = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
|
||||
outdir : Optional[Path] = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
|
||||
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes', json_schema_extra=Categories.Paths)
|
||||
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
|
||||
|
||||
# LOGGING
|
||||
@ -410,6 +411,13 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def custom_nodes_path(self) -> Path:
|
||||
"""
|
||||
Path to the custom nodes directory
|
||||
"""
|
||||
return self._resolve(self.custom_nodes_dir)
|
||||
|
||||
# the following methods support legacy calls leftover from the Globals era
|
||||
@property
|
||||
def full_precision(self) -> bool:
|
||||
|
@ -4,6 +4,8 @@ from typing import Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
|
||||
|
||||
|
||||
class ImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
@ -30,8 +32,8 @@ class ImageFileStorageBase(ABC):
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[dict] = None,
|
||||
workflow: Optional[str] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowField] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
|
@ -1,5 +1,4 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
import json
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional, Union
|
||||
@ -8,6 +7,7 @@ from PIL import Image, PngImagePlugin
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
@ -55,8 +55,8 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[dict] = None,
|
||||
workflow: Optional[str] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowField] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
try:
|
||||
@ -65,20 +65,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if metadata is not None or workflow is not None:
|
||||
if metadata is not None:
|
||||
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
|
||||
pnginfo.add_text("invokeai_metadata", metadata.model_dump_json())
|
||||
if workflow is not None:
|
||||
pnginfo.add_text("invokeai_workflow", workflow)
|
||||
else:
|
||||
# For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back
|
||||
# TODO: retain non-invokeai metadata on save...
|
||||
original_metadata = image.info.get("invokeai_metadata", None)
|
||||
if original_metadata is not None:
|
||||
pnginfo.add_text("invokeai_metadata", original_metadata)
|
||||
original_workflow = image.info.get("invokeai_workflow", None)
|
||||
if original_workflow is not None:
|
||||
pnginfo.add_text("invokeai_workflow", original_workflow)
|
||||
pnginfo.add_text("invokeai_workflow", workflow.model_dump_json())
|
||||
|
||||
image.save(
|
||||
image_path,
|
||||
|
@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.metadata import MetadataField
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
|
||||
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
|
||||
@ -18,7 +19,7 @@ class ImageRecordStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, image_name: str) -> Optional[dict]:
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
"""Gets an image's metadata'."""
|
||||
pass
|
||||
|
||||
@ -78,7 +79,7 @@ class ImageRecordStorageBase(ABC):
|
||||
starred: Optional[bool] = False,
|
||||
session_id: Optional[str] = None,
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
|
@ -1,9 +1,9 @@
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
|
||||
@ -141,22 +141,26 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
return deserialize_image_record(dict(result))
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[dict]:
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.metadata FROM images
|
||||
SELECT metadata FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||
if not result or not result[0]:
|
||||
return None
|
||||
return json.loads(result[0])
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
|
||||
as_dict = dict(result)
|
||||
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
|
||||
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordNotFoundException from e
|
||||
@ -408,10 +412,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
starred: Optional[bool] = False,
|
||||
session_id: Optional[str] = None,
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = None if metadata is None else json.dumps(metadata)
|
||||
metadata_json = metadata.model_dump_json() if metadata is not None else None
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
|
@ -3,7 +3,7 @@ from typing import Callable, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.metadata import ImageMetadata
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageRecord,
|
||||
@ -50,8 +50,8 @@ class ImageServiceABC(ABC):
|
||||
session_id: Optional[str] = None,
|
||||
board_id: Optional[str] = None,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
metadata: Optional[dict] = None,
|
||||
workflow: Optional[str] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowField] = None,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
@ -81,7 +81,7 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, image_name: str) -> ImageMetadata:
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
"""Gets an image's metadata."""
|
||||
pass
|
||||
|
||||
|
@ -24,8 +24,11 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
default=None, description="The id of the board the image belongs to, if one exists."
|
||||
)
|
||||
"""The id of the board the image belongs to, if one exists."""
|
||||
|
||||
pass
|
||||
workflow_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The workflow that generated this image.",
|
||||
)
|
||||
"""The workflow that generated this image."""
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
@ -33,6 +36,7 @@ def image_record_to_dto(
|
||||
image_url: str,
|
||||
thumbnail_url: str,
|
||||
board_id: Optional[str],
|
||||
workflow_id: Optional[str],
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
@ -40,4 +44,5 @@ def image_record_to_dto(
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
board_id=board_id,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
|
@ -2,10 +2,9 @@ from typing import Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.metadata import ImageMetadata
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
||||
|
||||
from ..image_files.image_files_common import (
|
||||
ImageFileDeleteException,
|
||||
@ -42,8 +41,8 @@ class ImageService(ImageServiceABC):
|
||||
session_id: Optional[str] = None,
|
||||
board_id: Optional[str] = None,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
metadata: Optional[dict] = None,
|
||||
workflow: Optional[str] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowField] = None,
|
||||
) -> ImageDTO:
|
||||
if image_origin not in ResourceOrigin:
|
||||
raise InvalidOriginException
|
||||
@ -56,6 +55,12 @@ class ImageService(ImageServiceABC):
|
||||
(width, height) = image.size
|
||||
|
||||
try:
|
||||
if workflow is not None:
|
||||
created_workflow = self.__invoker.services.workflow_records.create(workflow)
|
||||
workflow_id = created_workflow.model_dump()["id"]
|
||||
else:
|
||||
workflow_id = None
|
||||
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
self.__invoker.services.image_records.save(
|
||||
# Non-nullable fields
|
||||
@ -73,6 +78,8 @@ class ImageService(ImageServiceABC):
|
||||
)
|
||||
if board_id is not None:
|
||||
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
if workflow_id is not None:
|
||||
self.__invoker.services.workflow_image_records.create(workflow_id=workflow_id, image_name=image_name)
|
||||
self.__invoker.services.image_files.save(
|
||||
image_name=image_name, image=image, metadata=metadata, workflow=workflow
|
||||
)
|
||||
@ -132,10 +139,11 @@ class ImageService(ImageServiceABC):
|
||||
image_record = self.__invoker.services.image_records.get(image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self.__invoker.services.urls.get_image_url(image_name),
|
||||
self.__invoker.services.urls.get_image_url(image_name, True),
|
||||
self.__invoker.services.board_image_records.get_board_for_image(image_name),
|
||||
image_record=image_record,
|
||||
image_url=self.__invoker.services.urls.get_image_url(image_name),
|
||||
thumbnail_url=self.__invoker.services.urls.get_image_url(image_name, True),
|
||||
board_id=self.__invoker.services.board_image_records.get_board_for_image(image_name),
|
||||
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
@ -146,25 +154,22 @@ class ImageService(ImageServiceABC):
|
||||
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_metadata(self, image_name: str) -> ImageMetadata:
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
try:
|
||||
image_record = self.__invoker.services.image_records.get(image_name)
|
||||
metadata = self.__invoker.services.image_records.get_metadata(image_name)
|
||||
|
||||
if not image_record.session_id:
|
||||
return ImageMetadata(metadata=metadata)
|
||||
|
||||
session_raw = self.__invoker.services.graph_execution_manager.get_raw(image_record.session_id)
|
||||
graph = None
|
||||
|
||||
if session_raw:
|
||||
try:
|
||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
||||
return self.__invoker.services.image_records.get_metadata(image_name)
|
||||
except ImageRecordNotFoundException:
|
||||
self.__invoker.services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.warn(f"Failed to parse session graph: {e}")
|
||||
graph = None
|
||||
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
return ImageMetadata(graph=graph, metadata=metadata)
|
||||
def get_workflow(self, image_name: str) -> Optional[WorkflowField]:
|
||||
try:
|
||||
workflow_id = self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name)
|
||||
if workflow_id is None:
|
||||
return None
|
||||
return self.__invoker.services.workflow_records.get(workflow_id)
|
||||
except ImageRecordNotFoundException:
|
||||
self.__invoker.services.logger.error("Image record not found")
|
||||
raise
|
||||
@ -215,10 +220,11 @@ class ImageService(ImageServiceABC):
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self.__invoker.services.urls.get_image_url(r.image_name),
|
||||
self.__invoker.services.urls.get_image_url(r.image_name, True),
|
||||
self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
||||
image_record=r,
|
||||
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
|
||||
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
|
||||
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
||||
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
|
@ -27,6 +27,8 @@ if TYPE_CHECKING:
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
from .shared.graph import GraphExecutionState, LibraryGraph
|
||||
from .urls.urls_base import UrlServiceBase
|
||||
from .workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
|
||||
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
|
||||
|
||||
class InvocationServices:
|
||||
@ -55,6 +57,8 @@ class InvocationServices:
|
||||
invocation_cache: "InvocationCacheBase"
|
||||
names: "NameServiceBase"
|
||||
urls: "UrlServiceBase"
|
||||
workflow_image_records: "WorkflowImageRecordsStorageBase"
|
||||
workflow_records: "WorkflowRecordsStorageBase"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -80,6 +84,8 @@ class InvocationServices:
|
||||
invocation_cache: "InvocationCacheBase",
|
||||
names: "NameServiceBase",
|
||||
urls: "UrlServiceBase",
|
||||
workflow_image_records: "WorkflowImageRecordsStorageBase",
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
@ -103,3 +109,5 @@ class InvocationServices:
|
||||
self.invocation_cache = invocation_cache
|
||||
self.names = names
|
||||
self.urls = urls
|
||||
self.workflow_image_records = workflow_image_records
|
||||
self.workflow_records = workflow_records
|
||||
|
@ -18,7 +18,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
_cursor: sqlite3.Cursor
|
||||
_id_field: str
|
||||
_lock: threading.RLock
|
||||
_adapter: Optional[TypeAdapter[T]]
|
||||
_validator: Optional[TypeAdapter[T]]
|
||||
|
||||
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
||||
super().__init__()
|
||||
@ -28,7 +28,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._cursor = self._conn.cursor()
|
||||
self._adapter: Optional[TypeAdapter[T]] = None
|
||||
self._validator: Optional[TypeAdapter[T]] = None
|
||||
|
||||
self._create_table()
|
||||
|
||||
@ -47,14 +47,14 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._lock.release()
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
if self._adapter is None:
|
||||
if self._validator is None:
|
||||
"""
|
||||
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
|
||||
we can create it when it is first needed instead.
|
||||
__orig_class__ is technically an implementation detail of the typing module, not a supported API
|
||||
"""
|
||||
self._adapter = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
|
||||
return self._adapter.validate_json(item)
|
||||
self._validator = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
|
||||
return self._validator.validate_json(item)
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
|
@ -147,20 +147,20 @@ DEFAULT_QUEUE_ID = "default"
|
||||
|
||||
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
|
||||
|
||||
adapter_NodeFieldValue = TypeAdapter(list[NodeFieldValue])
|
||||
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
|
||||
|
||||
|
||||
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
|
||||
field_values_raw = queue_item_dict.get("field_values", None)
|
||||
return adapter_NodeFieldValue.validate_json(field_values_raw) if field_values_raw is not None else None
|
||||
return NodeFieldValueValidator.validate_json(field_values_raw) if field_values_raw is not None else None
|
||||
|
||||
|
||||
adapter_GraphExecutionState = TypeAdapter(GraphExecutionState)
|
||||
GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
|
||||
|
||||
|
||||
def get_session(queue_item_dict: dict) -> GraphExecutionState:
|
||||
session_raw = queue_item_dict.get("session", "{}")
|
||||
session = adapter_GraphExecutionState.validate_json(session_raw, strict=False)
|
||||
session = GraphExecutionStateValidator.validate_json(session_raw, strict=False)
|
||||
return session
|
||||
|
||||
|
||||
|
@ -193,7 +193,7 @@ class GraphInvocation(BaseInvocation):
|
||||
"""Execute a graph"""
|
||||
|
||||
# TODO: figure out how to create a default here
|
||||
graph: "Graph" = Field(description="The graph to run", default=None)
|
||||
graph: "Graph" = InputField(description="The graph to run", default=None)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
|
||||
"""Invoke with provided services and return outputs."""
|
||||
@ -439,6 +439,14 @@ class Graph(BaseModel):
|
||||
except Exception as e:
|
||||
raise UnknownGraphValidationError(f"Problem validating graph {e}") from e
|
||||
|
||||
def _is_destination_field_Any(self, edge: Edge) -> bool:
|
||||
"""Checks if the destination field for an edge is of type typing.Any"""
|
||||
return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == Any
|
||||
|
||||
def _is_destination_field_list_of_Any(self, edge: Edge) -> bool:
|
||||
"""Checks if the destination field for an edge is of type typing.Any"""
|
||||
return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any]
|
||||
|
||||
def _validate_edge(self, edge: Edge):
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
@ -491,8 +499,19 @@ class Graph(BaseModel):
|
||||
f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
# Validate that we are not connecting collector to iterator (currently unsupported)
|
||||
if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation):
|
||||
raise InvalidEdgeError(
|
||||
f"Cannot connect collector to iterator: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
|
||||
if (
|
||||
isinstance(from_node, CollectInvocation)
|
||||
and edge.source.field == "collection"
|
||||
and not self._is_destination_field_list_of_Any(edge)
|
||||
and not self._is_destination_field_Any(edge)
|
||||
):
|
||||
if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
|
||||
raise InvalidEdgeError(
|
||||
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
@ -725,16 +744,15 @@ class Graph(BaseModel):
|
||||
# Get the input root type
|
||||
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
||||
|
||||
# Verify that all outputs are lists
|
||||
# if not all((get_origin(f) == list for f in output_fields)):
|
||||
# return False
|
||||
|
||||
# Verify that all outputs are lists
|
||||
if not all(is_list_or_contains_list(f) for f in output_fields):
|
||||
return False
|
||||
|
||||
# Verify that all outputs match the input type (are a base class or the same class)
|
||||
if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)):
|
||||
if not all(
|
||||
is_union_subtype(input_root_type, get_args(f)[0]) or issubclass(input_root_type, get_args(f)[0])
|
||||
for f in output_fields
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
@ -0,0 +1,23 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class WorkflowImageRecordsStorageBase(ABC):
|
||||
"""Abstract base class for the one-to-many workflow-image relationship record storage."""
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
workflow_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Creates a workflow-image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_workflow_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
"""Gets an image's workflow id, if it has one."""
|
||||
pass
|
@ -0,0 +1,122 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, cast
|
||||
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
|
||||
|
||||
|
||||
class SqliteWorkflowImageRecordsStorage(WorkflowImageRecordsStorageBase):
|
||||
"""SQLite implementation of WorkflowImageRecordsStorageBase."""
|
||||
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
# Create the `workflow_images` junction table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS workflow_images (
|
||||
workflow_id TEXT NOT NULL,
|
||||
image_name TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
-- enforce one-to-many relationship between workflows and images using PK
|
||||
-- (we can extend this to many-to-many later)
|
||||
PRIMARY KEY (image_name),
|
||||
FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for workflow id
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for workflow id, sorted by created_at
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON workflow_images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def create(
|
||||
self,
|
||||
workflow_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Creates a workflow-image record."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO workflow_images (workflow_id, image_name)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(workflow_id, image_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_workflow_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
"""Gets an image's workflow id, if it has one."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id
|
||||
FROM workflow_images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
if result is None:
|
||||
return None
|
||||
return cast(str, result[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
0
invokeai/app/services/workflow_records/__init__.py
Normal file
0
invokeai/app/services/workflow_records/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import WorkflowField
|
||||
|
||||
|
||||
class WorkflowRecordsStorageBase(ABC):
|
||||
"""Base class for workflow storage services."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, workflow_id: str) -> WorkflowField:
|
||||
"""Get workflow by id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, workflow: WorkflowField) -> WorkflowField:
|
||||
"""Creates a workflow."""
|
||||
pass
|
@ -0,0 +1,2 @@
|
||||
class WorkflowNotFoundError(Exception):
|
||||
"""Raised when a workflow is not found"""
|
@ -0,0 +1,102 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import WorkflowField, WorkflowFieldValidator
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowNotFoundError
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
_invoker: Invoker
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
self._create_tables()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def get(self, workflow_id: str) -> WorkflowField:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow
|
||||
FROM workflows
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
row = self._cursor.fetchone()
|
||||
if row is None:
|
||||
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
||||
return WorkflowFieldValidator.validate_json(row[0])
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def create(self, workflow: WorkflowField) -> WorkflowField:
|
||||
try:
|
||||
# workflows do not have ids until they are saved
|
||||
workflow_id = uuid_string()
|
||||
workflow.root["id"] = workflow_id
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO workflows(workflow)
|
||||
VALUES (?);
|
||||
""",
|
||||
(workflow.json(),),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(workflow_id)
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS workflows (
|
||||
workflow TEXT NOT NULL,
|
||||
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
|
||||
AFTER UPDATE
|
||||
ON workflows FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE workflows
|
||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE workflow_id = old.workflow_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
@ -59,6 +59,8 @@ export type AppConfig = {
|
||||
nodesAllowlist: string[] | undefined;
|
||||
nodesDenylist: string[] | undefined;
|
||||
maxUpscalePixels?: number;
|
||||
metadataFetchDebounce?: number;
|
||||
workflowFetchDebounce?: number;
|
||||
sd: {
|
||||
defaultModel?: string;
|
||||
disabledControlNetModels: string[];
|
||||
|
@ -37,7 +37,12 @@ const useColorPicker = () => {
|
||||
1
|
||||
).data;
|
||||
|
||||
if (!(a && r && g && b)) {
|
||||
if (
|
||||
r === undefined ||
|
||||
g === undefined ||
|
||||
b === undefined ||
|
||||
a === undefined
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -27,7 +27,7 @@ import {
|
||||
setShouldShowImageDetails,
|
||||
setShouldShowProgressInViewer,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
@ -38,10 +38,9 @@ import {
|
||||
FaSeedling,
|
||||
} from 'react-icons/fa';
|
||||
import { FaCircleNodes, FaEllipsis } from 'react-icons/fa6';
|
||||
import {
|
||||
useGetImageDTOQuery,
|
||||
useGetImageMetadataFromFileQuery,
|
||||
} from 'services/api/endpoints/images';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||
import { useDebouncedWorkflow } from 'services/api/hooks/useDebouncedWorkflow';
|
||||
import { menuListMotionProps } from 'theme/components/menu';
|
||||
import { sentImageToImg2Img } from '../../store/actions';
|
||||
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
|
||||
@ -89,7 +88,6 @@ const CurrentImageButtons = () => {
|
||||
shouldShowImageDetails,
|
||||
lastSelectedImage,
|
||||
shouldShowProgressInViewer,
|
||||
shouldFetchMetadataFromApi,
|
||||
} = useAppSelector(currentImageButtonsSelector);
|
||||
|
||||
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
|
||||
@ -104,23 +102,12 @@ const CurrentImageButtons = () => {
|
||||
lastSelectedImage?.image_name ?? skipToken
|
||||
);
|
||||
|
||||
const getMetadataArg = useMemo(() => {
|
||||
if (lastSelectedImage) {
|
||||
return { image: lastSelectedImage, shouldFetchMetadataFromApi };
|
||||
} else {
|
||||
return skipToken;
|
||||
}
|
||||
}, [lastSelectedImage, shouldFetchMetadataFromApi]);
|
||||
const { metadata, isLoading: isLoadingMetadata } = useDebouncedMetadata(
|
||||
lastSelectedImage?.image_name
|
||||
);
|
||||
|
||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||
getMetadataArg,
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
isLoading: res.isFetching,
|
||||
metadata: res?.currentData?.metadata,
|
||||
workflow: res?.currentData?.workflow,
|
||||
}),
|
||||
}
|
||||
const { workflow, isLoading: isLoadingWorkflow } = useDebouncedWorkflow(
|
||||
lastSelectedImage?.workflow_id
|
||||
);
|
||||
|
||||
const handleLoadWorkflow = useCallback(() => {
|
||||
@ -257,7 +244,7 @@ const CurrentImageButtons = () => {
|
||||
|
||||
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
||||
<IAIIconButton
|
||||
isLoading={isLoading}
|
||||
isLoading={isLoadingWorkflow}
|
||||
icon={<FaCircleNodes />}
|
||||
tooltip={`${t('nodes.loadWorkflow')} (W)`}
|
||||
aria-label={`${t('nodes.loadWorkflow')} (W)`}
|
||||
@ -265,7 +252,7 @@ const CurrentImageButtons = () => {
|
||||
onClick={handleLoadWorkflow}
|
||||
/>
|
||||
<IAIIconButton
|
||||
isLoading={isLoading}
|
||||
isLoading={isLoadingMetadata}
|
||||
icon={<FaQuoteRight />}
|
||||
tooltip={`${t('parameters.usePrompt')} (P)`}
|
||||
aria-label={`${t('parameters.usePrompt')} (P)`}
|
||||
@ -273,7 +260,7 @@ const CurrentImageButtons = () => {
|
||||
onClick={handleUsePrompt}
|
||||
/>
|
||||
<IAIIconButton
|
||||
isLoading={isLoading}
|
||||
isLoading={isLoadingMetadata}
|
||||
icon={<FaSeedling />}
|
||||
tooltip={`${t('parameters.useSeed')} (S)`}
|
||||
aria-label={`${t('parameters.useSeed')} (S)`}
|
||||
@ -281,7 +268,7 @@ const CurrentImageButtons = () => {
|
||||
onClick={handleUseSeed}
|
||||
/>
|
||||
<IAIIconButton
|
||||
isLoading={isLoading}
|
||||
isLoading={isLoadingMetadata}
|
||||
icon={<FaAsterisk />}
|
||||
tooltip={`${t('parameters.useAll')} (A)`}
|
||||
aria-label={`${t('parameters.useAll')} (A)`}
|
||||
|
@ -2,7 +2,7 @@ import { Flex, MenuItem, Spinner } from '@chakra-ui/react';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { $customStarUI } from 'app/store/nanostores/customStarUI';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
imagesToChangeSelected,
|
||||
@ -32,12 +32,12 @@ import {
|
||||
import { FaCircleNodes } from 'react-icons/fa6';
|
||||
import { MdStar, MdStarBorder } from 'react-icons/md';
|
||||
import {
|
||||
useGetImageMetadataFromFileQuery,
|
||||
useStarImagesMutation,
|
||||
useUnstarImagesMutation,
|
||||
} from 'services/api/endpoints/images';
|
||||
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||
import { useDebouncedWorkflow } from 'services/api/hooks/useDebouncedWorkflow';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { configSelector } from '../../../system/store/configSelectors';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
||||
|
||||
type SingleSelectionMenuItemsProps = {
|
||||
@ -53,18 +53,13 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
const toaster = useAppToaster();
|
||||
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
const { shouldFetchMetadataFromApi } = useAppSelector(configSelector);
|
||||
const customStarUi = useStore($customStarUI);
|
||||
|
||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||
{ image: imageDTO, shouldFetchMetadataFromApi },
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
isLoading: res.isFetching,
|
||||
metadata: res?.currentData?.metadata,
|
||||
workflow: res?.currentData?.workflow,
|
||||
}),
|
||||
}
|
||||
const { metadata, isLoading: isLoadingMetadata } = useDebouncedMetadata(
|
||||
imageDTO?.image_name
|
||||
);
|
||||
const { workflow, isLoading: isLoadingWorkflow } = useDebouncedWorkflow(
|
||||
imageDTO?.workflow_id
|
||||
);
|
||||
|
||||
const [starImages] = useStarImagesMutation();
|
||||
@ -181,17 +176,17 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
{t('parameters.downloadImage')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={isLoading ? <SpinnerIcon /> : <FaCircleNodes />}
|
||||
icon={isLoadingWorkflow ? <SpinnerIcon /> : <FaCircleNodes />}
|
||||
onClickCapture={handleLoadWorkflow}
|
||||
isDisabled={isLoading || !workflow}
|
||||
isDisabled={isLoadingWorkflow || !workflow}
|
||||
>
|
||||
{t('nodes.loadWorkflow')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={isLoading ? <SpinnerIcon /> : <FaQuoteRight />}
|
||||
icon={isLoadingMetadata ? <SpinnerIcon /> : <FaQuoteRight />}
|
||||
onClickCapture={handleRecallPrompt}
|
||||
isDisabled={
|
||||
isLoading ||
|
||||
isLoadingMetadata ||
|
||||
(metadata?.positive_prompt === undefined &&
|
||||
metadata?.negative_prompt === undefined)
|
||||
}
|
||||
@ -199,16 +194,16 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
{t('parameters.usePrompt')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={isLoading ? <SpinnerIcon /> : <FaSeedling />}
|
||||
icon={isLoadingMetadata ? <SpinnerIcon /> : <FaSeedling />}
|
||||
onClickCapture={handleRecallSeed}
|
||||
isDisabled={isLoading || metadata?.seed === undefined}
|
||||
isDisabled={isLoadingMetadata || metadata?.seed === undefined}
|
||||
>
|
||||
{t('parameters.useSeed')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={isLoading ? <SpinnerIcon /> : <FaAsterisk />}
|
||||
icon={isLoadingMetadata ? <SpinnerIcon /> : <FaAsterisk />}
|
||||
onClickCapture={handleUseAllParameters}
|
||||
isDisabled={isLoading || !metadata}
|
||||
isDisabled={isLoadingMetadata || !metadata}
|
||||
>
|
||||
{t('parameters.useAll')}
|
||||
</MenuItem>
|
||||
|
@ -10,15 +10,14 @@ import {
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'features/nodes/components/sidePanel/ScrollableContent';
|
||||
import { memo } from 'react';
|
||||
import { useGetImageMetadataFromFileQuery } from 'services/api/endpoints/images';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||
import { useDebouncedWorkflow } from 'services/api/hooks/useDebouncedWorkflow';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import DataViewer from './DataViewer';
|
||||
import ImageMetadataActions from './ImageMetadataActions';
|
||||
import { useAppSelector } from '../../../../app/store/storeHooks';
|
||||
import { configSelector } from '../../../system/store/configSelectors';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import ScrollableContent from 'features/nodes/components/sidePanel/ScrollableContent';
|
||||
|
||||
type ImageMetadataViewerProps = {
|
||||
image: ImageDTO;
|
||||
@ -32,17 +31,8 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
// });
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { shouldFetchMetadataFromApi } = useAppSelector(configSelector);
|
||||
|
||||
const { metadata, workflow } = useGetImageMetadataFromFileQuery(
|
||||
{ image, shouldFetchMetadataFromApi },
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
metadata: res?.currentData?.metadata,
|
||||
workflow: res?.currentData?.workflow,
|
||||
}),
|
||||
}
|
||||
);
|
||||
const { metadata } = useDebouncedMetadata(image.image_name);
|
||||
const { workflow } = useDebouncedWorkflow(image.workflow_id);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
|
@ -1,13 +1,13 @@
|
||||
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEmbedWorkflow } from 'features/nodes/hooks/useEmbedWorkflow';
|
||||
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
||||
import { useWithWorkflow } from 'features/nodes/hooks/useWithWorkflow';
|
||||
import { nodeEmbedWorkflowChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
|
||||
const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const hasImageOutput = useHasImageOutput(nodeId);
|
||||
const withWorkflow = useWithWorkflow(nodeId);
|
||||
const embedWorkflow = useEmbedWorkflow(nodeId);
|
||||
const handleChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
@ -21,7 +21,7 @@ const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
|
||||
[dispatch, nodeId]
|
||||
);
|
||||
|
||||
if (!hasImageOutput) {
|
||||
if (!withWorkflow) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import { memo } from 'react';
|
||||
import { useFeatureStatus } from '../../../../../system/hooks/useFeatureStatus';
|
||||
import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
|
||||
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
|
||||
import UseCacheCheckbox from './UseCacheCheckbox';
|
||||
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
||||
import { useFeatureStatus } from '../../../../../system/hooks/useFeatureStatus';
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
|
@ -0,0 +1,31 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
|
||||
export const useWithWorkflow = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
|
||||
if (!nodeTemplate) {
|
||||
return false;
|
||||
}
|
||||
return nodeTemplate.withWorkflow;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const withWorkflow = useAppSelector(selector);
|
||||
return withWorkflow;
|
||||
};
|
@ -69,6 +69,8 @@ export const validateSourceAndTargetTypes = (
|
||||
(sourceType === 'integer' || sourceType === 'float') &&
|
||||
targetType === 'string';
|
||||
|
||||
const isTargetAnyType = targetType === 'Any';
|
||||
|
||||
return (
|
||||
isCollectionItemToNonCollection ||
|
||||
isNonCollectionToCollectionItem ||
|
||||
@ -76,6 +78,7 @@ export const validateSourceAndTargetTypes = (
|
||||
isGenericCollectionToAnyCollectionOrPolymorphic ||
|
||||
isCollectionToGenericCollection ||
|
||||
isIntToFloat ||
|
||||
isIntOrFloatToString
|
||||
isIntOrFloatToString ||
|
||||
isTargetAnyType
|
||||
);
|
||||
};
|
||||
|
@ -33,6 +33,8 @@ export const COLLECTION_TYPES: FieldType[] = [
|
||||
'ColorCollection',
|
||||
'T2IAdapterCollection',
|
||||
'IPAdapterCollection',
|
||||
'MetadataItemCollection',
|
||||
'MetadataCollection',
|
||||
];
|
||||
|
||||
export const POLYMORPHIC_TYPES: FieldType[] = [
|
||||
@ -47,6 +49,7 @@ export const POLYMORPHIC_TYPES: FieldType[] = [
|
||||
'ColorPolymorphic',
|
||||
'T2IAdapterPolymorphic',
|
||||
'IPAdapterPolymorphic',
|
||||
'MetadataItemPolymorphic',
|
||||
];
|
||||
|
||||
export const MODEL_TYPES: FieldType[] = [
|
||||
@ -78,6 +81,8 @@ export const COLLECTION_MAP: FieldTypeMapWithNumber = {
|
||||
ColorField: 'ColorCollection',
|
||||
T2IAdapterField: 'T2IAdapterCollection',
|
||||
IPAdapterField: 'IPAdapterCollection',
|
||||
MetadataItemField: 'MetadataItemCollection',
|
||||
MetadataField: 'MetadataCollection',
|
||||
};
|
||||
export const isCollectionItemType = (
|
||||
itemType: string | undefined
|
||||
@ -97,6 +102,7 @@ export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = {
|
||||
ColorField: 'ColorPolymorphic',
|
||||
T2IAdapterField: 'T2IAdapterPolymorphic',
|
||||
IPAdapterField: 'IPAdapterPolymorphic',
|
||||
MetadataItemField: 'MetadataItemPolymorphic',
|
||||
};
|
||||
|
||||
export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
||||
@ -111,6 +117,7 @@ export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
||||
ColorPolymorphic: 'ColorField',
|
||||
T2IAdapterPolymorphic: 'T2IAdapterField',
|
||||
IPAdapterPolymorphic: 'IPAdapterField',
|
||||
MetadataItemPolymorphic: 'MetadataItemField',
|
||||
};
|
||||
|
||||
export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
|
||||
@ -144,6 +151,37 @@ export const isPolymorphicItemType = (
|
||||
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
|
||||
|
||||
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
Any: {
|
||||
color: 'gray.500',
|
||||
description: 'Any field type is accepted.',
|
||||
title: 'Any',
|
||||
},
|
||||
MetadataField: {
|
||||
color: 'gray.500',
|
||||
description: 'A metadata dict.',
|
||||
title: 'Metadata Dict',
|
||||
},
|
||||
MetadataCollection: {
|
||||
color: 'gray.500',
|
||||
description: 'A collection of metadata dicts.',
|
||||
title: 'Metadata Dict Collection',
|
||||
},
|
||||
MetadataItemField: {
|
||||
color: 'gray.500',
|
||||
description: 'A metadata item.',
|
||||
title: 'Metadata Item',
|
||||
},
|
||||
MetadataItemCollection: {
|
||||
color: 'gray.500',
|
||||
description: 'Any field type is accepted.',
|
||||
title: 'Metadata Item Collection',
|
||||
},
|
||||
MetadataItemPolymorphic: {
|
||||
color: 'gray.500',
|
||||
description:
|
||||
'MetadataItem or MetadataItemCollection field types are accepted.',
|
||||
title: 'Metadata Item Polymorphic',
|
||||
},
|
||||
boolean: {
|
||||
color: 'green.500',
|
||||
description: t('nodes.booleanDescription'),
|
||||
|
@ -54,6 +54,10 @@ export type InvocationTemplate = {
|
||||
* The type of this node's output
|
||||
*/
|
||||
outputType: string; // TODO: generate a union of output types
|
||||
/**
|
||||
* Whether or not this invocation supports workflows
|
||||
*/
|
||||
withWorkflow: boolean;
|
||||
/**
|
||||
* The invocation's version.
|
||||
*/
|
||||
@ -72,6 +76,7 @@ export type FieldUIConfig = {
|
||||
|
||||
// TODO: Get this from the OpenAPI schema? may be tricky...
|
||||
export const zFieldType = z.enum([
|
||||
'Any',
|
||||
'BoardField',
|
||||
'boolean',
|
||||
'BooleanCollection',
|
||||
@ -109,6 +114,11 @@ export const zFieldType = z.enum([
|
||||
'LatentsPolymorphic',
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
'MetadataField',
|
||||
'MetadataCollection',
|
||||
'MetadataItemField',
|
||||
'MetadataItemCollection',
|
||||
'MetadataItemPolymorphic',
|
||||
'ONNXModelField',
|
||||
'Scheduler',
|
||||
'SDXLMainModelField',
|
||||
@ -685,6 +695,57 @@ export type CollectionItemInputFieldValue = z.infer<
|
||||
typeof zCollectionItemInputFieldValue
|
||||
>;
|
||||
|
||||
export const zMetadataItemField = z.object({
|
||||
label: z.string(),
|
||||
value: z.any(),
|
||||
});
|
||||
export type MetadataItemField = z.infer<typeof zMetadataItemField>;
|
||||
|
||||
export const zMetadataItemInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('MetadataItemField'),
|
||||
value: zMetadataItemField.optional(),
|
||||
});
|
||||
export type MetadataItemInputFieldValue = z.infer<
|
||||
typeof zMetadataItemInputFieldValue
|
||||
>;
|
||||
|
||||
export const zMetadataItemCollectionInputFieldValue =
|
||||
zInputFieldValueBase.extend({
|
||||
type: z.literal('MetadataItemCollection'),
|
||||
value: z.array(zMetadataItemField).optional(),
|
||||
});
|
||||
export type MetadataItemCollectionInputFieldValue = z.infer<
|
||||
typeof zMetadataItemCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zMetadataItemPolymorphicInputFieldValue =
|
||||
zInputFieldValueBase.extend({
|
||||
type: z.literal('MetadataItemPolymorphic'),
|
||||
value: z
|
||||
.union([zMetadataItemField, z.array(zMetadataItemField)])
|
||||
.optional(),
|
||||
});
|
||||
export type MetadataItemPolymorphicInputFieldValue = z.infer<
|
||||
typeof zMetadataItemPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zMetadataField = z.record(z.any());
|
||||
export type MetadataField = z.infer<typeof zMetadataField>;
|
||||
|
||||
export const zMetadataInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('MetadataField'),
|
||||
value: zMetadataField.optional(),
|
||||
});
|
||||
export type MetadataInputFieldValue = z.infer<typeof zMetadataInputFieldValue>;
|
||||
|
||||
export const zMetadataCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('MetadataCollection'),
|
||||
value: z.array(zMetadataField).optional(),
|
||||
});
|
||||
export type MetadataCollectionInputFieldValue = z.infer<
|
||||
typeof zMetadataCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zColorField = z.object({
|
||||
r: z.number().int().min(0).max(255),
|
||||
g: z.number().int().min(0).max(255),
|
||||
@ -723,7 +784,13 @@ export type SchedulerInputFieldValue = z.infer<
|
||||
typeof zSchedulerInputFieldValue
|
||||
>;
|
||||
|
||||
export const zAnyInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('Any'),
|
||||
value: z.any().optional(),
|
||||
});
|
||||
|
||||
export const zInputFieldValue = z.discriminatedUnion('type', [
|
||||
zAnyInputFieldValue,
|
||||
zBoardInputFieldValue,
|
||||
zBooleanCollectionInputFieldValue,
|
||||
zBooleanInputFieldValue,
|
||||
@ -774,6 +841,11 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
|
||||
zUNetInputFieldValue,
|
||||
zVaeInputFieldValue,
|
||||
zVaeModelInputFieldValue,
|
||||
zMetadataItemInputFieldValue,
|
||||
zMetadataItemCollectionInputFieldValue,
|
||||
zMetadataItemPolymorphicInputFieldValue,
|
||||
zMetadataInputFieldValue,
|
||||
zMetadataCollectionInputFieldValue,
|
||||
]);
|
||||
|
||||
export type InputFieldValue = z.infer<typeof zInputFieldValue>;
|
||||
@ -786,6 +858,11 @@ export type InputFieldTemplateBase = {
|
||||
fieldKind: 'input';
|
||||
} & _InputField;
|
||||
|
||||
export type AnyInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'Any';
|
||||
default: undefined;
|
||||
};
|
||||
|
||||
export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'integer';
|
||||
default: number;
|
||||
@ -939,6 +1016,11 @@ export type UNetInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'UNetField';
|
||||
};
|
||||
|
||||
export type MetadataItemFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'MetadataItemField';
|
||||
};
|
||||
|
||||
export type ClipInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'ClipField';
|
||||
@ -1087,6 +1169,34 @@ export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'WorkflowField';
|
||||
};
|
||||
|
||||
export type MetadataItemInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'MetadataItemField';
|
||||
};
|
||||
|
||||
export type MetadataItemCollectionInputFieldTemplate =
|
||||
InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'MetadataItemCollection';
|
||||
};
|
||||
|
||||
export type MetadataItemPolymorphicInputFieldTemplate = Omit<
|
||||
MetadataItemInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'MetadataItemPolymorphic';
|
||||
};
|
||||
|
||||
export type MetadataInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'MetadataField';
|
||||
};
|
||||
|
||||
export type MetadataCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'MetadataCollection';
|
||||
};
|
||||
|
||||
/**
|
||||
* An input field template is generated on each page load from the OpenAPI schema.
|
||||
*
|
||||
@ -1094,6 +1204,7 @@ export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
|
||||
* maximum length, pattern to match, etc).
|
||||
*/
|
||||
export type InputFieldTemplate =
|
||||
| AnyInputFieldTemplate
|
||||
| BoardInputFieldTemplate
|
||||
| BooleanCollectionInputFieldTemplate
|
||||
| BooleanPolymorphicInputFieldTemplate
|
||||
@ -1143,7 +1254,12 @@ export type InputFieldTemplate =
|
||||
| T2IAdapterPolymorphicInputFieldTemplate
|
||||
| UNetInputFieldTemplate
|
||||
| VaeInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate;
|
||||
| VaeModelInputFieldTemplate
|
||||
| MetadataItemInputFieldTemplate
|
||||
| MetadataItemCollectionInputFieldTemplate
|
||||
| MetadataInputFieldTemplate
|
||||
| MetadataItemPolymorphicInputFieldTemplate
|
||||
| MetadataCollectionInputFieldTemplate;
|
||||
|
||||
export const isInputFieldValue = (
|
||||
field?: InputFieldValue | OutputFieldValue
|
||||
@ -1264,7 +1380,7 @@ export const isInvocationFieldSchema = (
|
||||
|
||||
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
|
||||
|
||||
const zLoRAMetadataItem = z.object({
|
||||
export const zLoRAMetadataItem = z.object({
|
||||
lora: zLoRAModelField.deepPartial(),
|
||||
weight: z.number(),
|
||||
});
|
||||
|
@ -7,6 +7,7 @@ import {
|
||||
startCase,
|
||||
} from 'lodash-es';
|
||||
import { OpenAPIV3_1 } from 'openapi-types';
|
||||
import { ControlField } from 'services/api/types';
|
||||
import {
|
||||
COLLECTION_MAP,
|
||||
POLYMORPHIC_TYPES,
|
||||
@ -15,36 +16,70 @@ import {
|
||||
isPolymorphicItemType,
|
||||
} from '../types/constants';
|
||||
import {
|
||||
AnyInputFieldTemplate,
|
||||
BoardInputFieldTemplate,
|
||||
BooleanCollectionInputFieldTemplate,
|
||||
BooleanInputFieldTemplate,
|
||||
BooleanPolymorphicInputFieldTemplate,
|
||||
ClipInputFieldTemplate,
|
||||
CollectionInputFieldTemplate,
|
||||
CollectionItemInputFieldTemplate,
|
||||
ColorCollectionInputFieldTemplate,
|
||||
ColorInputFieldTemplate,
|
||||
ColorPolymorphicInputFieldTemplate,
|
||||
ConditioningCollectionInputFieldTemplate,
|
||||
ConditioningField,
|
||||
ConditioningInputFieldTemplate,
|
||||
ConditioningPolymorphicInputFieldTemplate,
|
||||
ControlCollectionInputFieldTemplate,
|
||||
ControlInputFieldTemplate,
|
||||
ControlNetModelInputFieldTemplate,
|
||||
ControlPolymorphicInputFieldTemplate,
|
||||
DenoiseMaskInputFieldTemplate,
|
||||
EnumInputFieldTemplate,
|
||||
FieldType,
|
||||
FloatCollectionInputFieldTemplate,
|
||||
FloatPolymorphicInputFieldTemplate,
|
||||
FloatInputFieldTemplate,
|
||||
FloatPolymorphicInputFieldTemplate,
|
||||
IPAdapterCollectionInputFieldTemplate,
|
||||
IPAdapterField,
|
||||
IPAdapterInputFieldTemplate,
|
||||
IPAdapterModelInputFieldTemplate,
|
||||
IPAdapterPolymorphicInputFieldTemplate,
|
||||
ImageCollectionInputFieldTemplate,
|
||||
ImageField,
|
||||
ImageInputFieldTemplate,
|
||||
ImagePolymorphicInputFieldTemplate,
|
||||
InputFieldTemplate,
|
||||
InputFieldTemplateBase,
|
||||
IntegerCollectionInputFieldTemplate,
|
||||
IntegerInputFieldTemplate,
|
||||
IntegerPolymorphicInputFieldTemplate,
|
||||
InvocationFieldSchema,
|
||||
InvocationSchemaObject,
|
||||
LatentsCollectionInputFieldTemplate,
|
||||
LatentsField,
|
||||
LatentsInputFieldTemplate,
|
||||
LatentsPolymorphicInputFieldTemplate,
|
||||
LoRAModelInputFieldTemplate,
|
||||
MainModelInputFieldTemplate,
|
||||
MetadataCollectionInputFieldTemplate,
|
||||
MetadataInputFieldTemplate,
|
||||
MetadataItemCollectionInputFieldTemplate,
|
||||
MetadataItemInputFieldTemplate,
|
||||
MetadataItemPolymorphicInputFieldTemplate,
|
||||
OpenAPIV3_1SchemaOrRef,
|
||||
SDXLMainModelInputFieldTemplate,
|
||||
SDXLRefinerModelInputFieldTemplate,
|
||||
SchedulerInputFieldTemplate,
|
||||
StringCollectionInputFieldTemplate,
|
||||
StringInputFieldTemplate,
|
||||
StringPolymorphicInputFieldTemplate,
|
||||
T2IAdapterCollectionInputFieldTemplate,
|
||||
T2IAdapterField,
|
||||
T2IAdapterInputFieldTemplate,
|
||||
T2IAdapterModelInputFieldTemplate,
|
||||
T2IAdapterPolymorphicInputFieldTemplate,
|
||||
UNetInputFieldTemplate,
|
||||
VaeInputFieldTemplate,
|
||||
VaeModelInputFieldTemplate,
|
||||
@ -52,36 +87,7 @@ import {
|
||||
isNonArraySchemaObject,
|
||||
isRefObject,
|
||||
isSchemaObject,
|
||||
ControlPolymorphicInputFieldTemplate,
|
||||
ColorPolymorphicInputFieldTemplate,
|
||||
ColorCollectionInputFieldTemplate,
|
||||
IntegerPolymorphicInputFieldTemplate,
|
||||
StringPolymorphicInputFieldTemplate,
|
||||
BooleanPolymorphicInputFieldTemplate,
|
||||
ImagePolymorphicInputFieldTemplate,
|
||||
LatentsPolymorphicInputFieldTemplate,
|
||||
LatentsCollectionInputFieldTemplate,
|
||||
ConditioningPolymorphicInputFieldTemplate,
|
||||
ConditioningCollectionInputFieldTemplate,
|
||||
ControlCollectionInputFieldTemplate,
|
||||
ImageField,
|
||||
LatentsField,
|
||||
ConditioningField,
|
||||
IPAdapterField,
|
||||
IPAdapterInputFieldTemplate,
|
||||
IPAdapterModelInputFieldTemplate,
|
||||
IPAdapterPolymorphicInputFieldTemplate,
|
||||
IPAdapterCollectionInputFieldTemplate,
|
||||
T2IAdapterField,
|
||||
T2IAdapterInputFieldTemplate,
|
||||
T2IAdapterModelInputFieldTemplate,
|
||||
T2IAdapterPolymorphicInputFieldTemplate,
|
||||
T2IAdapterCollectionInputFieldTemplate,
|
||||
BoardInputFieldTemplate,
|
||||
InputFieldTemplate,
|
||||
OpenAPIV3_1SchemaOrRef,
|
||||
} from '../types/types';
|
||||
import { ControlField } from 'services/api/types';
|
||||
|
||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||
|
||||
@ -851,6 +857,78 @@ const buildCollectionItemInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildAnyInputFieldTemplate = ({
|
||||
baseField,
|
||||
}: BuildInputFieldArg): AnyInputFieldTemplate => {
|
||||
const template: AnyInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'Any',
|
||||
default: undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildMetadataItemInputFieldTemplate = ({
|
||||
baseField,
|
||||
}: BuildInputFieldArg): MetadataItemInputFieldTemplate => {
|
||||
const template: MetadataItemInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'MetadataItemField',
|
||||
default: undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildMetadataItemCollectionInputFieldTemplate = ({
|
||||
baseField,
|
||||
}: BuildInputFieldArg): MetadataItemCollectionInputFieldTemplate => {
|
||||
const template: MetadataItemCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'MetadataItemCollection',
|
||||
default: undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildMetadataItemPolymorphicInputFieldTemplate = ({
|
||||
baseField,
|
||||
}: BuildInputFieldArg): MetadataItemPolymorphicInputFieldTemplate => {
|
||||
const template: MetadataItemPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'MetadataItemPolymorphic',
|
||||
default: undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildMetadataDictInputFieldTemplate = ({
|
||||
baseField,
|
||||
}: BuildInputFieldArg): MetadataInputFieldTemplate => {
|
||||
const template: MetadataInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'MetadataField',
|
||||
default: undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildMetadataCollectionInputFieldTemplate = ({
|
||||
baseField,
|
||||
}: BuildInputFieldArg): MetadataCollectionInputFieldTemplate => {
|
||||
const template: MetadataCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'MetadataCollection',
|
||||
default: undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildColorInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -1012,6 +1090,7 @@ const TEMPLATE_BUILDER_MAP: {
|
||||
[key in FieldType]?: (arg: BuildInputFieldArg) => InputFieldTemplate;
|
||||
} = {
|
||||
BoardField: buildBoardInputFieldTemplate,
|
||||
Any: buildAnyInputFieldTemplate,
|
||||
boolean: buildBooleanInputFieldTemplate,
|
||||
BooleanCollection: buildBooleanCollectionInputFieldTemplate,
|
||||
BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate,
|
||||
@ -1047,6 +1126,11 @@ const TEMPLATE_BUILDER_MAP: {
|
||||
LatentsField: buildLatentsInputFieldTemplate,
|
||||
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
|
||||
LoRAModelField: buildLoRAModelInputFieldTemplate,
|
||||
MetadataItemField: buildMetadataItemInputFieldTemplate,
|
||||
MetadataItemCollection: buildMetadataItemCollectionInputFieldTemplate,
|
||||
MetadataItemPolymorphic: buildMetadataItemPolymorphicInputFieldTemplate,
|
||||
MetadataField: buildMetadataDictInputFieldTemplate,
|
||||
MetadataCollection: buildMetadataCollectionInputFieldTemplate,
|
||||
MainModelField: buildMainModelInputFieldTemplate,
|
||||
Scheduler: buildSchedulerInputFieldTemplate,
|
||||
SDXLMainModelField: buildSDXLMainModelInputFieldTemplate,
|
||||
|
@ -3,6 +3,7 @@ import { FieldType, InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||
const FIELD_VALUE_FALLBACK_MAP: {
|
||||
[key in FieldType]: InputFieldValue['value'];
|
||||
} = {
|
||||
Any: undefined,
|
||||
enum: '',
|
||||
BoardField: undefined,
|
||||
boolean: false,
|
||||
@ -38,6 +39,11 @@ const FIELD_VALUE_FALLBACK_MAP: {
|
||||
LatentsCollection: [],
|
||||
LatentsField: undefined,
|
||||
LatentsPolymorphic: undefined,
|
||||
MetadataItemField: undefined,
|
||||
MetadataItemCollection: [],
|
||||
MetadataItemPolymorphic: undefined,
|
||||
MetadataField: undefined,
|
||||
MetadataCollection: [],
|
||||
LoRAModelField: undefined,
|
||||
MainModelField: undefined,
|
||||
ONNXModelField: undefined,
|
||||
|
@ -1,45 +0,0 @@
|
||||
import * as png from '@stevebel/png';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import {
|
||||
ImageMetadataAndWorkflow,
|
||||
zCoreMetadata,
|
||||
zWorkflow,
|
||||
} from 'features/nodes/types/types';
|
||||
import { get } from 'lodash-es';
|
||||
|
||||
export const getMetadataAndWorkflowFromImageBlob = async (
|
||||
image: Blob
|
||||
): Promise<ImageMetadataAndWorkflow> => {
|
||||
const data: ImageMetadataAndWorkflow = {};
|
||||
const buffer = await image.arrayBuffer();
|
||||
const text = png.decode(buffer).text;
|
||||
|
||||
const rawMetadata = get(text, 'invokeai_metadata');
|
||||
if (rawMetadata) {
|
||||
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
|
||||
if (metadataResult.success) {
|
||||
data.metadata = metadataResult.data;
|
||||
} else {
|
||||
logger('system').error(
|
||||
{ error: parseify(metadataResult.error) },
|
||||
'Problem reading metadata from image'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const rawWorkflow = get(text, 'invokeai_workflow');
|
||||
if (rawWorkflow) {
|
||||
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
|
||||
if (workflowResult.success) {
|
||||
data.workflow = workflowResult.data;
|
||||
} else {
|
||||
logger('system').error(
|
||||
{ error: parseify(workflowResult.error) },
|
||||
'Problem reading workflow from image'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return data;
|
||||
};
|
@ -5,14 +5,14 @@ import {
|
||||
CollectInvocation,
|
||||
ControlField,
|
||||
ControlNetInvocation,
|
||||
MetadataAccumulatorInvocation,
|
||||
CoreMetadataInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from '../../types/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
CONTROL_NET_COLLECT,
|
||||
METADATA_ACCUMULATOR,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addControlNetToLinearGraph = (
|
||||
state: RootState,
|
||||
@ -23,9 +23,11 @@ export const addControlNetToLinearGraph = (
|
||||
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
||||
);
|
||||
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
// | MetadataAccumulatorInvocation
|
||||
// | undefined;
|
||||
|
||||
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
|
||||
|
||||
if (validControlNets.length) {
|
||||
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
|
||||
@ -99,15 +101,9 @@ export const addControlNetToLinearGraph = (
|
||||
|
||||
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
|
||||
|
||||
if (metadataAccumulator?.controlnets) {
|
||||
// metadata accumulator only needs a control field - not the whole node
|
||||
// extract what we need and add to the accumulator
|
||||
const controlField = omit(controlNetNode, [
|
||||
'id',
|
||||
'type',
|
||||
]) as ControlField;
|
||||
metadataAccumulator.controlnets.push(controlField);
|
||||
}
|
||||
controlNetMetadata.push(
|
||||
omit(controlNetNode, ['id', 'type', 'is_intermediate']) as ControlField
|
||||
);
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
@ -117,5 +113,6 @@ export const addControlNetToLinearGraph = (
|
||||
},
|
||||
});
|
||||
});
|
||||
upsertMetadata(graph, { controlnets: controlNetMetadata });
|
||||
}
|
||||
};
|
||||
|
@ -1,25 +1,25 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
DenoiseLatentsInvocation,
|
||||
ResizeLatentsInvocation,
|
||||
NoiseInvocation,
|
||||
LatentsToImageInvocation,
|
||||
Edge,
|
||||
LatentsToImageInvocation,
|
||||
NoiseInvocation,
|
||||
ResizeLatentsInvocation,
|
||||
} from 'services/api/types';
|
||||
import {
|
||||
LATENTS_TO_IMAGE,
|
||||
DENOISE_LATENTS,
|
||||
NOISE,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
LATENTS_TO_IMAGE_HRF,
|
||||
DENOISE_LATENTS_HRF,
|
||||
RESCALE_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
LATENTS_TO_IMAGE_HRF,
|
||||
MAIN_MODEL_LOADER,
|
||||
NOISE,
|
||||
NOISE_HRF,
|
||||
RESCALE_LATENTS,
|
||||
VAE_LOADER,
|
||||
} from './constants';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
// Copy certain connections from previous DENOISE_LATENTS to new DENOISE_LATENTS_HRF.
|
||||
function copyConnectionsToDenoiseLatentsHrf(graph: NonNullableGraph): void {
|
||||
@ -71,10 +71,8 @@ export const addHrfToGraph = (
|
||||
}
|
||||
const log = logger('txt2img');
|
||||
|
||||
const { vae } = state.generation;
|
||||
const { vae, hrfWidth, hrfHeight, hrfStrength } = state.generation;
|
||||
const isAutoVae = !vae;
|
||||
const hrfWidth = state.generation.hrfWidth;
|
||||
const hrfHeight = state.generation.hrfHeight;
|
||||
|
||||
// Pre-existing (original) graph nodes.
|
||||
const originalDenoiseLatentsNode = graph.nodes[DENOISE_LATENTS] as
|
||||
@ -116,7 +114,7 @@ export const addHrfToGraph = (
|
||||
cfg_scale: originalDenoiseLatentsNode?.cfg_scale,
|
||||
scheduler: originalDenoiseLatentsNode?.scheduler,
|
||||
steps: originalDenoiseLatentsNode?.steps,
|
||||
denoising_start: 1 - state.generation.hrfStrength,
|
||||
denoising_start: 1 - hrfStrength,
|
||||
denoising_end: 1,
|
||||
};
|
||||
|
||||
@ -221,16 +219,6 @@ export const addHrfToGraph = (
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE_HRF,
|
||||
field: 'metadata',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||
@ -243,5 +231,11 @@ export const addHrfToGraph = (
|
||||
}
|
||||
);
|
||||
|
||||
upsertMetadata(graph, {
|
||||
hrf_height: hrfHeight,
|
||||
hrf_width: hrfWidth,
|
||||
hrf_strength: hrfStrength,
|
||||
});
|
||||
|
||||
copyConnectionsToDenoiseLatentsHrf(graph);
|
||||
};
|
||||
|
@ -1,16 +1,18 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { omit } from 'lodash-es';
|
||||
import {
|
||||
CollectInvocation,
|
||||
CoreMetadataInvocation,
|
||||
IPAdapterInvocation,
|
||||
MetadataAccumulatorInvocation,
|
||||
IPAdapterMetadataField,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from '../../types/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
IP_ADAPTER_COLLECT,
|
||||
METADATA_ACCUMULATOR,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addIPAdapterToLinearGraph = (
|
||||
state: RootState,
|
||||
@ -21,10 +23,6 @@ export const addIPAdapterToLinearGraph = (
|
||||
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
||||
);
|
||||
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (validIPAdapters.length) {
|
||||
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
|
||||
const ipAdapterCollectNode: CollectInvocation = {
|
||||
@ -50,6 +48,7 @@ export const addIPAdapterToLinearGraph = (
|
||||
},
|
||||
});
|
||||
}
|
||||
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
|
||||
|
||||
validIPAdapters.forEach((ipAdapter) => {
|
||||
if (!ipAdapter.model) {
|
||||
@ -76,19 +75,13 @@ export const addIPAdapterToLinearGraph = (
|
||||
|
||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
|
||||
|
||||
if (metadataAccumulator?.ipAdapters) {
|
||||
const ipAdapterField = {
|
||||
image: {
|
||||
image_name: ipAdapter.controlImage,
|
||||
},
|
||||
weight,
|
||||
ip_adapter_model: model,
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
};
|
||||
|
||||
metadataAccumulator.ipAdapters.push(ipAdapterField);
|
||||
}
|
||||
ipAdapterMetdata.push(
|
||||
omit(ipAdapterNode, [
|
||||
'id',
|
||||
'type',
|
||||
'is_intermediate',
|
||||
]) as IPAdapterMetadataField
|
||||
);
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||
@ -98,5 +91,7 @@ export const addIPAdapterToLinearGraph = (
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
upsertMetadata(graph, { ipAdapters: ipAdapterMetdata });
|
||||
}
|
||||
};
|
||||
|
@ -2,20 +2,20 @@ import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { forEach, size } from 'lodash-es';
|
||||
import {
|
||||
CoreMetadataInvocation,
|
||||
LoraLoaderInvocation,
|
||||
MetadataAccumulatorInvocation,
|
||||
} from 'services/api/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
CANVAS_INPAINT_GRAPH,
|
||||
CANVAS_OUTPAINT_GRAPH,
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
CLIP_SKIP,
|
||||
LORA_LOADER,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addLoRAsToGraph = (
|
||||
state: RootState,
|
||||
@ -33,11 +33,11 @@ export const addLoRAsToGraph = (
|
||||
|
||||
const { loras } = state.lora;
|
||||
const loraCount = size(loras);
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (loraCount > 0) {
|
||||
if (loraCount === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove modelLoaderNodeId unet connection to feed it to LoRAs
|
||||
graph.edges = graph.edges.filter(
|
||||
(e) =>
|
||||
@ -51,11 +51,11 @@ export const addLoRAsToGraph = (
|
||||
(e) =>
|
||||
!(e.source.node_id === CLIP_SKIP && ['clip'].includes(e.source.field))
|
||||
);
|
||||
}
|
||||
|
||||
// we need to remember the last lora so we can chain from it
|
||||
let lastLoraNodeId = '';
|
||||
let currentLoraIndex = 0;
|
||||
const loraMetadata: CoreMetadataInvocation['loras'] = [];
|
||||
|
||||
forEach(loras, (lora) => {
|
||||
const { model_name, base_model, weight } = lora;
|
||||
@ -69,13 +69,10 @@ export const addLoRAsToGraph = (
|
||||
weight,
|
||||
};
|
||||
|
||||
// add the lora to the metadata accumulator
|
||||
if (metadataAccumulator?.loras) {
|
||||
metadataAccumulator.loras.push({
|
||||
loraMetadata.push({
|
||||
lora: { model_name, base_model },
|
||||
weight,
|
||||
});
|
||||
}
|
||||
|
||||
// add to graph
|
||||
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||
@ -182,4 +179,6 @@ export const addLoRAsToGraph = (
|
||||
lastLoraNodeId = currentLoraNodeId;
|
||||
currentLoraIndex += 1;
|
||||
});
|
||||
|
||||
upsertMetadata(graph, { loras: loraMetadata });
|
||||
};
|
||||
|
@ -1,14 +1,14 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { forEach, size } from 'lodash-es';
|
||||
import {
|
||||
MetadataAccumulatorInvocation,
|
||||
SDXLLoraLoaderInvocation,
|
||||
} from 'services/api/types';
|
||||
LoRAMetadataItem,
|
||||
NonNullableGraph,
|
||||
zLoRAMetadataItem,
|
||||
} from 'features/nodes/types/types';
|
||||
import { forEach, size } from 'lodash-es';
|
||||
import { SDXLLoraLoaderInvocation } from 'services/api/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
LORA_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
POSITIVE_CONDITIONING,
|
||||
SDXL_CANVAS_INPAINT_GRAPH,
|
||||
@ -17,6 +17,7 @@ import {
|
||||
SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addSDXLLoRAsToGraph = (
|
||||
state: RootState,
|
||||
@ -34,9 +35,12 @@ export const addSDXLLoRAsToGraph = (
|
||||
|
||||
const { loras } = state.lora;
|
||||
const loraCount = size(loras);
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (loraCount === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const loraMetadata: LoRAMetadataItem[] = [];
|
||||
|
||||
// Handle Seamless Plugs
|
||||
const unetLoaderId = modelLoaderNodeId;
|
||||
@ -47,7 +51,6 @@ export const addSDXLLoRAsToGraph = (
|
||||
clipLoaderId = SDXL_MODEL_LOADER;
|
||||
}
|
||||
|
||||
if (loraCount > 0) {
|
||||
// Remove modelLoaderNodeId unet/clip/clip2 connections to feed it to LoRAs
|
||||
graph.edges = graph.edges.filter(
|
||||
(e) =>
|
||||
@ -57,12 +60,8 @@ export const addSDXLLoRAsToGraph = (
|
||||
!(
|
||||
e.source.node_id === clipLoaderId && ['clip'].includes(e.source.field)
|
||||
) &&
|
||||
!(
|
||||
e.source.node_id === clipLoaderId &&
|
||||
['clip2'].includes(e.source.field)
|
||||
)
|
||||
!(e.source.node_id === clipLoaderId && ['clip2'].includes(e.source.field))
|
||||
);
|
||||
}
|
||||
|
||||
// we need to remember the last lora so we can chain from it
|
||||
let lastLoraNodeId = '';
|
||||
@ -80,16 +79,12 @@ export const addSDXLLoRAsToGraph = (
|
||||
weight,
|
||||
};
|
||||
|
||||
// add the lora to the metadata accumulator
|
||||
if (metadataAccumulator) {
|
||||
if (!metadataAccumulator.loras) {
|
||||
metadataAccumulator.loras = [];
|
||||
}
|
||||
metadataAccumulator.loras.push({
|
||||
loraMetadata.push(
|
||||
zLoRAMetadataItem.parse({
|
||||
lora: { model_name, base_model },
|
||||
weight,
|
||||
});
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
// add to graph
|
||||
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||
@ -242,4 +237,6 @@ export const addSDXLLoRAsToGraph = (
|
||||
lastLoraNodeId = currentLoraNodeId;
|
||||
currentLoraIndex += 1;
|
||||
});
|
||||
|
||||
upsertMetadata(graph, { loras: loraMetadata });
|
||||
};
|
||||
|
@ -2,7 +2,6 @@ import { RootState } from 'app/store/store';
|
||||
import {
|
||||
CreateDenoiseMaskInvocation,
|
||||
ImageDTO,
|
||||
MetadataAccumulatorInvocation,
|
||||
SeamlessModeInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph } from '../../types/types';
|
||||
@ -12,7 +11,6 @@ import {
|
||||
LATENTS_TO_IMAGE,
|
||||
MASK_COMBINE,
|
||||
MASK_RESIZE_UP,
|
||||
METADATA_ACCUMULATOR,
|
||||
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||
SDXL_CANVAS_INPAINT_GRAPH,
|
||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||
@ -26,6 +24,7 @@ import {
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
} from './constants';
|
||||
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addSDXLRefinerToGraph = (
|
||||
state: RootState,
|
||||
@ -58,21 +57,15 @@ export const addSDXLRefinerToGraph = (
|
||||
return;
|
||||
}
|
||||
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (metadataAccumulator) {
|
||||
metadataAccumulator.refiner_model = refinerModel;
|
||||
metadataAccumulator.refiner_positive_aesthetic_score =
|
||||
refinerPositiveAestheticScore;
|
||||
metadataAccumulator.refiner_negative_aesthetic_score =
|
||||
refinerNegativeAestheticScore;
|
||||
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
|
||||
metadataAccumulator.refiner_scheduler = refinerScheduler;
|
||||
metadataAccumulator.refiner_start = refinerStart;
|
||||
metadataAccumulator.refiner_steps = refinerSteps;
|
||||
}
|
||||
upsertMetadata(graph, {
|
||||
refiner_model: refinerModel,
|
||||
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
|
||||
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
|
||||
refiner_cfg_scale: refinerCFGScale,
|
||||
refiner_scheduler: refinerScheduler,
|
||||
refiner_start: refinerStart,
|
||||
refiner_steps: refinerSteps,
|
||||
});
|
||||
|
||||
const modelLoaderId = modelLoaderNodeId
|
||||
? modelLoaderNodeId
|
||||
|
@ -1,19 +1,15 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { SaveImageInvocation } from 'services/api/types';
|
||||
import {
|
||||
CANVAS_OUTPUT,
|
||||
LATENTS_TO_IMAGE,
|
||||
LATENTS_TO_IMAGE_HRF,
|
||||
METADATA_ACCUMULATOR,
|
||||
NSFW_CHECKER,
|
||||
SAVE_IMAGE,
|
||||
WATERMARKER,
|
||||
} from './constants';
|
||||
import {
|
||||
MetadataAccumulatorInvocation,
|
||||
SaveImageInvocation,
|
||||
} from 'services/api/types';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
|
||||
/**
|
||||
* Set the `use_cache` field on the linear/canvas graph's final image output node to False.
|
||||
@ -37,23 +33,6 @@ export const addSaveImageNode = (
|
||||
|
||||
graph.nodes[SAVE_IMAGE] = saveImageNode;
|
||||
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (metadataAccumulator) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: SAVE_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const destination = {
|
||||
node_id: SAVE_IMAGE,
|
||||
field: 'image',
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { SeamlessModeInvocation } from 'services/api/types';
|
||||
import { NonNullableGraph } from '../../types/types';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
CANVAS_INPAINT_GRAPH,
|
||||
@ -31,6 +32,17 @@ export const addSeamlessToLinearGraph = (
|
||||
seamless_y: seamlessYAxis,
|
||||
} as SeamlessModeInvocation;
|
||||
|
||||
if (seamlessXAxis) {
|
||||
upsertMetadata(graph, {
|
||||
seamless_x: seamlessXAxis,
|
||||
});
|
||||
}
|
||||
if (seamlessYAxis) {
|
||||
upsertMetadata(graph, {
|
||||
seamless_y: seamlessYAxis,
|
||||
});
|
||||
}
|
||||
|
||||
let denoisingNodeId = DENOISE_LATENTS;
|
||||
|
||||
if (
|
||||
|
@ -3,15 +3,15 @@ import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAd
|
||||
import { omit } from 'lodash-es';
|
||||
import {
|
||||
CollectInvocation,
|
||||
MetadataAccumulatorInvocation,
|
||||
CoreMetadataInvocation,
|
||||
T2IAdapterInvocation,
|
||||
} from 'services/api/types';
|
||||
import { NonNullableGraph, T2IAdapterField } from '../../types/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
METADATA_ACCUMULATOR,
|
||||
T2I_ADAPTER_COLLECT,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addT2IAdaptersToLinearGraph = (
|
||||
state: RootState,
|
||||
@ -22,10 +22,6 @@ export const addT2IAdaptersToLinearGraph = (
|
||||
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
||||
);
|
||||
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (validT2IAdapters.length) {
|
||||
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
|
||||
const t2iAdapterCollectNode: CollectInvocation = {
|
||||
@ -51,6 +47,7 @@ export const addT2IAdaptersToLinearGraph = (
|
||||
},
|
||||
});
|
||||
}
|
||||
const t2iAdapterMetdata: CoreMetadataInvocation['t2iAdapters'] = [];
|
||||
|
||||
validT2IAdapters.forEach((t2iAdapter) => {
|
||||
if (!t2iAdapter.model) {
|
||||
@ -96,15 +93,13 @@ export const addT2IAdaptersToLinearGraph = (
|
||||
|
||||
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation;
|
||||
|
||||
if (metadataAccumulator?.t2iAdapters) {
|
||||
// metadata accumulator only needs a control field - not the whole node
|
||||
// extract what we need and add to the accumulator
|
||||
const t2iAdapterField = omit(t2iAdapterNode, [
|
||||
t2iAdapterMetdata.push(
|
||||
omit(t2iAdapterNode, [
|
||||
'id',
|
||||
'type',
|
||||
]) as T2IAdapterField;
|
||||
metadataAccumulator.t2iAdapters.push(t2iAdapterField);
|
||||
}
|
||||
'is_intermediate',
|
||||
]) as T2IAdapterField
|
||||
);
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
||||
@ -114,5 +109,7 @@ export const addT2IAdaptersToLinearGraph = (
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetdata });
|
||||
}
|
||||
};
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { MetadataAccumulatorInvocation } from 'services/api/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||
@ -14,7 +13,6 @@ import {
|
||||
INPAINT_IMAGE,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
ONNX_MODEL_LOADER,
|
||||
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||
SDXL_CANVAS_INPAINT_GRAPH,
|
||||
@ -26,6 +24,7 @@ import {
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
VAE_LOADER,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addVAEToGraph = (
|
||||
state: RootState,
|
||||
@ -41,9 +40,6 @@ export const addVAEToGraph = (
|
||||
);
|
||||
|
||||
const isAutoVae = !vae;
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (!isAutoVae) {
|
||||
graph.nodes[VAE_LOADER] = {
|
||||
@ -181,7 +177,7 @@ export const addVAEToGraph = (
|
||||
}
|
||||
}
|
||||
|
||||
if (vae && metadataAccumulator) {
|
||||
metadataAccumulator.vae = vae;
|
||||
if (vae) {
|
||||
upsertMetadata(graph, { vae });
|
||||
}
|
||||
};
|
||||
|
@ -5,14 +5,8 @@ import {
|
||||
ImageNSFWBlurInvocation,
|
||||
ImageWatermarkInvocation,
|
||||
LatentsToImageInvocation,
|
||||
MetadataAccumulatorInvocation,
|
||||
} from 'services/api/types';
|
||||
import {
|
||||
LATENTS_TO_IMAGE,
|
||||
METADATA_ACCUMULATOR,
|
||||
NSFW_CHECKER,
|
||||
WATERMARKER,
|
||||
} from './constants';
|
||||
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants';
|
||||
|
||||
export const addWatermarkerToGraph = (
|
||||
state: RootState,
|
||||
@ -32,10 +26,6 @@ export const addWatermarkerToGraph = (
|
||||
| ImageNSFWBlurInvocation
|
||||
| undefined;
|
||||
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (!nodeToAddTo) {
|
||||
// something has gone terribly awry
|
||||
return;
|
||||
@ -80,17 +70,4 @@ export const addWatermarkerToGraph = (
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
if (metadataAccumulator) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: WATERMARKER,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
}
|
||||
};
|
||||
|
@ -1,12 +1,13 @@
|
||||
import { BoardId } from 'features/gallery/store/types';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { ESRGANModelName } from 'features/parameters/store/postprocessingSlice';
|
||||
import {
|
||||
Graph,
|
||||
ESRGANInvocation,
|
||||
Graph,
|
||||
SaveImageInvocation,
|
||||
} from 'services/api/types';
|
||||
import { REALESRGAN as ESRGAN, SAVE_IMAGE } from './constants';
|
||||
import { BoardId } from 'features/gallery/store/types';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
type Arg = {
|
||||
image_name: string;
|
||||
@ -55,5 +56,9 @@ export const buildAdHocUpscaleGraph = ({
|
||||
],
|
||||
};
|
||||
|
||||
addCoreMetadataNode(graph, {
|
||||
esrgan_model: esrganModelName,
|
||||
});
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -20,12 +20,12 @@ import {
|
||||
IMG2IMG_RESIZE,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Image to Image graph.
|
||||
@ -308,10 +308,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
});
|
||||
}
|
||||
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
addCoreMetadataNode(graph, {
|
||||
generation_mode: 'img2img',
|
||||
cfg_scale,
|
||||
width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
|
||||
@ -325,15 +322,10 @@ export const buildCanvasImageToImageGraph = (
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined, // option; set in addVAEToGraph
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||
t2iAdapters: [],
|
||||
clip_skip: clipSkip,
|
||||
strength,
|
||||
init_image: initialImage.image_name,
|
||||
};
|
||||
});
|
||||
|
||||
// Add Seamless To Graph
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
|
@ -16,7 +16,6 @@ import {
|
||||
IMAGE_TO_LATENTS,
|
||||
IMG2IMG_RESIZE,
|
||||
LATENTS_TO_IMAGE,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
@ -28,6 +27,7 @@ import {
|
||||
} from './constants';
|
||||
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Image to Image graph.
|
||||
@ -319,10 +319,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
});
|
||||
}
|
||||
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
addCoreMetadataNode(graph, {
|
||||
generation_mode: 'img2img',
|
||||
cfg_scale,
|
||||
width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
|
||||
@ -336,24 +333,8 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined, // option; set in addVAEToGraph
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||
t2iAdapters: [],
|
||||
strength,
|
||||
init_image: initialImage.image_name,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_OUTPUT,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// Add Seamless To Graph
|
||||
|
@ -18,7 +18,6 @@ import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
CANVAS_OUTPUT,
|
||||
LATENTS_TO_IMAGE,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
ONNX_MODEL_LOADER,
|
||||
@ -30,6 +29,7 @@ import {
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Text to Image graph.
|
||||
@ -301,10 +301,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
||||
});
|
||||
}
|
||||
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
addCoreMetadataNode(graph, {
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
|
||||
@ -318,22 +315,6 @@ export const buildCanvasSDXLTextToImageGraph = (
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined, // option; set in addVAEToGraph
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||
t2iAdapters: [],
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_OUTPUT,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// Add Seamless To Graph
|
||||
|
@ -21,13 +21,13 @@ import {
|
||||
DENOISE_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
ONNX_MODEL_LOADER,
|
||||
POSITIVE_CONDITIONING,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Text to Image graph.
|
||||
@ -289,10 +289,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
});
|
||||
}
|
||||
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
addCoreMetadataNode(graph, {
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
|
||||
@ -306,23 +303,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined, // option; set in addVAEToGraph
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||
t2iAdapters: [],
|
||||
clip_skip: clipSkip,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_OUTPUT,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// Add Seamless To Graph
|
||||
|
@ -2,15 +2,16 @@ import { NUMPY_RAND_MAX } from 'app/constants';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { generateSeeds } from 'common/util/generateSeeds';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { range, unset } from 'lodash-es';
|
||||
import { range } from 'lodash-es';
|
||||
import { components } from 'services/api/schema';
|
||||
import { Batch, BatchConfig } from 'services/api/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_NOISE,
|
||||
METADATA_ACCUMULATOR,
|
||||
METADATA,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
} from './constants';
|
||||
import { getHasMetadata, removeMetadata } from './metadata';
|
||||
|
||||
export const prepareLinearUIBatch = (
|
||||
state: RootState,
|
||||
@ -24,7 +25,6 @@ export const prepareLinearUIBatch = (
|
||||
const data: Batch['data'] = [];
|
||||
|
||||
if (prompts.length === 1) {
|
||||
unset(graph.nodes[METADATA_ACCUMULATOR], 'seed');
|
||||
const seeds = generateSeeds({
|
||||
count: iterations,
|
||||
start: shouldRandomizeSeed ? undefined : seed,
|
||||
@ -40,9 +40,11 @@ export const prepareLinearUIBatch = (
|
||||
});
|
||||
}
|
||||
|
||||
if (graph.nodes[METADATA_ACCUMULATOR]) {
|
||||
if (getHasMetadata(graph)) {
|
||||
// add to metadata
|
||||
removeMetadata(graph, 'seed');
|
||||
zipped.push({
|
||||
node_path: METADATA_ACCUMULATOR,
|
||||
node_path: METADATA,
|
||||
field_name: 'seed',
|
||||
items: seeds,
|
||||
});
|
||||
@ -77,9 +79,11 @@ export const prepareLinearUIBatch = (
|
||||
});
|
||||
}
|
||||
|
||||
if (graph.nodes[METADATA_ACCUMULATOR]) {
|
||||
// add to metadata
|
||||
if (getHasMetadata(graph)) {
|
||||
removeMetadata(graph, 'seed');
|
||||
firstBatchDatumList.push({
|
||||
node_path: METADATA_ACCUMULATOR,
|
||||
node_path: METADATA,
|
||||
field_name: 'seed',
|
||||
items: seeds,
|
||||
});
|
||||
@ -106,13 +110,17 @@ export const prepareLinearUIBatch = (
|
||||
items: seeds,
|
||||
});
|
||||
}
|
||||
if (graph.nodes[METADATA_ACCUMULATOR]) {
|
||||
|
||||
// add to metadata
|
||||
if (getHasMetadata(graph)) {
|
||||
removeMetadata(graph, 'seed');
|
||||
secondBatchDatumList.push({
|
||||
node_path: METADATA_ACCUMULATOR,
|
||||
node_path: METADATA,
|
||||
field_name: 'seed',
|
||||
items: seeds,
|
||||
});
|
||||
}
|
||||
|
||||
if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
|
||||
secondBatchDatumList.push({
|
||||
node_path: CANVAS_COHERENCE_NOISE,
|
||||
@ -137,17 +145,17 @@ export const prepareLinearUIBatch = (
|
||||
});
|
||||
}
|
||||
|
||||
if (graph.nodes[METADATA_ACCUMULATOR]) {
|
||||
// add to metadata
|
||||
if (getHasMetadata(graph)) {
|
||||
removeMetadata(graph, 'positive_prompt');
|
||||
firstBatchDatumList.push({
|
||||
node_path: METADATA_ACCUMULATOR,
|
||||
node_path: METADATA,
|
||||
field_name: 'positive_prompt',
|
||||
items: extendedPrompts,
|
||||
});
|
||||
}
|
||||
|
||||
if (shouldConcatSDXLStylePrompt && model?.base_model === 'sdxl') {
|
||||
unset(graph.nodes[METADATA_ACCUMULATOR], 'positive_style_prompt');
|
||||
|
||||
const stylePrompts = extendedPrompts.map((p) =>
|
||||
[p, positiveStylePrompt].join(' ')
|
||||
);
|
||||
@ -160,11 +168,13 @@ export const prepareLinearUIBatch = (
|
||||
});
|
||||
}
|
||||
|
||||
if (graph.nodes[METADATA_ACCUMULATOR]) {
|
||||
// add to metadata
|
||||
if (getHasMetadata(graph)) {
|
||||
removeMetadata(graph, 'positive_style_prompt');
|
||||
firstBatchDatumList.push({
|
||||
node_path: METADATA_ACCUMULATOR,
|
||||
node_path: METADATA,
|
||||
field_name: 'positive_style_prompt',
|
||||
items: stylePrompts,
|
||||
items: extendedPrompts,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -21,13 +21,13 @@ import {
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
RESIZE,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Image to Image tab graph.
|
||||
@ -311,10 +311,7 @@ export const buildLinearImageToImageGraph = (
|
||||
});
|
||||
}
|
||||
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
addCoreMetadataNode(graph, {
|
||||
generation_mode: 'img2img',
|
||||
cfg_scale,
|
||||
height,
|
||||
@ -326,25 +323,9 @@ export const buildLinearImageToImageGraph = (
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined, // option; set in addVAEToGraph
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||
t2iAdapters: [], // populated in addT2IAdapterToLinearGraph
|
||||
clip_skip: clipSkip,
|
||||
strength,
|
||||
init_image: initialImage.imageName,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// Add Seamless To Graph
|
||||
|
@ -18,7 +18,6 @@ import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
@ -30,6 +29,7 @@ import {
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Image to Image tab graph.
|
||||
@ -331,10 +331,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
||||
});
|
||||
}
|
||||
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
addCoreMetadataNode(graph, {
|
||||
generation_mode: 'sdxl_img2img',
|
||||
cfg_scale,
|
||||
height,
|
||||
@ -346,26 +343,10 @@ export const buildLinearSDXLImageToImageGraph = (
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined,
|
||||
controlnets: [],
|
||||
loras: [],
|
||||
ipAdapters: [],
|
||||
t2iAdapters: [],
|
||||
strength: strength,
|
||||
strength,
|
||||
init_image: initialImage.imageName,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// Add Seamless To Graph
|
||||
|
@ -11,9 +11,9 @@ import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import {
|
||||
LATENTS_TO_IMAGE,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
@ -225,10 +225,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
||||
],
|
||||
};
|
||||
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
addCoreMetadataNode(graph, {
|
||||
generation_mode: 'sdxl_txt2img',
|
||||
cfg_scale,
|
||||
height,
|
||||
@ -240,24 +237,8 @@ export const buildLinearSDXLTextToImageGraph = (
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined,
|
||||
controlnets: [],
|
||||
loras: [],
|
||||
ipAdapters: [],
|
||||
t2iAdapters: [],
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// Add Seamless To Graph
|
||||
|
@ -15,12 +15,12 @@ import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||
import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import {
|
||||
CLIP_SKIP,
|
||||
DENOISE_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
ONNX_MODEL_LOADER,
|
||||
@ -48,10 +48,6 @@ export const buildLinearTextToImageGraph = (
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
seed,
|
||||
hrfWidth,
|
||||
hrfHeight,
|
||||
hrfStrength,
|
||||
hrfEnabled: hrfEnabled,
|
||||
} = state.generation;
|
||||
|
||||
const use_cpu = shouldUseCpuNoise;
|
||||
@ -238,10 +234,7 @@ export const buildLinearTextToImageGraph = (
|
||||
],
|
||||
};
|
||||
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
addCoreMetadataNode(graph, {
|
||||
generation_mode: 'txt2img',
|
||||
cfg_scale,
|
||||
height,
|
||||
@ -253,26 +246,7 @@ export const buildLinearTextToImageGraph = (
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined, // option; set in addVAEToGraph
|
||||
controlnets: [], // populated in addControlNetToLinearGraph
|
||||
loras: [], // populated in addLoRAsToGraph
|
||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||
t2iAdapters: [], // populated in addT2IAdapterToLinearGraph
|
||||
clip_skip: clipSkip,
|
||||
hrf_width: hrfEnabled ? hrfWidth : undefined,
|
||||
hrf_height: hrfEnabled ? hrfHeight : undefined,
|
||||
hrf_strength: hrfEnabled ? hrfStrength : undefined,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// Add Seamless To Graph
|
||||
|
@ -35,7 +35,6 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
||||
const { nodes, edges } = nodesState;
|
||||
|
||||
const filteredNodes = nodes.filter(isInvocationNode);
|
||||
const workflowJSON = JSON.stringify(buildWorkflow(nodesState));
|
||||
|
||||
// Reduce the node editor nodes into invocation graph nodes
|
||||
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>(
|
||||
@ -68,7 +67,7 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
||||
|
||||
if (embedWorkflow) {
|
||||
// add the workflow to the node
|
||||
Object.assign(graphNode, { workflow: workflowJSON });
|
||||
Object.assign(graphNode, { workflow: buildWorkflow(nodesState) });
|
||||
}
|
||||
|
||||
// Add it to the nodes object
|
||||
|
@ -56,7 +56,14 @@ export const IP_ADAPTER = 'ip_adapter';
|
||||
export const DYNAMIC_PROMPT = 'dynamic_prompt';
|
||||
export const IMAGE_COLLECTION = 'image_collection';
|
||||
export const IMAGE_COLLECTION_ITERATE = 'image_collection_iterate';
|
||||
export const METADATA_ACCUMULATOR = 'metadata_accumulator';
|
||||
export const METADATA = 'core_metadata';
|
||||
export const BATCH_METADATA = 'batch_metadata';
|
||||
export const BATCH_METADATA_COLLECT = 'batch_metadata_collect';
|
||||
export const BATCH_SEED = 'batch_seed';
|
||||
export const BATCH_PROMPT = 'batch_prompt';
|
||||
export const BATCH_STYLE_PROMPT = 'batch_style_prompt';
|
||||
export const METADATA_COLLECT = 'metadata_collect';
|
||||
export const MERGE_METADATA = 'merge_metadata';
|
||||
export const REALESRGAN = 'esrgan';
|
||||
export const DIVIDE = 'divide';
|
||||
export const SCALE = 'scale_image';
|
||||
|
@ -0,0 +1,66 @@
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { CoreMetadataInvocation } from 'services/api/types';
|
||||
import { JsonObject } from 'type-fest';
|
||||
import { METADATA, SAVE_IMAGE } from './constants';
|
||||
|
||||
export const addCoreMetadataNode = (
|
||||
graph: NonNullableGraph,
|
||||
metadata: Partial<CoreMetadataInvocation> | JsonObject
|
||||
): void => {
|
||||
graph.nodes[METADATA] = {
|
||||
id: METADATA,
|
||||
type: 'core_metadata',
|
||||
...metadata,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: SAVE_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
return;
|
||||
};
|
||||
|
||||
export const upsertMetadata = (
|
||||
graph: NonNullableGraph,
|
||||
metadata: Partial<CoreMetadataInvocation> | JsonObject
|
||||
): void => {
|
||||
const metadataNode = graph.nodes[METADATA] as
|
||||
| CoreMetadataInvocation
|
||||
| undefined;
|
||||
|
||||
if (!metadataNode) {
|
||||
return;
|
||||
}
|
||||
|
||||
Object.assign(metadataNode, metadata);
|
||||
};
|
||||
|
||||
export const removeMetadata = (
|
||||
graph: NonNullableGraph,
|
||||
key: keyof CoreMetadataInvocation
|
||||
): void => {
|
||||
const metadataNode = graph.nodes[METADATA] as
|
||||
| CoreMetadataInvocation
|
||||
| undefined;
|
||||
|
||||
if (!metadataNode) {
|
||||
return;
|
||||
}
|
||||
|
||||
delete metadataNode[key];
|
||||
};
|
||||
|
||||
export const getHasMetadata = (graph: NonNullableGraph): boolean => {
|
||||
const metadataNode = graph.nodes[METADATA] as
|
||||
| CoreMetadataInvocation
|
||||
| undefined;
|
||||
|
||||
return Boolean(metadataNode);
|
||||
};
|
@ -4,7 +4,6 @@ import { reduce, startCase } from 'lodash-es';
|
||||
import { OpenAPIV3_1 } from 'openapi-types';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import {
|
||||
FieldType,
|
||||
InputFieldTemplate,
|
||||
InvocationSchemaObject,
|
||||
InvocationTemplate,
|
||||
@ -16,18 +15,11 @@ import {
|
||||
} from '../types/types';
|
||||
import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
|
||||
|
||||
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'metadata', 'use_cache'];
|
||||
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache'];
|
||||
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
|
||||
const RESERVED_FIELD_TYPES = [
|
||||
'WorkflowField',
|
||||
'MetadataField',
|
||||
'IsIntermediate',
|
||||
];
|
||||
const RESERVED_FIELD_TYPES = ['IsIntermediate'];
|
||||
|
||||
const invocationDenylist: AnyInvocationType[] = [
|
||||
'graph',
|
||||
'metadata_accumulator',
|
||||
];
|
||||
const invocationDenylist: AnyInvocationType[] = ['graph'];
|
||||
|
||||
const isReservedInputField = (nodeType: string, fieldName: string) => {
|
||||
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
|
||||
@ -42,7 +34,7 @@ const isReservedInputField = (nodeType: string, fieldName: string) => {
|
||||
return false;
|
||||
};
|
||||
|
||||
const isReservedFieldType = (fieldType: FieldType) => {
|
||||
const isReservedFieldType = (fieldType: string) => {
|
||||
if (RESERVED_FIELD_TYPES.includes(fieldType)) {
|
||||
return true;
|
||||
}
|
||||
@ -86,6 +78,7 @@ export const parseSchema = (
|
||||
const tags = schema.tags ?? [];
|
||||
const description = schema.description ?? '';
|
||||
const version = schema.version;
|
||||
let withWorkflow = false;
|
||||
|
||||
const inputs = reduce(
|
||||
schema.properties,
|
||||
@ -112,7 +105,7 @@ export const parseSchema = (
|
||||
|
||||
const fieldType = property.ui_type ?? getFieldType(property);
|
||||
|
||||
if (!isFieldType(fieldType)) {
|
||||
if (!fieldType) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
@ -120,11 +113,16 @@ export const parseSchema = (
|
||||
fieldType,
|
||||
field: parseify(property),
|
||||
},
|
||||
'Skipping unknown input field type'
|
||||
'Missing input field type'
|
||||
);
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
if (fieldType === 'WorkflowField') {
|
||||
withWorkflow = true;
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
if (isReservedFieldType(fieldType)) {
|
||||
logger('nodes').trace(
|
||||
{
|
||||
@ -133,7 +131,20 @@ export const parseSchema = (
|
||||
fieldType,
|
||||
field: parseify(property),
|
||||
},
|
||||
'Skipping reserved field type'
|
||||
`Skipping reserved input field type: ${fieldType}`
|
||||
);
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
if (!isFieldType(fieldType)) {
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
fieldName: propertyName,
|
||||
fieldType,
|
||||
field: parseify(property),
|
||||
},
|
||||
`Skipping unknown input field type: ${fieldType}`
|
||||
);
|
||||
return inputsAccumulator;
|
||||
}
|
||||
@ -146,7 +157,7 @@ export const parseSchema = (
|
||||
);
|
||||
|
||||
if (!field) {
|
||||
logger('nodes').debug(
|
||||
logger('nodes').warn(
|
||||
{
|
||||
node: type,
|
||||
fieldName: propertyName,
|
||||
@ -248,6 +259,7 @@ export const parseSchema = (
|
||||
inputs,
|
||||
outputs,
|
||||
useCache,
|
||||
withWorkflow,
|
||||
};
|
||||
|
||||
Object.assign(invocationsAccumulator, { [type]: invocation });
|
||||
|
@ -1,20 +1,15 @@
|
||||
import { EntityState, Update } from '@reduxjs/toolkit';
|
||||
import { fetchBaseQuery } from '@reduxjs/toolkit/dist/query';
|
||||
import { PatchCollection } from '@reduxjs/toolkit/dist/query/core/buildThunks';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import {
|
||||
ASSETS_CATEGORIES,
|
||||
BoardId,
|
||||
IMAGE_CATEGORIES,
|
||||
IMAGE_LIMIT,
|
||||
} from 'features/gallery/store/types';
|
||||
import {
|
||||
ImageMetadataAndWorkflow,
|
||||
zCoreMetadata,
|
||||
} from 'features/nodes/types/types';
|
||||
import { getMetadataAndWorkflowFromImageBlob } from 'features/nodes/util/getMetadataAndWorkflowFromImageBlob';
|
||||
import { CoreMetadata, zCoreMetadata } from 'features/nodes/types/types';
|
||||
import { keyBy } from 'lodash-es';
|
||||
import { ApiTagDescription, LIST_TAG, api } from '..';
|
||||
import { $authToken, $projectId } from '../client';
|
||||
import { components, paths } from '../schema';
|
||||
import {
|
||||
DeleteBoardResult,
|
||||
@ -23,7 +18,6 @@ import {
|
||||
ListImagesArgs,
|
||||
OffsetPaginatedResults_ImageDTO_,
|
||||
PostUploadAction,
|
||||
UnsafeImageMetadata,
|
||||
} from '../types';
|
||||
import {
|
||||
getCategories,
|
||||
@ -114,73 +108,24 @@ export const imagesApi = api.injectEndpoints({
|
||||
],
|
||||
keepUnusedDataFor: 86400, // 24 hours
|
||||
}),
|
||||
getImageMetadata: build.query<UnsafeImageMetadata, string>({
|
||||
getImageMetadata: build.query<CoreMetadata | undefined, string>({
|
||||
query: (image_name) => ({ url: `images/i/${image_name}/metadata` }),
|
||||
providesTags: (result, error, image_name) => [
|
||||
{ type: 'ImageMetadata', id: image_name },
|
||||
],
|
||||
keepUnusedDataFor: 86400, // 24 hours
|
||||
}),
|
||||
getImageMetadataFromFile: build.query<
|
||||
ImageMetadataAndWorkflow,
|
||||
{ image: ImageDTO; shouldFetchMetadataFromApi: boolean }
|
||||
>({
|
||||
queryFn: async (
|
||||
args: { image: ImageDTO; shouldFetchMetadataFromApi: boolean },
|
||||
api,
|
||||
extraOptions,
|
||||
fetchWithBaseQuery
|
||||
transformResponse: (
|
||||
response: paths['/api/v1/images/i/{image_name}/metadata']['get']['responses']['200']['content']['application/json']
|
||||
) => {
|
||||
if (args.shouldFetchMetadataFromApi) {
|
||||
let metadata;
|
||||
const metadataResponse = await fetchWithBaseQuery(
|
||||
`images/i/${args.image.image_name}/metadata`
|
||||
);
|
||||
if (metadataResponse.data) {
|
||||
const metadataResult = zCoreMetadata.safeParse(
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(metadataResponse.data as any)?.metadata
|
||||
);
|
||||
if (metadataResult.success) {
|
||||
metadata = metadataResult.data;
|
||||
}
|
||||
}
|
||||
return { data: { metadata } };
|
||||
if (response) {
|
||||
const result = zCoreMetadata.safeParse(response);
|
||||
if (result.success) {
|
||||
return result.data;
|
||||
} else {
|
||||
const authToken = $authToken.get();
|
||||
const projectId = $projectId.get();
|
||||
const customBaseQuery = fetchBaseQuery({
|
||||
baseUrl: '',
|
||||
prepareHeaders: (headers) => {
|
||||
if (authToken) {
|
||||
headers.set('Authorization', `Bearer ${authToken}`);
|
||||
logger('images').warn('Problem parsing metadata');
|
||||
}
|
||||
if (projectId) {
|
||||
headers.set('project-id', projectId);
|
||||
}
|
||||
|
||||
return headers;
|
||||
return;
|
||||
},
|
||||
responseHandler: async (res) => {
|
||||
return await res.blob();
|
||||
},
|
||||
});
|
||||
|
||||
const response = await customBaseQuery(
|
||||
args.image.image_url,
|
||||
api,
|
||||
extraOptions
|
||||
);
|
||||
const data = await getMetadataAndWorkflowFromImageBlob(
|
||||
response.data as Blob
|
||||
);
|
||||
|
||||
return { data };
|
||||
}
|
||||
},
|
||||
providesTags: (result, error, { image }) => [
|
||||
{ type: 'ImageMetadataFromFile', id: image.image_name },
|
||||
],
|
||||
keepUnusedDataFor: 86400, // 24 hours
|
||||
}),
|
||||
deleteImage: build.mutation<void, ImageDTO>({
|
||||
@ -1629,6 +1574,5 @@ export const {
|
||||
useDeleteBoardMutation,
|
||||
useStarImagesMutation,
|
||||
useUnstarImagesMutation,
|
||||
useGetImageMetadataFromFileQuery,
|
||||
useBulkDownloadImagesMutation,
|
||||
} = imagesApi;
|
||||
|
@ -0,0 +1,30 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { Workflow, zWorkflow } from 'features/nodes/types/types';
|
||||
import { api } from '..';
|
||||
import { paths } from '../schema';
|
||||
|
||||
export const workflowsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getWorkflow: build.query<Workflow | undefined, string>({
|
||||
query: (workflow_id) => `workflows/i/${workflow_id}`,
|
||||
providesTags: (result, error, workflow_id) => [
|
||||
{ type: 'Workflow', id: workflow_id },
|
||||
],
|
||||
transformResponse: (
|
||||
response: paths['/api/v1/workflows/i/{workflow_id}']['get']['responses']['200']['content']['application/json']
|
||||
) => {
|
||||
if (response) {
|
||||
const result = zWorkflow.safeParse(response);
|
||||
if (result.success) {
|
||||
return result.data;
|
||||
} else {
|
||||
logger('images').warn('Problem parsing workflow');
|
||||
}
|
||||
}
|
||||
return;
|
||||
},
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
export const { useGetWorkflowQuery } = workflowsApi;
|
@ -0,0 +1,21 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
import { useGetImageMetadataQuery } from '../endpoints/images';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
|
||||
export const useDebouncedMetadata = (imageName?: string | null) => {
|
||||
const metadataFetchDebounce = useAppSelector(
|
||||
(state) => state.config.metadataFetchDebounce
|
||||
);
|
||||
|
||||
const [debouncedImageName] = useDebounce(
|
||||
imageName,
|
||||
metadataFetchDebounce ?? 0
|
||||
);
|
||||
|
||||
const { data: metadata, isLoading } = useGetImageMetadataQuery(
|
||||
debouncedImageName ?? skipToken
|
||||
);
|
||||
|
||||
return { metadata, isLoading };
|
||||
};
|
@ -0,0 +1,21 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
import { useGetWorkflowQuery } from '../endpoints/workflows';
|
||||
|
||||
export const useDebouncedWorkflow = (workflowId?: string | null) => {
|
||||
const workflowFetchDebounce = useAppSelector(
|
||||
(state) => state.config.workflowFetchDebounce
|
||||
);
|
||||
|
||||
const [debouncedWorkflowID] = useDebounce(
|
||||
workflowId,
|
||||
workflowFetchDebounce ?? 0
|
||||
);
|
||||
|
||||
const { data: workflow, isLoading } = useGetWorkflowQuery(
|
||||
debouncedWorkflowID ?? skipToken
|
||||
);
|
||||
|
||||
return { workflow, isLoading };
|
||||
};
|
@ -37,6 +37,7 @@ export const tagTypes = [
|
||||
'ControlNetModel',
|
||||
'LoRAModel',
|
||||
'SDXLRefinerModel',
|
||||
'Workflow',
|
||||
] as const;
|
||||
export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>;
|
||||
export const LIST_TAG = 'LIST';
|
||||
|
1731
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
1731
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because one or more lines are too long
@ -27,14 +27,6 @@ export type BatchConfig =
|
||||
|
||||
export type EnqueueBatchResult = components['schemas']['EnqueueBatchResult'];
|
||||
|
||||
/**
|
||||
* This is an unsafe type; the object inside is not guaranteed to be valid.
|
||||
*/
|
||||
export type UnsafeImageMetadata = {
|
||||
metadata: s['CoreMetadata'];
|
||||
graph: NonNullable<s['Graph']>;
|
||||
};
|
||||
|
||||
export type _InputField = s['_InputField'];
|
||||
export type _OutputField = s['_OutputField'];
|
||||
|
||||
@ -50,7 +42,6 @@ export type ImageChanges = s['ImageRecordChanges'];
|
||||
export type ImageCategory = s['ImageCategory'];
|
||||
export type ResourceOrigin = s['ResourceOrigin'];
|
||||
export type ImageField = s['ImageField'];
|
||||
export type ImageMetadata = s['ImageMetadata'];
|
||||
export type OffsetPaginatedResults_BoardDTO_ =
|
||||
s['OffsetPaginatedResults_BoardDTO_'];
|
||||
export type OffsetPaginatedResults_ImageDTO_ =
|
||||
@ -145,13 +136,19 @@ export type ImageCollectionInvocation = s['ImageCollectionInvocation'];
|
||||
export type MainModelLoaderInvocation = s['MainModelLoaderInvocation'];
|
||||
export type OnnxModelLoaderInvocation = s['OnnxModelLoaderInvocation'];
|
||||
export type LoraLoaderInvocation = s['LoraLoaderInvocation'];
|
||||
export type MetadataAccumulatorInvocation = s['MetadataAccumulatorInvocation'];
|
||||
export type ESRGANInvocation = s['ESRGANInvocation'];
|
||||
export type DivideInvocation = s['DivideInvocation'];
|
||||
export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation'];
|
||||
export type ImageWatermarkInvocation = s['ImageWatermarkInvocation'];
|
||||
export type SeamlessModeInvocation = s['SeamlessModeInvocation'];
|
||||
export type SaveImageInvocation = s['SaveImageInvocation'];
|
||||
export type MetadataInvocation = s['MetadataInvocation'];
|
||||
export type CoreMetadataInvocation = s['CoreMetadataInvocation'];
|
||||
export type MetadataItemInvocation = s['MetadataItemInvocation'];
|
||||
export type MergeMetadataInvocation = s['MergeMetadataInvocation'];
|
||||
export type IPAdapterMetadataField = s['IPAdapterMetadataField'];
|
||||
export type T2IAdapterField = s['T2IAdapterField'];
|
||||
export type LoRAMetadataField = s['LoRAMetadataField'];
|
||||
|
||||
// ControlNet Nodes
|
||||
export type ControlNetInvocation = s['ControlNetInvocation'];
|
||||
|
@ -75,6 +75,8 @@ def mock_services() -> InvocationServices:
|
||||
session_processor=None, # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
workflow_image_records=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
@ -80,6 +80,8 @@ def mock_services() -> InvocationServices:
|
||||
session_processor=None, # type: ignore
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
workflow_image_records=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
@ -10,7 +10,12 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.image import ShowImageInvocation
|
||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
|
||||
from invokeai.app.invocations.primitives import (
|
||||
FloatCollectionInvocation,
|
||||
FloatInvocation,
|
||||
IntegerInvocation,
|
||||
StringInvocation,
|
||||
)
|
||||
from invokeai.app.invocations.upscale import ESRGANInvocation
|
||||
from invokeai.app.services.shared.default_graphs import create_text_to_image
|
||||
from invokeai.app.services.shared.graph import (
|
||||
@ -27,8 +32,11 @@ from invokeai.app.services.shared.graph import (
|
||||
)
|
||||
|
||||
from .test_nodes import (
|
||||
AnyTypeTestInvocation,
|
||||
ImageToImageTestInvocation,
|
||||
ListPassThroughInvocation,
|
||||
PolymorphicStringTestInvocation,
|
||||
PromptCollectionTestInvocation,
|
||||
PromptTestInvocation,
|
||||
TextToImageTestInvocation,
|
||||
)
|
||||
@ -607,8 +615,8 @@ def test_graph_can_deserialize():
|
||||
g.add_edge(e)
|
||||
|
||||
json = g.model_dump_json()
|
||||
adapter_graph = TypeAdapter(Graph)
|
||||
g2 = adapter_graph.validate_json(json)
|
||||
GraphValidator = TypeAdapter(Graph)
|
||||
g2 = GraphValidator.validate_json(json)
|
||||
|
||||
assert g2 is not None
|
||||
assert g2.nodes["1"] is not None
|
||||
@ -692,6 +700,144 @@ def test_ints_do_not_accept_floats():
|
||||
g.add_edge(e)
|
||||
|
||||
|
||||
def test_polymorphic_accepts_single():
|
||||
g = Graph()
|
||||
n1 = StringInvocation(id="1", value="banana")
|
||||
n2 = PolymorphicStringTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e1 = create_edge(n1.id, "value", n2.id, "value")
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e1)
|
||||
|
||||
|
||||
def test_polymorphic_accepts_collection_of_same_base_type():
|
||||
g = Graph()
|
||||
n1 = PromptCollectionTestInvocation(id="1", collection=["banana", "sundae"])
|
||||
n2 = PolymorphicStringTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e1 = create_edge(n1.id, "collection", n2.id, "value")
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e1)
|
||||
|
||||
|
||||
def test_polymorphic_does_not_accept_collection_of_different_base_type():
|
||||
g = Graph()
|
||||
n1 = FloatCollectionInvocation(id="1", collection=[1.0, 2.0, 3.0])
|
||||
n2 = PolymorphicStringTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e1 = create_edge(n1.id, "collection", n2.id, "value")
|
||||
with pytest.raises(InvalidEdgeError):
|
||||
g.add_edge(e1)
|
||||
|
||||
|
||||
def test_polymorphic_does_not_accept_generic_collection():
|
||||
g = Graph()
|
||||
n1 = IntegerInvocation(id="1", value=1)
|
||||
n2 = IntegerInvocation(id="2", value=2)
|
||||
n3 = CollectInvocation(id="3")
|
||||
n4 = PolymorphicStringTestInvocation(id="4")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
g.add_node(n4)
|
||||
e1 = create_edge(n1.id, "value", n3.id, "item")
|
||||
e2 = create_edge(n2.id, "value", n3.id, "item")
|
||||
e3 = create_edge(n3.id, "collection", n4.id, "value")
|
||||
g.add_edge(e1)
|
||||
g.add_edge(e2)
|
||||
with pytest.raises(InvalidEdgeError):
|
||||
g.add_edge(e3)
|
||||
|
||||
|
||||
def test_any_accepts_integer():
|
||||
g = Graph()
|
||||
n1 = IntegerInvocation(id="1", value=1)
|
||||
n2 = AnyTypeTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id, "value", n2.id, "value")
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e)
|
||||
|
||||
|
||||
def test_any_accepts_string():
|
||||
g = Graph()
|
||||
n1 = StringInvocation(id="1", value="banana sundae")
|
||||
n2 = AnyTypeTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id, "value", n2.id, "value")
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e)
|
||||
|
||||
|
||||
def test_any_accepts_generic_collection():
|
||||
g = Graph()
|
||||
n1 = IntegerInvocation(id="1", value=1)
|
||||
n2 = IntegerInvocation(id="2", value=2)
|
||||
n3 = CollectInvocation(id="3")
|
||||
n4 = AnyTypeTestInvocation(id="4")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
g.add_node(n4)
|
||||
e1 = create_edge(n1.id, "value", n3.id, "item")
|
||||
e2 = create_edge(n2.id, "value", n3.id, "item")
|
||||
e3 = create_edge(n3.id, "collection", n4.id, "value")
|
||||
g.add_edge(e1)
|
||||
g.add_edge(e2)
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e3)
|
||||
|
||||
|
||||
def test_any_accepts_prompt_collection():
|
||||
g = Graph()
|
||||
n1 = PromptCollectionTestInvocation(id="1", collection=["banana", "sundae"])
|
||||
n2 = AnyTypeTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id, "collection", n2.id, "value")
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e)
|
||||
|
||||
|
||||
def test_any_accepts_any():
|
||||
g = Graph()
|
||||
n1 = AnyTypeTestInvocation(id="1")
|
||||
n2 = AnyTypeTestInvocation(id="2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id, "value", n2.id, "value")
|
||||
# Not throwing on this line is sufficient
|
||||
g.add_edge(e)
|
||||
|
||||
|
||||
def test_iterate_accepts_collection():
|
||||
"""We need to update the validation for Collect -> Iterate to traverse to the Iterate
|
||||
node's output and compare that against the item type of the Collect node's collection. Until
|
||||
then, Collect nodes may not output into Iterate nodes."""
|
||||
g = Graph()
|
||||
n1 = IntegerInvocation(id="1", value=1)
|
||||
n2 = IntegerInvocation(id="2", value=2)
|
||||
n3 = CollectInvocation(id="3")
|
||||
n4 = IterateInvocation(id="4")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
g.add_node(n4)
|
||||
e1 = create_edge(n1.id, "value", n3.id, "item")
|
||||
e2 = create_edge(n2.id, "value", n3.id, "item")
|
||||
e3 = create_edge(n3.id, "collection", n4.id, "collection")
|
||||
g.add_edge(e1)
|
||||
g.add_edge(e2)
|
||||
# Once we fix the validation logic as described, this should should not raise an error
|
||||
with pytest.raises(InvalidEdgeError, match="Cannot connect collector to iterator"):
|
||||
g.add_edge(e3)
|
||||
|
||||
|
||||
def test_graph_can_generate_schema():
|
||||
# Not throwing on this line is sufficient
|
||||
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
||||
|
@ -1,11 +1,11 @@
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -15,12 +15,12 @@ from invokeai.app.invocations.image import ImageField
|
||||
# Define test invocations before importing anything that uses invocations
|
||||
@invocation_output("test_list_output")
|
||||
class ListPassThroughInvocationOutput(BaseInvocationOutput):
|
||||
collection: list[ImageField] = Field(default_factory=list)
|
||||
collection: list[ImageField] = OutputField(default_factory=list)
|
||||
|
||||
|
||||
@invocation("test_list")
|
||||
class ListPassThroughInvocation(BaseInvocation):
|
||||
collection: list[ImageField] = Field(default_factory=list)
|
||||
collection: list[ImageField] = InputField(default_factory=list)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
|
||||
return ListPassThroughInvocationOutput(collection=self.collection)
|
||||
@ -28,12 +28,12 @@ class ListPassThroughInvocation(BaseInvocation):
|
||||
|
||||
@invocation_output("test_prompt_output")
|
||||
class PromptTestInvocationOutput(BaseInvocationOutput):
|
||||
prompt: str = Field(default="")
|
||||
prompt: str = OutputField(default="")
|
||||
|
||||
|
||||
@invocation("test_prompt")
|
||||
class PromptTestInvocation(BaseInvocation):
|
||||
prompt: str = Field(default="")
|
||||
prompt: str = InputField(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||
return PromptTestInvocationOutput(prompt=self.prompt)
|
||||
@ -47,13 +47,13 @@ class ErrorInvocation(BaseInvocation):
|
||||
|
||||
@invocation_output("test_image_output")
|
||||
class ImageTestInvocationOutput(BaseInvocationOutput):
|
||||
image: ImageField = Field()
|
||||
image: ImageField = OutputField()
|
||||
|
||||
|
||||
@invocation("test_text_to_image")
|
||||
class TextToImageTestInvocation(BaseInvocation):
|
||||
prompt: str = Field(default="")
|
||||
prompt2: str = Field(default="")
|
||||
prompt: str = InputField(default="")
|
||||
prompt2: str = InputField(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||
@ -61,8 +61,8 @@ class TextToImageTestInvocation(BaseInvocation):
|
||||
|
||||
@invocation("test_image_to_image")
|
||||
class ImageToImageTestInvocation(BaseInvocation):
|
||||
prompt: str = Field(default="")
|
||||
image: Union[ImageField, None] = Field(default=None)
|
||||
prompt: str = InputField(default="")
|
||||
image: Union[ImageField, None] = InputField(default=None)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||
@ -70,17 +70,40 @@ class ImageToImageTestInvocation(BaseInvocation):
|
||||
|
||||
@invocation_output("test_prompt_collection_output")
|
||||
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
||||
collection: list[str] = Field(default_factory=list)
|
||||
collection: list[str] = OutputField(default_factory=list)
|
||||
|
||||
|
||||
@invocation("test_prompt_collection")
|
||||
class PromptCollectionTestInvocation(BaseInvocation):
|
||||
collection: list[str] = Field()
|
||||
collection: list[str] = InputField()
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
||||
|
||||
|
||||
@invocation_output("test_any_output")
|
||||
class AnyTypeTestInvocationOutput(BaseInvocationOutput):
|
||||
value: Any = OutputField()
|
||||
|
||||
|
||||
@invocation("test_any")
|
||||
class AnyTypeTestInvocation(BaseInvocation):
|
||||
value: Any = InputField(default=None)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
|
||||
return AnyTypeTestInvocationOutput(value=self.value)
|
||||
|
||||
|
||||
@invocation("test_polymorphic")
|
||||
class PolymorphicStringTestInvocation(BaseInvocation):
|
||||
value: Union[str, list[str]] = InputField(default="")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||
if isinstance(self.value, str):
|
||||
return PromptCollectionTestInvocationOutput(collection=[self.value])
|
||||
return PromptCollectionTestInvocationOutput(collection=self.value)
|
||||
|
||||
|
||||
# Importing these must happen after test invocations are defined or they won't register
|
||||
from invokeai.app.services.events.events_base import EventServiceBase # noqa: E402
|
||||
from invokeai.app.services.shared.graph import Edge, EdgeConnection # noqa: E402
|
||||
|
@ -150,9 +150,9 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
|
||||
values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
|
||||
assert len(values) == 8
|
||||
|
||||
session_adapter = TypeAdapter(GraphExecutionState)
|
||||
GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
|
||||
# graph should be serialized
|
||||
ges = session_adapter.validate_json(values[0].session)
|
||||
ges = GraphExecutionStateValidator.validate_json(values[0].session)
|
||||
|
||||
# graph values should be populated
|
||||
assert ges.graph.get_node("1").prompt == "Banana sushi"
|
||||
@ -161,16 +161,16 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
|
||||
assert ges.graph.get_node("4").prompt == "Nissan"
|
||||
|
||||
# session ids should match deserialized graph
|
||||
assert [v.session_id for v in values] == [session_adapter.validate_json(v.session).id for v in values]
|
||||
assert [v.session_id for v in values] == [GraphExecutionStateValidator.validate_json(v.session).id for v in values]
|
||||
|
||||
# should unique session ids
|
||||
sids = [v.session_id for v in values]
|
||||
assert len(sids) == len(set(sids))
|
||||
|
||||
nfv_list_adapter = TypeAdapter(list[NodeFieldValue])
|
||||
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
|
||||
# should have 3 node field values
|
||||
assert type(values[0].field_values) is str
|
||||
assert len(nfv_list_adapter.validate_json(values[0].field_values)) == 3
|
||||
assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
|
||||
|
||||
# should have batch id and priority
|
||||
assert all(v.batch_id == b.batch_id for v in values)
|
||||
|
Loading…
Reference in New Issue
Block a user