mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into ebr/docker-py311
This commit is contained in:
commit
1177234931
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
|
||||||
|
from invokeai.app.services.workflow_image_records.workflow_image_records_sqlite import SqliteWorkflowImageRecordsStorage
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.version.invokeai_version import __version__
|
from invokeai.version.invokeai_version import __version__
|
||||||
|
|
||||||
@ -30,6 +31,7 @@ from ..services.shared.default_graphs import create_system_graphs
|
|||||||
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
||||||
from ..services.shared.sqlite import SqliteDatabase
|
from ..services.shared.sqlite import SqliteDatabase
|
||||||
from ..services.urls.urls_default import LocalUrlService
|
from ..services.urls.urls_default import LocalUrlService
|
||||||
|
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
|
|
||||||
@ -90,6 +92,8 @@ class ApiDependencies:
|
|||||||
session_processor = DefaultSessionProcessor()
|
session_processor = DefaultSessionProcessor()
|
||||||
session_queue = SqliteSessionQueue(db=db)
|
session_queue = SqliteSessionQueue(db=db)
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
|
workflow_image_records = SqliteWorkflowImageRecordsStorage(db=db)
|
||||||
|
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
board_image_records=board_image_records,
|
board_image_records=board_image_records,
|
||||||
@ -114,6 +118,8 @@ class ApiDependencies:
|
|||||||
session_processor=session_processor,
|
session_processor=session_processor,
|
||||||
session_queue=session_queue,
|
session_queue=session_queue,
|
||||||
urls=urls,
|
urls=urls,
|
||||||
|
workflow_image_records=workflow_image_records,
|
||||||
|
workflow_records=workflow_records,
|
||||||
)
|
)
|
||||||
|
|
||||||
create_system_graphs(services.graph_library)
|
create_system_graphs(services.graph_library)
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
import io
|
import io
|
||||||
|
import traceback
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import ImageMetadata
|
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator, WorkflowFieldValidator
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
@ -45,17 +46,38 @@ async def upload_image(
|
|||||||
if not file.content_type or not file.content_type.startswith("image"):
|
if not file.content_type or not file.content_type.startswith("image"):
|
||||||
raise HTTPException(status_code=415, detail="Not an image")
|
raise HTTPException(status_code=415, detail="Not an image")
|
||||||
|
|
||||||
contents = await file.read()
|
metadata = None
|
||||||
|
workflow = None
|
||||||
|
|
||||||
|
contents = await file.read()
|
||||||
try:
|
try:
|
||||||
pil_image = Image.open(io.BytesIO(contents))
|
pil_image = Image.open(io.BytesIO(contents))
|
||||||
if crop_visible:
|
if crop_visible:
|
||||||
bbox = pil_image.getbbox()
|
bbox = pil_image.getbbox()
|
||||||
pil_image = pil_image.crop(bbox)
|
pil_image = pil_image.crop(bbox)
|
||||||
except Exception:
|
except Exception:
|
||||||
# Error opening the image
|
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||||
|
|
||||||
|
# TODO: retain non-invokeai metadata on upload?
|
||||||
|
# attempt to parse metadata from image
|
||||||
|
metadata_raw = pil_image.info.get("invokeai_metadata", None)
|
||||||
|
if metadata_raw:
|
||||||
|
try:
|
||||||
|
metadata = MetadataFieldValidator.validate_json(metadata_raw)
|
||||||
|
except ValidationError:
|
||||||
|
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
||||||
|
pass
|
||||||
|
|
||||||
|
# attempt to parse workflow from image
|
||||||
|
workflow_raw = pil_image.info.get("invokeai_workflow", None)
|
||||||
|
if workflow_raw is not None:
|
||||||
|
try:
|
||||||
|
workflow = WorkflowFieldValidator.validate_json(workflow_raw)
|
||||||
|
except ValidationError:
|
||||||
|
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image_dto = ApiDependencies.invoker.services.images.create(
|
image_dto = ApiDependencies.invoker.services.images.create(
|
||||||
image=pil_image,
|
image=pil_image,
|
||||||
@ -63,6 +85,8 @@ async def upload_image(
|
|||||||
image_category=image_category,
|
image_category=image_category,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
board_id=board_id,
|
board_id=board_id,
|
||||||
|
metadata=metadata,
|
||||||
|
workflow=workflow,
|
||||||
is_intermediate=is_intermediate,
|
is_intermediate=is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -71,6 +95,7 @@ async def upload_image(
|
|||||||
|
|
||||||
return image_dto
|
return image_dto
|
||||||
except Exception:
|
except Exception:
|
||||||
|
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||||
|
|
||||||
|
|
||||||
@ -146,11 +171,11 @@ async def get_image_dto(
|
|||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/i/{image_name}/metadata",
|
"/i/{image_name}/metadata",
|
||||||
operation_id="get_image_metadata",
|
operation_id="get_image_metadata",
|
||||||
response_model=ImageMetadata,
|
response_model=Optional[MetadataField],
|
||||||
)
|
)
|
||||||
async def get_image_metadata(
|
async def get_image_metadata(
|
||||||
image_name: str = Path(description="The name of image to get"),
|
image_name: str = Path(description="The name of image to get"),
|
||||||
) -> ImageMetadata:
|
) -> Optional[MetadataField]:
|
||||||
"""Gets an image's metadata"""
|
"""Gets an image's metadata"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -23,13 +23,13 @@ from ..dependencies import ApiDependencies
|
|||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
update_models_response_adapter = TypeAdapter(UpdateModelResponse)
|
UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse)
|
||||||
|
|
||||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
import_models_response_adapter = TypeAdapter(ImportModelResponse)
|
ImportModelResponseValidator = TypeAdapter(ImportModelResponse)
|
||||||
|
|
||||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
convert_models_response_adapter = TypeAdapter(ConvertModelResponse)
|
ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse)
|
||||||
|
|
||||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
@ -41,7 +41,7 @@ class ModelsList(BaseModel):
|
|||||||
model_config = ConfigDict(use_enum_values=True)
|
model_config = ConfigDict(use_enum_values=True)
|
||||||
|
|
||||||
|
|
||||||
models_list_adapter = TypeAdapter(ModelsList)
|
ModelsListValidator = TypeAdapter(ModelsList)
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
@ -60,7 +60,7 @@ async def list_models(
|
|||||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||||
else:
|
else:
|
||||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||||
models = models_list_adapter.validate_python({"models": models_raw})
|
models = ModelsListValidator.validate_python({"models": models_raw})
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ async def update_model(
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
model_response = update_models_response_adapter.validate_python(model_raw)
|
model_response = UpdateModelResponseValidator.validate_python(model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@ -186,7 +186,7 @@ async def import_model(
|
|||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
||||||
)
|
)
|
||||||
return import_models_response_adapter.validate_python(model_raw)
|
return ImportModelResponseValidator.validate_python(model_raw)
|
||||||
|
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
@ -231,7 +231,7 @@ async def add_model(
|
|||||||
base_model=info.base_model,
|
base_model=info.base_model,
|
||||||
model_type=info.model_type,
|
model_type=info.model_type,
|
||||||
)
|
)
|
||||||
return import_models_response_adapter.validate_python(model_raw)
|
return ImportModelResponseValidator.validate_python(model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
@ -302,7 +302,7 @@ async def convert_model(
|
|||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name, base_model=base_model, model_type=model_type
|
model_name, base_model=base_model, model_type=model_type
|
||||||
)
|
)
|
||||||
response = convert_models_response_adapter.validate_python(model_raw)
|
response = ConvertModelResponseValidator.validate_python(model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@ -417,7 +417,7 @@ async def merge_models(
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Main,
|
model_type=ModelType.Main,
|
||||||
)
|
)
|
||||||
response = convert_models_response_adapter.validate_python(model_raw)
|
response = ConvertModelResponseValidator.validate_python(model_raw)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
|
20
invokeai/app/api/routers/workflows.py
Normal file
20
invokeai/app/api/routers/workflows.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from fastapi import APIRouter, Path
|
||||||
|
|
||||||
|
from invokeai.app.api.dependencies import ApiDependencies
|
||||||
|
from invokeai.app.invocations.baseinvocation import WorkflowField
|
||||||
|
|
||||||
|
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||||
|
|
||||||
|
|
||||||
|
@workflows_router.get(
|
||||||
|
"/i/{workflow_id}",
|
||||||
|
operation_id="get_workflow",
|
||||||
|
responses={
|
||||||
|
200: {"model": WorkflowField},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_workflow(
|
||||||
|
workflow_id: str = Path(description="The workflow to get"),
|
||||||
|
) -> WorkflowField:
|
||||||
|
"""Gets a workflow"""
|
||||||
|
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
@ -38,7 +38,17 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
|
|
||||||
from ..backend.util.logging import InvokeAILogger
|
from ..backend.util.logging import InvokeAILogger
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
|
from .api.routers import (
|
||||||
|
app_info,
|
||||||
|
board_images,
|
||||||
|
boards,
|
||||||
|
images,
|
||||||
|
models,
|
||||||
|
session_queue,
|
||||||
|
sessions,
|
||||||
|
utilities,
|
||||||
|
workflows,
|
||||||
|
)
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||||
|
|
||||||
@ -95,18 +105,13 @@ async def shutdown_event() -> None:
|
|||||||
app.include_router(sessions.session_router, prefix="/api")
|
app.include_router(sessions.session_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(utilities.utilities_router, prefix="/api")
|
app.include_router(utilities.utilities_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(models.models_router, prefix="/api")
|
app.include_router(models.models_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(images.images_router, prefix="/api")
|
app.include_router(images.images_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(boards.boards_router, prefix="/api")
|
app.include_router(boards.boards_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(board_images.board_images_router, prefix="/api")
|
app.include_router(board_images.board_images_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(app_info.app_router, prefix="/api")
|
app.include_router(app_info.app_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||||
|
app.include_router(workflows.workflows_router, prefix="/api")
|
||||||
|
|
||||||
|
|
||||||
# Build a custom OpenAPI to include all outputs
|
# Build a custom OpenAPI to include all outputs
|
||||||
@ -166,7 +171,6 @@ def custom_openapi() -> dict[str, Any]:
|
|||||||
# print(f"Config with name {name} already defined")
|
# print(f"Config with name {name} already defined")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
|
|
||||||
openapi_schema["components"]["schemas"][name] = dict(
|
openapi_schema["components"]["schemas"][name] = dict(
|
||||||
title=name,
|
title=name,
|
||||||
description="An enumeration.",
|
description="An enumeration.",
|
||||||
|
@ -1,8 +1,28 @@
|
|||||||
import os
|
import shutil
|
||||||
|
import sys
|
||||||
|
from importlib.util import module_from_spec, spec_from_file_location
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
__all__ = []
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
|
|
||||||
dirname = os.path.dirname(os.path.abspath(__file__))
|
custom_nodes_path = Path(InvokeAIAppConfig.get_config().custom_nodes_path.absolute())
|
||||||
for f in os.listdir(dirname):
|
custom_nodes_path.mkdir(parents=True, exist_ok=True)
|
||||||
if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py":
|
|
||||||
__all__.append(f[:-3])
|
custom_nodes_init_path = str(custom_nodes_path / "__init__.py")
|
||||||
|
custom_nodes_readme_path = str(custom_nodes_path / "README.md")
|
||||||
|
|
||||||
|
# copy our custom nodes __init__.py to the custom nodes directory
|
||||||
|
shutil.copy(Path(__file__).parent / "custom_nodes/init.py", custom_nodes_init_path)
|
||||||
|
shutil.copy(Path(__file__).parent / "custom_nodes/README.md", custom_nodes_readme_path)
|
||||||
|
|
||||||
|
# Import custom nodes, see https://docs.python.org/3/library/importlib.html#importing-programmatically
|
||||||
|
spec = spec_from_file_location("custom_nodes", custom_nodes_init_path)
|
||||||
|
if spec is None or spec.loader is None:
|
||||||
|
raise RuntimeError(f"Could not load custom nodes from {custom_nodes_init_path}")
|
||||||
|
module = module_from_spec(spec)
|
||||||
|
sys.modules[spec.name] = module
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
# add core nodes to __all__
|
||||||
|
python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
|
||||||
|
__all__ = list(f.stem for f in python_files) # type: ignore
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import inspect
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -11,8 +11,8 @@ from types import UnionType
|
|||||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
|
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
import semver
|
import semver
|
||||||
from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
|
||||||
from pydantic.fields import _Unset
|
from pydantic.fields import FieldInfo, _Unset
|
||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
@ -26,6 +26,10 @@ class InvalidVersionError(ValueError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidFieldError(TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FieldDescriptions:
|
class FieldDescriptions:
|
||||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||||
@ -60,7 +64,12 @@ class FieldDescriptions:
|
|||||||
denoised_latents = "Denoised latents tensor"
|
denoised_latents = "Denoised latents tensor"
|
||||||
latents = "Latents tensor"
|
latents = "Latents tensor"
|
||||||
strength = "Strength of denoising (proportional to steps)"
|
strength = "Strength of denoising (proportional to steps)"
|
||||||
core_metadata = "Optional core metadata to be written to image"
|
metadata = "Optional metadata to be saved with the image"
|
||||||
|
metadata_collection = "Collection of Metadata"
|
||||||
|
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
|
||||||
|
metadata_item_label = "Label for this metadata item"
|
||||||
|
metadata_item_value = "The value for this metadata item (may be any type)"
|
||||||
|
workflow = "Optional workflow to be saved with the image"
|
||||||
interp_mode = "Interpolation mode"
|
interp_mode = "Interpolation mode"
|
||||||
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||||
fp32 = "Whether or not to use full float32 precision"
|
fp32 = "Whether or not to use full float32 precision"
|
||||||
@ -167,8 +176,12 @@ class UIType(str, Enum):
|
|||||||
Scheduler = "Scheduler"
|
Scheduler = "Scheduler"
|
||||||
WorkflowField = "WorkflowField"
|
WorkflowField = "WorkflowField"
|
||||||
IsIntermediate = "IsIntermediate"
|
IsIntermediate = "IsIntermediate"
|
||||||
MetadataField = "MetadataField"
|
|
||||||
BoardField = "BoardField"
|
BoardField = "BoardField"
|
||||||
|
Any = "Any"
|
||||||
|
MetadataItem = "MetadataItem"
|
||||||
|
MetadataItemCollection = "MetadataItemCollection"
|
||||||
|
MetadataItemPolymorphic = "MetadataItemPolymorphic"
|
||||||
|
MetadataDict = "MetadataDict"
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
@ -294,6 +307,7 @@ def InputField(
|
|||||||
ui_order=ui_order,
|
ui_order=ui_order,
|
||||||
item_default=item_default,
|
item_default=item_default,
|
||||||
ui_choice_labels=ui_choice_labels,
|
ui_choice_labels=ui_choice_labels,
|
||||||
|
_field_kind="input",
|
||||||
)
|
)
|
||||||
|
|
||||||
field_args = dict(
|
field_args = dict(
|
||||||
@ -436,6 +450,7 @@ def OutputField(
|
|||||||
ui_type=ui_type,
|
ui_type=ui_type,
|
||||||
ui_hidden=ui_hidden,
|
ui_hidden=ui_hidden,
|
||||||
ui_order=ui_order,
|
ui_order=ui_order,
|
||||||
|
_field_kind="output",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -519,6 +534,7 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
schema["required"].extend(["type"])
|
schema["required"].extend(["type"])
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
|
protected_namespaces=(),
|
||||||
validate_assignment=True,
|
validate_assignment=True,
|
||||||
json_schema_serialization_defaults_required=True,
|
json_schema_serialization_defaults_required=True,
|
||||||
json_schema_extra=json_schema_extra,
|
json_schema_extra=json_schema_extra,
|
||||||
@ -541,9 +557,6 @@ class MissingInputException(Exception):
|
|||||||
|
|
||||||
class BaseInvocation(ABC, BaseModel):
|
class BaseInvocation(ABC, BaseModel):
|
||||||
"""
|
"""
|
||||||
A node to process inputs and produce outputs.
|
|
||||||
May use dependency injection in __init__ to receive providers.
|
|
||||||
|
|
||||||
All invocations must use the `@invocation` decorator to provide their unique type.
|
All invocations must use the `@invocation` decorator to provide their unique type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -659,37 +672,21 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
id: str = Field(
|
id: str = Field(
|
||||||
default_factory=uuid_string,
|
default_factory=uuid_string,
|
||||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
|
||||||
|
json_schema_extra=dict(_field_kind="internal"),
|
||||||
)
|
)
|
||||||
is_intermediate: Optional[bool] = Field(
|
is_intermediate: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether or not this is an intermediate invocation.",
|
description="Whether or not this is an intermediate invocation.",
|
||||||
json_schema_extra=dict(ui_type=UIType.IsIntermediate),
|
json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"),
|
||||||
)
|
)
|
||||||
workflow: Optional[str] = Field(
|
use_cache: bool = Field(
|
||||||
default=None,
|
default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal")
|
||||||
description="The workflow to save with the image",
|
|
||||||
json_schema_extra=dict(ui_type=UIType.WorkflowField),
|
|
||||||
)
|
)
|
||||||
use_cache: Optional[bool] = Field(
|
|
||||||
default=True,
|
|
||||||
description="Whether or not to use the cache",
|
|
||||||
)
|
|
||||||
|
|
||||||
@field_validator("workflow", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def validate_workflow_is_json(cls, v):
|
|
||||||
"""We don't have a workflow schema in the backend, so we just check that it's valid JSON"""
|
|
||||||
if v is None:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
json.loads(v)
|
|
||||||
except json.decoder.JSONDecodeError:
|
|
||||||
raise ValueError("Workflow must be valid JSON")
|
|
||||||
return v
|
|
||||||
|
|
||||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
|
protected_namespaces=(),
|
||||||
validate_assignment=True,
|
validate_assignment=True,
|
||||||
json_schema_extra=json_schema_extra,
|
json_schema_extra=json_schema_extra,
|
||||||
json_schema_serialization_defaults_required=True,
|
json_schema_serialization_defaults_required=True,
|
||||||
@ -700,6 +697,68 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
||||||
|
|
||||||
|
|
||||||
|
RESERVED_INPUT_FIELD_NAMES = {
|
||||||
|
"id",
|
||||||
|
"is_intermediate",
|
||||||
|
"use_cache",
|
||||||
|
"type",
|
||||||
|
"workflow",
|
||||||
|
"metadata",
|
||||||
|
}
|
||||||
|
|
||||||
|
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
|
||||||
|
|
||||||
|
|
||||||
|
class _Model(BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Get all pydantic model attrs, methods, etc
|
||||||
|
RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model())))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
||||||
|
"""
|
||||||
|
Validates the fields of an invocation or invocation output:
|
||||||
|
- must not override any pydantic reserved fields
|
||||||
|
- must be created via `InputField`, `OutputField`, or be an internal field defined in this file
|
||||||
|
"""
|
||||||
|
for name, field in model_fields.items():
|
||||||
|
if name in RESERVED_PYDANTIC_FIELD_NAMES:
|
||||||
|
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
|
||||||
|
|
||||||
|
field_kind = (
|
||||||
|
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
|
||||||
|
field.json_schema_extra.get("_field_kind", None)
|
||||||
|
if field.json_schema_extra
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# must have a field_kind
|
||||||
|
if field_kind is None or field_kind not in {"input", "output", "internal"}:
|
||||||
|
raise InvalidFieldError(
|
||||||
|
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
|
||||||
|
)
|
||||||
|
|
||||||
|
if field_kind == "input" and name in RESERVED_INPUT_FIELD_NAMES:
|
||||||
|
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
|
||||||
|
|
||||||
|
if field_kind == "output" and name in RESERVED_OUTPUT_FIELD_NAMES:
|
||||||
|
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
|
||||||
|
|
||||||
|
# internal fields *must* be in the reserved list
|
||||||
|
if (
|
||||||
|
field_kind == "internal"
|
||||||
|
and name not in RESERVED_INPUT_FIELD_NAMES
|
||||||
|
and name not in RESERVED_OUTPUT_FIELD_NAMES
|
||||||
|
):
|
||||||
|
raise InvalidFieldError(
|
||||||
|
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def invocation(
|
def invocation(
|
||||||
invocation_type: str,
|
invocation_type: str,
|
||||||
title: Optional[str] = None,
|
title: Optional[str] = None,
|
||||||
@ -709,7 +768,7 @@ def invocation(
|
|||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
||||||
"""
|
"""
|
||||||
Adds metadata to an invocation.
|
Registers an invocation.
|
||||||
|
|
||||||
:param str invocation_type: The type of the invocation. Must be unique among all invocations.
|
:param str invocation_type: The type of the invocation. Must be unique among all invocations.
|
||||||
:param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
|
:param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
|
||||||
@ -728,6 +787,8 @@ def invocation(
|
|||||||
if invocation_type in BaseInvocation.get_invocation_types():
|
if invocation_type in BaseInvocation.get_invocation_types():
|
||||||
raise ValueError(f'Invocation type "{invocation_type}" already exists')
|
raise ValueError(f'Invocation type "{invocation_type}" already exists')
|
||||||
|
|
||||||
|
validate_fields(cls.model_fields, invocation_type)
|
||||||
|
|
||||||
# Add OpenAPI schema extras
|
# Add OpenAPI schema extras
|
||||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||||
@ -758,8 +819,7 @@ def invocation(
|
|||||||
|
|
||||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||||
invocation_type_field = Field(
|
invocation_type_field = Field(
|
||||||
title="type",
|
title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal")
|
||||||
default=invocation_type,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
docstring = cls.__doc__
|
docstring = cls.__doc__
|
||||||
@ -800,13 +860,12 @@ def invocation_output(
|
|||||||
if output_type in BaseInvocationOutput.get_output_types():
|
if output_type in BaseInvocationOutput.get_output_types():
|
||||||
raise ValueError(f'Invocation type "{output_type}" already exists')
|
raise ValueError(f'Invocation type "{output_type}" already exists')
|
||||||
|
|
||||||
|
validate_fields(cls.model_fields, output_type)
|
||||||
|
|
||||||
# Add the output type to the model.
|
# Add the output type to the model.
|
||||||
|
|
||||||
output_type_annotation = Literal[output_type] # type: ignore
|
output_type_annotation = Literal[output_type] # type: ignore
|
||||||
output_type_field = Field(
|
output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal"))
|
||||||
title="type",
|
|
||||||
default=output_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
docstring = cls.__doc__
|
docstring = cls.__doc__
|
||||||
cls = create_model(
|
cls = create_model(
|
||||||
@ -824,4 +883,37 @@ def invocation_output(
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
class WorkflowField(RootModel):
|
||||||
|
"""
|
||||||
|
Pydantic model for workflows with custom root of type dict[str, Any].
|
||||||
|
Workflows are stored without a strict schema.
|
||||||
|
"""
|
||||||
|
|
||||||
|
root: dict[str, Any] = Field(description="The workflow")
|
||||||
|
|
||||||
|
|
||||||
|
WorkflowFieldValidator = TypeAdapter(WorkflowField)
|
||||||
|
|
||||||
|
|
||||||
|
class WithWorkflow(BaseModel):
|
||||||
|
workflow: Optional[WorkflowField] = Field(
|
||||||
|
default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataField(RootModel):
|
||||||
|
"""
|
||||||
|
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||||
|
Metadata is stored without a strict schema.
|
||||||
|
"""
|
||||||
|
|
||||||
|
root: dict[str, Any] = Field(description="The metadata")
|
||||||
|
|
||||||
|
|
||||||
|
MetadataFieldValidator = TypeAdapter(MetadataField)
|
||||||
|
|
||||||
|
|
||||||
|
class WithMetadata(BaseModel):
|
||||||
|
metadata: Optional[MetadataField] = Field(
|
||||||
|
default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal")
|
||||||
|
)
|
||||||
|
@ -38,6 +38,8 @@ from .baseinvocation import (
|
|||||||
InputField,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
|
WithMetadata,
|
||||||
|
WithWorkflow,
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -127,12 +129,12 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||||
class ImageProcessorInvocation(BaseInvocation):
|
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||||
"""Base class for invocations that preprocess images for ControlNet"""
|
"""Base class for invocations that preprocess images for ControlNet"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to process")
|
image: ImageField = InputField(description="The image to process")
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
# superclass just passes through image without processing
|
# superclass just passes through image without processing
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@ -150,6 +152,7 @@ class ImageProcessorInvocation(BaseInvocation):
|
|||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
51
invokeai/app/invocations/custom_nodes/README.md
Normal file
51
invokeai/app/invocations/custom_nodes/README.md
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# Custom Nodes / Node Packs
|
||||||
|
|
||||||
|
Copy your node packs to this directory.
|
||||||
|
|
||||||
|
When nodes are added or changed, you must restart the app to see the changes.
|
||||||
|
|
||||||
|
## Directory Structure
|
||||||
|
|
||||||
|
For a node pack to be loaded, it must be placed in a directory alongside this
|
||||||
|
file. Here's an example structure:
|
||||||
|
|
||||||
|
```py
|
||||||
|
.
|
||||||
|
├── __init__.py # Invoke-managed custom node loader
|
||||||
|
│
|
||||||
|
├── cool_node
|
||||||
|
│ ├── __init__.py # see example below
|
||||||
|
│ └── cool_node.py
|
||||||
|
│
|
||||||
|
└── my_node_pack
|
||||||
|
├── __init__.py # see example below
|
||||||
|
├── tasty_node.py
|
||||||
|
├── bodacious_node.py
|
||||||
|
├── utils.py
|
||||||
|
└── extra_nodes
|
||||||
|
└── fancy_node.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Node Pack `__init__.py`
|
||||||
|
|
||||||
|
Each node pack must have an `__init__.py` file that imports its nodes.
|
||||||
|
|
||||||
|
The structure of each node or node pack is otherwise not important.
|
||||||
|
|
||||||
|
Here are examples, based on the example directory structure.
|
||||||
|
|
||||||
|
### `cool_node/__init__.py`
|
||||||
|
|
||||||
|
```py
|
||||||
|
from .cool_node import CoolInvocation
|
||||||
|
```
|
||||||
|
|
||||||
|
### `my_node_pack/__init__.py`
|
||||||
|
|
||||||
|
```py
|
||||||
|
from .tasty_node import TastyInvocation
|
||||||
|
from .bodacious_node import BodaciousInvocation
|
||||||
|
from .extra_nodes.fancy_node import FancyInvocation
|
||||||
|
```
|
||||||
|
|
||||||
|
Only nodes imported in the `__init__.py` file are loaded.
|
51
invokeai/app/invocations/custom_nodes/init.py
Normal file
51
invokeai/app/invocations/custom_nodes/init.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
"""
|
||||||
|
Invoke-managed custom node loader. See README.md for more information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from importlib.util import module_from_spec, spec_from_file_location
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
|
logger = InvokeAILogger.get_logger()
|
||||||
|
loaded_count = 0
|
||||||
|
|
||||||
|
|
||||||
|
for d in Path(__file__).parent.iterdir():
|
||||||
|
# skip files
|
||||||
|
if not d.is_dir():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# skip hidden directories
|
||||||
|
if d.name.startswith("_") or d.name.startswith("."):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# skip directories without an `__init__.py`
|
||||||
|
init = d / "__init__.py"
|
||||||
|
if not init.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
module_name = init.parent.stem
|
||||||
|
|
||||||
|
# skip if already imported
|
||||||
|
if module_name in globals():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# we have a legit module to import
|
||||||
|
spec = spec_from_file_location(module_name, init.absolute())
|
||||||
|
|
||||||
|
if spec is None or spec.loader is None:
|
||||||
|
logger.warn(f"Could not load {init}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
module = module_from_spec(spec)
|
||||||
|
sys.modules[spec.name] = module
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
loaded_count += 1
|
||||||
|
|
||||||
|
del init, module_name
|
||||||
|
|
||||||
|
|
||||||
|
logger.info(f"Loaded {loaded_count} modules from {Path(__file__).parent}")
|
@ -8,11 +8,11 @@ from PIL import Image, ImageOps
|
|||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
|
||||||
|
|
||||||
|
|
||||||
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.0.0")
|
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.0.0")
|
||||||
class CvInpaintInvocation(BaseInvocation):
|
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||||
"""Simple inpaint using opencv."""
|
"""Simple inpaint using opencv."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to inpaint")
|
image: ImageField = InputField(description="The image to inpaint")
|
||||||
|
@ -16,6 +16,8 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
InputField,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
|
WithMetadata,
|
||||||
|
WithWorkflow,
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -437,7 +439,7 @@ def get_faces_list(
|
|||||||
|
|
||||||
|
|
||||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.0.2")
|
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.0.2")
|
||||||
class FaceOffInvocation(BaseInvocation):
|
class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="Image for face detection")
|
image: ImageField = InputField(description="Image for face detection")
|
||||||
@ -531,7 +533,7 @@ class FaceOffInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.0.2")
|
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.0.2")
|
||||||
class FaceMaskInvocation(BaseInvocation):
|
class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
"""Face mask creation using mediapipe face detection"""
|
"""Face mask creation using mediapipe face detection"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="Image to face detect")
|
image: ImageField = InputField(description="Image to face detect")
|
||||||
@ -650,7 +652,7 @@ class FaceMaskInvocation(BaseInvocation):
|
|||||||
@invocation(
|
@invocation(
|
||||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.0.2"
|
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.0.2"
|
||||||
)
|
)
|
||||||
class FaceIdentifierInvocation(BaseInvocation):
|
class FaceIdentifierInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="Image to face detect")
|
image: ImageField = InputField(description="Image to face detect")
|
||||||
|
@ -7,13 +7,21 @@ import cv2
|
|||||||
import numpy
|
import numpy
|
||||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
|
||||||
from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, invocation
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
|
WithMetadata,
|
||||||
|
WithWorkflow,
|
||||||
|
invocation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
|
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
|
||||||
@ -36,14 +44,8 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0")
|
||||||
"blank_image",
|
class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||||
title="Blank Image",
|
|
||||||
tags=["image"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class BlankImageInvocation(BaseInvocation):
|
|
||||||
"""Creates a blank image and forwards it to the pipeline"""
|
"""Creates a blank image and forwards it to the pipeline"""
|
||||||
|
|
||||||
width: int = InputField(default=512, description="The width of the image")
|
width: int = InputField(default=512, description="The width of the image")
|
||||||
@ -61,6 +63,7 @@ class BlankImageInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -71,14 +74,8 @@ class BlankImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0")
|
||||||
"img_crop",
|
class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
title="Crop Image",
|
|
||||||
tags=["image", "crop"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class ImageCropInvocation(BaseInvocation):
|
|
||||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to crop")
|
image: ImageField = InputField(description="The image to crop")
|
||||||
@ -100,6 +97,7 @@ class ImageCropInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -110,14 +108,8 @@ class ImageCropInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.1")
|
||||||
"img_paste",
|
class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
title="Paste Image",
|
|
||||||
tags=["image", "paste"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.1",
|
|
||||||
)
|
|
||||||
class ImagePasteInvocation(BaseInvocation):
|
|
||||||
"""Pastes an image into another image."""
|
"""Pastes an image into another image."""
|
||||||
|
|
||||||
base_image: ImageField = InputField(description="The base image")
|
base_image: ImageField = InputField(description="The base image")
|
||||||
@ -159,6 +151,7 @@ class ImagePasteInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -169,14 +162,8 @@ class ImagePasteInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0")
|
||||||
"tomask",
|
class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
title="Mask from Alpha",
|
|
||||||
tags=["image", "mask"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class MaskFromAlphaInvocation(BaseInvocation):
|
|
||||||
"""Extracts the alpha channel of an image as a mask."""
|
"""Extracts the alpha channel of an image as a mask."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to create the mask from")
|
image: ImageField = InputField(description="The image to create the mask from")
|
||||||
@ -196,6 +183,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -206,14 +194,8 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0")
|
||||||
"img_mul",
|
class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
title="Multiply Images",
|
|
||||||
tags=["image", "multiply"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class ImageMultiplyInvocation(BaseInvocation):
|
|
||||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||||
|
|
||||||
image1: ImageField = InputField(description="The first image to multiply")
|
image1: ImageField = InputField(description="The first image to multiply")
|
||||||
@ -232,6 +214,7 @@ class ImageMultiplyInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -245,14 +228,8 @@ class ImageMultiplyInvocation(BaseInvocation):
|
|||||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0")
|
||||||
"img_chan",
|
class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
title="Extract Image Channel",
|
|
||||||
tags=["image", "channel"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class ImageChannelInvocation(BaseInvocation):
|
|
||||||
"""Gets a channel from an image."""
|
"""Gets a channel from an image."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to get the channel from")
|
image: ImageField = InputField(description="The image to get the channel from")
|
||||||
@ -270,6 +247,7 @@ class ImageChannelInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -283,14 +261,8 @@ class ImageChannelInvocation(BaseInvocation):
|
|||||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0")
|
||||||
"img_conv",
|
class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
title="Convert Image Mode",
|
|
||||||
tags=["image", "convert"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class ImageConvertInvocation(BaseInvocation):
|
|
||||||
"""Converts an image to a different mode."""
|
"""Converts an image to a different mode."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to convert")
|
image: ImageField = InputField(description="The image to convert")
|
||||||
@ -308,6 +280,7 @@ class ImageConvertInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -318,14 +291,8 @@ class ImageConvertInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0")
|
||||||
"img_blur",
|
class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
title="Blur Image",
|
|
||||||
tags=["image", "blur"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class ImageBlurInvocation(BaseInvocation):
|
|
||||||
"""Blurs an image"""
|
"""Blurs an image"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to blur")
|
image: ImageField = InputField(description="The image to blur")
|
||||||
@ -348,6 +315,7 @@ class ImageBlurInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -378,23 +346,14 @@ PIL_RESAMPLING_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0")
|
||||||
"img_resize",
|
class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||||
title="Resize Image",
|
|
||||||
tags=["image", "resize"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class ImageResizeInvocation(BaseInvocation):
|
|
||||||
"""Resizes an image to specific dimensions"""
|
"""Resizes an image to specific dimensions"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to resize")
|
image: ImageField = InputField(description="The image to resize")
|
||||||
width: int = InputField(default=512, gt=0, description="The width to resize to (px)")
|
width: int = InputField(default=512, gt=0, description="The width to resize to (px)")
|
||||||
height: int = InputField(default=512, gt=0, description="The height to resize to (px)")
|
height: int = InputField(default=512, gt=0, description="The height to resize to (px)")
|
||||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||||
metadata: Optional[CoreMetadata] = InputField(
|
|
||||||
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -413,7 +372,7 @@ class ImageResizeInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -424,14 +383,8 @@ class ImageResizeInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0")
|
||||||
"img_scale",
|
class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||||
title="Scale Image",
|
|
||||||
tags=["image", "scale"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class ImageScaleInvocation(BaseInvocation):
|
|
||||||
"""Scales an image by a factor"""
|
"""Scales an image by a factor"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to scale")
|
image: ImageField = InputField(description="The image to scale")
|
||||||
@ -461,6 +414,7 @@ class ImageScaleInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -471,14 +425,8 @@ class ImageScaleInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0")
|
||||||
"img_lerp",
|
class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
title="Lerp Image",
|
|
||||||
tags=["image", "lerp"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class ImageLerpInvocation(BaseInvocation):
|
|
||||||
"""Linear interpolation of all pixels of an image"""
|
"""Linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to lerp")
|
image: ImageField = InputField(description="The image to lerp")
|
||||||
@ -500,6 +448,7 @@ class ImageLerpInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -510,14 +459,8 @@ class ImageLerpInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0")
|
||||||
"img_ilerp",
|
class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
title="Inverse Lerp Image",
|
|
||||||
tags=["image", "ilerp"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class ImageInverseLerpInvocation(BaseInvocation):
|
|
||||||
"""Inverse linear interpolation of all pixels of an image"""
|
"""Inverse linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to lerp")
|
image: ImageField = InputField(description="The image to lerp")
|
||||||
@ -539,6 +482,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -549,20 +493,11 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0")
|
||||||
"img_nsfw",
|
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||||
title="Blur NSFW Image",
|
|
||||||
tags=["image", "nsfw"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class ImageNSFWBlurInvocation(BaseInvocation):
|
|
||||||
"""Add blur to NSFW-flagged images"""
|
"""Add blur to NSFW-flagged images"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to check")
|
image: ImageField = InputField(description="The image to check")
|
||||||
metadata: Optional[CoreMetadata] = InputField(
|
|
||||||
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -583,7 +518,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -607,14 +542,11 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
|||||||
category="image",
|
category="image",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ImageWatermarkInvocation(BaseInvocation):
|
class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||||
"""Add an invisible watermark to an image"""
|
"""Add an invisible watermark to an image"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to check")
|
image: ImageField = InputField(description="The image to check")
|
||||||
text: str = InputField(default="InvokeAI", description="Watermark text")
|
text: str = InputField(default="InvokeAI", description="Watermark text")
|
||||||
metadata: Optional[CoreMetadata] = InputField(
|
|
||||||
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -626,7 +558,7 @@ class ImageWatermarkInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -637,14 +569,8 @@ class ImageWatermarkInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0")
|
||||||
"mask_edge",
|
class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
title="Mask Edge",
|
|
||||||
tags=["image", "mask", "inpaint"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class MaskEdgeInvocation(BaseInvocation):
|
|
||||||
"""Applies an edge mask to an image"""
|
"""Applies an edge mask to an image"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to apply the mask to")
|
image: ImageField = InputField(description="The image to apply the mask to")
|
||||||
@ -678,6 +604,7 @@ class MaskEdgeInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -695,7 +622,7 @@ class MaskEdgeInvocation(BaseInvocation):
|
|||||||
category="image",
|
category="image",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class MaskCombineInvocation(BaseInvocation):
|
class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||||
|
|
||||||
mask1: ImageField = InputField(description="The first mask to combine")
|
mask1: ImageField = InputField(description="The first mask to combine")
|
||||||
@ -714,6 +641,7 @@ class MaskCombineInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -724,14 +652,8 @@ class MaskCombineInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0")
|
||||||
"color_correct",
|
class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
title="Color Correct",
|
|
||||||
tags=["image", "color"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class ColorCorrectInvocation(BaseInvocation):
|
|
||||||
"""
|
"""
|
||||||
Shifts the colors of a target image to match the reference image, optionally
|
Shifts the colors of a target image to match the reference image, optionally
|
||||||
using a mask to only color-correct certain regions of the target image.
|
using a mask to only color-correct certain regions of the target image.
|
||||||
@ -830,6 +752,7 @@ class ColorCorrectInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -840,14 +763,8 @@ class ColorCorrectInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0")
|
||||||
"img_hue_adjust",
|
class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
title="Adjust Image Hue",
|
|
||||||
tags=["image", "hue"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
|
||||||
"""Adjusts the Hue of an image."""
|
"""Adjusts the Hue of an image."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
@ -875,6 +792,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -950,7 +868,7 @@ CHANNEL_FORMATS = {
|
|||||||
category="image",
|
category="image",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ImageChannelOffsetInvocation(BaseInvocation):
|
class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
"""Add or subtract a value from a specific color channel of an image."""
|
"""Add or subtract a value from a specific color channel of an image."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
@ -984,6 +902,7 @@ class ImageChannelOffsetInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1020,7 +939,7 @@ class ImageChannelOffsetInvocation(BaseInvocation):
|
|||||||
category="image",
|
category="image",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ImageChannelMultiplyInvocation(BaseInvocation):
|
class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
"""Scale a specific color channel of an image."""
|
"""Scale a specific color channel of an image."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
@ -1060,6 +979,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation):
|
|||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
|
metadata=self.metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -1079,16 +999,11 @@ class ImageChannelMultiplyInvocation(BaseInvocation):
|
|||||||
version="1.0.1",
|
version="1.0.1",
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
)
|
)
|
||||||
class SaveImageInvocation(BaseInvocation):
|
class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
|
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
|
||||||
|
|
||||||
image: ImageField = InputField(description=FieldDescriptions.image)
|
image: ImageField = InputField(description=FieldDescriptions.image)
|
||||||
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
|
board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
|
||||||
metadata: Optional[CoreMetadata] = InputField(
|
|
||||||
default=None,
|
|
||||||
description=FieldDescriptions.core_metadata,
|
|
||||||
ui_hidden=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -1101,7 +1016,7 @@ class SaveImageInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
|||||||
from invokeai.backend.image_util.lama import LaMA
|
from invokeai.backend.image_util.lama import LaMA
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
|
||||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||||
|
|
||||||
|
|
||||||
@ -119,7 +119,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
|||||||
|
|
||||||
|
|
||||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||||
class InfillColorInvocation(BaseInvocation):
|
class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
"""Infills transparent areas of an image with a solid color"""
|
"""Infills transparent areas of an image with a solid color"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
@ -143,6 +143,7 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -154,7 +155,7 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||||
class InfillTileInvocation(BaseInvocation):
|
class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
"""Infills transparent areas of an image with tiles of the image"""
|
"""Infills transparent areas of an image with tiles of the image"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
@ -179,6 +180,7 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -192,7 +194,7 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
@invocation(
|
@invocation(
|
||||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0"
|
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0"
|
||||||
)
|
)
|
||||||
class InfillPatchMatchInvocation(BaseInvocation):
|
class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
@ -232,6 +234,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -243,7 +246,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||||
class LaMaInfillInvocation(BaseInvocation):
|
class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
"""Infills transparent areas of an image using the LaMa model"""
|
"""Infills transparent areas of an image using the LaMa model"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
@ -260,6 +263,8 @@ class LaMaInfillInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -269,8 +274,8 @@ class LaMaInfillInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint")
|
||||||
class CV2InfillInvocation(BaseInvocation):
|
class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
@ -287,6 +292,8 @@ class CV2InfillInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
|
@ -23,7 +23,6 @@ from pydantic import field_validator
|
|||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
|
|
||||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
|
||||||
from invokeai.app.invocations.primitives import (
|
from invokeai.app.invocations.primitives import (
|
||||||
DenoiseMaskField,
|
DenoiseMaskField,
|
||||||
DenoiseMaskOutput,
|
DenoiseMaskOutput,
|
||||||
@ -64,6 +63,8 @@ from .baseinvocation import (
|
|||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
UIType,
|
UIType,
|
||||||
|
WithMetadata,
|
||||||
|
WithWorkflow,
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -792,7 +793,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class LatentsToImageInvocation(BaseInvocation):
|
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
latents: LatentsField = InputField(
|
latents: LatentsField = InputField(
|
||||||
@ -805,11 +806,6 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||||
metadata: Optional[CoreMetadata] = InputField(
|
|
||||||
default=None,
|
|
||||||
description=FieldDescriptions.core_metadata,
|
|
||||||
ui_hidden=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@ -878,7 +874,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,13 +1,16 @@
|
|||||||
from typing import Optional
|
from typing import Any, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
InputField,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
|
MetadataField,
|
||||||
OutputField,
|
OutputField,
|
||||||
|
UIType,
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -16,116 +19,99 @@ from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
|||||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||||
from invokeai.app.invocations.primitives import ImageField
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
|
||||||
|
|
||||||
from ...version import __version__
|
from ...version import __version__
|
||||||
|
|
||||||
|
|
||||||
class LoRAMetadataField(BaseModelExcludeNull):
|
class MetadataItemField(BaseModel):
|
||||||
"""LoRA metadata for an image generated in InvokeAI."""
|
label: str = Field(description=FieldDescriptions.metadata_item_label)
|
||||||
|
value: Any = Field(description=FieldDescriptions.metadata_item_value)
|
||||||
lora: LoRAModelField = Field(description="The LoRA model")
|
|
||||||
weight: float = Field(description="The weight of the LoRA model")
|
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterMetadataField(BaseModelExcludeNull):
|
class LoRAMetadataField(BaseModel):
|
||||||
|
"""LoRA Metadata Field"""
|
||||||
|
|
||||||
|
lora: LoRAModelField = Field(description=FieldDescriptions.lora_model)
|
||||||
|
weight: float = Field(description=FieldDescriptions.lora_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterMetadataField(BaseModel):
|
||||||
|
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
|
||||||
|
|
||||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||||
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
ip_adapter_model: IPAdapterModelField = Field(
|
||||||
weight: float = Field(description="The weight of the IP-Adapter model")
|
description="The IP-Adapter model.",
|
||||||
begin_step_percent: float = Field(
|
|
||||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
|
||||||
)
|
)
|
||||||
end_step_percent: float = Field(
|
weight: Union[float, list[float]] = Field(
|
||||||
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
description="The weight given to the IP-Adapter",
|
||||||
|
)
|
||||||
|
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
|
||||||
|
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("metadata_item_output")
|
||||||
|
class MetadataItemOutput(BaseInvocationOutput):
|
||||||
|
"""Metadata Item Output"""
|
||||||
|
|
||||||
|
item: MetadataItemField = OutputField(description="Metadata Item")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation("metadata_item", title="Metadata Item", tags=["metadata"], category="metadata", version="1.0.0")
|
||||||
|
class MetadataItemInvocation(BaseInvocation):
|
||||||
|
"""Used to create an arbitrary metadata item. Provide "label" and make a connection to "value" to store that data as the value."""
|
||||||
|
|
||||||
|
label: str = InputField(description=FieldDescriptions.metadata_item_label)
|
||||||
|
value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> MetadataItemOutput:
|
||||||
|
return MetadataItemOutput(item=MetadataItemField(label=self.label, value=self.value))
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("metadata_output")
|
||||||
|
class MetadataOutput(BaseInvocationOutput):
|
||||||
|
metadata: MetadataField = OutputField(description="Metadata Dict")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation("metadata", title="Metadata", tags=["metadata"], category="metadata", version="1.0.0")
|
||||||
|
class MetadataInvocation(BaseInvocation):
|
||||||
|
"""Takes a MetadataItem or collection of MetadataItems and outputs a MetadataDict."""
|
||||||
|
|
||||||
|
items: Union[list[MetadataItemField], MetadataItemField] = InputField(
|
||||||
|
description=FieldDescriptions.metadata_item_polymorphic
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||||
|
if isinstance(self.items, MetadataItemField):
|
||||||
|
# single metadata item
|
||||||
|
data = {self.items.label: self.items.value}
|
||||||
|
else:
|
||||||
|
# collection of metadata items
|
||||||
|
data = {item.label: item.value for item in self.items}
|
||||||
|
|
||||||
class CoreMetadata(BaseModelExcludeNull):
|
# add app version
|
||||||
"""Core generation metadata for an image generated in InvokeAI."""
|
data.update({"app_version": __version__})
|
||||||
|
return MetadataOutput(metadata=MetadataField.model_validate(data))
|
||||||
app_version: str = Field(default=__version__, description="The version of InvokeAI used to generate this image")
|
|
||||||
generation_mode: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The generation mode that output this image",
|
|
||||||
)
|
|
||||||
created_by: Optional[str] = Field(default=None, description="The name of the creator of the image")
|
|
||||||
positive_prompt: Optional[str] = Field(default=None, description="The positive prompt parameter")
|
|
||||||
negative_prompt: Optional[str] = Field(default=None, description="The negative prompt parameter")
|
|
||||||
width: Optional[int] = Field(default=None, description="The width parameter")
|
|
||||||
height: Optional[int] = Field(default=None, description="The height parameter")
|
|
||||||
seed: Optional[int] = Field(default=None, description="The seed used for noise generation")
|
|
||||||
rand_device: Optional[str] = Field(default=None, description="The device used for random number generation")
|
|
||||||
cfg_scale: Optional[float] = Field(default=None, description="The classifier-free guidance scale parameter")
|
|
||||||
steps: Optional[int] = Field(default=None, description="The number of steps used for inference")
|
|
||||||
scheduler: Optional[str] = Field(default=None, description="The scheduler used for inference")
|
|
||||||
clip_skip: Optional[int] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The number of skipped CLIP layers",
|
|
||||||
)
|
|
||||||
model: Optional[MainModelField] = Field(default=None, description="The main model used for inference")
|
|
||||||
controlnets: Optional[list[ControlField]] = Field(default=None, description="The ControlNets used for inference")
|
|
||||||
ipAdapters: Optional[list[IPAdapterMetadataField]] = Field(
|
|
||||||
default=None, description="The IP Adapters used for inference"
|
|
||||||
)
|
|
||||||
t2iAdapters: Optional[list[T2IAdapterField]] = Field(default=None, description="The IP Adapters used for inference")
|
|
||||||
loras: Optional[list[LoRAMetadataField]] = Field(default=None, description="The LoRAs used for inference")
|
|
||||||
vae: Optional[VAEModelField] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The VAE used for decoding, if the main model's default was not used",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Latents-to-Latents
|
|
||||||
strength: Optional[float] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The strength used for latents-to-latents",
|
|
||||||
)
|
|
||||||
init_image: Optional[str] = Field(default=None, description="The name of the initial image")
|
|
||||||
|
|
||||||
# SDXL
|
|
||||||
positive_style_prompt: Optional[str] = Field(default=None, description="The positive style prompt parameter")
|
|
||||||
negative_style_prompt: Optional[str] = Field(default=None, description="The negative style prompt parameter")
|
|
||||||
|
|
||||||
# SDXL Refiner
|
|
||||||
refiner_model: Optional[MainModelField] = Field(default=None, description="The SDXL Refiner model used")
|
|
||||||
refiner_cfg_scale: Optional[float] = Field(
|
|
||||||
default=None,
|
|
||||||
description="The classifier-free guidance scale parameter used for the refiner",
|
|
||||||
)
|
|
||||||
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
|
|
||||||
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
|
|
||||||
refiner_positive_aesthetic_score: Optional[float] = Field(
|
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
|
||||||
)
|
|
||||||
refiner_negative_aesthetic_score: Optional[float] = Field(
|
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
|
||||||
)
|
|
||||||
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
|
|
||||||
|
|
||||||
|
|
||||||
class ImageMetadata(BaseModelExcludeNull):
|
@invocation("merge_metadata", title="Metadata Merge", tags=["metadata"], category="metadata", version="1.0.0")
|
||||||
"""An image's generation metadata"""
|
class MergeMetadataInvocation(BaseInvocation):
|
||||||
|
"""Merged a collection of MetadataDict into a single MetadataDict."""
|
||||||
|
|
||||||
metadata: Optional[dict] = Field(
|
collection: list[MetadataField] = InputField(description=FieldDescriptions.metadata_collection)
|
||||||
default=None,
|
|
||||||
description="The image's core metadata, if it was created in the Linear or Canvas UI",
|
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||||
)
|
data = {}
|
||||||
graph: Optional[dict] = Field(default=None, description="The graph that created the image")
|
for item in self.collection:
|
||||||
|
data.update(item.model_dump())
|
||||||
|
|
||||||
|
return MetadataOutput(metadata=MetadataField.model_validate(data))
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("metadata_accumulator_output")
|
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.0.0")
|
||||||
class MetadataAccumulatorOutput(BaseInvocationOutput):
|
class CoreMetadataInvocation(BaseInvocation):
|
||||||
"""The output of the MetadataAccumulator node"""
|
"""Collects core generation metadata into a MetadataField"""
|
||||||
|
|
||||||
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
|
generation_mode: Literal["txt2img", "img2img", "inpaint", "outpaint"] = InputField(
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
|
||||||
"metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata", version="1.0.0"
|
|
||||||
)
|
|
||||||
class MetadataAccumulatorInvocation(BaseInvocation):
|
|
||||||
"""Outputs a Core Metadata Object"""
|
|
||||||
|
|
||||||
generation_mode: Optional[str] = InputField(
|
|
||||||
default=None,
|
default=None,
|
||||||
description="The generation mode that output this image",
|
description="The generation mode that output this image",
|
||||||
)
|
)
|
||||||
@ -138,6 +124,8 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
|||||||
cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter")
|
cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter")
|
||||||
steps: Optional[int] = InputField(default=None, description="The number of steps used for inference")
|
steps: Optional[int] = InputField(default=None, description="The number of steps used for inference")
|
||||||
scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference")
|
scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference")
|
||||||
|
seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis")
|
||||||
|
seamless_y: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the Y axis")
|
||||||
clip_skip: Optional[int] = InputField(
|
clip_skip: Optional[int] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The number of skipped CLIP layers",
|
description="The number of skipped CLIP layers",
|
||||||
@ -220,7 +208,13 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
|||||||
description="The start value used for refiner denoising",
|
description="The start value used for refiner denoising",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||||
"""Collects and outputs a CoreMetadata object"""
|
"""Collects and outputs a CoreMetadata object"""
|
||||||
|
|
||||||
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.model_dump()))
|
return MetadataOutput(
|
||||||
|
metadata=MetadataField.model_validate(
|
||||||
|
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
@ -4,7 +4,7 @@ import inspect
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
# from contextlib import ExitStack
|
# from contextlib import ExitStack
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -12,7 +12,6 @@ from diffusers.image_processor import VaeImageProcessor
|
|||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
|
||||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
@ -31,6 +30,8 @@ from .baseinvocation import (
|
|||||||
OutputField,
|
OutputField,
|
||||||
UIComponent,
|
UIComponent,
|
||||||
UIType,
|
UIType,
|
||||||
|
WithMetadata,
|
||||||
|
WithWorkflow,
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -327,7 +328,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
category="image",
|
category="image",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
latents: LatentsField = InputField(
|
latents: LatentsField = InputField(
|
||||||
@ -338,11 +339,6 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
|||||||
description=FieldDescriptions.vae,
|
description=FieldDescriptions.vae,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
metadata: Optional[CoreMetadata] = InputField(
|
|
||||||
default=None,
|
|
||||||
description=FieldDescriptions.core_metadata,
|
|
||||||
ui_hidden=True,
|
|
||||||
)
|
|
||||||
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@ -381,7 +377,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -251,7 +251,9 @@ class ImageCollectionOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
|
|
||||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0")
|
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0")
|
||||||
class ImageInvocation(BaseInvocation):
|
class ImageInvocation(
|
||||||
|
BaseInvocation,
|
||||||
|
):
|
||||||
"""An image primitive value"""
|
"""An image primitive value"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to load")
|
image: ImageField = InputField(description="The image to load")
|
||||||
|
@ -14,7 +14,7 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
|||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
|
||||||
|
|
||||||
# TODO: Populate this from disk?
|
# TODO: Populate this from disk?
|
||||||
# TODO: Use model manager to load?
|
# TODO: Use model manager to load?
|
||||||
@ -30,7 +30,7 @@ if choose_torch_device() == torch.device("mps"):
|
|||||||
|
|
||||||
|
|
||||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.1.0")
|
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.1.0")
|
||||||
class ESRGANInvocation(BaseInvocation):
|
class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||||
"""Upscales an image using RealESRGAN."""
|
"""Upscales an image using RealESRGAN."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The input image")
|
image: ImageField = InputField(description="The input image")
|
||||||
@ -123,6 +123,7 @@ class ESRGANInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
metadata=self.metadata,
|
||||||
workflow=self.workflow,
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -243,6 +243,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
db_dir : Optional[Path] = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
|
db_dir : Optional[Path] = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
|
||||||
outdir : Optional[Path] = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
|
outdir : Optional[Path] = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
|
||||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
|
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
|
||||||
|
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes', json_schema_extra=Categories.Paths)
|
||||||
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
|
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
|
||||||
|
|
||||||
# LOGGING
|
# LOGGING
|
||||||
@ -410,6 +411,13 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
"""
|
"""
|
||||||
return self._resolve(self.models_dir)
|
return self._resolve(self.models_dir)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def custom_nodes_path(self) -> Path:
|
||||||
|
"""
|
||||||
|
Path to the custom nodes directory
|
||||||
|
"""
|
||||||
|
return self._resolve(self.custom_nodes_dir)
|
||||||
|
|
||||||
# the following methods support legacy calls leftover from the Globals era
|
# the following methods support legacy calls leftover from the Globals era
|
||||||
@property
|
@property
|
||||||
def full_precision(self) -> bool:
|
def full_precision(self) -> bool:
|
||||||
|
@ -4,6 +4,8 @@ from typing import Optional
|
|||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
|
||||||
|
|
||||||
|
|
||||||
class ImageFileStorageBase(ABC):
|
class ImageFileStorageBase(ABC):
|
||||||
"""Low-level service responsible for storing and retrieving image files."""
|
"""Low-level service responsible for storing and retrieving image files."""
|
||||||
@ -30,8 +32,8 @@ class ImageFileStorageBase(ABC):
|
|||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[MetadataField] = None,
|
||||||
workflow: Optional[str] = None,
|
workflow: Optional[WorkflowField] = None,
|
||||||
thumbnail_size: int = 256,
|
thumbnail_size: int = 256,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
@ -8,6 +7,7 @@ from PIL import Image, PngImagePlugin
|
|||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||||
|
|
||||||
@ -55,8 +55,8 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[MetadataField] = None,
|
||||||
workflow: Optional[str] = None,
|
workflow: Optional[WorkflowField] = None,
|
||||||
thumbnail_size: int = 256,
|
thumbnail_size: int = 256,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
@ -65,20 +65,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
|
|
||||||
pnginfo = PngImagePlugin.PngInfo()
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
|
|
||||||
if metadata is not None or workflow is not None:
|
if metadata is not None:
|
||||||
if metadata is not None:
|
pnginfo.add_text("invokeai_metadata", metadata.model_dump_json())
|
||||||
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
|
if workflow is not None:
|
||||||
if workflow is not None:
|
pnginfo.add_text("invokeai_workflow", workflow.model_dump_json())
|
||||||
pnginfo.add_text("invokeai_workflow", workflow)
|
|
||||||
else:
|
|
||||||
# For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back
|
|
||||||
# TODO: retain non-invokeai metadata on save...
|
|
||||||
original_metadata = image.info.get("invokeai_metadata", None)
|
|
||||||
if original_metadata is not None:
|
|
||||||
pnginfo.add_text("invokeai_metadata", original_metadata)
|
|
||||||
original_workflow = image.info.get("invokeai_workflow", None)
|
|
||||||
if original_workflow is not None:
|
|
||||||
pnginfo.add_text("invokeai_workflow", original_workflow)
|
|
||||||
|
|
||||||
image.save(
|
image.save(
|
||||||
image_path,
|
image_path,
|
||||||
|
@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.app.invocations.metadata import MetadataField
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
|
|
||||||
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
|
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
|
||||||
@ -18,7 +19,7 @@ class ImageRecordStorageBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_metadata(self, image_name: str) -> Optional[dict]:
|
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||||
"""Gets an image's metadata'."""
|
"""Gets an image's metadata'."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -78,7 +79,7 @@ class ImageRecordStorageBase(ABC):
|
|||||||
starred: Optional[bool] = False,
|
starred: Optional[bool] = False,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[MetadataField] = None,
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
"""Saves an image record."""
|
"""Saves an image record."""
|
||||||
pass
|
pass
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import json
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Union, cast
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
|
|
||||||
@ -141,22 +141,26 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
|
|
||||||
return deserialize_image_record(dict(result))
|
return deserialize_image_record(dict(result))
|
||||||
|
|
||||||
def get_metadata(self, image_name: str) -> Optional[dict]:
|
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
SELECT images.metadata FROM images
|
SELECT metadata FROM images
|
||||||
WHERE image_name = ?;
|
WHERE image_name = ?;
|
||||||
""",
|
""",
|
||||||
(image_name,),
|
(image_name,),
|
||||||
)
|
)
|
||||||
|
|
||||||
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||||
if not result or not result[0]:
|
|
||||||
return None
|
if not result:
|
||||||
return json.loads(result[0])
|
raise ImageRecordNotFoundException
|
||||||
|
|
||||||
|
as_dict = dict(result)
|
||||||
|
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
|
||||||
|
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
raise ImageRecordNotFoundException from e
|
raise ImageRecordNotFoundException from e
|
||||||
@ -408,10 +412,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
starred: Optional[bool] = False,
|
starred: Optional[bool] = False,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[MetadataField] = None,
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
try:
|
try:
|
||||||
metadata_json = None if metadata is None else json.dumps(metadata)
|
metadata_json = metadata.model_dump_json() if metadata is not None else None
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
|
@ -3,7 +3,7 @@ from typing import Callable, Optional
|
|||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import ImageMetadata
|
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
|
||||||
from invokeai.app.services.image_records.image_records_common import (
|
from invokeai.app.services.image_records.image_records_common import (
|
||||||
ImageCategory,
|
ImageCategory,
|
||||||
ImageRecord,
|
ImageRecord,
|
||||||
@ -50,8 +50,8 @@ class ImageServiceABC(ABC):
|
|||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
board_id: Optional[str] = None,
|
board_id: Optional[str] = None,
|
||||||
is_intermediate: Optional[bool] = False,
|
is_intermediate: Optional[bool] = False,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[MetadataField] = None,
|
||||||
workflow: Optional[str] = None,
|
workflow: Optional[WorkflowField] = None,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Creates an image, storing the file and its metadata."""
|
"""Creates an image, storing the file and its metadata."""
|
||||||
pass
|
pass
|
||||||
@ -81,7 +81,7 @@ class ImageServiceABC(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_metadata(self, image_name: str) -> ImageMetadata:
|
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||||
"""Gets an image's metadata."""
|
"""Gets an image's metadata."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -24,8 +24,11 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
|
|||||||
default=None, description="The id of the board the image belongs to, if one exists."
|
default=None, description="The id of the board the image belongs to, if one exists."
|
||||||
)
|
)
|
||||||
"""The id of the board the image belongs to, if one exists."""
|
"""The id of the board the image belongs to, if one exists."""
|
||||||
|
workflow_id: Optional[str] = Field(
|
||||||
pass
|
default=None,
|
||||||
|
description="The workflow that generated this image.",
|
||||||
|
)
|
||||||
|
"""The workflow that generated this image."""
|
||||||
|
|
||||||
|
|
||||||
def image_record_to_dto(
|
def image_record_to_dto(
|
||||||
@ -33,6 +36,7 @@ def image_record_to_dto(
|
|||||||
image_url: str,
|
image_url: str,
|
||||||
thumbnail_url: str,
|
thumbnail_url: str,
|
||||||
board_id: Optional[str],
|
board_id: Optional[str],
|
||||||
|
workflow_id: Optional[str],
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Converts an image record to an image DTO."""
|
"""Converts an image record to an image DTO."""
|
||||||
return ImageDTO(
|
return ImageDTO(
|
||||||
@ -40,4 +44,5 @@ def image_record_to_dto(
|
|||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
thumbnail_url=thumbnail_url,
|
thumbnail_url=thumbnail_url,
|
||||||
board_id=board_id,
|
board_id=board_id,
|
||||||
|
workflow_id=workflow_id,
|
||||||
)
|
)
|
||||||
|
@ -2,10 +2,9 @@ from typing import Optional
|
|||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import ImageMetadata
|
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
|
||||||
|
|
||||||
from ..image_files.image_files_common import (
|
from ..image_files.image_files_common import (
|
||||||
ImageFileDeleteException,
|
ImageFileDeleteException,
|
||||||
@ -42,8 +41,8 @@ class ImageService(ImageServiceABC):
|
|||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
board_id: Optional[str] = None,
|
board_id: Optional[str] = None,
|
||||||
is_intermediate: Optional[bool] = False,
|
is_intermediate: Optional[bool] = False,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[MetadataField] = None,
|
||||||
workflow: Optional[str] = None,
|
workflow: Optional[WorkflowField] = None,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
if image_origin not in ResourceOrigin:
|
if image_origin not in ResourceOrigin:
|
||||||
raise InvalidOriginException
|
raise InvalidOriginException
|
||||||
@ -56,6 +55,12 @@ class ImageService(ImageServiceABC):
|
|||||||
(width, height) = image.size
|
(width, height) = image.size
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if workflow is not None:
|
||||||
|
created_workflow = self.__invoker.services.workflow_records.create(workflow)
|
||||||
|
workflow_id = created_workflow.model_dump()["id"]
|
||||||
|
else:
|
||||||
|
workflow_id = None
|
||||||
|
|
||||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||||
self.__invoker.services.image_records.save(
|
self.__invoker.services.image_records.save(
|
||||||
# Non-nullable fields
|
# Non-nullable fields
|
||||||
@ -73,6 +78,8 @@ class ImageService(ImageServiceABC):
|
|||||||
)
|
)
|
||||||
if board_id is not None:
|
if board_id is not None:
|
||||||
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||||
|
if workflow_id is not None:
|
||||||
|
self.__invoker.services.workflow_image_records.create(workflow_id=workflow_id, image_name=image_name)
|
||||||
self.__invoker.services.image_files.save(
|
self.__invoker.services.image_files.save(
|
||||||
image_name=image_name, image=image, metadata=metadata, workflow=workflow
|
image_name=image_name, image=image, metadata=metadata, workflow=workflow
|
||||||
)
|
)
|
||||||
@ -132,10 +139,11 @@ class ImageService(ImageServiceABC):
|
|||||||
image_record = self.__invoker.services.image_records.get(image_name)
|
image_record = self.__invoker.services.image_records.get(image_name)
|
||||||
|
|
||||||
image_dto = image_record_to_dto(
|
image_dto = image_record_to_dto(
|
||||||
image_record,
|
image_record=image_record,
|
||||||
self.__invoker.services.urls.get_image_url(image_name),
|
image_url=self.__invoker.services.urls.get_image_url(image_name),
|
||||||
self.__invoker.services.urls.get_image_url(image_name, True),
|
thumbnail_url=self.__invoker.services.urls.get_image_url(image_name, True),
|
||||||
self.__invoker.services.board_image_records.get_board_for_image(image_name),
|
board_id=self.__invoker.services.board_image_records.get_board_for_image(image_name),
|
||||||
|
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
return image_dto
|
return image_dto
|
||||||
@ -146,25 +154,22 @@ class ImageService(ImageServiceABC):
|
|||||||
self.__invoker.services.logger.error("Problem getting image DTO")
|
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_metadata(self, image_name: str) -> ImageMetadata:
|
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||||
try:
|
try:
|
||||||
image_record = self.__invoker.services.image_records.get(image_name)
|
return self.__invoker.services.image_records.get_metadata(image_name)
|
||||||
metadata = self.__invoker.services.image_records.get_metadata(image_name)
|
except ImageRecordNotFoundException:
|
||||||
|
self.__invoker.services.logger.error("Image record not found")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||||
|
raise e
|
||||||
|
|
||||||
if not image_record.session_id:
|
def get_workflow(self, image_name: str) -> Optional[WorkflowField]:
|
||||||
return ImageMetadata(metadata=metadata)
|
try:
|
||||||
|
workflow_id = self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name)
|
||||||
session_raw = self.__invoker.services.graph_execution_manager.get_raw(image_record.session_id)
|
if workflow_id is None:
|
||||||
graph = None
|
return None
|
||||||
|
return self.__invoker.services.workflow_records.get(workflow_id)
|
||||||
if session_raw:
|
|
||||||
try:
|
|
||||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.warn(f"Failed to parse session graph: {e}")
|
|
||||||
graph = None
|
|
||||||
|
|
||||||
return ImageMetadata(graph=graph, metadata=metadata)
|
|
||||||
except ImageRecordNotFoundException:
|
except ImageRecordNotFoundException:
|
||||||
self.__invoker.services.logger.error("Image record not found")
|
self.__invoker.services.logger.error("Image record not found")
|
||||||
raise
|
raise
|
||||||
@ -215,10 +220,11 @@ class ImageService(ImageServiceABC):
|
|||||||
image_dtos = list(
|
image_dtos = list(
|
||||||
map(
|
map(
|
||||||
lambda r: image_record_to_dto(
|
lambda r: image_record_to_dto(
|
||||||
r,
|
image_record=r,
|
||||||
self.__invoker.services.urls.get_image_url(r.image_name),
|
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
|
||||||
self.__invoker.services.urls.get_image_url(r.image_name, True),
|
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
|
||||||
self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
||||||
|
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
|
||||||
),
|
),
|
||||||
results.items,
|
results.items,
|
||||||
)
|
)
|
||||||
|
@ -27,6 +27,8 @@ if TYPE_CHECKING:
|
|||||||
from .session_queue.session_queue_base import SessionQueueBase
|
from .session_queue.session_queue_base import SessionQueueBase
|
||||||
from .shared.graph import GraphExecutionState, LibraryGraph
|
from .shared.graph import GraphExecutionState, LibraryGraph
|
||||||
from .urls.urls_base import UrlServiceBase
|
from .urls.urls_base import UrlServiceBase
|
||||||
|
from .workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
|
||||||
|
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||||
|
|
||||||
|
|
||||||
class InvocationServices:
|
class InvocationServices:
|
||||||
@ -55,6 +57,8 @@ class InvocationServices:
|
|||||||
invocation_cache: "InvocationCacheBase"
|
invocation_cache: "InvocationCacheBase"
|
||||||
names: "NameServiceBase"
|
names: "NameServiceBase"
|
||||||
urls: "UrlServiceBase"
|
urls: "UrlServiceBase"
|
||||||
|
workflow_image_records: "WorkflowImageRecordsStorageBase"
|
||||||
|
workflow_records: "WorkflowRecordsStorageBase"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -80,6 +84,8 @@ class InvocationServices:
|
|||||||
invocation_cache: "InvocationCacheBase",
|
invocation_cache: "InvocationCacheBase",
|
||||||
names: "NameServiceBase",
|
names: "NameServiceBase",
|
||||||
urls: "UrlServiceBase",
|
urls: "UrlServiceBase",
|
||||||
|
workflow_image_records: "WorkflowImageRecordsStorageBase",
|
||||||
|
workflow_records: "WorkflowRecordsStorageBase",
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.board_image_records = board_image_records
|
self.board_image_records = board_image_records
|
||||||
@ -103,3 +109,5 @@ class InvocationServices:
|
|||||||
self.invocation_cache = invocation_cache
|
self.invocation_cache = invocation_cache
|
||||||
self.names = names
|
self.names = names
|
||||||
self.urls = urls
|
self.urls = urls
|
||||||
|
self.workflow_image_records = workflow_image_records
|
||||||
|
self.workflow_records = workflow_records
|
||||||
|
@ -18,7 +18,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_id_field: str
|
_id_field: str
|
||||||
_lock: threading.RLock
|
_lock: threading.RLock
|
||||||
_adapter: Optional[TypeAdapter[T]]
|
_validator: Optional[TypeAdapter[T]]
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -28,7 +28,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
self._id_field = id_field # TODO: validate that T has this field
|
self._id_field = id_field # TODO: validate that T has this field
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
self._adapter: Optional[TypeAdapter[T]] = None
|
self._validator: Optional[TypeAdapter[T]] = None
|
||||||
|
|
||||||
self._create_table()
|
self._create_table()
|
||||||
|
|
||||||
@ -47,14 +47,14 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
def _parse_item(self, item: str) -> T:
|
||||||
if self._adapter is None:
|
if self._validator is None:
|
||||||
"""
|
"""
|
||||||
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
|
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
|
||||||
we can create it when it is first needed instead.
|
we can create it when it is first needed instead.
|
||||||
__orig_class__ is technically an implementation detail of the typing module, not a supported API
|
__orig_class__ is technically an implementation detail of the typing module, not a supported API
|
||||||
"""
|
"""
|
||||||
self._adapter = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
|
self._validator = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
|
||||||
return self._adapter.validate_json(item)
|
return self._validator.validate_json(item)
|
||||||
|
|
||||||
def set(self, item: T):
|
def set(self, item: T):
|
||||||
try:
|
try:
|
||||||
|
@ -147,20 +147,20 @@ DEFAULT_QUEUE_ID = "default"
|
|||||||
|
|
||||||
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
|
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
|
||||||
|
|
||||||
adapter_NodeFieldValue = TypeAdapter(list[NodeFieldValue])
|
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
|
||||||
|
|
||||||
|
|
||||||
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
|
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
|
||||||
field_values_raw = queue_item_dict.get("field_values", None)
|
field_values_raw = queue_item_dict.get("field_values", None)
|
||||||
return adapter_NodeFieldValue.validate_json(field_values_raw) if field_values_raw is not None else None
|
return NodeFieldValueValidator.validate_json(field_values_raw) if field_values_raw is not None else None
|
||||||
|
|
||||||
|
|
||||||
adapter_GraphExecutionState = TypeAdapter(GraphExecutionState)
|
GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
|
||||||
|
|
||||||
|
|
||||||
def get_session(queue_item_dict: dict) -> GraphExecutionState:
|
def get_session(queue_item_dict: dict) -> GraphExecutionState:
|
||||||
session_raw = queue_item_dict.get("session", "{}")
|
session_raw = queue_item_dict.get("session", "{}")
|
||||||
session = adapter_GraphExecutionState.validate_json(session_raw, strict=False)
|
session = GraphExecutionStateValidator.validate_json(session_raw, strict=False)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
@ -193,7 +193,7 @@ class GraphInvocation(BaseInvocation):
|
|||||||
"""Execute a graph"""
|
"""Execute a graph"""
|
||||||
|
|
||||||
# TODO: figure out how to create a default here
|
# TODO: figure out how to create a default here
|
||||||
graph: "Graph" = Field(description="The graph to run", default=None)
|
graph: "Graph" = InputField(description="The graph to run", default=None)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
|
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
|
||||||
"""Invoke with provided services and return outputs."""
|
"""Invoke with provided services and return outputs."""
|
||||||
@ -439,6 +439,14 @@ class Graph(BaseModel):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise UnknownGraphValidationError(f"Problem validating graph {e}") from e
|
raise UnknownGraphValidationError(f"Problem validating graph {e}") from e
|
||||||
|
|
||||||
|
def _is_destination_field_Any(self, edge: Edge) -> bool:
|
||||||
|
"""Checks if the destination field for an edge is of type typing.Any"""
|
||||||
|
return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == Any
|
||||||
|
|
||||||
|
def _is_destination_field_list_of_Any(self, edge: Edge) -> bool:
|
||||||
|
"""Checks if the destination field for an edge is of type typing.Any"""
|
||||||
|
return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any]
|
||||||
|
|
||||||
def _validate_edge(self, edge: Edge):
|
def _validate_edge(self, edge: Edge):
|
||||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||||
|
|
||||||
@ -491,8 +499,19 @@ class Graph(BaseModel):
|
|||||||
f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
# Validate that we are not connecting collector to iterator (currently unsupported)
|
||||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation):
|
||||||
|
raise InvalidEdgeError(
|
||||||
|
f"Cannot connect collector to iterator: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
|
||||||
|
if (
|
||||||
|
isinstance(from_node, CollectInvocation)
|
||||||
|
and edge.source.field == "collection"
|
||||||
|
and not self._is_destination_field_list_of_Any(edge)
|
||||||
|
and not self._is_destination_field_Any(edge)
|
||||||
|
):
|
||||||
if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
|
if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
|
||||||
raise InvalidEdgeError(
|
raise InvalidEdgeError(
|
||||||
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||||
@ -725,16 +744,15 @@ class Graph(BaseModel):
|
|||||||
# Get the input root type
|
# Get the input root type
|
||||||
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
||||||
|
|
||||||
# Verify that all outputs are lists
|
|
||||||
# if not all((get_origin(f) == list for f in output_fields)):
|
|
||||||
# return False
|
|
||||||
|
|
||||||
# Verify that all outputs are lists
|
# Verify that all outputs are lists
|
||||||
if not all(is_list_or_contains_list(f) for f in output_fields):
|
if not all(is_list_or_contains_list(f) for f in output_fields):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Verify that all outputs match the input type (are a base class or the same class)
|
# Verify that all outputs match the input type (are a base class or the same class)
|
||||||
if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)):
|
if not all(
|
||||||
|
is_union_subtype(input_root_type, get_args(f)[0]) or issubclass(input_root_type, get_args(f)[0])
|
||||||
|
for f in output_fields
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
@ -0,0 +1,23 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowImageRecordsStorageBase(ABC):
|
||||||
|
"""Abstract base class for the one-to-many workflow-image relationship record storage."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
workflow_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Creates a workflow-image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_workflow_for_image(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Gets an image's workflow id, if it has one."""
|
||||||
|
pass
|
@ -0,0 +1,122 @@
|
|||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from typing import Optional, cast
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
|
from invokeai.app.services.workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteWorkflowImageRecordsStorage(WorkflowImageRecordsStorageBase):
|
||||||
|
"""SQLite implementation of WorkflowImageRecordsStorageBase."""
|
||||||
|
|
||||||
|
_conn: sqlite3.Connection
|
||||||
|
_cursor: sqlite3.Cursor
|
||||||
|
_lock: threading.RLock
|
||||||
|
|
||||||
|
def __init__(self, db: SqliteDatabase) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._lock = db.lock
|
||||||
|
self._conn = db.conn
|
||||||
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._create_tables()
|
||||||
|
self._conn.commit()
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def _create_tables(self) -> None:
|
||||||
|
# Create the `workflow_images` junction table.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS workflow_images (
|
||||||
|
workflow_id TEXT NOT NULL,
|
||||||
|
image_name TEXT NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Soft delete, currently unused
|
||||||
|
deleted_at DATETIME,
|
||||||
|
-- enforce one-to-many relationship between workflows and images using PK
|
||||||
|
-- (we can extend this to many-to-many later)
|
||||||
|
PRIMARY KEY (image_name),
|
||||||
|
FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add index for workflow id
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add index for workflow id, sorted by created_at
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add trigger for `updated_at`.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON workflow_images FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
workflow_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Creates a workflow-image record."""
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT INTO workflow_images (workflow_id, image_name)
|
||||||
|
VALUES (?, ?);
|
||||||
|
""",
|
||||||
|
(workflow_id, image_name),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def get_workflow_for_image(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Gets an image's workflow id, if it has one."""
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT workflow_id
|
||||||
|
FROM workflow_images
|
||||||
|
WHERE image_name = ?;
|
||||||
|
""",
|
||||||
|
(image_name,),
|
||||||
|
)
|
||||||
|
result = self._cursor.fetchone()
|
||||||
|
if result is None:
|
||||||
|
return None
|
||||||
|
return cast(str, result[0])
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
0
invokeai/app/services/workflow_records/__init__.py
Normal file
0
invokeai/app/services/workflow_records/__init__.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import WorkflowField
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRecordsStorageBase(ABC):
|
||||||
|
"""Base class for workflow storage services."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, workflow_id: str) -> WorkflowField:
|
||||||
|
"""Get workflow by id."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create(self, workflow: WorkflowField) -> WorkflowField:
|
||||||
|
"""Creates a workflow."""
|
||||||
|
pass
|
@ -0,0 +1,2 @@
|
|||||||
|
class WorkflowNotFoundError(Exception):
|
||||||
|
"""Raised when a workflow is not found"""
|
@ -0,0 +1,102 @@
|
|||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import WorkflowField, WorkflowFieldValidator
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||||
|
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||||
|
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowNotFoundError
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||||
|
_invoker: Invoker
|
||||||
|
_conn: sqlite3.Connection
|
||||||
|
_cursor: sqlite3.Cursor
|
||||||
|
_lock: threading.RLock
|
||||||
|
|
||||||
|
def __init__(self, db: SqliteDatabase) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._lock = db.lock
|
||||||
|
self._conn = db.conn
|
||||||
|
self._cursor = self._conn.cursor()
|
||||||
|
self._create_tables()
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self._invoker = invoker
|
||||||
|
|
||||||
|
def get(self, workflow_id: str) -> WorkflowField:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT workflow
|
||||||
|
FROM workflows
|
||||||
|
WHERE workflow_id = ?;
|
||||||
|
""",
|
||||||
|
(workflow_id,),
|
||||||
|
)
|
||||||
|
row = self._cursor.fetchone()
|
||||||
|
if row is None:
|
||||||
|
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
||||||
|
return WorkflowFieldValidator.validate_json(row[0])
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def create(self, workflow: WorkflowField) -> WorkflowField:
|
||||||
|
try:
|
||||||
|
# workflows do not have ids until they are saved
|
||||||
|
workflow_id = uuid_string()
|
||||||
|
workflow.root["id"] = workflow_id
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT INTO workflows(workflow)
|
||||||
|
VALUES (?);
|
||||||
|
""",
|
||||||
|
(workflow.json(),),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get(workflow_id)
|
||||||
|
|
||||||
|
def _create_tables(self) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS workflows (
|
||||||
|
workflow TEXT NOT NULL,
|
||||||
|
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON workflows FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE workflows
|
||||||
|
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE workflow_id = old.workflow_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self._conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
@ -59,6 +59,8 @@ export type AppConfig = {
|
|||||||
nodesAllowlist: string[] | undefined;
|
nodesAllowlist: string[] | undefined;
|
||||||
nodesDenylist: string[] | undefined;
|
nodesDenylist: string[] | undefined;
|
||||||
maxUpscalePixels?: number;
|
maxUpscalePixels?: number;
|
||||||
|
metadataFetchDebounce?: number;
|
||||||
|
workflowFetchDebounce?: number;
|
||||||
sd: {
|
sd: {
|
||||||
defaultModel?: string;
|
defaultModel?: string;
|
||||||
disabledControlNetModels: string[];
|
disabledControlNetModels: string[];
|
||||||
|
@ -37,7 +37,12 @@ const useColorPicker = () => {
|
|||||||
1
|
1
|
||||||
).data;
|
).data;
|
||||||
|
|
||||||
if (!(a && r && g && b)) {
|
if (
|
||||||
|
r === undefined ||
|
||||||
|
g === undefined ||
|
||||||
|
b === undefined ||
|
||||||
|
a === undefined
|
||||||
|
) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ import {
|
|||||||
setShouldShowImageDetails,
|
setShouldShowImageDetails,
|
||||||
setShouldShowProgressInViewer,
|
setShouldShowProgressInViewer,
|
||||||
} from 'features/ui/store/uiSlice';
|
} from 'features/ui/store/uiSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import {
|
||||||
@ -38,10 +38,9 @@ import {
|
|||||||
FaSeedling,
|
FaSeedling,
|
||||||
} from 'react-icons/fa';
|
} from 'react-icons/fa';
|
||||||
import { FaCircleNodes, FaEllipsis } from 'react-icons/fa6';
|
import { FaCircleNodes, FaEllipsis } from 'react-icons/fa6';
|
||||||
import {
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
useGetImageDTOQuery,
|
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||||
useGetImageMetadataFromFileQuery,
|
import { useDebouncedWorkflow } from 'services/api/hooks/useDebouncedWorkflow';
|
||||||
} from 'services/api/endpoints/images';
|
|
||||||
import { menuListMotionProps } from 'theme/components/menu';
|
import { menuListMotionProps } from 'theme/components/menu';
|
||||||
import { sentImageToImg2Img } from '../../store/actions';
|
import { sentImageToImg2Img } from '../../store/actions';
|
||||||
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
|
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
|
||||||
@ -89,7 +88,6 @@ const CurrentImageButtons = () => {
|
|||||||
shouldShowImageDetails,
|
shouldShowImageDetails,
|
||||||
lastSelectedImage,
|
lastSelectedImage,
|
||||||
shouldShowProgressInViewer,
|
shouldShowProgressInViewer,
|
||||||
shouldFetchMetadataFromApi,
|
|
||||||
} = useAppSelector(currentImageButtonsSelector);
|
} = useAppSelector(currentImageButtonsSelector);
|
||||||
|
|
||||||
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
|
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
|
||||||
@ -104,23 +102,12 @@ const CurrentImageButtons = () => {
|
|||||||
lastSelectedImage?.image_name ?? skipToken
|
lastSelectedImage?.image_name ?? skipToken
|
||||||
);
|
);
|
||||||
|
|
||||||
const getMetadataArg = useMemo(() => {
|
const { metadata, isLoading: isLoadingMetadata } = useDebouncedMetadata(
|
||||||
if (lastSelectedImage) {
|
lastSelectedImage?.image_name
|
||||||
return { image: lastSelectedImage, shouldFetchMetadataFromApi };
|
);
|
||||||
} else {
|
|
||||||
return skipToken;
|
|
||||||
}
|
|
||||||
}, [lastSelectedImage, shouldFetchMetadataFromApi]);
|
|
||||||
|
|
||||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
const { workflow, isLoading: isLoadingWorkflow } = useDebouncedWorkflow(
|
||||||
getMetadataArg,
|
lastSelectedImage?.workflow_id
|
||||||
{
|
|
||||||
selectFromResult: (res) => ({
|
|
||||||
isLoading: res.isFetching,
|
|
||||||
metadata: res?.currentData?.metadata,
|
|
||||||
workflow: res?.currentData?.workflow,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleLoadWorkflow = useCallback(() => {
|
const handleLoadWorkflow = useCallback(() => {
|
||||||
@ -257,7 +244,7 @@ const CurrentImageButtons = () => {
|
|||||||
|
|
||||||
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
isLoading={isLoading}
|
isLoading={isLoadingWorkflow}
|
||||||
icon={<FaCircleNodes />}
|
icon={<FaCircleNodes />}
|
||||||
tooltip={`${t('nodes.loadWorkflow')} (W)`}
|
tooltip={`${t('nodes.loadWorkflow')} (W)`}
|
||||||
aria-label={`${t('nodes.loadWorkflow')} (W)`}
|
aria-label={`${t('nodes.loadWorkflow')} (W)`}
|
||||||
@ -265,7 +252,7 @@ const CurrentImageButtons = () => {
|
|||||||
onClick={handleLoadWorkflow}
|
onClick={handleLoadWorkflow}
|
||||||
/>
|
/>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
isLoading={isLoading}
|
isLoading={isLoadingMetadata}
|
||||||
icon={<FaQuoteRight />}
|
icon={<FaQuoteRight />}
|
||||||
tooltip={`${t('parameters.usePrompt')} (P)`}
|
tooltip={`${t('parameters.usePrompt')} (P)`}
|
||||||
aria-label={`${t('parameters.usePrompt')} (P)`}
|
aria-label={`${t('parameters.usePrompt')} (P)`}
|
||||||
@ -273,7 +260,7 @@ const CurrentImageButtons = () => {
|
|||||||
onClick={handleUsePrompt}
|
onClick={handleUsePrompt}
|
||||||
/>
|
/>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
isLoading={isLoading}
|
isLoading={isLoadingMetadata}
|
||||||
icon={<FaSeedling />}
|
icon={<FaSeedling />}
|
||||||
tooltip={`${t('parameters.useSeed')} (S)`}
|
tooltip={`${t('parameters.useSeed')} (S)`}
|
||||||
aria-label={`${t('parameters.useSeed')} (S)`}
|
aria-label={`${t('parameters.useSeed')} (S)`}
|
||||||
@ -281,7 +268,7 @@ const CurrentImageButtons = () => {
|
|||||||
onClick={handleUseSeed}
|
onClick={handleUseSeed}
|
||||||
/>
|
/>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
isLoading={isLoading}
|
isLoading={isLoadingMetadata}
|
||||||
icon={<FaAsterisk />}
|
icon={<FaAsterisk />}
|
||||||
tooltip={`${t('parameters.useAll')} (A)`}
|
tooltip={`${t('parameters.useAll')} (A)`}
|
||||||
aria-label={`${t('parameters.useAll')} (A)`}
|
aria-label={`${t('parameters.useAll')} (A)`}
|
||||||
|
@ -2,7 +2,7 @@ import { Flex, MenuItem, Spinner } from '@chakra-ui/react';
|
|||||||
import { useStore } from '@nanostores/react';
|
import { useStore } from '@nanostores/react';
|
||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
import { $customStarUI } from 'app/store/nanostores/customStarUI';
|
import { $customStarUI } from 'app/store/nanostores/customStarUI';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
import {
|
import {
|
||||||
imagesToChangeSelected,
|
imagesToChangeSelected,
|
||||||
@ -32,12 +32,12 @@ import {
|
|||||||
import { FaCircleNodes } from 'react-icons/fa6';
|
import { FaCircleNodes } from 'react-icons/fa6';
|
||||||
import { MdStar, MdStarBorder } from 'react-icons/md';
|
import { MdStar, MdStarBorder } from 'react-icons/md';
|
||||||
import {
|
import {
|
||||||
useGetImageMetadataFromFileQuery,
|
|
||||||
useStarImagesMutation,
|
useStarImagesMutation,
|
||||||
useUnstarImagesMutation,
|
useUnstarImagesMutation,
|
||||||
} from 'services/api/endpoints/images';
|
} from 'services/api/endpoints/images';
|
||||||
|
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||||
|
import { useDebouncedWorkflow } from 'services/api/hooks/useDebouncedWorkflow';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import { configSelector } from '../../../system/store/configSelectors';
|
|
||||||
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
||||||
|
|
||||||
type SingleSelectionMenuItemsProps = {
|
type SingleSelectionMenuItemsProps = {
|
||||||
@ -53,18 +53,13 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
const toaster = useAppToaster();
|
const toaster = useAppToaster();
|
||||||
|
|
||||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||||
const { shouldFetchMetadataFromApi } = useAppSelector(configSelector);
|
|
||||||
const customStarUi = useStore($customStarUI);
|
const customStarUi = useStore($customStarUI);
|
||||||
|
|
||||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
const { metadata, isLoading: isLoadingMetadata } = useDebouncedMetadata(
|
||||||
{ image: imageDTO, shouldFetchMetadataFromApi },
|
imageDTO?.image_name
|
||||||
{
|
);
|
||||||
selectFromResult: (res) => ({
|
const { workflow, isLoading: isLoadingWorkflow } = useDebouncedWorkflow(
|
||||||
isLoading: res.isFetching,
|
imageDTO?.workflow_id
|
||||||
metadata: res?.currentData?.metadata,
|
|
||||||
workflow: res?.currentData?.workflow,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const [starImages] = useStarImagesMutation();
|
const [starImages] = useStarImagesMutation();
|
||||||
@ -181,17 +176,17 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
{t('parameters.downloadImage')}
|
{t('parameters.downloadImage')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={isLoading ? <SpinnerIcon /> : <FaCircleNodes />}
|
icon={isLoadingWorkflow ? <SpinnerIcon /> : <FaCircleNodes />}
|
||||||
onClickCapture={handleLoadWorkflow}
|
onClickCapture={handleLoadWorkflow}
|
||||||
isDisabled={isLoading || !workflow}
|
isDisabled={isLoadingWorkflow || !workflow}
|
||||||
>
|
>
|
||||||
{t('nodes.loadWorkflow')}
|
{t('nodes.loadWorkflow')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={isLoading ? <SpinnerIcon /> : <FaQuoteRight />}
|
icon={isLoadingMetadata ? <SpinnerIcon /> : <FaQuoteRight />}
|
||||||
onClickCapture={handleRecallPrompt}
|
onClickCapture={handleRecallPrompt}
|
||||||
isDisabled={
|
isDisabled={
|
||||||
isLoading ||
|
isLoadingMetadata ||
|
||||||
(metadata?.positive_prompt === undefined &&
|
(metadata?.positive_prompt === undefined &&
|
||||||
metadata?.negative_prompt === undefined)
|
metadata?.negative_prompt === undefined)
|
||||||
}
|
}
|
||||||
@ -199,16 +194,16 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
{t('parameters.usePrompt')}
|
{t('parameters.usePrompt')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={isLoading ? <SpinnerIcon /> : <FaSeedling />}
|
icon={isLoadingMetadata ? <SpinnerIcon /> : <FaSeedling />}
|
||||||
onClickCapture={handleRecallSeed}
|
onClickCapture={handleRecallSeed}
|
||||||
isDisabled={isLoading || metadata?.seed === undefined}
|
isDisabled={isLoadingMetadata || metadata?.seed === undefined}
|
||||||
>
|
>
|
||||||
{t('parameters.useSeed')}
|
{t('parameters.useSeed')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={isLoading ? <SpinnerIcon /> : <FaAsterisk />}
|
icon={isLoadingMetadata ? <SpinnerIcon /> : <FaAsterisk />}
|
||||||
onClickCapture={handleUseAllParameters}
|
onClickCapture={handleUseAllParameters}
|
||||||
isDisabled={isLoading || !metadata}
|
isDisabled={isLoadingMetadata || !metadata}
|
||||||
>
|
>
|
||||||
{t('parameters.useAll')}
|
{t('parameters.useAll')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
|
@ -10,15 +10,14 @@ import {
|
|||||||
Text,
|
Text,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import ScrollableContent from 'features/nodes/components/sidePanel/ScrollableContent';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useGetImageMetadataFromFileQuery } from 'services/api/endpoints/images';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||||
|
import { useDebouncedWorkflow } from 'services/api/hooks/useDebouncedWorkflow';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import DataViewer from './DataViewer';
|
import DataViewer from './DataViewer';
|
||||||
import ImageMetadataActions from './ImageMetadataActions';
|
import ImageMetadataActions from './ImageMetadataActions';
|
||||||
import { useAppSelector } from '../../../../app/store/storeHooks';
|
|
||||||
import { configSelector } from '../../../system/store/configSelectors';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import ScrollableContent from 'features/nodes/components/sidePanel/ScrollableContent';
|
|
||||||
|
|
||||||
type ImageMetadataViewerProps = {
|
type ImageMetadataViewerProps = {
|
||||||
image: ImageDTO;
|
image: ImageDTO;
|
||||||
@ -32,17 +31,8 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
|||||||
// });
|
// });
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { shouldFetchMetadataFromApi } = useAppSelector(configSelector);
|
const { metadata } = useDebouncedMetadata(image.image_name);
|
||||||
|
const { workflow } = useDebouncedWorkflow(image.workflow_id);
|
||||||
const { metadata, workflow } = useGetImageMetadataFromFileQuery(
|
|
||||||
{ image, shouldFetchMetadataFromApi },
|
|
||||||
{
|
|
||||||
selectFromResult: (res) => ({
|
|
||||||
metadata: res?.currentData?.metadata,
|
|
||||||
workflow: res?.currentData?.workflow,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
|
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { useEmbedWorkflow } from 'features/nodes/hooks/useEmbedWorkflow';
|
import { useEmbedWorkflow } from 'features/nodes/hooks/useEmbedWorkflow';
|
||||||
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
import { useWithWorkflow } from 'features/nodes/hooks/useWithWorkflow';
|
||||||
import { nodeEmbedWorkflowChanged } from 'features/nodes/store/nodesSlice';
|
import { nodeEmbedWorkflowChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import { ChangeEvent, memo, useCallback } from 'react';
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
|
|
||||||
const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
|
const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const hasImageOutput = useHasImageOutput(nodeId);
|
const withWorkflow = useWithWorkflow(nodeId);
|
||||||
const embedWorkflow = useEmbedWorkflow(nodeId);
|
const embedWorkflow = useEmbedWorkflow(nodeId);
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(e: ChangeEvent<HTMLInputElement>) => {
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
@ -21,7 +21,7 @@ const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
|
|||||||
[dispatch, nodeId]
|
[dispatch, nodeId]
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!hasImageOutput) {
|
if (!withWorkflow) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
import { Flex } from '@chakra-ui/react';
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
||||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
import { useFeatureStatus } from '../../../../../system/hooks/useFeatureStatus';
|
||||||
import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
|
import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
|
||||||
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
|
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
|
||||||
import UseCacheCheckbox from './UseCacheCheckbox';
|
import UseCacheCheckbox from './UseCacheCheckbox';
|
||||||
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
|
||||||
import { useFeatureStatus } from '../../../../../system/hooks/useFeatureStatus';
|
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
|
@ -0,0 +1,31 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
import { isInvocationNode } from '../types/types';
|
||||||
|
|
||||||
|
export const useWithWorkflow = (nodeId: string) => {
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ nodes }) => {
|
||||||
|
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||||
|
if (!isInvocationNode(node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
|
||||||
|
if (!nodeTemplate) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return nodeTemplate.withWorkflow;
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const withWorkflow = useAppSelector(selector);
|
||||||
|
return withWorkflow;
|
||||||
|
};
|
@ -69,6 +69,8 @@ export const validateSourceAndTargetTypes = (
|
|||||||
(sourceType === 'integer' || sourceType === 'float') &&
|
(sourceType === 'integer' || sourceType === 'float') &&
|
||||||
targetType === 'string';
|
targetType === 'string';
|
||||||
|
|
||||||
|
const isTargetAnyType = targetType === 'Any';
|
||||||
|
|
||||||
return (
|
return (
|
||||||
isCollectionItemToNonCollection ||
|
isCollectionItemToNonCollection ||
|
||||||
isNonCollectionToCollectionItem ||
|
isNonCollectionToCollectionItem ||
|
||||||
@ -76,6 +78,7 @@ export const validateSourceAndTargetTypes = (
|
|||||||
isGenericCollectionToAnyCollectionOrPolymorphic ||
|
isGenericCollectionToAnyCollectionOrPolymorphic ||
|
||||||
isCollectionToGenericCollection ||
|
isCollectionToGenericCollection ||
|
||||||
isIntToFloat ||
|
isIntToFloat ||
|
||||||
isIntOrFloatToString
|
isIntOrFloatToString ||
|
||||||
|
isTargetAnyType
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -33,6 +33,8 @@ export const COLLECTION_TYPES: FieldType[] = [
|
|||||||
'ColorCollection',
|
'ColorCollection',
|
||||||
'T2IAdapterCollection',
|
'T2IAdapterCollection',
|
||||||
'IPAdapterCollection',
|
'IPAdapterCollection',
|
||||||
|
'MetadataItemCollection',
|
||||||
|
'MetadataCollection',
|
||||||
];
|
];
|
||||||
|
|
||||||
export const POLYMORPHIC_TYPES: FieldType[] = [
|
export const POLYMORPHIC_TYPES: FieldType[] = [
|
||||||
@ -47,6 +49,7 @@ export const POLYMORPHIC_TYPES: FieldType[] = [
|
|||||||
'ColorPolymorphic',
|
'ColorPolymorphic',
|
||||||
'T2IAdapterPolymorphic',
|
'T2IAdapterPolymorphic',
|
||||||
'IPAdapterPolymorphic',
|
'IPAdapterPolymorphic',
|
||||||
|
'MetadataItemPolymorphic',
|
||||||
];
|
];
|
||||||
|
|
||||||
export const MODEL_TYPES: FieldType[] = [
|
export const MODEL_TYPES: FieldType[] = [
|
||||||
@ -78,6 +81,8 @@ export const COLLECTION_MAP: FieldTypeMapWithNumber = {
|
|||||||
ColorField: 'ColorCollection',
|
ColorField: 'ColorCollection',
|
||||||
T2IAdapterField: 'T2IAdapterCollection',
|
T2IAdapterField: 'T2IAdapterCollection',
|
||||||
IPAdapterField: 'IPAdapterCollection',
|
IPAdapterField: 'IPAdapterCollection',
|
||||||
|
MetadataItemField: 'MetadataItemCollection',
|
||||||
|
MetadataField: 'MetadataCollection',
|
||||||
};
|
};
|
||||||
export const isCollectionItemType = (
|
export const isCollectionItemType = (
|
||||||
itemType: string | undefined
|
itemType: string | undefined
|
||||||
@ -97,6 +102,7 @@ export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = {
|
|||||||
ColorField: 'ColorPolymorphic',
|
ColorField: 'ColorPolymorphic',
|
||||||
T2IAdapterField: 'T2IAdapterPolymorphic',
|
T2IAdapterField: 'T2IAdapterPolymorphic',
|
||||||
IPAdapterField: 'IPAdapterPolymorphic',
|
IPAdapterField: 'IPAdapterPolymorphic',
|
||||||
|
MetadataItemField: 'MetadataItemPolymorphic',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
||||||
@ -111,6 +117,7 @@ export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
|||||||
ColorPolymorphic: 'ColorField',
|
ColorPolymorphic: 'ColorField',
|
||||||
T2IAdapterPolymorphic: 'T2IAdapterField',
|
T2IAdapterPolymorphic: 'T2IAdapterField',
|
||||||
IPAdapterPolymorphic: 'IPAdapterField',
|
IPAdapterPolymorphic: 'IPAdapterField',
|
||||||
|
MetadataItemPolymorphic: 'MetadataItemField',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
|
export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
|
||||||
@ -144,6 +151,37 @@ export const isPolymorphicItemType = (
|
|||||||
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
|
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
|
||||||
|
|
||||||
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||||
|
Any: {
|
||||||
|
color: 'gray.500',
|
||||||
|
description: 'Any field type is accepted.',
|
||||||
|
title: 'Any',
|
||||||
|
},
|
||||||
|
MetadataField: {
|
||||||
|
color: 'gray.500',
|
||||||
|
description: 'A metadata dict.',
|
||||||
|
title: 'Metadata Dict',
|
||||||
|
},
|
||||||
|
MetadataCollection: {
|
||||||
|
color: 'gray.500',
|
||||||
|
description: 'A collection of metadata dicts.',
|
||||||
|
title: 'Metadata Dict Collection',
|
||||||
|
},
|
||||||
|
MetadataItemField: {
|
||||||
|
color: 'gray.500',
|
||||||
|
description: 'A metadata item.',
|
||||||
|
title: 'Metadata Item',
|
||||||
|
},
|
||||||
|
MetadataItemCollection: {
|
||||||
|
color: 'gray.500',
|
||||||
|
description: 'Any field type is accepted.',
|
||||||
|
title: 'Metadata Item Collection',
|
||||||
|
},
|
||||||
|
MetadataItemPolymorphic: {
|
||||||
|
color: 'gray.500',
|
||||||
|
description:
|
||||||
|
'MetadataItem or MetadataItemCollection field types are accepted.',
|
||||||
|
title: 'Metadata Item Polymorphic',
|
||||||
|
},
|
||||||
boolean: {
|
boolean: {
|
||||||
color: 'green.500',
|
color: 'green.500',
|
||||||
description: t('nodes.booleanDescription'),
|
description: t('nodes.booleanDescription'),
|
||||||
|
@ -54,6 +54,10 @@ export type InvocationTemplate = {
|
|||||||
* The type of this node's output
|
* The type of this node's output
|
||||||
*/
|
*/
|
||||||
outputType: string; // TODO: generate a union of output types
|
outputType: string; // TODO: generate a union of output types
|
||||||
|
/**
|
||||||
|
* Whether or not this invocation supports workflows
|
||||||
|
*/
|
||||||
|
withWorkflow: boolean;
|
||||||
/**
|
/**
|
||||||
* The invocation's version.
|
* The invocation's version.
|
||||||
*/
|
*/
|
||||||
@ -72,6 +76,7 @@ export type FieldUIConfig = {
|
|||||||
|
|
||||||
// TODO: Get this from the OpenAPI schema? may be tricky...
|
// TODO: Get this from the OpenAPI schema? may be tricky...
|
||||||
export const zFieldType = z.enum([
|
export const zFieldType = z.enum([
|
||||||
|
'Any',
|
||||||
'BoardField',
|
'BoardField',
|
||||||
'boolean',
|
'boolean',
|
||||||
'BooleanCollection',
|
'BooleanCollection',
|
||||||
@ -109,6 +114,11 @@ export const zFieldType = z.enum([
|
|||||||
'LatentsPolymorphic',
|
'LatentsPolymorphic',
|
||||||
'LoRAModelField',
|
'LoRAModelField',
|
||||||
'MainModelField',
|
'MainModelField',
|
||||||
|
'MetadataField',
|
||||||
|
'MetadataCollection',
|
||||||
|
'MetadataItemField',
|
||||||
|
'MetadataItemCollection',
|
||||||
|
'MetadataItemPolymorphic',
|
||||||
'ONNXModelField',
|
'ONNXModelField',
|
||||||
'Scheduler',
|
'Scheduler',
|
||||||
'SDXLMainModelField',
|
'SDXLMainModelField',
|
||||||
@ -685,6 +695,57 @@ export type CollectionItemInputFieldValue = z.infer<
|
|||||||
typeof zCollectionItemInputFieldValue
|
typeof zCollectionItemInputFieldValue
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
export const zMetadataItemField = z.object({
|
||||||
|
label: z.string(),
|
||||||
|
value: z.any(),
|
||||||
|
});
|
||||||
|
export type MetadataItemField = z.infer<typeof zMetadataItemField>;
|
||||||
|
|
||||||
|
export const zMetadataItemInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('MetadataItemField'),
|
||||||
|
value: zMetadataItemField.optional(),
|
||||||
|
});
|
||||||
|
export type MetadataItemInputFieldValue = z.infer<
|
||||||
|
typeof zMetadataItemInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
|
export const zMetadataItemCollectionInputFieldValue =
|
||||||
|
zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('MetadataItemCollection'),
|
||||||
|
value: z.array(zMetadataItemField).optional(),
|
||||||
|
});
|
||||||
|
export type MetadataItemCollectionInputFieldValue = z.infer<
|
||||||
|
typeof zMetadataItemCollectionInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
|
export const zMetadataItemPolymorphicInputFieldValue =
|
||||||
|
zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('MetadataItemPolymorphic'),
|
||||||
|
value: z
|
||||||
|
.union([zMetadataItemField, z.array(zMetadataItemField)])
|
||||||
|
.optional(),
|
||||||
|
});
|
||||||
|
export type MetadataItemPolymorphicInputFieldValue = z.infer<
|
||||||
|
typeof zMetadataItemPolymorphicInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
|
export const zMetadataField = z.record(z.any());
|
||||||
|
export type MetadataField = z.infer<typeof zMetadataField>;
|
||||||
|
|
||||||
|
export const zMetadataInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('MetadataField'),
|
||||||
|
value: zMetadataField.optional(),
|
||||||
|
});
|
||||||
|
export type MetadataInputFieldValue = z.infer<typeof zMetadataInputFieldValue>;
|
||||||
|
|
||||||
|
export const zMetadataCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('MetadataCollection'),
|
||||||
|
value: z.array(zMetadataField).optional(),
|
||||||
|
});
|
||||||
|
export type MetadataCollectionInputFieldValue = z.infer<
|
||||||
|
typeof zMetadataCollectionInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
export const zColorField = z.object({
|
export const zColorField = z.object({
|
||||||
r: z.number().int().min(0).max(255),
|
r: z.number().int().min(0).max(255),
|
||||||
g: z.number().int().min(0).max(255),
|
g: z.number().int().min(0).max(255),
|
||||||
@ -723,7 +784,13 @@ export type SchedulerInputFieldValue = z.infer<
|
|||||||
typeof zSchedulerInputFieldValue
|
typeof zSchedulerInputFieldValue
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
export const zAnyInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('Any'),
|
||||||
|
value: z.any().optional(),
|
||||||
|
});
|
||||||
|
|
||||||
export const zInputFieldValue = z.discriminatedUnion('type', [
|
export const zInputFieldValue = z.discriminatedUnion('type', [
|
||||||
|
zAnyInputFieldValue,
|
||||||
zBoardInputFieldValue,
|
zBoardInputFieldValue,
|
||||||
zBooleanCollectionInputFieldValue,
|
zBooleanCollectionInputFieldValue,
|
||||||
zBooleanInputFieldValue,
|
zBooleanInputFieldValue,
|
||||||
@ -774,6 +841,11 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
|
|||||||
zUNetInputFieldValue,
|
zUNetInputFieldValue,
|
||||||
zVaeInputFieldValue,
|
zVaeInputFieldValue,
|
||||||
zVaeModelInputFieldValue,
|
zVaeModelInputFieldValue,
|
||||||
|
zMetadataItemInputFieldValue,
|
||||||
|
zMetadataItemCollectionInputFieldValue,
|
||||||
|
zMetadataItemPolymorphicInputFieldValue,
|
||||||
|
zMetadataInputFieldValue,
|
||||||
|
zMetadataCollectionInputFieldValue,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
export type InputFieldValue = z.infer<typeof zInputFieldValue>;
|
export type InputFieldValue = z.infer<typeof zInputFieldValue>;
|
||||||
@ -786,6 +858,11 @@ export type InputFieldTemplateBase = {
|
|||||||
fieldKind: 'input';
|
fieldKind: 'input';
|
||||||
} & _InputField;
|
} & _InputField;
|
||||||
|
|
||||||
|
export type AnyInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
type: 'Any';
|
||||||
|
default: undefined;
|
||||||
|
};
|
||||||
|
|
||||||
export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
|
export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
type: 'integer';
|
type: 'integer';
|
||||||
default: number;
|
default: number;
|
||||||
@ -939,6 +1016,11 @@ export type UNetInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'UNetField';
|
type: 'UNetField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type MetadataItemFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'MetadataItemField';
|
||||||
|
};
|
||||||
|
|
||||||
export type ClipInputFieldTemplate = InputFieldTemplateBase & {
|
export type ClipInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: undefined;
|
default: undefined;
|
||||||
type: 'ClipField';
|
type: 'ClipField';
|
||||||
@ -1087,6 +1169,34 @@ export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'WorkflowField';
|
type: 'WorkflowField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type MetadataItemInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'MetadataItemField';
|
||||||
|
};
|
||||||
|
|
||||||
|
export type MetadataItemCollectionInputFieldTemplate =
|
||||||
|
InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'MetadataItemCollection';
|
||||||
|
};
|
||||||
|
|
||||||
|
export type MetadataItemPolymorphicInputFieldTemplate = Omit<
|
||||||
|
MetadataItemInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'MetadataItemPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
|
export type MetadataInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'MetadataField';
|
||||||
|
};
|
||||||
|
|
||||||
|
export type MetadataCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'MetadataCollection';
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An input field template is generated on each page load from the OpenAPI schema.
|
* An input field template is generated on each page load from the OpenAPI schema.
|
||||||
*
|
*
|
||||||
@ -1094,6 +1204,7 @@ export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
* maximum length, pattern to match, etc).
|
* maximum length, pattern to match, etc).
|
||||||
*/
|
*/
|
||||||
export type InputFieldTemplate =
|
export type InputFieldTemplate =
|
||||||
|
| AnyInputFieldTemplate
|
||||||
| BoardInputFieldTemplate
|
| BoardInputFieldTemplate
|
||||||
| BooleanCollectionInputFieldTemplate
|
| BooleanCollectionInputFieldTemplate
|
||||||
| BooleanPolymorphicInputFieldTemplate
|
| BooleanPolymorphicInputFieldTemplate
|
||||||
@ -1143,7 +1254,12 @@ export type InputFieldTemplate =
|
|||||||
| T2IAdapterPolymorphicInputFieldTemplate
|
| T2IAdapterPolymorphicInputFieldTemplate
|
||||||
| UNetInputFieldTemplate
|
| UNetInputFieldTemplate
|
||||||
| VaeInputFieldTemplate
|
| VaeInputFieldTemplate
|
||||||
| VaeModelInputFieldTemplate;
|
| VaeModelInputFieldTemplate
|
||||||
|
| MetadataItemInputFieldTemplate
|
||||||
|
| MetadataItemCollectionInputFieldTemplate
|
||||||
|
| MetadataInputFieldTemplate
|
||||||
|
| MetadataItemPolymorphicInputFieldTemplate
|
||||||
|
| MetadataCollectionInputFieldTemplate;
|
||||||
|
|
||||||
export const isInputFieldValue = (
|
export const isInputFieldValue = (
|
||||||
field?: InputFieldValue | OutputFieldValue
|
field?: InputFieldValue | OutputFieldValue
|
||||||
@ -1264,7 +1380,7 @@ export const isInvocationFieldSchema = (
|
|||||||
|
|
||||||
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
|
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
|
||||||
|
|
||||||
const zLoRAMetadataItem = z.object({
|
export const zLoRAMetadataItem = z.object({
|
||||||
lora: zLoRAModelField.deepPartial(),
|
lora: zLoRAModelField.deepPartial(),
|
||||||
weight: z.number(),
|
weight: z.number(),
|
||||||
});
|
});
|
||||||
|
@ -7,6 +7,7 @@ import {
|
|||||||
startCase,
|
startCase,
|
||||||
} from 'lodash-es';
|
} from 'lodash-es';
|
||||||
import { OpenAPIV3_1 } from 'openapi-types';
|
import { OpenAPIV3_1 } from 'openapi-types';
|
||||||
|
import { ControlField } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
COLLECTION_MAP,
|
COLLECTION_MAP,
|
||||||
POLYMORPHIC_TYPES,
|
POLYMORPHIC_TYPES,
|
||||||
@ -15,36 +16,70 @@ import {
|
|||||||
isPolymorphicItemType,
|
isPolymorphicItemType,
|
||||||
} from '../types/constants';
|
} from '../types/constants';
|
||||||
import {
|
import {
|
||||||
|
AnyInputFieldTemplate,
|
||||||
|
BoardInputFieldTemplate,
|
||||||
BooleanCollectionInputFieldTemplate,
|
BooleanCollectionInputFieldTemplate,
|
||||||
BooleanInputFieldTemplate,
|
BooleanInputFieldTemplate,
|
||||||
|
BooleanPolymorphicInputFieldTemplate,
|
||||||
ClipInputFieldTemplate,
|
ClipInputFieldTemplate,
|
||||||
CollectionInputFieldTemplate,
|
CollectionInputFieldTemplate,
|
||||||
CollectionItemInputFieldTemplate,
|
CollectionItemInputFieldTemplate,
|
||||||
|
ColorCollectionInputFieldTemplate,
|
||||||
ColorInputFieldTemplate,
|
ColorInputFieldTemplate,
|
||||||
|
ColorPolymorphicInputFieldTemplate,
|
||||||
|
ConditioningCollectionInputFieldTemplate,
|
||||||
|
ConditioningField,
|
||||||
ConditioningInputFieldTemplate,
|
ConditioningInputFieldTemplate,
|
||||||
|
ConditioningPolymorphicInputFieldTemplate,
|
||||||
|
ControlCollectionInputFieldTemplate,
|
||||||
ControlInputFieldTemplate,
|
ControlInputFieldTemplate,
|
||||||
ControlNetModelInputFieldTemplate,
|
ControlNetModelInputFieldTemplate,
|
||||||
|
ControlPolymorphicInputFieldTemplate,
|
||||||
DenoiseMaskInputFieldTemplate,
|
DenoiseMaskInputFieldTemplate,
|
||||||
EnumInputFieldTemplate,
|
EnumInputFieldTemplate,
|
||||||
FieldType,
|
FieldType,
|
||||||
FloatCollectionInputFieldTemplate,
|
FloatCollectionInputFieldTemplate,
|
||||||
FloatPolymorphicInputFieldTemplate,
|
|
||||||
FloatInputFieldTemplate,
|
FloatInputFieldTemplate,
|
||||||
|
FloatPolymorphicInputFieldTemplate,
|
||||||
|
IPAdapterCollectionInputFieldTemplate,
|
||||||
|
IPAdapterField,
|
||||||
|
IPAdapterInputFieldTemplate,
|
||||||
|
IPAdapterModelInputFieldTemplate,
|
||||||
|
IPAdapterPolymorphicInputFieldTemplate,
|
||||||
ImageCollectionInputFieldTemplate,
|
ImageCollectionInputFieldTemplate,
|
||||||
|
ImageField,
|
||||||
ImageInputFieldTemplate,
|
ImageInputFieldTemplate,
|
||||||
|
ImagePolymorphicInputFieldTemplate,
|
||||||
|
InputFieldTemplate,
|
||||||
InputFieldTemplateBase,
|
InputFieldTemplateBase,
|
||||||
IntegerCollectionInputFieldTemplate,
|
IntegerCollectionInputFieldTemplate,
|
||||||
IntegerInputFieldTemplate,
|
IntegerInputFieldTemplate,
|
||||||
|
IntegerPolymorphicInputFieldTemplate,
|
||||||
InvocationFieldSchema,
|
InvocationFieldSchema,
|
||||||
InvocationSchemaObject,
|
InvocationSchemaObject,
|
||||||
|
LatentsCollectionInputFieldTemplate,
|
||||||
|
LatentsField,
|
||||||
LatentsInputFieldTemplate,
|
LatentsInputFieldTemplate,
|
||||||
|
LatentsPolymorphicInputFieldTemplate,
|
||||||
LoRAModelInputFieldTemplate,
|
LoRAModelInputFieldTemplate,
|
||||||
MainModelInputFieldTemplate,
|
MainModelInputFieldTemplate,
|
||||||
|
MetadataCollectionInputFieldTemplate,
|
||||||
|
MetadataInputFieldTemplate,
|
||||||
|
MetadataItemCollectionInputFieldTemplate,
|
||||||
|
MetadataItemInputFieldTemplate,
|
||||||
|
MetadataItemPolymorphicInputFieldTemplate,
|
||||||
|
OpenAPIV3_1SchemaOrRef,
|
||||||
SDXLMainModelInputFieldTemplate,
|
SDXLMainModelInputFieldTemplate,
|
||||||
SDXLRefinerModelInputFieldTemplate,
|
SDXLRefinerModelInputFieldTemplate,
|
||||||
SchedulerInputFieldTemplate,
|
SchedulerInputFieldTemplate,
|
||||||
StringCollectionInputFieldTemplate,
|
StringCollectionInputFieldTemplate,
|
||||||
StringInputFieldTemplate,
|
StringInputFieldTemplate,
|
||||||
|
StringPolymorphicInputFieldTemplate,
|
||||||
|
T2IAdapterCollectionInputFieldTemplate,
|
||||||
|
T2IAdapterField,
|
||||||
|
T2IAdapterInputFieldTemplate,
|
||||||
|
T2IAdapterModelInputFieldTemplate,
|
||||||
|
T2IAdapterPolymorphicInputFieldTemplate,
|
||||||
UNetInputFieldTemplate,
|
UNetInputFieldTemplate,
|
||||||
VaeInputFieldTemplate,
|
VaeInputFieldTemplate,
|
||||||
VaeModelInputFieldTemplate,
|
VaeModelInputFieldTemplate,
|
||||||
@ -52,36 +87,7 @@ import {
|
|||||||
isNonArraySchemaObject,
|
isNonArraySchemaObject,
|
||||||
isRefObject,
|
isRefObject,
|
||||||
isSchemaObject,
|
isSchemaObject,
|
||||||
ControlPolymorphicInputFieldTemplate,
|
|
||||||
ColorPolymorphicInputFieldTemplate,
|
|
||||||
ColorCollectionInputFieldTemplate,
|
|
||||||
IntegerPolymorphicInputFieldTemplate,
|
|
||||||
StringPolymorphicInputFieldTemplate,
|
|
||||||
BooleanPolymorphicInputFieldTemplate,
|
|
||||||
ImagePolymorphicInputFieldTemplate,
|
|
||||||
LatentsPolymorphicInputFieldTemplate,
|
|
||||||
LatentsCollectionInputFieldTemplate,
|
|
||||||
ConditioningPolymorphicInputFieldTemplate,
|
|
||||||
ConditioningCollectionInputFieldTemplate,
|
|
||||||
ControlCollectionInputFieldTemplate,
|
|
||||||
ImageField,
|
|
||||||
LatentsField,
|
|
||||||
ConditioningField,
|
|
||||||
IPAdapterField,
|
|
||||||
IPAdapterInputFieldTemplate,
|
|
||||||
IPAdapterModelInputFieldTemplate,
|
|
||||||
IPAdapterPolymorphicInputFieldTemplate,
|
|
||||||
IPAdapterCollectionInputFieldTemplate,
|
|
||||||
T2IAdapterField,
|
|
||||||
T2IAdapterInputFieldTemplate,
|
|
||||||
T2IAdapterModelInputFieldTemplate,
|
|
||||||
T2IAdapterPolymorphicInputFieldTemplate,
|
|
||||||
T2IAdapterCollectionInputFieldTemplate,
|
|
||||||
BoardInputFieldTemplate,
|
|
||||||
InputFieldTemplate,
|
|
||||||
OpenAPIV3_1SchemaOrRef,
|
|
||||||
} from '../types/types';
|
} from '../types/types';
|
||||||
import { ControlField } from 'services/api/types';
|
|
||||||
|
|
||||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||||
|
|
||||||
@ -851,6 +857,78 @@ const buildCollectionItemInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildAnyInputFieldTemplate = ({
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): AnyInputFieldTemplate => {
|
||||||
|
const template: AnyInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'Any',
|
||||||
|
default: undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildMetadataItemInputFieldTemplate = ({
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): MetadataItemInputFieldTemplate => {
|
||||||
|
const template: MetadataItemInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'MetadataItemField',
|
||||||
|
default: undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildMetadataItemCollectionInputFieldTemplate = ({
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): MetadataItemCollectionInputFieldTemplate => {
|
||||||
|
const template: MetadataItemCollectionInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'MetadataItemCollection',
|
||||||
|
default: undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildMetadataItemPolymorphicInputFieldTemplate = ({
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): MetadataItemPolymorphicInputFieldTemplate => {
|
||||||
|
const template: MetadataItemPolymorphicInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'MetadataItemPolymorphic',
|
||||||
|
default: undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildMetadataDictInputFieldTemplate = ({
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): MetadataInputFieldTemplate => {
|
||||||
|
const template: MetadataInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'MetadataField',
|
||||||
|
default: undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildMetadataCollectionInputFieldTemplate = ({
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): MetadataCollectionInputFieldTemplate => {
|
||||||
|
const template: MetadataCollectionInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'MetadataCollection',
|
||||||
|
default: undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildColorInputFieldTemplate = ({
|
const buildColorInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -1012,6 +1090,7 @@ const TEMPLATE_BUILDER_MAP: {
|
|||||||
[key in FieldType]?: (arg: BuildInputFieldArg) => InputFieldTemplate;
|
[key in FieldType]?: (arg: BuildInputFieldArg) => InputFieldTemplate;
|
||||||
} = {
|
} = {
|
||||||
BoardField: buildBoardInputFieldTemplate,
|
BoardField: buildBoardInputFieldTemplate,
|
||||||
|
Any: buildAnyInputFieldTemplate,
|
||||||
boolean: buildBooleanInputFieldTemplate,
|
boolean: buildBooleanInputFieldTemplate,
|
||||||
BooleanCollection: buildBooleanCollectionInputFieldTemplate,
|
BooleanCollection: buildBooleanCollectionInputFieldTemplate,
|
||||||
BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate,
|
BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate,
|
||||||
@ -1047,6 +1126,11 @@ const TEMPLATE_BUILDER_MAP: {
|
|||||||
LatentsField: buildLatentsInputFieldTemplate,
|
LatentsField: buildLatentsInputFieldTemplate,
|
||||||
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
|
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
|
||||||
LoRAModelField: buildLoRAModelInputFieldTemplate,
|
LoRAModelField: buildLoRAModelInputFieldTemplate,
|
||||||
|
MetadataItemField: buildMetadataItemInputFieldTemplate,
|
||||||
|
MetadataItemCollection: buildMetadataItemCollectionInputFieldTemplate,
|
||||||
|
MetadataItemPolymorphic: buildMetadataItemPolymorphicInputFieldTemplate,
|
||||||
|
MetadataField: buildMetadataDictInputFieldTemplate,
|
||||||
|
MetadataCollection: buildMetadataCollectionInputFieldTemplate,
|
||||||
MainModelField: buildMainModelInputFieldTemplate,
|
MainModelField: buildMainModelInputFieldTemplate,
|
||||||
Scheduler: buildSchedulerInputFieldTemplate,
|
Scheduler: buildSchedulerInputFieldTemplate,
|
||||||
SDXLMainModelField: buildSDXLMainModelInputFieldTemplate,
|
SDXLMainModelField: buildSDXLMainModelInputFieldTemplate,
|
||||||
|
@ -3,6 +3,7 @@ import { FieldType, InputFieldTemplate, InputFieldValue } from '../types/types';
|
|||||||
const FIELD_VALUE_FALLBACK_MAP: {
|
const FIELD_VALUE_FALLBACK_MAP: {
|
||||||
[key in FieldType]: InputFieldValue['value'];
|
[key in FieldType]: InputFieldValue['value'];
|
||||||
} = {
|
} = {
|
||||||
|
Any: undefined,
|
||||||
enum: '',
|
enum: '',
|
||||||
BoardField: undefined,
|
BoardField: undefined,
|
||||||
boolean: false,
|
boolean: false,
|
||||||
@ -38,6 +39,11 @@ const FIELD_VALUE_FALLBACK_MAP: {
|
|||||||
LatentsCollection: [],
|
LatentsCollection: [],
|
||||||
LatentsField: undefined,
|
LatentsField: undefined,
|
||||||
LatentsPolymorphic: undefined,
|
LatentsPolymorphic: undefined,
|
||||||
|
MetadataItemField: undefined,
|
||||||
|
MetadataItemCollection: [],
|
||||||
|
MetadataItemPolymorphic: undefined,
|
||||||
|
MetadataField: undefined,
|
||||||
|
MetadataCollection: [],
|
||||||
LoRAModelField: undefined,
|
LoRAModelField: undefined,
|
||||||
MainModelField: undefined,
|
MainModelField: undefined,
|
||||||
ONNXModelField: undefined,
|
ONNXModelField: undefined,
|
||||||
|
@ -1,45 +0,0 @@
|
|||||||
import * as png from '@stevebel/png';
|
|
||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { parseify } from 'common/util/serialize';
|
|
||||||
import {
|
|
||||||
ImageMetadataAndWorkflow,
|
|
||||||
zCoreMetadata,
|
|
||||||
zWorkflow,
|
|
||||||
} from 'features/nodes/types/types';
|
|
||||||
import { get } from 'lodash-es';
|
|
||||||
|
|
||||||
export const getMetadataAndWorkflowFromImageBlob = async (
|
|
||||||
image: Blob
|
|
||||||
): Promise<ImageMetadataAndWorkflow> => {
|
|
||||||
const data: ImageMetadataAndWorkflow = {};
|
|
||||||
const buffer = await image.arrayBuffer();
|
|
||||||
const text = png.decode(buffer).text;
|
|
||||||
|
|
||||||
const rawMetadata = get(text, 'invokeai_metadata');
|
|
||||||
if (rawMetadata) {
|
|
||||||
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
|
|
||||||
if (metadataResult.success) {
|
|
||||||
data.metadata = metadataResult.data;
|
|
||||||
} else {
|
|
||||||
logger('system').error(
|
|
||||||
{ error: parseify(metadataResult.error) },
|
|
||||||
'Problem reading metadata from image'
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const rawWorkflow = get(text, 'invokeai_workflow');
|
|
||||||
if (rawWorkflow) {
|
|
||||||
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
|
|
||||||
if (workflowResult.success) {
|
|
||||||
data.workflow = workflowResult.data;
|
|
||||||
} else {
|
|
||||||
logger('system').error(
|
|
||||||
{ error: parseify(workflowResult.error) },
|
|
||||||
'Problem reading workflow from image'
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return data;
|
|
||||||
};
|
|
@ -5,14 +5,14 @@ import {
|
|||||||
CollectInvocation,
|
CollectInvocation,
|
||||||
ControlField,
|
ControlField,
|
||||||
ControlNetInvocation,
|
ControlNetInvocation,
|
||||||
MetadataAccumulatorInvocation,
|
CoreMetadataInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { NonNullableGraph } from '../../types/types';
|
import { NonNullableGraph } from '../../types/types';
|
||||||
import {
|
import {
|
||||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
CONTROL_NET_COLLECT,
|
CONTROL_NET_COLLECT,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addControlNetToLinearGraph = (
|
export const addControlNetToLinearGraph = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -23,9 +23,11 @@ export const addControlNetToLinearGraph = (
|
|||||||
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
||||||
);
|
);
|
||||||
|
|
||||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||||
| MetadataAccumulatorInvocation
|
// | MetadataAccumulatorInvocation
|
||||||
| undefined;
|
// | undefined;
|
||||||
|
|
||||||
|
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
|
||||||
|
|
||||||
if (validControlNets.length) {
|
if (validControlNets.length) {
|
||||||
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
|
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
|
||||||
@ -99,15 +101,9 @@ export const addControlNetToLinearGraph = (
|
|||||||
|
|
||||||
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
|
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
|
||||||
|
|
||||||
if (metadataAccumulator?.controlnets) {
|
controlNetMetadata.push(
|
||||||
// metadata accumulator only needs a control field - not the whole node
|
omit(controlNetNode, ['id', 'type', 'is_intermediate']) as ControlField
|
||||||
// extract what we need and add to the accumulator
|
);
|
||||||
const controlField = omit(controlNetNode, [
|
|
||||||
'id',
|
|
||||||
'type',
|
|
||||||
]) as ControlField;
|
|
||||||
metadataAccumulator.controlnets.push(controlField);
|
|
||||||
}
|
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: controlNetNode.id, field: 'control' },
|
source: { node_id: controlNetNode.id, field: 'control' },
|
||||||
@ -117,5 +113,6 @@ export const addControlNetToLinearGraph = (
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
upsertMetadata(graph, { controlnets: controlNetMetadata });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1,25 +1,25 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
DenoiseLatentsInvocation,
|
DenoiseLatentsInvocation,
|
||||||
ResizeLatentsInvocation,
|
|
||||||
NoiseInvocation,
|
|
||||||
LatentsToImageInvocation,
|
|
||||||
Edge,
|
Edge,
|
||||||
|
LatentsToImageInvocation,
|
||||||
|
NoiseInvocation,
|
||||||
|
ResizeLatentsInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
LATENTS_TO_IMAGE,
|
|
||||||
DENOISE_LATENTS,
|
DENOISE_LATENTS,
|
||||||
NOISE,
|
|
||||||
MAIN_MODEL_LOADER,
|
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
LATENTS_TO_IMAGE_HRF,
|
|
||||||
DENOISE_LATENTS_HRF,
|
DENOISE_LATENTS_HRF,
|
||||||
RESCALE_LATENTS,
|
LATENTS_TO_IMAGE,
|
||||||
|
LATENTS_TO_IMAGE_HRF,
|
||||||
|
MAIN_MODEL_LOADER,
|
||||||
|
NOISE,
|
||||||
NOISE_HRF,
|
NOISE_HRF,
|
||||||
|
RESCALE_LATENTS,
|
||||||
VAE_LOADER,
|
VAE_LOADER,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { logger } from 'app/logging/logger';
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
// Copy certain connections from previous DENOISE_LATENTS to new DENOISE_LATENTS_HRF.
|
// Copy certain connections from previous DENOISE_LATENTS to new DENOISE_LATENTS_HRF.
|
||||||
function copyConnectionsToDenoiseLatentsHrf(graph: NonNullableGraph): void {
|
function copyConnectionsToDenoiseLatentsHrf(graph: NonNullableGraph): void {
|
||||||
@ -71,10 +71,8 @@ export const addHrfToGraph = (
|
|||||||
}
|
}
|
||||||
const log = logger('txt2img');
|
const log = logger('txt2img');
|
||||||
|
|
||||||
const { vae } = state.generation;
|
const { vae, hrfWidth, hrfHeight, hrfStrength } = state.generation;
|
||||||
const isAutoVae = !vae;
|
const isAutoVae = !vae;
|
||||||
const hrfWidth = state.generation.hrfWidth;
|
|
||||||
const hrfHeight = state.generation.hrfHeight;
|
|
||||||
|
|
||||||
// Pre-existing (original) graph nodes.
|
// Pre-existing (original) graph nodes.
|
||||||
const originalDenoiseLatentsNode = graph.nodes[DENOISE_LATENTS] as
|
const originalDenoiseLatentsNode = graph.nodes[DENOISE_LATENTS] as
|
||||||
@ -116,7 +114,7 @@ export const addHrfToGraph = (
|
|||||||
cfg_scale: originalDenoiseLatentsNode?.cfg_scale,
|
cfg_scale: originalDenoiseLatentsNode?.cfg_scale,
|
||||||
scheduler: originalDenoiseLatentsNode?.scheduler,
|
scheduler: originalDenoiseLatentsNode?.scheduler,
|
||||||
steps: originalDenoiseLatentsNode?.steps,
|
steps: originalDenoiseLatentsNode?.steps,
|
||||||
denoising_start: 1 - state.generation.hrfStrength,
|
denoising_start: 1 - hrfStrength,
|
||||||
denoising_end: 1,
|
denoising_end: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -221,16 +219,6 @@ export const addHrfToGraph = (
|
|||||||
field: 'latents',
|
field: 'latents',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
source: {
|
|
||||||
node_id: METADATA_ACCUMULATOR,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE_HRF,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
node_id: isAutoVae ? MAIN_MODEL_LOADER : VAE_LOADER,
|
||||||
@ -243,5 +231,11 @@ export const addHrfToGraph = (
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
upsertMetadata(graph, {
|
||||||
|
hrf_height: hrfHeight,
|
||||||
|
hrf_width: hrfWidth,
|
||||||
|
hrf_strength: hrfStrength,
|
||||||
|
});
|
||||||
|
|
||||||
copyConnectionsToDenoiseLatentsHrf(graph);
|
copyConnectionsToDenoiseLatentsHrf(graph);
|
||||||
};
|
};
|
||||||
|
@ -1,16 +1,18 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
|
import { omit } from 'lodash-es';
|
||||||
import {
|
import {
|
||||||
CollectInvocation,
|
CollectInvocation,
|
||||||
|
CoreMetadataInvocation,
|
||||||
IPAdapterInvocation,
|
IPAdapterInvocation,
|
||||||
MetadataAccumulatorInvocation,
|
IPAdapterMetadataField,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { NonNullableGraph } from '../../types/types';
|
import { NonNullableGraph } from '../../types/types';
|
||||||
import {
|
import {
|
||||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
IP_ADAPTER_COLLECT,
|
IP_ADAPTER_COLLECT,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addIPAdapterToLinearGraph = (
|
export const addIPAdapterToLinearGraph = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -21,10 +23,6 @@ export const addIPAdapterToLinearGraph = (
|
|||||||
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
||||||
);
|
);
|
||||||
|
|
||||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
|
||||||
| MetadataAccumulatorInvocation
|
|
||||||
| undefined;
|
|
||||||
|
|
||||||
if (validIPAdapters.length) {
|
if (validIPAdapters.length) {
|
||||||
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
|
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
|
||||||
const ipAdapterCollectNode: CollectInvocation = {
|
const ipAdapterCollectNode: CollectInvocation = {
|
||||||
@ -50,6 +48,7 @@ export const addIPAdapterToLinearGraph = (
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
|
||||||
|
|
||||||
validIPAdapters.forEach((ipAdapter) => {
|
validIPAdapters.forEach((ipAdapter) => {
|
||||||
if (!ipAdapter.model) {
|
if (!ipAdapter.model) {
|
||||||
@ -76,19 +75,13 @@ export const addIPAdapterToLinearGraph = (
|
|||||||
|
|
||||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
|
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
|
||||||
|
|
||||||
if (metadataAccumulator?.ipAdapters) {
|
ipAdapterMetdata.push(
|
||||||
const ipAdapterField = {
|
omit(ipAdapterNode, [
|
||||||
image: {
|
'id',
|
||||||
image_name: ipAdapter.controlImage,
|
'type',
|
||||||
},
|
'is_intermediate',
|
||||||
weight,
|
]) as IPAdapterMetadataField
|
||||||
ip_adapter_model: model,
|
);
|
||||||
begin_step_percent: beginStepPct,
|
|
||||||
end_step_percent: endStepPct,
|
|
||||||
};
|
|
||||||
|
|
||||||
metadataAccumulator.ipAdapters.push(ipAdapterField);
|
|
||||||
}
|
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||||
@ -98,5 +91,7 @@ export const addIPAdapterToLinearGraph = (
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
upsertMetadata(graph, { ipAdapters: ipAdapterMetdata });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -2,20 +2,20 @@ import { RootState } from 'app/store/store';
|
|||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import { forEach, size } from 'lodash-es';
|
import { forEach, size } from 'lodash-es';
|
||||||
import {
|
import {
|
||||||
|
CoreMetadataInvocation,
|
||||||
LoraLoaderInvocation,
|
LoraLoaderInvocation,
|
||||||
MetadataAccumulatorInvocation,
|
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
CANVAS_INPAINT_GRAPH,
|
CANVAS_INPAINT_GRAPH,
|
||||||
CANVAS_OUTPAINT_GRAPH,
|
CANVAS_OUTPAINT_GRAPH,
|
||||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
|
||||||
CLIP_SKIP,
|
CLIP_SKIP,
|
||||||
LORA_LOADER,
|
LORA_LOADER,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addLoRAsToGraph = (
|
export const addLoRAsToGraph = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -33,29 +33,29 @@ export const addLoRAsToGraph = (
|
|||||||
|
|
||||||
const { loras } = state.lora;
|
const { loras } = state.lora;
|
||||||
const loraCount = size(loras);
|
const loraCount = size(loras);
|
||||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
|
||||||
| MetadataAccumulatorInvocation
|
|
||||||
| undefined;
|
|
||||||
|
|
||||||
if (loraCount > 0) {
|
if (loraCount === 0) {
|
||||||
// Remove modelLoaderNodeId unet connection to feed it to LoRAs
|
return;
|
||||||
graph.edges = graph.edges.filter(
|
|
||||||
(e) =>
|
|
||||||
!(
|
|
||||||
e.source.node_id === modelLoaderNodeId &&
|
|
||||||
['unet'].includes(e.source.field)
|
|
||||||
)
|
|
||||||
);
|
|
||||||
// Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
|
|
||||||
graph.edges = graph.edges.filter(
|
|
||||||
(e) =>
|
|
||||||
!(e.source.node_id === CLIP_SKIP && ['clip'].includes(e.source.field))
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove modelLoaderNodeId unet connection to feed it to LoRAs
|
||||||
|
graph.edges = graph.edges.filter(
|
||||||
|
(e) =>
|
||||||
|
!(
|
||||||
|
e.source.node_id === modelLoaderNodeId &&
|
||||||
|
['unet'].includes(e.source.field)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
// Remove CLIP_SKIP connections to conditionings to feed it through LoRAs
|
||||||
|
graph.edges = graph.edges.filter(
|
||||||
|
(e) =>
|
||||||
|
!(e.source.node_id === CLIP_SKIP && ['clip'].includes(e.source.field))
|
||||||
|
);
|
||||||
|
|
||||||
// we need to remember the last lora so we can chain from it
|
// we need to remember the last lora so we can chain from it
|
||||||
let lastLoraNodeId = '';
|
let lastLoraNodeId = '';
|
||||||
let currentLoraIndex = 0;
|
let currentLoraIndex = 0;
|
||||||
|
const loraMetadata: CoreMetadataInvocation['loras'] = [];
|
||||||
|
|
||||||
forEach(loras, (lora) => {
|
forEach(loras, (lora) => {
|
||||||
const { model_name, base_model, weight } = lora;
|
const { model_name, base_model, weight } = lora;
|
||||||
@ -69,13 +69,10 @@ export const addLoRAsToGraph = (
|
|||||||
weight,
|
weight,
|
||||||
};
|
};
|
||||||
|
|
||||||
// add the lora to the metadata accumulator
|
loraMetadata.push({
|
||||||
if (metadataAccumulator?.loras) {
|
lora: { model_name, base_model },
|
||||||
metadataAccumulator.loras.push({
|
weight,
|
||||||
lora: { model_name, base_model },
|
});
|
||||||
weight,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// add to graph
|
// add to graph
|
||||||
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||||
@ -182,4 +179,6 @@ export const addLoRAsToGraph = (
|
|||||||
lastLoraNodeId = currentLoraNodeId;
|
lastLoraNodeId = currentLoraNodeId;
|
||||||
currentLoraIndex += 1;
|
currentLoraIndex += 1;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
upsertMetadata(graph, { loras: loraMetadata });
|
||||||
};
|
};
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
|
||||||
import { forEach, size } from 'lodash-es';
|
|
||||||
import {
|
import {
|
||||||
MetadataAccumulatorInvocation,
|
LoRAMetadataItem,
|
||||||
SDXLLoraLoaderInvocation,
|
NonNullableGraph,
|
||||||
} from 'services/api/types';
|
zLoRAMetadataItem,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { forEach, size } from 'lodash-es';
|
||||||
|
import { SDXLLoraLoaderInvocation } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
LORA_LOADER,
|
LORA_LOADER,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
SDXL_CANVAS_INPAINT_GRAPH,
|
SDXL_CANVAS_INPAINT_GRAPH,
|
||||||
@ -17,6 +17,7 @@ import {
|
|||||||
SDXL_REFINER_INPAINT_CREATE_MASK,
|
SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addSDXLLoRAsToGraph = (
|
export const addSDXLLoRAsToGraph = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -34,9 +35,12 @@ export const addSDXLLoRAsToGraph = (
|
|||||||
|
|
||||||
const { loras } = state.lora;
|
const { loras } = state.lora;
|
||||||
const loraCount = size(loras);
|
const loraCount = size(loras);
|
||||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
|
||||||
| MetadataAccumulatorInvocation
|
if (loraCount === 0) {
|
||||||
| undefined;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const loraMetadata: LoRAMetadataItem[] = [];
|
||||||
|
|
||||||
// Handle Seamless Plugs
|
// Handle Seamless Plugs
|
||||||
const unetLoaderId = modelLoaderNodeId;
|
const unetLoaderId = modelLoaderNodeId;
|
||||||
@ -47,22 +51,17 @@ export const addSDXLLoRAsToGraph = (
|
|||||||
clipLoaderId = SDXL_MODEL_LOADER;
|
clipLoaderId = SDXL_MODEL_LOADER;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (loraCount > 0) {
|
// Remove modelLoaderNodeId unet/clip/clip2 connections to feed it to LoRAs
|
||||||
// Remove modelLoaderNodeId unet/clip/clip2 connections to feed it to LoRAs
|
graph.edges = graph.edges.filter(
|
||||||
graph.edges = graph.edges.filter(
|
(e) =>
|
||||||
(e) =>
|
!(
|
||||||
!(
|
e.source.node_id === unetLoaderId && ['unet'].includes(e.source.field)
|
||||||
e.source.node_id === unetLoaderId && ['unet'].includes(e.source.field)
|
) &&
|
||||||
) &&
|
!(
|
||||||
!(
|
e.source.node_id === clipLoaderId && ['clip'].includes(e.source.field)
|
||||||
e.source.node_id === clipLoaderId && ['clip'].includes(e.source.field)
|
) &&
|
||||||
) &&
|
!(e.source.node_id === clipLoaderId && ['clip2'].includes(e.source.field))
|
||||||
!(
|
);
|
||||||
e.source.node_id === clipLoaderId &&
|
|
||||||
['clip2'].includes(e.source.field)
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// we need to remember the last lora so we can chain from it
|
// we need to remember the last lora so we can chain from it
|
||||||
let lastLoraNodeId = '';
|
let lastLoraNodeId = '';
|
||||||
@ -80,16 +79,12 @@ export const addSDXLLoRAsToGraph = (
|
|||||||
weight,
|
weight,
|
||||||
};
|
};
|
||||||
|
|
||||||
// add the lora to the metadata accumulator
|
loraMetadata.push(
|
||||||
if (metadataAccumulator) {
|
zLoRAMetadataItem.parse({
|
||||||
if (!metadataAccumulator.loras) {
|
|
||||||
metadataAccumulator.loras = [];
|
|
||||||
}
|
|
||||||
metadataAccumulator.loras.push({
|
|
||||||
lora: { model_name, base_model },
|
lora: { model_name, base_model },
|
||||||
weight,
|
weight,
|
||||||
});
|
})
|
||||||
}
|
);
|
||||||
|
|
||||||
// add to graph
|
// add to graph
|
||||||
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||||
@ -242,4 +237,6 @@ export const addSDXLLoRAsToGraph = (
|
|||||||
lastLoraNodeId = currentLoraNodeId;
|
lastLoraNodeId = currentLoraNodeId;
|
||||||
currentLoraIndex += 1;
|
currentLoraIndex += 1;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
upsertMetadata(graph, { loras: loraMetadata });
|
||||||
};
|
};
|
||||||
|
@ -2,7 +2,6 @@ import { RootState } from 'app/store/store';
|
|||||||
import {
|
import {
|
||||||
CreateDenoiseMaskInvocation,
|
CreateDenoiseMaskInvocation,
|
||||||
ImageDTO,
|
ImageDTO,
|
||||||
MetadataAccumulatorInvocation,
|
|
||||||
SeamlessModeInvocation,
|
SeamlessModeInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { NonNullableGraph } from '../../types/types';
|
import { NonNullableGraph } from '../../types/types';
|
||||||
@ -12,7 +11,6 @@ import {
|
|||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MASK_COMBINE,
|
MASK_COMBINE,
|
||||||
MASK_RESIZE_UP,
|
MASK_RESIZE_UP,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||||
SDXL_CANVAS_INPAINT_GRAPH,
|
SDXL_CANVAS_INPAINT_GRAPH,
|
||||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||||
@ -26,6 +24,7 @@ import {
|
|||||||
SDXL_REFINER_SEAMLESS,
|
SDXL_REFINER_SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
||||||
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addSDXLRefinerToGraph = (
|
export const addSDXLRefinerToGraph = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -58,21 +57,15 @@ export const addSDXLRefinerToGraph = (
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
upsertMetadata(graph, {
|
||||||
| MetadataAccumulatorInvocation
|
refiner_model: refinerModel,
|
||||||
| undefined;
|
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
|
||||||
|
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
|
||||||
if (metadataAccumulator) {
|
refiner_cfg_scale: refinerCFGScale,
|
||||||
metadataAccumulator.refiner_model = refinerModel;
|
refiner_scheduler: refinerScheduler,
|
||||||
metadataAccumulator.refiner_positive_aesthetic_score =
|
refiner_start: refinerStart,
|
||||||
refinerPositiveAestheticScore;
|
refiner_steps: refinerSteps,
|
||||||
metadataAccumulator.refiner_negative_aesthetic_score =
|
});
|
||||||
refinerNegativeAestheticScore;
|
|
||||||
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
|
|
||||||
metadataAccumulator.refiner_scheduler = refinerScheduler;
|
|
||||||
metadataAccumulator.refiner_start = refinerStart;
|
|
||||||
metadataAccumulator.refiner_steps = refinerSteps;
|
|
||||||
}
|
|
||||||
|
|
||||||
const modelLoaderId = modelLoaderNodeId
|
const modelLoaderId = modelLoaderNodeId
|
||||||
? modelLoaderNodeId
|
? modelLoaderNodeId
|
||||||
|
@ -1,19 +1,15 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
import { SaveImageInvocation } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
CANVAS_OUTPUT,
|
CANVAS_OUTPUT,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
LATENTS_TO_IMAGE_HRF,
|
LATENTS_TO_IMAGE_HRF,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
NSFW_CHECKER,
|
NSFW_CHECKER,
|
||||||
SAVE_IMAGE,
|
SAVE_IMAGE,
|
||||||
WATERMARKER,
|
WATERMARKER,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import {
|
|
||||||
MetadataAccumulatorInvocation,
|
|
||||||
SaveImageInvocation,
|
|
||||||
} from 'services/api/types';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set the `use_cache` field on the linear/canvas graph's final image output node to False.
|
* Set the `use_cache` field on the linear/canvas graph's final image output node to False.
|
||||||
@ -37,23 +33,6 @@ export const addSaveImageNode = (
|
|||||||
|
|
||||||
graph.nodes[SAVE_IMAGE] = saveImageNode;
|
graph.nodes[SAVE_IMAGE] = saveImageNode;
|
||||||
|
|
||||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
|
||||||
| MetadataAccumulatorInvocation
|
|
||||||
| undefined;
|
|
||||||
|
|
||||||
if (metadataAccumulator) {
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: METADATA_ACCUMULATOR,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: SAVE_IMAGE,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
const destination = {
|
const destination = {
|
||||||
node_id: SAVE_IMAGE,
|
node_id: SAVE_IMAGE,
|
||||||
field: 'image',
|
field: 'image',
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { SeamlessModeInvocation } from 'services/api/types';
|
import { SeamlessModeInvocation } from 'services/api/types';
|
||||||
import { NonNullableGraph } from '../../types/types';
|
import { NonNullableGraph } from '../../types/types';
|
||||||
|
import { upsertMetadata } from './metadata';
|
||||||
import {
|
import {
|
||||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
CANVAS_INPAINT_GRAPH,
|
CANVAS_INPAINT_GRAPH,
|
||||||
@ -31,6 +32,17 @@ export const addSeamlessToLinearGraph = (
|
|||||||
seamless_y: seamlessYAxis,
|
seamless_y: seamlessYAxis,
|
||||||
} as SeamlessModeInvocation;
|
} as SeamlessModeInvocation;
|
||||||
|
|
||||||
|
if (seamlessXAxis) {
|
||||||
|
upsertMetadata(graph, {
|
||||||
|
seamless_x: seamlessXAxis,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if (seamlessYAxis) {
|
||||||
|
upsertMetadata(graph, {
|
||||||
|
seamless_y: seamlessYAxis,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let denoisingNodeId = DENOISE_LATENTS;
|
let denoisingNodeId = DENOISE_LATENTS;
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -3,15 +3,15 @@ import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAd
|
|||||||
import { omit } from 'lodash-es';
|
import { omit } from 'lodash-es';
|
||||||
import {
|
import {
|
||||||
CollectInvocation,
|
CollectInvocation,
|
||||||
MetadataAccumulatorInvocation,
|
CoreMetadataInvocation,
|
||||||
T2IAdapterInvocation,
|
T2IAdapterInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { NonNullableGraph, T2IAdapterField } from '../../types/types';
|
import { NonNullableGraph, T2IAdapterField } from '../../types/types';
|
||||||
import {
|
import {
|
||||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
T2I_ADAPTER_COLLECT,
|
T2I_ADAPTER_COLLECT,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addT2IAdaptersToLinearGraph = (
|
export const addT2IAdaptersToLinearGraph = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -22,10 +22,6 @@ export const addT2IAdaptersToLinearGraph = (
|
|||||||
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
(ca) => ca.model?.base_model === state.generation.model?.base_model
|
||||||
);
|
);
|
||||||
|
|
||||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
|
||||||
| MetadataAccumulatorInvocation
|
|
||||||
| undefined;
|
|
||||||
|
|
||||||
if (validT2IAdapters.length) {
|
if (validT2IAdapters.length) {
|
||||||
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
|
// Even though denoise_latents' control input is polymorphic, keep it simple and always use a collect
|
||||||
const t2iAdapterCollectNode: CollectInvocation = {
|
const t2iAdapterCollectNode: CollectInvocation = {
|
||||||
@ -51,6 +47,7 @@ export const addT2IAdaptersToLinearGraph = (
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
const t2iAdapterMetdata: CoreMetadataInvocation['t2iAdapters'] = [];
|
||||||
|
|
||||||
validT2IAdapters.forEach((t2iAdapter) => {
|
validT2IAdapters.forEach((t2iAdapter) => {
|
||||||
if (!t2iAdapter.model) {
|
if (!t2iAdapter.model) {
|
||||||
@ -96,15 +93,13 @@ export const addT2IAdaptersToLinearGraph = (
|
|||||||
|
|
||||||
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation;
|
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation;
|
||||||
|
|
||||||
if (metadataAccumulator?.t2iAdapters) {
|
t2iAdapterMetdata.push(
|
||||||
// metadata accumulator only needs a control field - not the whole node
|
omit(t2iAdapterNode, [
|
||||||
// extract what we need and add to the accumulator
|
|
||||||
const t2iAdapterField = omit(t2iAdapterNode, [
|
|
||||||
'id',
|
'id',
|
||||||
'type',
|
'type',
|
||||||
]) as T2IAdapterField;
|
'is_intermediate',
|
||||||
metadataAccumulator.t2iAdapters.push(t2iAdapterField);
|
]) as T2IAdapterField
|
||||||
}
|
);
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
||||||
@ -114,5 +109,7 @@ export const addT2IAdaptersToLinearGraph = (
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetdata });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import { MetadataAccumulatorInvocation } from 'services/api/types';
|
|
||||||
import {
|
import {
|
||||||
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||||
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||||
@ -14,7 +13,6 @@ import {
|
|||||||
INPAINT_IMAGE,
|
INPAINT_IMAGE,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
ONNX_MODEL_LOADER,
|
ONNX_MODEL_LOADER,
|
||||||
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||||
SDXL_CANVAS_INPAINT_GRAPH,
|
SDXL_CANVAS_INPAINT_GRAPH,
|
||||||
@ -26,6 +24,7 @@ import {
|
|||||||
TEXT_TO_IMAGE_GRAPH,
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
VAE_LOADER,
|
VAE_LOADER,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addVAEToGraph = (
|
export const addVAEToGraph = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -41,9 +40,6 @@ export const addVAEToGraph = (
|
|||||||
);
|
);
|
||||||
|
|
||||||
const isAutoVae = !vae;
|
const isAutoVae = !vae;
|
||||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
|
||||||
| MetadataAccumulatorInvocation
|
|
||||||
| undefined;
|
|
||||||
|
|
||||||
if (!isAutoVae) {
|
if (!isAutoVae) {
|
||||||
graph.nodes[VAE_LOADER] = {
|
graph.nodes[VAE_LOADER] = {
|
||||||
@ -181,7 +177,7 @@ export const addVAEToGraph = (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (vae && metadataAccumulator) {
|
if (vae) {
|
||||||
metadataAccumulator.vae = vae;
|
upsertMetadata(graph, { vae });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -5,14 +5,8 @@ import {
|
|||||||
ImageNSFWBlurInvocation,
|
ImageNSFWBlurInvocation,
|
||||||
ImageWatermarkInvocation,
|
ImageWatermarkInvocation,
|
||||||
LatentsToImageInvocation,
|
LatentsToImageInvocation,
|
||||||
MetadataAccumulatorInvocation,
|
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import {
|
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants';
|
||||||
LATENTS_TO_IMAGE,
|
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
NSFW_CHECKER,
|
|
||||||
WATERMARKER,
|
|
||||||
} from './constants';
|
|
||||||
|
|
||||||
export const addWatermarkerToGraph = (
|
export const addWatermarkerToGraph = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -32,10 +26,6 @@ export const addWatermarkerToGraph = (
|
|||||||
| ImageNSFWBlurInvocation
|
| ImageNSFWBlurInvocation
|
||||||
| undefined;
|
| undefined;
|
||||||
|
|
||||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
|
||||||
| MetadataAccumulatorInvocation
|
|
||||||
| undefined;
|
|
||||||
|
|
||||||
if (!nodeToAddTo) {
|
if (!nodeToAddTo) {
|
||||||
// something has gone terribly awry
|
// something has gone terribly awry
|
||||||
return;
|
return;
|
||||||
@ -80,17 +70,4 @@ export const addWatermarkerToGraph = (
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (metadataAccumulator) {
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: METADATA_ACCUMULATOR,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: WATERMARKER,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
|
import { BoardId } from 'features/gallery/store/types';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import { ESRGANModelName } from 'features/parameters/store/postprocessingSlice';
|
import { ESRGANModelName } from 'features/parameters/store/postprocessingSlice';
|
||||||
import {
|
import {
|
||||||
Graph,
|
|
||||||
ESRGANInvocation,
|
ESRGANInvocation,
|
||||||
|
Graph,
|
||||||
SaveImageInvocation,
|
SaveImageInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { REALESRGAN as ESRGAN, SAVE_IMAGE } from './constants';
|
import { REALESRGAN as ESRGAN, SAVE_IMAGE } from './constants';
|
||||||
import { BoardId } from 'features/gallery/store/types';
|
import { addCoreMetadataNode } from './metadata';
|
||||||
|
|
||||||
type Arg = {
|
type Arg = {
|
||||||
image_name: string;
|
image_name: string;
|
||||||
@ -55,5 +56,9 @@ export const buildAdHocUpscaleGraph = ({
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
addCoreMetadataNode(graph, {
|
||||||
|
esrgan_model: esrganModelName,
|
||||||
|
});
|
||||||
|
|
||||||
return graph;
|
return graph;
|
||||||
};
|
};
|
||||||
|
@ -20,12 +20,12 @@ import {
|
|||||||
IMG2IMG_RESIZE,
|
IMG2IMG_RESIZE,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
import { addCoreMetadataNode } from './metadata';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Image to Image graph.
|
* Builds the Canvas tab's Image to Image graph.
|
||||||
@ -308,10 +308,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
addCoreMetadataNode(graph, {
|
||||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
|
||||||
id: METADATA_ACCUMULATOR,
|
|
||||||
type: 'metadata_accumulator',
|
|
||||||
generation_mode: 'img2img',
|
generation_mode: 'img2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
|
width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
|
||||||
@ -325,15 +322,10 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
scheduler,
|
scheduler,
|
||||||
vae: undefined, // option; set in addVAEToGraph
|
|
||||||
controlnets: [], // populated in addControlNetToLinearGraph
|
|
||||||
loras: [], // populated in addLoRAsToGraph
|
|
||||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
|
||||||
t2iAdapters: [],
|
|
||||||
clip_skip: clipSkip,
|
clip_skip: clipSkip,
|
||||||
strength,
|
strength,
|
||||||
init_image: initialImage.image_name,
|
init_image: initialImage.image_name,
|
||||||
};
|
});
|
||||||
|
|
||||||
// Add Seamless To Graph
|
// Add Seamless To Graph
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
|
@ -16,7 +16,6 @@ import {
|
|||||||
IMAGE_TO_LATENTS,
|
IMAGE_TO_LATENTS,
|
||||||
IMG2IMG_RESIZE,
|
IMG2IMG_RESIZE,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
@ -28,6 +27,7 @@ import {
|
|||||||
} from './constants';
|
} from './constants';
|
||||||
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
||||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||||
|
import { addCoreMetadataNode } from './metadata';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Image to Image graph.
|
* Builds the Canvas tab's Image to Image graph.
|
||||||
@ -319,10 +319,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
addCoreMetadataNode(graph, {
|
||||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
|
||||||
id: METADATA_ACCUMULATOR,
|
|
||||||
type: 'metadata_accumulator',
|
|
||||||
generation_mode: 'img2img',
|
generation_mode: 'img2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
|
width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
|
||||||
@ -336,24 +333,8 @@ export const buildCanvasSDXLImageToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
scheduler,
|
scheduler,
|
||||||
vae: undefined, // option; set in addVAEToGraph
|
|
||||||
controlnets: [], // populated in addControlNetToLinearGraph
|
|
||||||
loras: [], // populated in addLoRAsToGraph
|
|
||||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
|
||||||
t2iAdapters: [],
|
|
||||||
strength,
|
strength,
|
||||||
init_image: initialImage.image_name,
|
init_image: initialImage.image_name,
|
||||||
};
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: METADATA_ACCUMULATOR,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: CANVAS_OUTPUT,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add Seamless To Graph
|
// Add Seamless To Graph
|
||||||
|
@ -18,7 +18,6 @@ import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
|||||||
import {
|
import {
|
||||||
CANVAS_OUTPUT,
|
CANVAS_OUTPUT,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
ONNX_MODEL_LOADER,
|
ONNX_MODEL_LOADER,
|
||||||
@ -30,6 +29,7 @@ import {
|
|||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
||||||
|
import { addCoreMetadataNode } from './metadata';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Text to Image graph.
|
* Builds the Canvas tab's Text to Image graph.
|
||||||
@ -301,10 +301,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
addCoreMetadataNode(graph, {
|
||||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
|
||||||
id: METADATA_ACCUMULATOR,
|
|
||||||
type: 'metadata_accumulator',
|
|
||||||
generation_mode: 'txt2img',
|
generation_mode: 'txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
|
width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
|
||||||
@ -318,22 +315,6 @@ export const buildCanvasSDXLTextToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
scheduler,
|
scheduler,
|
||||||
vae: undefined, // option; set in addVAEToGraph
|
|
||||||
controlnets: [], // populated in addControlNetToLinearGraph
|
|
||||||
loras: [], // populated in addLoRAsToGraph
|
|
||||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
|
||||||
t2iAdapters: [],
|
|
||||||
};
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: METADATA_ACCUMULATOR,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: CANVAS_OUTPUT,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add Seamless To Graph
|
// Add Seamless To Graph
|
||||||
|
@ -21,13 +21,13 @@ import {
|
|||||||
DENOISE_LATENTS,
|
DENOISE_LATENTS,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
ONNX_MODEL_LOADER,
|
ONNX_MODEL_LOADER,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
import { addCoreMetadataNode } from './metadata';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Text to Image graph.
|
* Builds the Canvas tab's Text to Image graph.
|
||||||
@ -289,10 +289,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
addCoreMetadataNode(graph, {
|
||||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
|
||||||
id: METADATA_ACCUMULATOR,
|
|
||||||
type: 'metadata_accumulator',
|
|
||||||
generation_mode: 'txt2img',
|
generation_mode: 'txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
|
width: !isUsingScaledDimensions ? width : scaledBoundingBoxDimensions.width,
|
||||||
@ -306,23 +303,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
scheduler,
|
scheduler,
|
||||||
vae: undefined, // option; set in addVAEToGraph
|
|
||||||
controlnets: [], // populated in addControlNetToLinearGraph
|
|
||||||
loras: [], // populated in addLoRAsToGraph
|
|
||||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
|
||||||
t2iAdapters: [],
|
|
||||||
clip_skip: clipSkip,
|
clip_skip: clipSkip,
|
||||||
};
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: METADATA_ACCUMULATOR,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: CANVAS_OUTPUT,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add Seamless To Graph
|
// Add Seamless To Graph
|
||||||
|
@ -2,15 +2,16 @@ import { NUMPY_RAND_MAX } from 'app/constants';
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { generateSeeds } from 'common/util/generateSeeds';
|
import { generateSeeds } from 'common/util/generateSeeds';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import { range, unset } from 'lodash-es';
|
import { range } from 'lodash-es';
|
||||||
import { components } from 'services/api/schema';
|
import { components } from 'services/api/schema';
|
||||||
import { Batch, BatchConfig } from 'services/api/types';
|
import { Batch, BatchConfig } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
CANVAS_COHERENCE_NOISE,
|
CANVAS_COHERENCE_NOISE,
|
||||||
METADATA_ACCUMULATOR,
|
METADATA,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
import { getHasMetadata, removeMetadata } from './metadata';
|
||||||
|
|
||||||
export const prepareLinearUIBatch = (
|
export const prepareLinearUIBatch = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -24,7 +25,6 @@ export const prepareLinearUIBatch = (
|
|||||||
const data: Batch['data'] = [];
|
const data: Batch['data'] = [];
|
||||||
|
|
||||||
if (prompts.length === 1) {
|
if (prompts.length === 1) {
|
||||||
unset(graph.nodes[METADATA_ACCUMULATOR], 'seed');
|
|
||||||
const seeds = generateSeeds({
|
const seeds = generateSeeds({
|
||||||
count: iterations,
|
count: iterations,
|
||||||
start: shouldRandomizeSeed ? undefined : seed,
|
start: shouldRandomizeSeed ? undefined : seed,
|
||||||
@ -40,9 +40,11 @@ export const prepareLinearUIBatch = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (graph.nodes[METADATA_ACCUMULATOR]) {
|
if (getHasMetadata(graph)) {
|
||||||
|
// add to metadata
|
||||||
|
removeMetadata(graph, 'seed');
|
||||||
zipped.push({
|
zipped.push({
|
||||||
node_path: METADATA_ACCUMULATOR,
|
node_path: METADATA,
|
||||||
field_name: 'seed',
|
field_name: 'seed',
|
||||||
items: seeds,
|
items: seeds,
|
||||||
});
|
});
|
||||||
@ -77,9 +79,11 @@ export const prepareLinearUIBatch = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (graph.nodes[METADATA_ACCUMULATOR]) {
|
// add to metadata
|
||||||
|
if (getHasMetadata(graph)) {
|
||||||
|
removeMetadata(graph, 'seed');
|
||||||
firstBatchDatumList.push({
|
firstBatchDatumList.push({
|
||||||
node_path: METADATA_ACCUMULATOR,
|
node_path: METADATA,
|
||||||
field_name: 'seed',
|
field_name: 'seed',
|
||||||
items: seeds,
|
items: seeds,
|
||||||
});
|
});
|
||||||
@ -106,13 +110,17 @@ export const prepareLinearUIBatch = (
|
|||||||
items: seeds,
|
items: seeds,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
if (graph.nodes[METADATA_ACCUMULATOR]) {
|
|
||||||
|
// add to metadata
|
||||||
|
if (getHasMetadata(graph)) {
|
||||||
|
removeMetadata(graph, 'seed');
|
||||||
secondBatchDatumList.push({
|
secondBatchDatumList.push({
|
||||||
node_path: METADATA_ACCUMULATOR,
|
node_path: METADATA,
|
||||||
field_name: 'seed',
|
field_name: 'seed',
|
||||||
items: seeds,
|
items: seeds,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
|
if (graph.nodes[CANVAS_COHERENCE_NOISE]) {
|
||||||
secondBatchDatumList.push({
|
secondBatchDatumList.push({
|
||||||
node_path: CANVAS_COHERENCE_NOISE,
|
node_path: CANVAS_COHERENCE_NOISE,
|
||||||
@ -137,17 +145,17 @@ export const prepareLinearUIBatch = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (graph.nodes[METADATA_ACCUMULATOR]) {
|
// add to metadata
|
||||||
|
if (getHasMetadata(graph)) {
|
||||||
|
removeMetadata(graph, 'positive_prompt');
|
||||||
firstBatchDatumList.push({
|
firstBatchDatumList.push({
|
||||||
node_path: METADATA_ACCUMULATOR,
|
node_path: METADATA,
|
||||||
field_name: 'positive_prompt',
|
field_name: 'positive_prompt',
|
||||||
items: extendedPrompts,
|
items: extendedPrompts,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (shouldConcatSDXLStylePrompt && model?.base_model === 'sdxl') {
|
if (shouldConcatSDXLStylePrompt && model?.base_model === 'sdxl') {
|
||||||
unset(graph.nodes[METADATA_ACCUMULATOR], 'positive_style_prompt');
|
|
||||||
|
|
||||||
const stylePrompts = extendedPrompts.map((p) =>
|
const stylePrompts = extendedPrompts.map((p) =>
|
||||||
[p, positiveStylePrompt].join(' ')
|
[p, positiveStylePrompt].join(' ')
|
||||||
);
|
);
|
||||||
@ -160,11 +168,13 @@ export const prepareLinearUIBatch = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (graph.nodes[METADATA_ACCUMULATOR]) {
|
// add to metadata
|
||||||
|
if (getHasMetadata(graph)) {
|
||||||
|
removeMetadata(graph, 'positive_style_prompt');
|
||||||
firstBatchDatumList.push({
|
firstBatchDatumList.push({
|
||||||
node_path: METADATA_ACCUMULATOR,
|
node_path: METADATA,
|
||||||
field_name: 'positive_style_prompt',
|
field_name: 'positive_style_prompt',
|
||||||
items: stylePrompts,
|
items: extendedPrompts,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -21,13 +21,13 @@ import {
|
|||||||
IMAGE_TO_LATENTS,
|
IMAGE_TO_LATENTS,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
RESIZE,
|
RESIZE,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
import { addCoreMetadataNode } from './metadata';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Image to Image tab graph.
|
* Builds the Image to Image tab graph.
|
||||||
@ -311,10 +311,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
addCoreMetadataNode(graph, {
|
||||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
|
||||||
id: METADATA_ACCUMULATOR,
|
|
||||||
type: 'metadata_accumulator',
|
|
||||||
generation_mode: 'img2img',
|
generation_mode: 'img2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
height,
|
height,
|
||||||
@ -326,25 +323,9 @@ export const buildLinearImageToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
scheduler,
|
scheduler,
|
||||||
vae: undefined, // option; set in addVAEToGraph
|
|
||||||
controlnets: [], // populated in addControlNetToLinearGraph
|
|
||||||
loras: [], // populated in addLoRAsToGraph
|
|
||||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
|
||||||
t2iAdapters: [], // populated in addT2IAdapterToLinearGraph
|
|
||||||
clip_skip: clipSkip,
|
clip_skip: clipSkip,
|
||||||
strength,
|
strength,
|
||||||
init_image: initialImage.imageName,
|
init_image: initialImage.imageName,
|
||||||
};
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: METADATA_ACCUMULATOR,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add Seamless To Graph
|
// Add Seamless To Graph
|
||||||
|
@ -18,7 +18,6 @@ import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
|||||||
import {
|
import {
|
||||||
IMAGE_TO_LATENTS,
|
IMAGE_TO_LATENTS,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
@ -30,6 +29,7 @@ import {
|
|||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt';
|
||||||
|
import { addCoreMetadataNode } from './metadata';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Image to Image tab graph.
|
* Builds the Image to Image tab graph.
|
||||||
@ -331,10 +331,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
addCoreMetadataNode(graph, {
|
||||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
|
||||||
id: METADATA_ACCUMULATOR,
|
|
||||||
type: 'metadata_accumulator',
|
|
||||||
generation_mode: 'sdxl_img2img',
|
generation_mode: 'sdxl_img2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
height,
|
height,
|
||||||
@ -346,26 +343,10 @@ export const buildLinearSDXLImageToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
scheduler,
|
scheduler,
|
||||||
vae: undefined,
|
strength,
|
||||||
controlnets: [],
|
|
||||||
loras: [],
|
|
||||||
ipAdapters: [],
|
|
||||||
t2iAdapters: [],
|
|
||||||
strength: strength,
|
|
||||||
init_image: initialImage.imageName,
|
init_image: initialImage.imageName,
|
||||||
positive_style_prompt: positiveStylePrompt,
|
positive_style_prompt: positiveStylePrompt,
|
||||||
negative_style_prompt: negativeStylePrompt,
|
negative_style_prompt: negativeStylePrompt,
|
||||||
};
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: METADATA_ACCUMULATOR,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add Seamless To Graph
|
// Add Seamless To Graph
|
||||||
|
@ -11,9 +11,9 @@ import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
|||||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||||
import { addVAEToGraph } from './addVAEToGraph';
|
import { addVAEToGraph } from './addVAEToGraph';
|
||||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||||
|
import { addCoreMetadataNode } from './metadata';
|
||||||
import {
|
import {
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
@ -225,10 +225,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
addCoreMetadataNode(graph, {
|
||||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
|
||||||
id: METADATA_ACCUMULATOR,
|
|
||||||
type: 'metadata_accumulator',
|
|
||||||
generation_mode: 'sdxl_txt2img',
|
generation_mode: 'sdxl_txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
height,
|
height,
|
||||||
@ -240,24 +237,8 @@ export const buildLinearSDXLTextToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
scheduler,
|
scheduler,
|
||||||
vae: undefined,
|
|
||||||
controlnets: [],
|
|
||||||
loras: [],
|
|
||||||
ipAdapters: [],
|
|
||||||
t2iAdapters: [],
|
|
||||||
positive_style_prompt: positiveStylePrompt,
|
positive_style_prompt: positiveStylePrompt,
|
||||||
negative_style_prompt: negativeStylePrompt,
|
negative_style_prompt: negativeStylePrompt,
|
||||||
};
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: METADATA_ACCUMULATOR,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add Seamless To Graph
|
// Add Seamless To Graph
|
||||||
|
@ -15,12 +15,12 @@ import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
|||||||
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
|
||||||
import { addVAEToGraph } from './addVAEToGraph';
|
import { addVAEToGraph } from './addVAEToGraph';
|
||||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||||
|
import { addCoreMetadataNode } from './metadata';
|
||||||
import {
|
import {
|
||||||
CLIP_SKIP,
|
CLIP_SKIP,
|
||||||
DENOISE_LATENTS,
|
DENOISE_LATENTS,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MAIN_MODEL_LOADER,
|
MAIN_MODEL_LOADER,
|
||||||
METADATA_ACCUMULATOR,
|
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
ONNX_MODEL_LOADER,
|
ONNX_MODEL_LOADER,
|
||||||
@ -48,10 +48,6 @@ export const buildLinearTextToImageGraph = (
|
|||||||
seamlessXAxis,
|
seamlessXAxis,
|
||||||
seamlessYAxis,
|
seamlessYAxis,
|
||||||
seed,
|
seed,
|
||||||
hrfWidth,
|
|
||||||
hrfHeight,
|
|
||||||
hrfStrength,
|
|
||||||
hrfEnabled: hrfEnabled,
|
|
||||||
} = state.generation;
|
} = state.generation;
|
||||||
|
|
||||||
const use_cpu = shouldUseCpuNoise;
|
const use_cpu = shouldUseCpuNoise;
|
||||||
@ -238,10 +234,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
addCoreMetadataNode(graph, {
|
||||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
|
||||||
id: METADATA_ACCUMULATOR,
|
|
||||||
type: 'metadata_accumulator',
|
|
||||||
generation_mode: 'txt2img',
|
generation_mode: 'txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
height,
|
height,
|
||||||
@ -253,26 +246,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
scheduler,
|
scheduler,
|
||||||
vae: undefined, // option; set in addVAEToGraph
|
|
||||||
controlnets: [], // populated in addControlNetToLinearGraph
|
|
||||||
loras: [], // populated in addLoRAsToGraph
|
|
||||||
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
|
||||||
t2iAdapters: [], // populated in addT2IAdapterToLinearGraph
|
|
||||||
clip_skip: clipSkip,
|
clip_skip: clipSkip,
|
||||||
hrf_width: hrfEnabled ? hrfWidth : undefined,
|
|
||||||
hrf_height: hrfEnabled ? hrfHeight : undefined,
|
|
||||||
hrf_strength: hrfEnabled ? hrfStrength : undefined,
|
|
||||||
};
|
|
||||||
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: METADATA_ACCUMULATOR,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: LATENTS_TO_IMAGE,
|
|
||||||
field: 'metadata',
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add Seamless To Graph
|
// Add Seamless To Graph
|
||||||
|
@ -35,7 +35,6 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
|||||||
const { nodes, edges } = nodesState;
|
const { nodes, edges } = nodesState;
|
||||||
|
|
||||||
const filteredNodes = nodes.filter(isInvocationNode);
|
const filteredNodes = nodes.filter(isInvocationNode);
|
||||||
const workflowJSON = JSON.stringify(buildWorkflow(nodesState));
|
|
||||||
|
|
||||||
// Reduce the node editor nodes into invocation graph nodes
|
// Reduce the node editor nodes into invocation graph nodes
|
||||||
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>(
|
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>(
|
||||||
@ -68,7 +67,7 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
|||||||
|
|
||||||
if (embedWorkflow) {
|
if (embedWorkflow) {
|
||||||
// add the workflow to the node
|
// add the workflow to the node
|
||||||
Object.assign(graphNode, { workflow: workflowJSON });
|
Object.assign(graphNode, { workflow: buildWorkflow(nodesState) });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add it to the nodes object
|
// Add it to the nodes object
|
||||||
|
@ -56,7 +56,14 @@ export const IP_ADAPTER = 'ip_adapter';
|
|||||||
export const DYNAMIC_PROMPT = 'dynamic_prompt';
|
export const DYNAMIC_PROMPT = 'dynamic_prompt';
|
||||||
export const IMAGE_COLLECTION = 'image_collection';
|
export const IMAGE_COLLECTION = 'image_collection';
|
||||||
export const IMAGE_COLLECTION_ITERATE = 'image_collection_iterate';
|
export const IMAGE_COLLECTION_ITERATE = 'image_collection_iterate';
|
||||||
export const METADATA_ACCUMULATOR = 'metadata_accumulator';
|
export const METADATA = 'core_metadata';
|
||||||
|
export const BATCH_METADATA = 'batch_metadata';
|
||||||
|
export const BATCH_METADATA_COLLECT = 'batch_metadata_collect';
|
||||||
|
export const BATCH_SEED = 'batch_seed';
|
||||||
|
export const BATCH_PROMPT = 'batch_prompt';
|
||||||
|
export const BATCH_STYLE_PROMPT = 'batch_style_prompt';
|
||||||
|
export const METADATA_COLLECT = 'metadata_collect';
|
||||||
|
export const MERGE_METADATA = 'merge_metadata';
|
||||||
export const REALESRGAN = 'esrgan';
|
export const REALESRGAN = 'esrgan';
|
||||||
export const DIVIDE = 'divide';
|
export const DIVIDE = 'divide';
|
||||||
export const SCALE = 'scale_image';
|
export const SCALE = 'scale_image';
|
||||||
|
@ -0,0 +1,66 @@
|
|||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { CoreMetadataInvocation } from 'services/api/types';
|
||||||
|
import { JsonObject } from 'type-fest';
|
||||||
|
import { METADATA, SAVE_IMAGE } from './constants';
|
||||||
|
|
||||||
|
export const addCoreMetadataNode = (
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
metadata: Partial<CoreMetadataInvocation> | JsonObject
|
||||||
|
): void => {
|
||||||
|
graph.nodes[METADATA] = {
|
||||||
|
id: METADATA,
|
||||||
|
type: 'core_metadata',
|
||||||
|
...metadata,
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: METADATA,
|
||||||
|
field: 'metadata',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SAVE_IMAGE,
|
||||||
|
field: 'metadata',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const upsertMetadata = (
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
metadata: Partial<CoreMetadataInvocation> | JsonObject
|
||||||
|
): void => {
|
||||||
|
const metadataNode = graph.nodes[METADATA] as
|
||||||
|
| CoreMetadataInvocation
|
||||||
|
| undefined;
|
||||||
|
|
||||||
|
if (!metadataNode) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Object.assign(metadataNode, metadata);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const removeMetadata = (
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
key: keyof CoreMetadataInvocation
|
||||||
|
): void => {
|
||||||
|
const metadataNode = graph.nodes[METADATA] as
|
||||||
|
| CoreMetadataInvocation
|
||||||
|
| undefined;
|
||||||
|
|
||||||
|
if (!metadataNode) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
delete metadataNode[key];
|
||||||
|
};
|
||||||
|
|
||||||
|
export const getHasMetadata = (graph: NonNullableGraph): boolean => {
|
||||||
|
const metadataNode = graph.nodes[METADATA] as
|
||||||
|
| CoreMetadataInvocation
|
||||||
|
| undefined;
|
||||||
|
|
||||||
|
return Boolean(metadataNode);
|
||||||
|
};
|
@ -4,7 +4,6 @@ import { reduce, startCase } from 'lodash-es';
|
|||||||
import { OpenAPIV3_1 } from 'openapi-types';
|
import { OpenAPIV3_1 } from 'openapi-types';
|
||||||
import { AnyInvocationType } from 'services/events/types';
|
import { AnyInvocationType } from 'services/events/types';
|
||||||
import {
|
import {
|
||||||
FieldType,
|
|
||||||
InputFieldTemplate,
|
InputFieldTemplate,
|
||||||
InvocationSchemaObject,
|
InvocationSchemaObject,
|
||||||
InvocationTemplate,
|
InvocationTemplate,
|
||||||
@ -16,18 +15,11 @@ import {
|
|||||||
} from '../types/types';
|
} from '../types/types';
|
||||||
import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
|
import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
|
||||||
|
|
||||||
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'metadata', 'use_cache'];
|
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache'];
|
||||||
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
|
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
|
||||||
const RESERVED_FIELD_TYPES = [
|
const RESERVED_FIELD_TYPES = ['IsIntermediate'];
|
||||||
'WorkflowField',
|
|
||||||
'MetadataField',
|
|
||||||
'IsIntermediate',
|
|
||||||
];
|
|
||||||
|
|
||||||
const invocationDenylist: AnyInvocationType[] = [
|
const invocationDenylist: AnyInvocationType[] = ['graph'];
|
||||||
'graph',
|
|
||||||
'metadata_accumulator',
|
|
||||||
];
|
|
||||||
|
|
||||||
const isReservedInputField = (nodeType: string, fieldName: string) => {
|
const isReservedInputField = (nodeType: string, fieldName: string) => {
|
||||||
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
|
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
|
||||||
@ -42,7 +34,7 @@ const isReservedInputField = (nodeType: string, fieldName: string) => {
|
|||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
const isReservedFieldType = (fieldType: FieldType) => {
|
const isReservedFieldType = (fieldType: string) => {
|
||||||
if (RESERVED_FIELD_TYPES.includes(fieldType)) {
|
if (RESERVED_FIELD_TYPES.includes(fieldType)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -86,6 +78,7 @@ export const parseSchema = (
|
|||||||
const tags = schema.tags ?? [];
|
const tags = schema.tags ?? [];
|
||||||
const description = schema.description ?? '';
|
const description = schema.description ?? '';
|
||||||
const version = schema.version;
|
const version = schema.version;
|
||||||
|
let withWorkflow = false;
|
||||||
|
|
||||||
const inputs = reduce(
|
const inputs = reduce(
|
||||||
schema.properties,
|
schema.properties,
|
||||||
@ -112,7 +105,7 @@ export const parseSchema = (
|
|||||||
|
|
||||||
const fieldType = property.ui_type ?? getFieldType(property);
|
const fieldType = property.ui_type ?? getFieldType(property);
|
||||||
|
|
||||||
if (!isFieldType(fieldType)) {
|
if (!fieldType) {
|
||||||
logger('nodes').warn(
|
logger('nodes').warn(
|
||||||
{
|
{
|
||||||
node: type,
|
node: type,
|
||||||
@ -120,11 +113,16 @@ export const parseSchema = (
|
|||||||
fieldType,
|
fieldType,
|
||||||
field: parseify(property),
|
field: parseify(property),
|
||||||
},
|
},
|
||||||
'Skipping unknown input field type'
|
'Missing input field type'
|
||||||
);
|
);
|
||||||
return inputsAccumulator;
|
return inputsAccumulator;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (fieldType === 'WorkflowField') {
|
||||||
|
withWorkflow = true;
|
||||||
|
return inputsAccumulator;
|
||||||
|
}
|
||||||
|
|
||||||
if (isReservedFieldType(fieldType)) {
|
if (isReservedFieldType(fieldType)) {
|
||||||
logger('nodes').trace(
|
logger('nodes').trace(
|
||||||
{
|
{
|
||||||
@ -133,7 +131,20 @@ export const parseSchema = (
|
|||||||
fieldType,
|
fieldType,
|
||||||
field: parseify(property),
|
field: parseify(property),
|
||||||
},
|
},
|
||||||
'Skipping reserved field type'
|
`Skipping reserved input field type: ${fieldType}`
|
||||||
|
);
|
||||||
|
return inputsAccumulator;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isFieldType(fieldType)) {
|
||||||
|
logger('nodes').warn(
|
||||||
|
{
|
||||||
|
node: type,
|
||||||
|
fieldName: propertyName,
|
||||||
|
fieldType,
|
||||||
|
field: parseify(property),
|
||||||
|
},
|
||||||
|
`Skipping unknown input field type: ${fieldType}`
|
||||||
);
|
);
|
||||||
return inputsAccumulator;
|
return inputsAccumulator;
|
||||||
}
|
}
|
||||||
@ -146,7 +157,7 @@ export const parseSchema = (
|
|||||||
);
|
);
|
||||||
|
|
||||||
if (!field) {
|
if (!field) {
|
||||||
logger('nodes').debug(
|
logger('nodes').warn(
|
||||||
{
|
{
|
||||||
node: type,
|
node: type,
|
||||||
fieldName: propertyName,
|
fieldName: propertyName,
|
||||||
@ -248,6 +259,7 @@ export const parseSchema = (
|
|||||||
inputs,
|
inputs,
|
||||||
outputs,
|
outputs,
|
||||||
useCache,
|
useCache,
|
||||||
|
withWorkflow,
|
||||||
};
|
};
|
||||||
|
|
||||||
Object.assign(invocationsAccumulator, { [type]: invocation });
|
Object.assign(invocationsAccumulator, { [type]: invocation });
|
||||||
|
@ -1,20 +1,15 @@
|
|||||||
import { EntityState, Update } from '@reduxjs/toolkit';
|
import { EntityState, Update } from '@reduxjs/toolkit';
|
||||||
import { fetchBaseQuery } from '@reduxjs/toolkit/dist/query';
|
|
||||||
import { PatchCollection } from '@reduxjs/toolkit/dist/query/core/buildThunks';
|
import { PatchCollection } from '@reduxjs/toolkit/dist/query/core/buildThunks';
|
||||||
|
import { logger } from 'app/logging/logger';
|
||||||
import {
|
import {
|
||||||
ASSETS_CATEGORIES,
|
ASSETS_CATEGORIES,
|
||||||
BoardId,
|
BoardId,
|
||||||
IMAGE_CATEGORIES,
|
IMAGE_CATEGORIES,
|
||||||
IMAGE_LIMIT,
|
IMAGE_LIMIT,
|
||||||
} from 'features/gallery/store/types';
|
} from 'features/gallery/store/types';
|
||||||
import {
|
import { CoreMetadata, zCoreMetadata } from 'features/nodes/types/types';
|
||||||
ImageMetadataAndWorkflow,
|
|
||||||
zCoreMetadata,
|
|
||||||
} from 'features/nodes/types/types';
|
|
||||||
import { getMetadataAndWorkflowFromImageBlob } from 'features/nodes/util/getMetadataAndWorkflowFromImageBlob';
|
|
||||||
import { keyBy } from 'lodash-es';
|
import { keyBy } from 'lodash-es';
|
||||||
import { ApiTagDescription, LIST_TAG, api } from '..';
|
import { ApiTagDescription, LIST_TAG, api } from '..';
|
||||||
import { $authToken, $projectId } from '../client';
|
|
||||||
import { components, paths } from '../schema';
|
import { components, paths } from '../schema';
|
||||||
import {
|
import {
|
||||||
DeleteBoardResult,
|
DeleteBoardResult,
|
||||||
@ -23,7 +18,6 @@ import {
|
|||||||
ListImagesArgs,
|
ListImagesArgs,
|
||||||
OffsetPaginatedResults_ImageDTO_,
|
OffsetPaginatedResults_ImageDTO_,
|
||||||
PostUploadAction,
|
PostUploadAction,
|
||||||
UnsafeImageMetadata,
|
|
||||||
} from '../types';
|
} from '../types';
|
||||||
import {
|
import {
|
||||||
getCategories,
|
getCategories,
|
||||||
@ -114,73 +108,24 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
],
|
],
|
||||||
keepUnusedDataFor: 86400, // 24 hours
|
keepUnusedDataFor: 86400, // 24 hours
|
||||||
}),
|
}),
|
||||||
getImageMetadata: build.query<UnsafeImageMetadata, string>({
|
getImageMetadata: build.query<CoreMetadata | undefined, string>({
|
||||||
query: (image_name) => ({ url: `images/i/${image_name}/metadata` }),
|
query: (image_name) => ({ url: `images/i/${image_name}/metadata` }),
|
||||||
providesTags: (result, error, image_name) => [
|
providesTags: (result, error, image_name) => [
|
||||||
{ type: 'ImageMetadata', id: image_name },
|
{ type: 'ImageMetadata', id: image_name },
|
||||||
],
|
],
|
||||||
keepUnusedDataFor: 86400, // 24 hours
|
transformResponse: (
|
||||||
}),
|
response: paths['/api/v1/images/i/{image_name}/metadata']['get']['responses']['200']['content']['application/json']
|
||||||
getImageMetadataFromFile: build.query<
|
|
||||||
ImageMetadataAndWorkflow,
|
|
||||||
{ image: ImageDTO; shouldFetchMetadataFromApi: boolean }
|
|
||||||
>({
|
|
||||||
queryFn: async (
|
|
||||||
args: { image: ImageDTO; shouldFetchMetadataFromApi: boolean },
|
|
||||||
api,
|
|
||||||
extraOptions,
|
|
||||||
fetchWithBaseQuery
|
|
||||||
) => {
|
) => {
|
||||||
if (args.shouldFetchMetadataFromApi) {
|
if (response) {
|
||||||
let metadata;
|
const result = zCoreMetadata.safeParse(response);
|
||||||
const metadataResponse = await fetchWithBaseQuery(
|
if (result.success) {
|
||||||
`images/i/${args.image.image_name}/metadata`
|
return result.data;
|
||||||
);
|
} else {
|
||||||
if (metadataResponse.data) {
|
logger('images').warn('Problem parsing metadata');
|
||||||
const metadataResult = zCoreMetadata.safeParse(
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
||||||
(metadataResponse.data as any)?.metadata
|
|
||||||
);
|
|
||||||
if (metadataResult.success) {
|
|
||||||
metadata = metadataResult.data;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return { data: { metadata } };
|
|
||||||
} else {
|
|
||||||
const authToken = $authToken.get();
|
|
||||||
const projectId = $projectId.get();
|
|
||||||
const customBaseQuery = fetchBaseQuery({
|
|
||||||
baseUrl: '',
|
|
||||||
prepareHeaders: (headers) => {
|
|
||||||
if (authToken) {
|
|
||||||
headers.set('Authorization', `Bearer ${authToken}`);
|
|
||||||
}
|
|
||||||
if (projectId) {
|
|
||||||
headers.set('project-id', projectId);
|
|
||||||
}
|
|
||||||
|
|
||||||
return headers;
|
|
||||||
},
|
|
||||||
responseHandler: async (res) => {
|
|
||||||
return await res.blob();
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const response = await customBaseQuery(
|
|
||||||
args.image.image_url,
|
|
||||||
api,
|
|
||||||
extraOptions
|
|
||||||
);
|
|
||||||
const data = await getMetadataAndWorkflowFromImageBlob(
|
|
||||||
response.data as Blob
|
|
||||||
);
|
|
||||||
|
|
||||||
return { data };
|
|
||||||
}
|
}
|
||||||
|
return;
|
||||||
},
|
},
|
||||||
providesTags: (result, error, { image }) => [
|
|
||||||
{ type: 'ImageMetadataFromFile', id: image.image_name },
|
|
||||||
],
|
|
||||||
keepUnusedDataFor: 86400, // 24 hours
|
keepUnusedDataFor: 86400, // 24 hours
|
||||||
}),
|
}),
|
||||||
deleteImage: build.mutation<void, ImageDTO>({
|
deleteImage: build.mutation<void, ImageDTO>({
|
||||||
@ -1629,6 +1574,5 @@ export const {
|
|||||||
useDeleteBoardMutation,
|
useDeleteBoardMutation,
|
||||||
useStarImagesMutation,
|
useStarImagesMutation,
|
||||||
useUnstarImagesMutation,
|
useUnstarImagesMutation,
|
||||||
useGetImageMetadataFromFileQuery,
|
|
||||||
useBulkDownloadImagesMutation,
|
useBulkDownloadImagesMutation,
|
||||||
} = imagesApi;
|
} = imagesApi;
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { Workflow, zWorkflow } from 'features/nodes/types/types';
|
||||||
|
import { api } from '..';
|
||||||
|
import { paths } from '../schema';
|
||||||
|
|
||||||
|
export const workflowsApi = api.injectEndpoints({
|
||||||
|
endpoints: (build) => ({
|
||||||
|
getWorkflow: build.query<Workflow | undefined, string>({
|
||||||
|
query: (workflow_id) => `workflows/i/${workflow_id}`,
|
||||||
|
providesTags: (result, error, workflow_id) => [
|
||||||
|
{ type: 'Workflow', id: workflow_id },
|
||||||
|
],
|
||||||
|
transformResponse: (
|
||||||
|
response: paths['/api/v1/workflows/i/{workflow_id}']['get']['responses']['200']['content']['application/json']
|
||||||
|
) => {
|
||||||
|
if (response) {
|
||||||
|
const result = zWorkflow.safeParse(response);
|
||||||
|
if (result.success) {
|
||||||
|
return result.data;
|
||||||
|
} else {
|
||||||
|
logger('images').warn('Problem parsing workflow');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const { useGetWorkflowQuery } = workflowsApi;
|
@ -0,0 +1,21 @@
|
|||||||
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
|
import { useDebounce } from 'use-debounce';
|
||||||
|
import { useGetImageMetadataQuery } from '../endpoints/images';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
|
||||||
|
export const useDebouncedMetadata = (imageName?: string | null) => {
|
||||||
|
const metadataFetchDebounce = useAppSelector(
|
||||||
|
(state) => state.config.metadataFetchDebounce
|
||||||
|
);
|
||||||
|
|
||||||
|
const [debouncedImageName] = useDebounce(
|
||||||
|
imageName,
|
||||||
|
metadataFetchDebounce ?? 0
|
||||||
|
);
|
||||||
|
|
||||||
|
const { data: metadata, isLoading } = useGetImageMetadataQuery(
|
||||||
|
debouncedImageName ?? skipToken
|
||||||
|
);
|
||||||
|
|
||||||
|
return { metadata, isLoading };
|
||||||
|
};
|
@ -0,0 +1,21 @@
|
|||||||
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { useDebounce } from 'use-debounce';
|
||||||
|
import { useGetWorkflowQuery } from '../endpoints/workflows';
|
||||||
|
|
||||||
|
export const useDebouncedWorkflow = (workflowId?: string | null) => {
|
||||||
|
const workflowFetchDebounce = useAppSelector(
|
||||||
|
(state) => state.config.workflowFetchDebounce
|
||||||
|
);
|
||||||
|
|
||||||
|
const [debouncedWorkflowID] = useDebounce(
|
||||||
|
workflowId,
|
||||||
|
workflowFetchDebounce ?? 0
|
||||||
|
);
|
||||||
|
|
||||||
|
const { data: workflow, isLoading } = useGetWorkflowQuery(
|
||||||
|
debouncedWorkflowID ?? skipToken
|
||||||
|
);
|
||||||
|
|
||||||
|
return { workflow, isLoading };
|
||||||
|
};
|
@ -37,6 +37,7 @@ export const tagTypes = [
|
|||||||
'ControlNetModel',
|
'ControlNetModel',
|
||||||
'LoRAModel',
|
'LoRAModel',
|
||||||
'SDXLRefinerModel',
|
'SDXLRefinerModel',
|
||||||
|
'Workflow',
|
||||||
] as const;
|
] as const;
|
||||||
export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>;
|
export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>;
|
||||||
export const LIST_TAG = 'LIST';
|
export const LIST_TAG = 'LIST';
|
||||||
|
1731
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
1731
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because one or more lines are too long
@ -27,14 +27,6 @@ export type BatchConfig =
|
|||||||
|
|
||||||
export type EnqueueBatchResult = components['schemas']['EnqueueBatchResult'];
|
export type EnqueueBatchResult = components['schemas']['EnqueueBatchResult'];
|
||||||
|
|
||||||
/**
|
|
||||||
* This is an unsafe type; the object inside is not guaranteed to be valid.
|
|
||||||
*/
|
|
||||||
export type UnsafeImageMetadata = {
|
|
||||||
metadata: s['CoreMetadata'];
|
|
||||||
graph: NonNullable<s['Graph']>;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type _InputField = s['_InputField'];
|
export type _InputField = s['_InputField'];
|
||||||
export type _OutputField = s['_OutputField'];
|
export type _OutputField = s['_OutputField'];
|
||||||
|
|
||||||
@ -50,7 +42,6 @@ export type ImageChanges = s['ImageRecordChanges'];
|
|||||||
export type ImageCategory = s['ImageCategory'];
|
export type ImageCategory = s['ImageCategory'];
|
||||||
export type ResourceOrigin = s['ResourceOrigin'];
|
export type ResourceOrigin = s['ResourceOrigin'];
|
||||||
export type ImageField = s['ImageField'];
|
export type ImageField = s['ImageField'];
|
||||||
export type ImageMetadata = s['ImageMetadata'];
|
|
||||||
export type OffsetPaginatedResults_BoardDTO_ =
|
export type OffsetPaginatedResults_BoardDTO_ =
|
||||||
s['OffsetPaginatedResults_BoardDTO_'];
|
s['OffsetPaginatedResults_BoardDTO_'];
|
||||||
export type OffsetPaginatedResults_ImageDTO_ =
|
export type OffsetPaginatedResults_ImageDTO_ =
|
||||||
@ -145,13 +136,19 @@ export type ImageCollectionInvocation = s['ImageCollectionInvocation'];
|
|||||||
export type MainModelLoaderInvocation = s['MainModelLoaderInvocation'];
|
export type MainModelLoaderInvocation = s['MainModelLoaderInvocation'];
|
||||||
export type OnnxModelLoaderInvocation = s['OnnxModelLoaderInvocation'];
|
export type OnnxModelLoaderInvocation = s['OnnxModelLoaderInvocation'];
|
||||||
export type LoraLoaderInvocation = s['LoraLoaderInvocation'];
|
export type LoraLoaderInvocation = s['LoraLoaderInvocation'];
|
||||||
export type MetadataAccumulatorInvocation = s['MetadataAccumulatorInvocation'];
|
|
||||||
export type ESRGANInvocation = s['ESRGANInvocation'];
|
export type ESRGANInvocation = s['ESRGANInvocation'];
|
||||||
export type DivideInvocation = s['DivideInvocation'];
|
export type DivideInvocation = s['DivideInvocation'];
|
||||||
export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation'];
|
export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation'];
|
||||||
export type ImageWatermarkInvocation = s['ImageWatermarkInvocation'];
|
export type ImageWatermarkInvocation = s['ImageWatermarkInvocation'];
|
||||||
export type SeamlessModeInvocation = s['SeamlessModeInvocation'];
|
export type SeamlessModeInvocation = s['SeamlessModeInvocation'];
|
||||||
export type SaveImageInvocation = s['SaveImageInvocation'];
|
export type SaveImageInvocation = s['SaveImageInvocation'];
|
||||||
|
export type MetadataInvocation = s['MetadataInvocation'];
|
||||||
|
export type CoreMetadataInvocation = s['CoreMetadataInvocation'];
|
||||||
|
export type MetadataItemInvocation = s['MetadataItemInvocation'];
|
||||||
|
export type MergeMetadataInvocation = s['MergeMetadataInvocation'];
|
||||||
|
export type IPAdapterMetadataField = s['IPAdapterMetadataField'];
|
||||||
|
export type T2IAdapterField = s['T2IAdapterField'];
|
||||||
|
export type LoRAMetadataField = s['LoRAMetadataField'];
|
||||||
|
|
||||||
// ControlNet Nodes
|
// ControlNet Nodes
|
||||||
export type ControlNetInvocation = s['ControlNetInvocation'];
|
export type ControlNetInvocation = s['ControlNetInvocation'];
|
||||||
|
@ -75,6 +75,8 @@ def mock_services() -> InvocationServices:
|
|||||||
session_processor=None, # type: ignore
|
session_processor=None, # type: ignore
|
||||||
session_queue=None, # type: ignore
|
session_queue=None, # type: ignore
|
||||||
urls=None, # type: ignore
|
urls=None, # type: ignore
|
||||||
|
workflow_records=None, # type: ignore
|
||||||
|
workflow_image_records=None, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -80,6 +80,8 @@ def mock_services() -> InvocationServices:
|
|||||||
session_processor=None, # type: ignore
|
session_processor=None, # type: ignore
|
||||||
session_queue=None, # type: ignore
|
session_queue=None, # type: ignore
|
||||||
urls=None, # type: ignore
|
urls=None, # type: ignore
|
||||||
|
workflow_records=None, # type: ignore
|
||||||
|
workflow_image_records=None, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,7 +10,12 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.invocations.image import ShowImageInvocation
|
from invokeai.app.invocations.image import ShowImageInvocation
|
||||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||||
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
|
from invokeai.app.invocations.primitives import (
|
||||||
|
FloatCollectionInvocation,
|
||||||
|
FloatInvocation,
|
||||||
|
IntegerInvocation,
|
||||||
|
StringInvocation,
|
||||||
|
)
|
||||||
from invokeai.app.invocations.upscale import ESRGANInvocation
|
from invokeai.app.invocations.upscale import ESRGANInvocation
|
||||||
from invokeai.app.services.shared.default_graphs import create_text_to_image
|
from invokeai.app.services.shared.default_graphs import create_text_to_image
|
||||||
from invokeai.app.services.shared.graph import (
|
from invokeai.app.services.shared.graph import (
|
||||||
@ -27,8 +32,11 @@ from invokeai.app.services.shared.graph import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .test_nodes import (
|
from .test_nodes import (
|
||||||
|
AnyTypeTestInvocation,
|
||||||
ImageToImageTestInvocation,
|
ImageToImageTestInvocation,
|
||||||
ListPassThroughInvocation,
|
ListPassThroughInvocation,
|
||||||
|
PolymorphicStringTestInvocation,
|
||||||
|
PromptCollectionTestInvocation,
|
||||||
PromptTestInvocation,
|
PromptTestInvocation,
|
||||||
TextToImageTestInvocation,
|
TextToImageTestInvocation,
|
||||||
)
|
)
|
||||||
@ -607,8 +615,8 @@ def test_graph_can_deserialize():
|
|||||||
g.add_edge(e)
|
g.add_edge(e)
|
||||||
|
|
||||||
json = g.model_dump_json()
|
json = g.model_dump_json()
|
||||||
adapter_graph = TypeAdapter(Graph)
|
GraphValidator = TypeAdapter(Graph)
|
||||||
g2 = adapter_graph.validate_json(json)
|
g2 = GraphValidator.validate_json(json)
|
||||||
|
|
||||||
assert g2 is not None
|
assert g2 is not None
|
||||||
assert g2.nodes["1"] is not None
|
assert g2.nodes["1"] is not None
|
||||||
@ -692,6 +700,144 @@ def test_ints_do_not_accept_floats():
|
|||||||
g.add_edge(e)
|
g.add_edge(e)
|
||||||
|
|
||||||
|
|
||||||
|
def test_polymorphic_accepts_single():
|
||||||
|
g = Graph()
|
||||||
|
n1 = StringInvocation(id="1", value="banana")
|
||||||
|
n2 = PolymorphicStringTestInvocation(id="2")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
e1 = create_edge(n1.id, "value", n2.id, "value")
|
||||||
|
# Not throwing on this line is sufficient
|
||||||
|
g.add_edge(e1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_polymorphic_accepts_collection_of_same_base_type():
|
||||||
|
g = Graph()
|
||||||
|
n1 = PromptCollectionTestInvocation(id="1", collection=["banana", "sundae"])
|
||||||
|
n2 = PolymorphicStringTestInvocation(id="2")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
e1 = create_edge(n1.id, "collection", n2.id, "value")
|
||||||
|
# Not throwing on this line is sufficient
|
||||||
|
g.add_edge(e1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_polymorphic_does_not_accept_collection_of_different_base_type():
|
||||||
|
g = Graph()
|
||||||
|
n1 = FloatCollectionInvocation(id="1", collection=[1.0, 2.0, 3.0])
|
||||||
|
n2 = PolymorphicStringTestInvocation(id="2")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
e1 = create_edge(n1.id, "collection", n2.id, "value")
|
||||||
|
with pytest.raises(InvalidEdgeError):
|
||||||
|
g.add_edge(e1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_polymorphic_does_not_accept_generic_collection():
|
||||||
|
g = Graph()
|
||||||
|
n1 = IntegerInvocation(id="1", value=1)
|
||||||
|
n2 = IntegerInvocation(id="2", value=2)
|
||||||
|
n3 = CollectInvocation(id="3")
|
||||||
|
n4 = PolymorphicStringTestInvocation(id="4")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
g.add_node(n3)
|
||||||
|
g.add_node(n4)
|
||||||
|
e1 = create_edge(n1.id, "value", n3.id, "item")
|
||||||
|
e2 = create_edge(n2.id, "value", n3.id, "item")
|
||||||
|
e3 = create_edge(n3.id, "collection", n4.id, "value")
|
||||||
|
g.add_edge(e1)
|
||||||
|
g.add_edge(e2)
|
||||||
|
with pytest.raises(InvalidEdgeError):
|
||||||
|
g.add_edge(e3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_any_accepts_integer():
|
||||||
|
g = Graph()
|
||||||
|
n1 = IntegerInvocation(id="1", value=1)
|
||||||
|
n2 = AnyTypeTestInvocation(id="2")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
e = create_edge(n1.id, "value", n2.id, "value")
|
||||||
|
# Not throwing on this line is sufficient
|
||||||
|
g.add_edge(e)
|
||||||
|
|
||||||
|
|
||||||
|
def test_any_accepts_string():
|
||||||
|
g = Graph()
|
||||||
|
n1 = StringInvocation(id="1", value="banana sundae")
|
||||||
|
n2 = AnyTypeTestInvocation(id="2")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
e = create_edge(n1.id, "value", n2.id, "value")
|
||||||
|
# Not throwing on this line is sufficient
|
||||||
|
g.add_edge(e)
|
||||||
|
|
||||||
|
|
||||||
|
def test_any_accepts_generic_collection():
|
||||||
|
g = Graph()
|
||||||
|
n1 = IntegerInvocation(id="1", value=1)
|
||||||
|
n2 = IntegerInvocation(id="2", value=2)
|
||||||
|
n3 = CollectInvocation(id="3")
|
||||||
|
n4 = AnyTypeTestInvocation(id="4")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
g.add_node(n3)
|
||||||
|
g.add_node(n4)
|
||||||
|
e1 = create_edge(n1.id, "value", n3.id, "item")
|
||||||
|
e2 = create_edge(n2.id, "value", n3.id, "item")
|
||||||
|
e3 = create_edge(n3.id, "collection", n4.id, "value")
|
||||||
|
g.add_edge(e1)
|
||||||
|
g.add_edge(e2)
|
||||||
|
# Not throwing on this line is sufficient
|
||||||
|
g.add_edge(e3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_any_accepts_prompt_collection():
|
||||||
|
g = Graph()
|
||||||
|
n1 = PromptCollectionTestInvocation(id="1", collection=["banana", "sundae"])
|
||||||
|
n2 = AnyTypeTestInvocation(id="2")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
e = create_edge(n1.id, "collection", n2.id, "value")
|
||||||
|
# Not throwing on this line is sufficient
|
||||||
|
g.add_edge(e)
|
||||||
|
|
||||||
|
|
||||||
|
def test_any_accepts_any():
|
||||||
|
g = Graph()
|
||||||
|
n1 = AnyTypeTestInvocation(id="1")
|
||||||
|
n2 = AnyTypeTestInvocation(id="2")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
e = create_edge(n1.id, "value", n2.id, "value")
|
||||||
|
# Not throwing on this line is sufficient
|
||||||
|
g.add_edge(e)
|
||||||
|
|
||||||
|
|
||||||
|
def test_iterate_accepts_collection():
|
||||||
|
"""We need to update the validation for Collect -> Iterate to traverse to the Iterate
|
||||||
|
node's output and compare that against the item type of the Collect node's collection. Until
|
||||||
|
then, Collect nodes may not output into Iterate nodes."""
|
||||||
|
g = Graph()
|
||||||
|
n1 = IntegerInvocation(id="1", value=1)
|
||||||
|
n2 = IntegerInvocation(id="2", value=2)
|
||||||
|
n3 = CollectInvocation(id="3")
|
||||||
|
n4 = IterateInvocation(id="4")
|
||||||
|
g.add_node(n1)
|
||||||
|
g.add_node(n2)
|
||||||
|
g.add_node(n3)
|
||||||
|
g.add_node(n4)
|
||||||
|
e1 = create_edge(n1.id, "value", n3.id, "item")
|
||||||
|
e2 = create_edge(n2.id, "value", n3.id, "item")
|
||||||
|
e3 = create_edge(n3.id, "collection", n4.id, "collection")
|
||||||
|
g.add_edge(e1)
|
||||||
|
g.add_edge(e2)
|
||||||
|
# Once we fix the validation logic as described, this should should not raise an error
|
||||||
|
with pytest.raises(InvalidEdgeError, match="Cannot connect collector to iterator"):
|
||||||
|
g.add_edge(e3)
|
||||||
|
|
||||||
|
|
||||||
def test_graph_can_generate_schema():
|
def test_graph_can_generate_schema():
|
||||||
# Not throwing on this line is sufficient
|
# Not throwing on this line is sufficient
|
||||||
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
from typing import Any, Callable, Union
|
from typing import Any, Callable, Union
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -15,12 +15,12 @@ from invokeai.app.invocations.image import ImageField
|
|||||||
# Define test invocations before importing anything that uses invocations
|
# Define test invocations before importing anything that uses invocations
|
||||||
@invocation_output("test_list_output")
|
@invocation_output("test_list_output")
|
||||||
class ListPassThroughInvocationOutput(BaseInvocationOutput):
|
class ListPassThroughInvocationOutput(BaseInvocationOutput):
|
||||||
collection: list[ImageField] = Field(default_factory=list)
|
collection: list[ImageField] = OutputField(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
@invocation("test_list")
|
@invocation("test_list")
|
||||||
class ListPassThroughInvocation(BaseInvocation):
|
class ListPassThroughInvocation(BaseInvocation):
|
||||||
collection: list[ImageField] = Field(default_factory=list)
|
collection: list[ImageField] = InputField(default_factory=list)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
|
||||||
return ListPassThroughInvocationOutput(collection=self.collection)
|
return ListPassThroughInvocationOutput(collection=self.collection)
|
||||||
@ -28,12 +28,12 @@ class ListPassThroughInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@invocation_output("test_prompt_output")
|
@invocation_output("test_prompt_output")
|
||||||
class PromptTestInvocationOutput(BaseInvocationOutput):
|
class PromptTestInvocationOutput(BaseInvocationOutput):
|
||||||
prompt: str = Field(default="")
|
prompt: str = OutputField(default="")
|
||||||
|
|
||||||
|
|
||||||
@invocation("test_prompt")
|
@invocation("test_prompt")
|
||||||
class PromptTestInvocation(BaseInvocation):
|
class PromptTestInvocation(BaseInvocation):
|
||||||
prompt: str = Field(default="")
|
prompt: str = InputField(default="")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||||
return PromptTestInvocationOutput(prompt=self.prompt)
|
return PromptTestInvocationOutput(prompt=self.prompt)
|
||||||
@ -47,13 +47,13 @@ class ErrorInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@invocation_output("test_image_output")
|
@invocation_output("test_image_output")
|
||||||
class ImageTestInvocationOutput(BaseInvocationOutput):
|
class ImageTestInvocationOutput(BaseInvocationOutput):
|
||||||
image: ImageField = Field()
|
image: ImageField = OutputField()
|
||||||
|
|
||||||
|
|
||||||
@invocation("test_text_to_image")
|
@invocation("test_text_to_image")
|
||||||
class TextToImageTestInvocation(BaseInvocation):
|
class TextToImageTestInvocation(BaseInvocation):
|
||||||
prompt: str = Field(default="")
|
prompt: str = InputField(default="")
|
||||||
prompt2: str = Field(default="")
|
prompt2: str = InputField(default="")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||||
@ -61,8 +61,8 @@ class TextToImageTestInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@invocation("test_image_to_image")
|
@invocation("test_image_to_image")
|
||||||
class ImageToImageTestInvocation(BaseInvocation):
|
class ImageToImageTestInvocation(BaseInvocation):
|
||||||
prompt: str = Field(default="")
|
prompt: str = InputField(default="")
|
||||||
image: Union[ImageField, None] = Field(default=None)
|
image: Union[ImageField, None] = InputField(default=None)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||||
@ -70,17 +70,40 @@ class ImageToImageTestInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@invocation_output("test_prompt_collection_output")
|
@invocation_output("test_prompt_collection_output")
|
||||||
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
||||||
collection: list[str] = Field(default_factory=list)
|
collection: list[str] = OutputField(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
@invocation("test_prompt_collection")
|
@invocation("test_prompt_collection")
|
||||||
class PromptCollectionTestInvocation(BaseInvocation):
|
class PromptCollectionTestInvocation(BaseInvocation):
|
||||||
collection: list[str] = Field()
|
collection: list[str] = InputField()
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||||
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("test_any_output")
|
||||||
|
class AnyTypeTestInvocationOutput(BaseInvocationOutput):
|
||||||
|
value: Any = OutputField()
|
||||||
|
|
||||||
|
|
||||||
|
@invocation("test_any")
|
||||||
|
class AnyTypeTestInvocation(BaseInvocation):
|
||||||
|
value: Any = InputField(default=None)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
|
||||||
|
return AnyTypeTestInvocationOutput(value=self.value)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation("test_polymorphic")
|
||||||
|
class PolymorphicStringTestInvocation(BaseInvocation):
|
||||||
|
value: Union[str, list[str]] = InputField(default="")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||||
|
if isinstance(self.value, str):
|
||||||
|
return PromptCollectionTestInvocationOutput(collection=[self.value])
|
||||||
|
return PromptCollectionTestInvocationOutput(collection=self.value)
|
||||||
|
|
||||||
|
|
||||||
# Importing these must happen after test invocations are defined or they won't register
|
# Importing these must happen after test invocations are defined or they won't register
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase # noqa: E402
|
from invokeai.app.services.events.events_base import EventServiceBase # noqa: E402
|
||||||
from invokeai.app.services.shared.graph import Edge, EdgeConnection # noqa: E402
|
from invokeai.app.services.shared.graph import Edge, EdgeConnection # noqa: E402
|
||||||
|
@ -150,9 +150,9 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
|
|||||||
values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
|
values = prepare_values_to_insert(queue_id="default", batch=b, priority=0, max_new_queue_items=1000)
|
||||||
assert len(values) == 8
|
assert len(values) == 8
|
||||||
|
|
||||||
session_adapter = TypeAdapter(GraphExecutionState)
|
GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
|
||||||
# graph should be serialized
|
# graph should be serialized
|
||||||
ges = session_adapter.validate_json(values[0].session)
|
ges = GraphExecutionStateValidator.validate_json(values[0].session)
|
||||||
|
|
||||||
# graph values should be populated
|
# graph values should be populated
|
||||||
assert ges.graph.get_node("1").prompt == "Banana sushi"
|
assert ges.graph.get_node("1").prompt == "Banana sushi"
|
||||||
@ -161,16 +161,16 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
|
|||||||
assert ges.graph.get_node("4").prompt == "Nissan"
|
assert ges.graph.get_node("4").prompt == "Nissan"
|
||||||
|
|
||||||
# session ids should match deserialized graph
|
# session ids should match deserialized graph
|
||||||
assert [v.session_id for v in values] == [session_adapter.validate_json(v.session).id for v in values]
|
assert [v.session_id for v in values] == [GraphExecutionStateValidator.validate_json(v.session).id for v in values]
|
||||||
|
|
||||||
# should unique session ids
|
# should unique session ids
|
||||||
sids = [v.session_id for v in values]
|
sids = [v.session_id for v in values]
|
||||||
assert len(sids) == len(set(sids))
|
assert len(sids) == len(set(sids))
|
||||||
|
|
||||||
nfv_list_adapter = TypeAdapter(list[NodeFieldValue])
|
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
|
||||||
# should have 3 node field values
|
# should have 3 node field values
|
||||||
assert type(values[0].field_values) is str
|
assert type(values[0].field_values) is str
|
||||||
assert len(nfv_list_adapter.validate_json(values[0].field_values)) == 3
|
assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
|
||||||
|
|
||||||
# should have batch id and priority
|
# should have batch id and priority
|
||||||
assert all(v.batch_id == b.batch_id for v in values)
|
assert all(v.batch_id == b.batch_id for v in values)
|
||||||
|
Loading…
Reference in New Issue
Block a user