mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/feat/simple-mm2-api
This commit is contained in:
commit
2276f327e5
4
Makefile
4
Makefile
@ -18,6 +18,7 @@ help:
|
|||||||
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
||||||
@echo "installer-zip Build the installer .zip file for the current version"
|
@echo "installer-zip Build the installer .zip file for the current version"
|
||||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||||
|
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
|
||||||
|
|
||||||
# Runs ruff, fixing any safely-fixable errors and formatting
|
# Runs ruff, fixing any safely-fixable errors and formatting
|
||||||
ruff:
|
ruff:
|
||||||
@ -70,3 +71,6 @@ installer-zip:
|
|||||||
tag-release:
|
tag-release:
|
||||||
cd installer && ./tag_release.sh
|
cd installer && ./tag_release.sh
|
||||||
|
|
||||||
|
# Generate the OpenAPI Schema for the app
|
||||||
|
openapi:
|
||||||
|
python scripts/generate_openapi_schema.py
|
||||||
|
@ -154,6 +154,18 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file.
|
|||||||
|
|
||||||
Check the [configuration docs] for more detail about the settings and how to specify them.
|
Check the [configuration docs] for more detail about the settings and how to specify them.
|
||||||
|
|
||||||
|
## `ModuleNotFoundError: No module named 'controlnet_aux'`
|
||||||
|
|
||||||
|
`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control.
|
||||||
|
|
||||||
|
If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed:
|
||||||
|
|
||||||
|
- Run the Invoke launcher
|
||||||
|
- Choose the developer console option
|
||||||
|
- Run this command: `pip cache remove controlnet_aux`
|
||||||
|
- Close the terminal window
|
||||||
|
- Download and run the [installer](https://github.com/invoke-ai/InvokeAI/releases/latest), selecting your current install location
|
||||||
|
|
||||||
## Out of Memory Issues
|
## Out of Memory Issues
|
||||||
|
|
||||||
The models are large, VRAM is expensive, and you may find yourself
|
The models are large, VRAM is expensive, and you may find yourself
|
||||||
|
@ -3,9 +3,7 @@ import logging
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
import socket
|
import socket
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from inspect import signature
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@ -13,11 +11,9 @@ from fastapi import FastAPI
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.middleware.gzip import GZipMiddleware
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||||
from fastapi.openapi.utils import get_openapi
|
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||||
from pydantic.json_schema import models_json_schema
|
|
||||||
from torch.backends.mps import is_available as is_mps_available
|
from torch.backends.mps import is_available as is_mps_available
|
||||||
|
|
||||||
# for PyCharm:
|
# for PyCharm:
|
||||||
@ -25,10 +21,8 @@ from torch.backends.mps import is_available as is_mps_available
|
|||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
import invokeai.frontend.web as web_dir
|
import invokeai.frontend.web as web_dir
|
||||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
from invokeai.app.services.events.events_common import EventBase
|
from invokeai.app.util.custom_openapi import get_openapi_func
|
||||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
from ..backend.util.logging import InvokeAILogger
|
from ..backend.util.logging import InvokeAILogger
|
||||||
@ -45,11 +39,6 @@ from .api.routers import (
|
|||||||
workflows,
|
workflows,
|
||||||
)
|
)
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import (
|
|
||||||
BaseInvocation,
|
|
||||||
UIConfigBase,
|
|
||||||
)
|
|
||||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
|
||||||
|
|
||||||
app_config = get_config()
|
app_config = get_config()
|
||||||
|
|
||||||
@ -119,84 +108,7 @@ app.include_router(app_info.app_router, prefix="/api")
|
|||||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||||
app.include_router(workflows.workflows_router, prefix="/api")
|
app.include_router(workflows.workflows_router, prefix="/api")
|
||||||
|
|
||||||
|
app.openapi = get_openapi_func(app)
|
||||||
# Build a custom OpenAPI to include all outputs
|
|
||||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
|
||||||
def custom_openapi() -> dict[str, Any]:
|
|
||||||
if app.openapi_schema:
|
|
||||||
return app.openapi_schema
|
|
||||||
openapi_schema = get_openapi(
|
|
||||||
title=app.title,
|
|
||||||
description="An API for invoking AI image operations",
|
|
||||||
version="1.0.0",
|
|
||||||
routes=app.routes,
|
|
||||||
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add all outputs
|
|
||||||
all_invocations = BaseInvocation.get_invocations()
|
|
||||||
output_types = set()
|
|
||||||
output_type_titles = {}
|
|
||||||
for invoker in all_invocations:
|
|
||||||
output_type = signature(invoker.invoke).return_annotation
|
|
||||||
output_types.add(output_type)
|
|
||||||
|
|
||||||
output_schemas = models_json_schema(
|
|
||||||
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
|
|
||||||
)
|
|
||||||
for schema_key, output_schema in output_schemas[1]["$defs"].items():
|
|
||||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
|
||||||
# This could break in some cases, figure out a better way to do it
|
|
||||||
output_type_titles[schema_key] = output_schema["title"]
|
|
||||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
|
||||||
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
|
|
||||||
|
|
||||||
# Some models don't end up in the schemas as standalone definitions
|
|
||||||
additional_schemas = models_json_schema(
|
|
||||||
[
|
|
||||||
(UIConfigBase, "serialization"),
|
|
||||||
(InputFieldJSONSchemaExtra, "serialization"),
|
|
||||||
(OutputFieldJSONSchemaExtra, "serialization"),
|
|
||||||
(ModelIdentifierField, "serialization"),
|
|
||||||
(ProgressImage, "serialization"),
|
|
||||||
],
|
|
||||||
ref_template="#/components/schemas/{model}",
|
|
||||||
)
|
|
||||||
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
|
|
||||||
openapi_schema["components"]["schemas"][schema_key] = schema_json
|
|
||||||
|
|
||||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {},
|
|
||||||
"required": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
|
||||||
for invoker in all_invocations:
|
|
||||||
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
|
|
||||||
output_type = signature(obj=invoker.invoke).return_annotation
|
|
||||||
output_type_title = output_type_titles[output_type.__name__]
|
|
||||||
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
|
|
||||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
|
||||||
invoker_schema["output"] = outputs_ref
|
|
||||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
|
|
||||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
|
|
||||||
invoker_schema["class"] = "invocation"
|
|
||||||
|
|
||||||
# Add all event schemas
|
|
||||||
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
|
|
||||||
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
|
||||||
if "$defs" in json_schema:
|
|
||||||
for schema_key, schema in json_schema["$defs"].items():
|
|
||||||
openapi_schema["components"]["schemas"][schema_key] = schema
|
|
||||||
del json_schema["$defs"]
|
|
||||||
openapi_schema["components"]["schemas"][event.__name__] = json_schema
|
|
||||||
|
|
||||||
app.openapi_schema = openapi_schema
|
|
||||||
return app.openapi_schema
|
|
||||||
|
|
||||||
|
|
||||||
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/docs", include_in_schema=False)
|
@app.get("/docs", include_in_schema=False)
|
||||||
|
@ -98,11 +98,13 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
|
|
||||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||||
|
_typeadapter_needs_update: ClassVar[bool] = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||||
"""Registers an invocation output."""
|
"""Registers an invocation output."""
|
||||||
cls._output_classes.add(output)
|
cls._output_classes.add(output)
|
||||||
|
cls._typeadapter_needs_update = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
||||||
@ -112,11 +114,12 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||||
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
||||||
if not cls._typeadapter:
|
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||||
InvocationOutputsUnion = TypeAliasType(
|
AnyInvocationOutput = TypeAliasType(
|
||||||
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||||
)
|
)
|
||||||
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
|
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
|
||||||
|
cls._typeadapter_needs_update = False
|
||||||
return cls._typeadapter
|
return cls._typeadapter
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -125,12 +128,13 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
|
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
|
||||||
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
|
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
|
||||||
# Because we use a pydantic Literal field with default value for the invocation type,
|
# Because we use a pydantic Literal field with default value for the invocation type,
|
||||||
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
||||||
if "required" not in schema or not isinstance(schema["required"], list):
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
schema["required"] = []
|
schema["required"] = []
|
||||||
|
schema["class"] = "output"
|
||||||
schema["required"].extend(["type"])
|
schema["required"].extend(["type"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -167,6 +171,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
|
|
||||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||||
|
_typeadapter_needs_update: ClassVar[bool] = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_type(cls) -> str:
|
def get_type(cls) -> str:
|
||||||
@ -177,15 +182,17 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
||||||
"""Registers an invocation."""
|
"""Registers an invocation."""
|
||||||
cls._invocation_classes.add(invocation)
|
cls._invocation_classes.add(invocation)
|
||||||
|
cls._typeadapter_needs_update = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||||
if not cls._typeadapter:
|
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||||
InvocationsUnion = TypeAliasType(
|
AnyInvocation = TypeAliasType(
|
||||||
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
"AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||||
)
|
)
|
||||||
cls._typeadapter = TypeAdapter(InvocationsUnion)
|
cls._typeadapter = TypeAdapter(AnyInvocation)
|
||||||
|
cls._typeadapter_needs_update = False
|
||||||
return cls._typeadapter
|
return cls._typeadapter
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -221,7 +228,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
return signature(cls.invoke).return_annotation
|
return signature(cls.invoke).return_annotation
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
||||||
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||||
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
|
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
|
||||||
if uiconfig is not None:
|
if uiconfig is not None:
|
||||||
@ -237,6 +244,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
schema["version"] = uiconfig.version
|
schema["version"] = uiconfig.version
|
||||||
if "required" not in schema or not isinstance(schema["required"], list):
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
schema["required"] = []
|
schema["required"] = []
|
||||||
|
schema["class"] = "invocation"
|
||||||
schema["required"].extend(["type", "id"])
|
schema["required"].extend(["type", "id"])
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -310,7 +318,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
protected_namespaces=(),
|
protected_namespaces=(),
|
||||||
validate_assignment=True,
|
validate_assignment=True,
|
||||||
json_schema_extra=json_schema_extra,
|
json_schema_extra=json_schema_extra,
|
||||||
json_schema_serialization_defaults_required=True,
|
json_schema_serialization_defaults_required=False,
|
||||||
coerce_numbers_to_str=True,
|
coerce_numbers_to_str=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
from math import floor
|
from math import floor
|
||||||
from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
|
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
|
||||||
|
|
||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.registry.payload_schema import registry as payload_schema
|
from fastapi_events.registry.payload_schema import registry as payload_schema
|
||||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
|
||||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||||
from invokeai.app.services.session_queue.session_queue_common import (
|
from invokeai.app.services.session_queue.session_queue_common import (
|
||||||
QUEUE_ITEM_STATUS,
|
QUEUE_ITEM_STATUS,
|
||||||
@ -14,6 +13,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
SessionQueueItem,
|
SessionQueueItem,
|
||||||
SessionQueueStatus,
|
SessionQueueStatus,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
|
||||||
from invokeai.app.util.misc import get_timestamp
|
from invokeai.app.util.misc import get_timestamp
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||||
@ -33,6 +33,7 @@ class EventBase(BaseModel):
|
|||||||
A timestamp is automatically added to the event when it is created.
|
A timestamp is automatically added to the event when it is created.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
__event_name__: ClassVar[str]
|
||||||
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
|
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
|
||||||
|
|
||||||
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
|
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
|
||||||
@ -97,7 +98,7 @@ class InvocationEventBase(QueueItemEventBase):
|
|||||||
item_id: int = Field(description="The ID of the queue item")
|
item_id: int = Field(description="The ID of the queue item")
|
||||||
batch_id: str = Field(description="The ID of the queue batch")
|
batch_id: str = Field(description="The ID of the queue batch")
|
||||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||||
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
|
invocation: AnyInvocation = Field(description="The ID of the invocation")
|
||||||
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
||||||
|
|
||||||
|
|
||||||
@ -108,7 +109,7 @@ class InvocationStartedEvent(InvocationEventBase):
|
|||||||
__event_name__ = "invocation_started"
|
__event_name__ = "invocation_started"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
|
def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
|
||||||
return cls(
|
return cls(
|
||||||
queue_id=queue_item.queue_id,
|
queue_id=queue_item.queue_id,
|
||||||
item_id=queue_item.item_id,
|
item_id=queue_item.item_id,
|
||||||
@ -135,7 +136,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
|
|||||||
def build(
|
def build(
|
||||||
cls,
|
cls,
|
||||||
queue_item: SessionQueueItem,
|
queue_item: SessionQueueItem,
|
||||||
invocation: BaseInvocation,
|
invocation: AnyInvocation,
|
||||||
intermediate_state: PipelineIntermediateState,
|
intermediate_state: PipelineIntermediateState,
|
||||||
progress_image: ProgressImage,
|
progress_image: ProgressImage,
|
||||||
) -> "InvocationDenoiseProgressEvent":
|
) -> "InvocationDenoiseProgressEvent":
|
||||||
@ -173,11 +174,11 @@ class InvocationCompleteEvent(InvocationEventBase):
|
|||||||
|
|
||||||
__event_name__ = "invocation_complete"
|
__event_name__ = "invocation_complete"
|
||||||
|
|
||||||
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
|
result: AnyInvocationOutput = Field(description="The result of the invocation")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build(
|
def build(
|
||||||
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
|
cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
|
||||||
) -> "InvocationCompleteEvent":
|
) -> "InvocationCompleteEvent":
|
||||||
return cls(
|
return cls(
|
||||||
queue_id=queue_item.queue_id,
|
queue_id=queue_item.queue_id,
|
||||||
@ -206,7 +207,7 @@ class InvocationErrorEvent(InvocationEventBase):
|
|||||||
def build(
|
def build(
|
||||||
cls,
|
cls,
|
||||||
queue_item: SessionQueueItem,
|
queue_item: SessionQueueItem,
|
||||||
invocation: BaseInvocation,
|
invocation: AnyInvocation,
|
||||||
error_type: str,
|
error_type: str,
|
||||||
error_message: str,
|
error_message: str,
|
||||||
error_traceback: str,
|
error_traceback: str,
|
||||||
|
@ -2,18 +2,19 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
|
GetCoreSchemaHandler,
|
||||||
GetJsonSchemaHandler,
|
GetJsonSchemaHandler,
|
||||||
ValidationError,
|
ValidationError,
|
||||||
field_validator,
|
field_validator,
|
||||||
)
|
)
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
from pydantic.json_schema import JsonSchemaValue
|
from pydantic.json_schema import JsonSchemaValue
|
||||||
from pydantic_core import CoreSchema
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
# Importing * is bad karma but needed here for node detection
|
# Importing * is bad karma but needed here for node detection
|
||||||
from invokeai.app.invocations import * # noqa: F401 F403
|
from invokeai.app.invocations import * # noqa: F401 F403
|
||||||
@ -277,73 +278,58 @@ class CollectInvocation(BaseInvocation):
|
|||||||
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
||||||
|
|
||||||
|
|
||||||
|
class AnyInvocation(BaseInvocation):
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||||
|
def validate_invocation(v: Any) -> "AnyInvocation":
|
||||||
|
return BaseInvocation.get_typeadapter().validate_python(v)
|
||||||
|
|
||||||
|
return core_schema.no_info_plain_validator_function(validate_invocation)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_json_schema__(
|
||||||
|
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||||
|
) -> JsonSchemaValue:
|
||||||
|
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||||
|
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||||
|
oneOf: list[dict[str, str]] = []
|
||||||
|
names = [i.__name__ for i in BaseInvocation.get_invocations()]
|
||||||
|
for name in sorted(names):
|
||||||
|
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||||
|
return {"oneOf": oneOf}
|
||||||
|
|
||||||
|
|
||||||
|
class AnyInvocationOutput(BaseInvocationOutput):
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
||||||
|
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
|
||||||
|
return BaseInvocationOutput.get_typeadapter().validate_python(v)
|
||||||
|
|
||||||
|
return core_schema.no_info_plain_validator_function(validate_invocation_output)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_json_schema__(
|
||||||
|
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||||
|
) -> JsonSchemaValue:
|
||||||
|
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||||
|
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||||
|
|
||||||
|
oneOf: list[dict[str, str]] = []
|
||||||
|
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
|
||||||
|
for name in sorted(names):
|
||||||
|
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||||
|
return {"oneOf": oneOf}
|
||||||
|
|
||||||
|
|
||||||
class Graph(BaseModel):
|
class Graph(BaseModel):
|
||||||
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
||||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||||
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
||||||
edges: list[Edge] = Field(
|
edges: list[Edge] = Field(
|
||||||
description="The connections between nodes and their fields in this graph",
|
description="The connections between nodes and their fields in this graph",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("nodes", mode="plain")
|
|
||||||
@classmethod
|
|
||||||
def validate_nodes(cls, v: dict[str, Any]):
|
|
||||||
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
|
|
||||||
|
|
||||||
# Invocations register themselves as their python modules are executed. The union of all invocations is
|
|
||||||
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
|
|
||||||
#
|
|
||||||
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
|
|
||||||
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
|
|
||||||
# invocations will cause a graph to fail if they are used.
|
|
||||||
#
|
|
||||||
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
|
|
||||||
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
|
|
||||||
#
|
|
||||||
# This same pattern is used in `GraphExecutionState`.
|
|
||||||
|
|
||||||
nodes: dict[str, BaseInvocation] = {}
|
|
||||||
typeadapter = BaseInvocation.get_typeadapter()
|
|
||||||
for node_id, node in v.items():
|
|
||||||
nodes[node_id] = typeadapter.validate_python(node)
|
|
||||||
return nodes
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
|
||||||
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
|
|
||||||
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
|
|
||||||
# the generated schema as options for the `nodes` field.
|
|
||||||
#
|
|
||||||
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
|
|
||||||
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
|
|
||||||
# expected.
|
|
||||||
#
|
|
||||||
# You might be tempted to do something like this:
|
|
||||||
#
|
|
||||||
# ```py
|
|
||||||
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
|
|
||||||
# delattr(cloned_model, "validate_nodes")
|
|
||||||
# cloned_model.model_rebuild(force=True)
|
|
||||||
# json_schema = handler(cloned_model.__pydantic_core_schema__)
|
|
||||||
# ```
|
|
||||||
#
|
|
||||||
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
|
|
||||||
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
|
|
||||||
#
|
|
||||||
# This same pattern is used in `GraphExecutionState`.
|
|
||||||
|
|
||||||
class Graph(BaseModel):
|
|
||||||
id: Optional[str] = Field(default=None, description="The id of this graph")
|
|
||||||
nodes: dict[
|
|
||||||
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
|
|
||||||
] = Field(description="The nodes in this graph")
|
|
||||||
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
|
|
||||||
|
|
||||||
json_schema = handler(Graph.__pydantic_core_schema__)
|
|
||||||
json_schema = handler.resolve_ref_schema(json_schema)
|
|
||||||
return json_schema
|
|
||||||
|
|
||||||
def add_node(self, node: BaseInvocation) -> None:
|
def add_node(self, node: BaseInvocation) -> None:
|
||||||
"""Adds a node to a graph
|
"""Adds a node to a graph
|
||||||
|
|
||||||
@ -774,7 +760,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# The results of executed nodes
|
# The results of executed nodes
|
||||||
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
||||||
|
|
||||||
# Errors raised when executing nodes
|
# Errors raised when executing nodes
|
||||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||||
@ -791,52 +777,12 @@ class GraphExecutionState(BaseModel):
|
|||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("results", mode="plain")
|
|
||||||
@classmethod
|
|
||||||
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
|
|
||||||
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
|
|
||||||
|
|
||||||
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
|
|
||||||
results: dict[str, BaseInvocationOutput] = {}
|
|
||||||
typeadapter = BaseInvocationOutput.get_typeadapter()
|
|
||||||
for result_id, result in v.items():
|
|
||||||
results[result_id] = typeadapter.validate_python(result)
|
|
||||||
return results
|
|
||||||
|
|
||||||
@field_validator("graph")
|
@field_validator("graph")
|
||||||
def graph_is_valid(cls, v: Graph):
|
def graph_is_valid(cls, v: Graph):
|
||||||
"""Validates that the graph is valid"""
|
"""Validates that the graph is valid"""
|
||||||
v.validate_self()
|
v.validate_self()
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
|
||||||
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
|
|
||||||
class GraphExecutionState(BaseModel):
|
|
||||||
"""Tracks the state of a graph execution"""
|
|
||||||
|
|
||||||
id: str = Field(description="The id of the execution state")
|
|
||||||
graph: Graph = Field(description="The graph being executed")
|
|
||||||
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
|
|
||||||
executed: set[str] = Field(description="The set of node ids that have been executed")
|
|
||||||
executed_history: list[str] = Field(
|
|
||||||
description="The list of node ids that have been executed, in order of execution"
|
|
||||||
)
|
|
||||||
results: dict[
|
|
||||||
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
|
|
||||||
] = Field(description="The results of node executions")
|
|
||||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
|
|
||||||
prepared_source_mapping: dict[str, str] = Field(
|
|
||||||
description="The map of prepared nodes to original graph nodes"
|
|
||||||
)
|
|
||||||
source_prepared_mapping: dict[str, set[str]] = Field(
|
|
||||||
description="The map of original graph nodes to prepared nodes"
|
|
||||||
)
|
|
||||||
|
|
||||||
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
|
|
||||||
json_schema = handler.resolve_ref_schema(json_schema)
|
|
||||||
return json_schema
|
|
||||||
|
|
||||||
def next(self) -> Optional[BaseInvocation]:
|
def next(self) -> Optional[BaseInvocation]:
|
||||||
"""Gets the next node ready to execute."""
|
"""Gets the next node ready to execute."""
|
||||||
|
|
||||||
|
116
invokeai/app/util/custom_openapi.py
Normal file
116
invokeai/app/util/custom_openapi.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
from pydantic.json_schema import models_json_schema
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
|
||||||
|
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||||
|
from invokeai.app.invocations.model import ModelIdentifierField
|
||||||
|
from invokeai.app.services.events.events_common import EventBase
|
||||||
|
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||||
|
|
||||||
|
|
||||||
|
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
|
||||||
|
"""Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema
|
||||||
|
for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and
|
||||||
|
component_schema."""
|
||||||
|
|
||||||
|
defs = component_schema.pop("$defs", {})
|
||||||
|
for schema_key, json_schema in defs.items():
|
||||||
|
if schema_key in openapi_schema["components"]["schemas"]:
|
||||||
|
continue
|
||||||
|
openapi_schema["components"]["schemas"][schema_key] = json_schema
|
||||||
|
|
||||||
|
|
||||||
|
def get_openapi_func(
|
||||||
|
app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
|
||||||
|
) -> Callable[[], dict[str, Any]]:
|
||||||
|
"""Gets the OpenAPI schema generator function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app (FastAPI): The FastAPI app to generate the schema for.
|
||||||
|
post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the
|
||||||
|
generated schema before returning it. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is
|
||||||
|
cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour
|
||||||
|
matches FastAPI's default schema generation caching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def openapi() -> dict[str, Any]:
|
||||||
|
if app.openapi_schema:
|
||||||
|
return app.openapi_schema
|
||||||
|
|
||||||
|
openapi_schema = get_openapi(
|
||||||
|
title=app.title,
|
||||||
|
description="An API for invoking AI image operations",
|
||||||
|
version="1.0.0",
|
||||||
|
routes=app.routes,
|
||||||
|
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
|
||||||
|
)
|
||||||
|
|
||||||
|
# We'll create a map of invocation type to output schema to make some types simpler on the client.
|
||||||
|
invocation_output_map_properties: dict[str, Any] = {}
|
||||||
|
invocation_output_map_required: list[str] = []
|
||||||
|
|
||||||
|
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
|
||||||
|
for output in BaseInvocationOutput.get_outputs():
|
||||||
|
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
||||||
|
move_defs_to_top_level(openapi_schema, json_schema)
|
||||||
|
openapi_schema["components"]["schemas"][output.__name__] = json_schema
|
||||||
|
|
||||||
|
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
|
||||||
|
# property, so we'll just do it all manually.
|
||||||
|
for invocation in BaseInvocation.get_invocations():
|
||||||
|
json_schema = invocation.model_json_schema(
|
||||||
|
mode="serialization", ref_template="#/components/schemas/{model}"
|
||||||
|
)
|
||||||
|
move_defs_to_top_level(openapi_schema, json_schema)
|
||||||
|
output_title = invocation.get_output_annotation().__name__
|
||||||
|
outputs_ref = {"$ref": f"#/components/schemas/{output_title}"}
|
||||||
|
json_schema["output"] = outputs_ref
|
||||||
|
openapi_schema["components"]["schemas"][invocation.__name__] = json_schema
|
||||||
|
|
||||||
|
# Add this invocation and its output to the output map
|
||||||
|
invocation_type = invocation.get_type()
|
||||||
|
invocation_output_map_properties[invocation_type] = json_schema["output"]
|
||||||
|
invocation_output_map_required.append(invocation_type)
|
||||||
|
|
||||||
|
# Add the output map to the schema
|
||||||
|
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": invocation_output_map_properties,
|
||||||
|
"required": invocation_output_map_required,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API.
|
||||||
|
# We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get
|
||||||
|
# a schema. This has something to do with schema refs - not totally clear. For whatever reason, using
|
||||||
|
# `models_json_schema` seems to work fine.
|
||||||
|
additional_models = [
|
||||||
|
*EventBase.get_events(),
|
||||||
|
UIConfigBase,
|
||||||
|
InputFieldJSONSchemaExtra,
|
||||||
|
OutputFieldJSONSchemaExtra,
|
||||||
|
ModelIdentifierField,
|
||||||
|
ProgressImage,
|
||||||
|
]
|
||||||
|
|
||||||
|
additional_schemas = models_json_schema(
|
||||||
|
[(m, "serialization") for m in additional_models],
|
||||||
|
ref_template="#/components/schemas/{model}",
|
||||||
|
)
|
||||||
|
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
|
||||||
|
move_defs_to_top_level(openapi_schema, additional_schemas[1])
|
||||||
|
|
||||||
|
if post_transform is not None:
|
||||||
|
openapi_schema = post_transform(openapi_schema)
|
||||||
|
|
||||||
|
openapi_schema["components"]["schemas"] = dict(sorted(openapi_schema["components"]["schemas"].items()))
|
||||||
|
|
||||||
|
app.openapi_schema = openapi_schema
|
||||||
|
return app.openapi_schema
|
||||||
|
|
||||||
|
return openapi
|
@ -53,5 +53,5 @@ class ModelLocker(ModelLockerBase):
|
|||||||
"""Call upon exit from context."""
|
"""Call upon exit from context."""
|
||||||
self._cache_entry.unlock()
|
self._cache_entry.unlock()
|
||||||
if not self._cache.lazy_offloading:
|
if not self._cache.lazy_offloading:
|
||||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
self._cache.offload_unlocked_models(0)
|
||||||
self._cache.print_cuda_stats()
|
self._cache.print_cuda_stats()
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Textual Inversion wrapper class."""
|
"""Textual Inversion wrapper class."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel.embeddings_provider import BaseTextualInversionManager
|
from compel.embeddings_provider import BaseTextualInversionManager
|
||||||
@ -66,35 +66,52 @@ class TextualInversionModelRaw(RawModel):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# no type hints for BaseTextualInversionManager?
|
class TextualInversionManager(BaseTextualInversionManager):
|
||||||
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
|
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
|
||||||
pad_tokens: Dict[int, List[int]]
|
|
||||||
tokenizer: CLIPTokenizer
|
|
||||||
|
|
||||||
def __init__(self, tokenizer: CLIPTokenizer):
|
def __init__(self, tokenizer: CLIPTokenizer):
|
||||||
self.pad_tokens = {}
|
self.pad_tokens: dict[int, list[int]] = {}
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
||||||
|
"""Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens.
|
||||||
|
|
||||||
|
For example, suppose we have a `<ti_dog>` TI with 4 vectors that was added to the tokenizer with the following
|
||||||
|
mapping of tokens to token_ids:
|
||||||
|
```
|
||||||
|
<ti_dog>: 49408
|
||||||
|
<ti_dog-!pad-1>: 49409
|
||||||
|
<ti_dog-!pad-2>: 49410
|
||||||
|
<ti_dog-!pad-3>: 49411
|
||||||
|
```
|
||||||
|
`self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`.
|
||||||
|
This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`.
|
||||||
|
"""
|
||||||
|
# Short circuit if there are no pad tokens to save a little time.
|
||||||
if len(self.pad_tokens) == 0:
|
if len(self.pad_tokens) == 0:
|
||||||
return token_ids
|
return token_ids
|
||||||
|
|
||||||
|
# This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify
|
||||||
|
# this assumption here.
|
||||||
if token_ids[0] == self.tokenizer.bos_token_id:
|
if token_ids[0] == self.tokenizer.bos_token_id:
|
||||||
raise ValueError("token_ids must not start with bos_token_id")
|
raise ValueError("token_ids must not start with bos_token_id")
|
||||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||||
raise ValueError("token_ids must not end with eos_token_id")
|
raise ValueError("token_ids must not end with eos_token_id")
|
||||||
|
|
||||||
new_token_ids = []
|
# Expand any TI tokens to their corresponding pad tokens.
|
||||||
|
new_token_ids: list[int] = []
|
||||||
for token_id in token_ids:
|
for token_id in token_ids:
|
||||||
new_token_ids.append(token_id)
|
new_token_ids.append(token_id)
|
||||||
if token_id in self.pad_tokens:
|
if token_id in self.pad_tokens:
|
||||||
new_token_ids.extend(self.pad_tokens[token_id])
|
new_token_ids.extend(self.pad_tokens[token_id])
|
||||||
|
|
||||||
# Do not exceed the max model input size
|
# Do not exceed the max model input size. The -2 here is compensating for
|
||||||
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
|
# compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens.
|
||||||
# which first removes and then adds back the start and end tokens.
|
max_length = self.tokenizer.model_max_length - 2
|
||||||
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
|
|
||||||
if len(new_token_ids) > max_length:
|
if len(new_token_ids) > max_length:
|
||||||
|
# HACK: If TI token expansion causes us to exceed the max text encoder input length, we silently discard
|
||||||
|
# tokens. Token expansion should happen in a way that is compatible with compel's default handling of long
|
||||||
|
# prompts.
|
||||||
new_token_ids = new_token_ids[0:max_length]
|
new_token_ids = new_token_ids[0:max_length]
|
||||||
|
|
||||||
return new_token_ids
|
return new_token_ids
|
||||||
|
@ -148,6 +148,8 @@
|
|||||||
"viewingDesc": "Review images in a large gallery view",
|
"viewingDesc": "Review images in a large gallery view",
|
||||||
"editing": "Editing",
|
"editing": "Editing",
|
||||||
"editingDesc": "Edit on the Control Layers canvas",
|
"editingDesc": "Edit on the Control Layers canvas",
|
||||||
|
"comparing": "Comparing",
|
||||||
|
"comparingDesc": "Comparing two images",
|
||||||
"enabled": "Enabled",
|
"enabled": "Enabled",
|
||||||
"disabled": "Disabled"
|
"disabled": "Disabled"
|
||||||
},
|
},
|
||||||
@ -375,7 +377,23 @@
|
|||||||
"bulkDownloadRequestFailed": "Problem Preparing Download",
|
"bulkDownloadRequestFailed": "Problem Preparing Download",
|
||||||
"bulkDownloadFailed": "Download Failed",
|
"bulkDownloadFailed": "Download Failed",
|
||||||
"problemDeletingImages": "Problem Deleting Images",
|
"problemDeletingImages": "Problem Deleting Images",
|
||||||
"problemDeletingImagesDesc": "One or more images could not be deleted"
|
"problemDeletingImagesDesc": "One or more images could not be deleted",
|
||||||
|
"viewerImage": "Viewer Image",
|
||||||
|
"compareImage": "Compare Image",
|
||||||
|
"openInViewer": "Open in Viewer",
|
||||||
|
"selectForCompare": "Select for Compare",
|
||||||
|
"selectAnImageToCompare": "Select an Image to Compare",
|
||||||
|
"slider": "Slider",
|
||||||
|
"sideBySide": "Side-by-Side",
|
||||||
|
"hover": "Hover",
|
||||||
|
"swapImages": "Swap Images",
|
||||||
|
"compareOptions": "Comparison Options",
|
||||||
|
"stretchToFit": "Stretch to Fit",
|
||||||
|
"exitCompare": "Exit Compare",
|
||||||
|
"compareHelp1": "Hold <Kbd>Alt</Kbd> while clicking a gallery image or using the arrow keys to change the compare image.",
|
||||||
|
"compareHelp2": "Press <Kbd>M</Kbd> to cycle through comparison modes.",
|
||||||
|
"compareHelp3": "Press <Kbd>C</Kbd> to swap the compared images.",
|
||||||
|
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit."
|
||||||
},
|
},
|
||||||
"hotkeys": {
|
"hotkeys": {
|
||||||
"searchHotkeys": "Search Hotkeys",
|
"searchHotkeys": "Search Hotkeys",
|
||||||
|
@ -19,6 +19,13 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
|||||||
return extendTheme({
|
return extendTheme({
|
||||||
..._theme,
|
..._theme,
|
||||||
direction,
|
direction,
|
||||||
|
shadows: {
|
||||||
|
..._theme.shadows,
|
||||||
|
selectedForCompare:
|
||||||
|
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-400)',
|
||||||
|
hoverSelectedForCompare:
|
||||||
|
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-300)',
|
||||||
|
},
|
||||||
});
|
});
|
||||||
}, [direction]);
|
}, [direction]);
|
||||||
|
|
||||||
|
@ -13,7 +13,6 @@ import {
|
|||||||
isControlAdapterLayer,
|
isControlAdapterLayer,
|
||||||
} from 'features/controlLayers/store/controlLayersSlice';
|
} from 'features/controlLayers/store/controlLayersSlice';
|
||||||
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
|
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
|
||||||
import { isImageOutput } from 'features/nodes/types/common';
|
|
||||||
import { toast } from 'features/toast/toast';
|
import { toast } from 'features/toast/toast';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
@ -139,7 +138,7 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
|
|||||||
|
|
||||||
// We still have to check the output type
|
// We still have to check the output type
|
||||||
assert(
|
assert(
|
||||||
isImageOutput(invocationCompleteAction.payload.data.result),
|
invocationCompleteAction.payload.data.result.type === 'image_output',
|
||||||
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
|
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
|
||||||
);
|
);
|
||||||
const { image_name } = invocationCompleteAction.payload.data.result.image;
|
const { image_name } = invocationCompleteAction.payload.data.result.image;
|
||||||
|
@ -9,7 +9,6 @@ import {
|
|||||||
selectControlAdapterById,
|
selectControlAdapterById,
|
||||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||||
import { isImageOutput } from 'features/nodes/types/common';
|
|
||||||
import { toast } from 'features/toast/toast';
|
import { toast } from 'features/toast/toast';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
@ -74,7 +73,7 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
|
|||||||
);
|
);
|
||||||
|
|
||||||
// We still have to check the output type
|
// We still have to check the output type
|
||||||
if (isImageOutput(invocationCompleteAction.payload.data.result)) {
|
if (invocationCompleteAction.payload.data.result.type === 'image_output') {
|
||||||
const { image_name } = invocationCompleteAction.payload.data.result.image;
|
const { image_name } = invocationCompleteAction.payload.data.result.image;
|
||||||
|
|
||||||
// Wait for the ImageDTO to be received
|
// Wait for the ImageDTO to be received
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||||
import { selectionChanged } from 'features/gallery/store/gallerySlice';
|
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import type { ImageDTO } from 'services/api/types';
|
import type { ImageDTO } from 'services/api/types';
|
||||||
import { imagesSelectors } from 'services/api/util';
|
import { imagesSelectors } from 'services/api/util';
|
||||||
@ -11,6 +11,7 @@ export const galleryImageClicked = createAction<{
|
|||||||
shiftKey: boolean;
|
shiftKey: boolean;
|
||||||
ctrlKey: boolean;
|
ctrlKey: boolean;
|
||||||
metaKey: boolean;
|
metaKey: boolean;
|
||||||
|
altKey: boolean;
|
||||||
}>('gallery/imageClicked');
|
}>('gallery/imageClicked');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -28,7 +29,7 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
|
|||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: galleryImageClicked,
|
actionCreator: galleryImageClicked,
|
||||||
effect: async (action, { dispatch, getState }) => {
|
effect: async (action, { dispatch, getState }) => {
|
||||||
const { imageDTO, shiftKey, ctrlKey, metaKey } = action.payload;
|
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
|
||||||
const state = getState();
|
const state = getState();
|
||||||
const queryArgs = selectListImagesQueryArgs(state);
|
const queryArgs = selectListImagesQueryArgs(state);
|
||||||
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||||
@ -41,7 +42,13 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
|
|||||||
const imageDTOs = imagesSelectors.selectAll(listImagesData);
|
const imageDTOs = imagesSelectors.selectAll(listImagesData);
|
||||||
const selection = state.gallery.selection;
|
const selection = state.gallery.selection;
|
||||||
|
|
||||||
if (shiftKey) {
|
if (altKey) {
|
||||||
|
if (state.gallery.imageToCompare?.image_name === imageDTO.image_name) {
|
||||||
|
dispatch(imageToCompareChanged(null));
|
||||||
|
} else {
|
||||||
|
dispatch(imageToCompareChanged(imageDTO));
|
||||||
|
}
|
||||||
|
} else if (shiftKey) {
|
||||||
const rangeEndImageName = imageDTO.image_name;
|
const rangeEndImageName = imageDTO.image_name;
|
||||||
const lastSelectedImage = selection[selection.length - 1]?.image_name;
|
const lastSelectedImage = selection[selection.length - 1]?.image_name;
|
||||||
const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage);
|
const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage);
|
||||||
|
@ -14,7 +14,8 @@ import {
|
|||||||
rgLayerIPAdapterImageChanged,
|
rgLayerIPAdapterImageChanged,
|
||||||
} from 'features/controlLayers/store/controlLayersSlice';
|
} from 'features/controlLayers/store/controlLayersSlice';
|
||||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||||
|
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
@ -30,6 +31,9 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
|||||||
effect: async (action, { dispatch, getState }) => {
|
effect: async (action, { dispatch, getState }) => {
|
||||||
const log = logger('dnd');
|
const log = logger('dnd');
|
||||||
const { activeData, overData } = action.payload;
|
const { activeData, overData } = action.payload;
|
||||||
|
if (!isValidDrop(overData, activeData)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (activeData.payloadType === 'IMAGE_DTO') {
|
if (activeData.payloadType === 'IMAGE_DTO') {
|
||||||
log.debug({ activeData, overData }, 'Image dropped');
|
log.debug({ activeData, overData }, 'Image dropped');
|
||||||
@ -50,6 +54,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
|||||||
activeData.payload.imageDTO
|
activeData.payload.imageDTO
|
||||||
) {
|
) {
|
||||||
dispatch(imageSelected(activeData.payload.imageDTO));
|
dispatch(imageSelected(activeData.payload.imageDTO));
|
||||||
|
dispatch(isImageViewerOpenChanged(true));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -182,24 +187,18 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TODO
|
* Image selected for compare
|
||||||
* Image selection dropped on node image collection field
|
|
||||||
*/
|
*/
|
||||||
// if (
|
if (
|
||||||
// overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
|
overData.actionType === 'SELECT_FOR_COMPARE' &&
|
||||||
// activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
// activeData.payload.imageDTO
|
activeData.payload.imageDTO
|
||||||
// ) {
|
) {
|
||||||
// const { fieldName, nodeId } = overData.context;
|
const { imageDTO } = activeData.payload;
|
||||||
// dispatch(
|
dispatch(imageToCompareChanged(imageDTO));
|
||||||
// fieldValueChanged({
|
dispatch(isImageViewerOpenChanged(true));
|
||||||
// nodeId,
|
return;
|
||||||
// fieldName,
|
}
|
||||||
// value: [activeData.payload.imageDTO],
|
|
||||||
// })
|
|
||||||
// );
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Image dropped on user board
|
* Image dropped on user board
|
||||||
|
@ -11,7 +11,6 @@ import {
|
|||||||
} from 'features/gallery/store/gallerySlice';
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||||
import { isImageOutput } from 'features/nodes/types/common';
|
|
||||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||||
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
||||||
import { boardsApi } from 'services/api/endpoints/boards';
|
import { boardsApi } from 'services/api/endpoints/boards';
|
||||||
@ -33,7 +32,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
|||||||
|
|
||||||
const { result, invocation_source_id } = data;
|
const { result, invocation_source_id } = data;
|
||||||
// This complete event has an associated image output
|
// This complete event has an associated image output
|
||||||
if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation.type)) {
|
if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) {
|
||||||
const { image_name } = data.result.image;
|
const { image_name } = data.result.image;
|
||||||
const { canvas, gallery } = getState();
|
const { canvas, gallery } = getState();
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
|
|||||||
import { parseify } from 'common/util/serialize';
|
import { parseify } from 'common/util/serialize';
|
||||||
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
|
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
|
||||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
import { $needsFit } from 'features/nodes/store/reactFlowInstance';
|
||||||
import type { Templates } from 'features/nodes/store/types';
|
import type { Templates } from 'features/nodes/store/types';
|
||||||
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
|
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
|
||||||
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
|
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
|
||||||
@ -65,9 +65,7 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
requestAnimationFrame(() => {
|
$needsFit.set(true);
|
||||||
$flow.get()?.fitView();
|
|
||||||
});
|
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
if (e instanceof WorkflowVersionError) {
|
if (e instanceof WorkflowVersionError) {
|
||||||
// The workflow version was not recognized in the valid list of versions
|
// The workflow version was not recognized in the valid list of versions
|
||||||
|
@ -35,6 +35,7 @@ type IAIDndImageProps = FlexProps & {
|
|||||||
draggableData?: TypesafeDraggableData;
|
draggableData?: TypesafeDraggableData;
|
||||||
dropLabel?: ReactNode;
|
dropLabel?: ReactNode;
|
||||||
isSelected?: boolean;
|
isSelected?: boolean;
|
||||||
|
isSelectedForCompare?: boolean;
|
||||||
thumbnail?: boolean;
|
thumbnail?: boolean;
|
||||||
noContentFallback?: ReactElement;
|
noContentFallback?: ReactElement;
|
||||||
useThumbailFallback?: boolean;
|
useThumbailFallback?: boolean;
|
||||||
@ -61,6 +62,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
draggableData,
|
draggableData,
|
||||||
dropLabel,
|
dropLabel,
|
||||||
isSelected = false,
|
isSelected = false,
|
||||||
|
isSelectedForCompare = false,
|
||||||
thumbnail = false,
|
thumbnail = false,
|
||||||
noContentFallback = defaultNoContentFallback,
|
noContentFallback = defaultNoContentFallback,
|
||||||
uploadElement = defaultUploadElement,
|
uploadElement = defaultUploadElement,
|
||||||
@ -165,7 +167,11 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
data-testid={dataTestId}
|
data-testid={dataTestId}
|
||||||
/>
|
/>
|
||||||
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
|
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
|
||||||
<SelectionOverlay isSelected={isSelected} isHovered={withHoverOverlay ? isHovered : false} />
|
<SelectionOverlay
|
||||||
|
isSelected={isSelected}
|
||||||
|
isSelectedForCompare={isSelectedForCompare}
|
||||||
|
isHovered={withHoverOverlay ? isHovered : false}
|
||||||
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
{!imageDTO && !isUploadDisabled && (
|
{!imageDTO && !isUploadDisabled && (
|
||||||
|
@ -36,7 +36,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
|
|||||||
pointerEvents={active ? 'auto' : 'none'}
|
pointerEvents={active ? 'auto' : 'none'}
|
||||||
>
|
>
|
||||||
<AnimatePresence>
|
<AnimatePresence>
|
||||||
{isValidDrop(data, active) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
|
{isValidDrop(data, active?.data.current) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
|
||||||
</AnimatePresence>
|
</AnimatePresence>
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
|
@ -3,10 +3,17 @@ import { memo, useMemo } from 'react';
|
|||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
isSelected: boolean;
|
isSelected: boolean;
|
||||||
|
isSelectedForCompare: boolean;
|
||||||
isHovered: boolean;
|
isHovered: boolean;
|
||||||
};
|
};
|
||||||
const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
|
const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props) => {
|
||||||
const shadow = useMemo(() => {
|
const shadow = useMemo(() => {
|
||||||
|
if (isSelectedForCompare && isHovered) {
|
||||||
|
return 'hoverSelectedForCompare';
|
||||||
|
}
|
||||||
|
if (isSelectedForCompare && !isHovered) {
|
||||||
|
return 'selectedForCompare';
|
||||||
|
}
|
||||||
if (isSelected && isHovered) {
|
if (isSelected && isHovered) {
|
||||||
return 'hoverSelected';
|
return 'hoverSelected';
|
||||||
}
|
}
|
||||||
@ -17,7 +24,7 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
|
|||||||
return 'hoverUnselected';
|
return 'hoverUnselected';
|
||||||
}
|
}
|
||||||
return undefined;
|
return undefined;
|
||||||
}, [isHovered, isSelected]);
|
}, [isHovered, isSelected, isSelectedForCompare]);
|
||||||
return (
|
return (
|
||||||
<Box
|
<Box
|
||||||
className="selection-box"
|
className="selection-box"
|
||||||
@ -27,7 +34,7 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
|
|||||||
bottom={0}
|
bottom={0}
|
||||||
insetInlineStart={0}
|
insetInlineStart={0}
|
||||||
borderRadius="base"
|
borderRadius="base"
|
||||||
opacity={isSelected ? 1 : 0.7}
|
opacity={isSelected || isSelectedForCompare ? 1 : 0.7}
|
||||||
transitionProperty="common"
|
transitionProperty="common"
|
||||||
transitionDuration="0.1s"
|
transitionDuration="0.1s"
|
||||||
pointerEvents="none"
|
pointerEvents="none"
|
||||||
|
21
invokeai/frontend/web/src/common/hooks/useBoolean.ts
Normal file
21
invokeai/frontend/web/src/common/hooks/useBoolean.ts
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import { useCallback, useMemo, useState } from 'react';
|
||||||
|
|
||||||
|
export const useBoolean = (initialValue: boolean) => {
|
||||||
|
const [isTrue, set] = useState(initialValue);
|
||||||
|
const setTrue = useCallback(() => set(true), []);
|
||||||
|
const setFalse = useCallback(() => set(false), []);
|
||||||
|
const toggle = useCallback(() => set((v) => !v), []);
|
||||||
|
|
||||||
|
const api = useMemo(
|
||||||
|
() => ({
|
||||||
|
isTrue,
|
||||||
|
set,
|
||||||
|
setTrue,
|
||||||
|
setFalse,
|
||||||
|
toggle,
|
||||||
|
}),
|
||||||
|
[isTrue, set, setTrue, setFalse, toggle]
|
||||||
|
);
|
||||||
|
|
||||||
|
return api;
|
||||||
|
};
|
@ -1,3 +1,7 @@
|
|||||||
export const stopPropagation = (e: React.MouseEvent) => {
|
export const stopPropagation = (e: React.MouseEvent) => {
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const preventDefault = (e: React.MouseEvent) => {
|
||||||
|
e.preventDefault();
|
||||||
|
};
|
||||||
|
@ -1,7 +1,13 @@
|
|||||||
import { deepClone } from 'common/util/deepClone';
|
import { deepClone } from 'common/util/deepClone';
|
||||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { merge, omit } from 'lodash-es';
|
import { merge, omit } from 'lodash-es';
|
||||||
import type { BaseModelType, ControlNetModelConfig, Graph, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
|
import type {
|
||||||
|
AnyInvocation,
|
||||||
|
BaseModelType,
|
||||||
|
ControlNetModelConfig,
|
||||||
|
ImageDTO,
|
||||||
|
T2IAdapterModelConfig,
|
||||||
|
} from 'services/api/types';
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
|
|
||||||
const zId = z.string().min(1);
|
const zId = z.string().min(1);
|
||||||
@ -147,7 +153,7 @@ const zBeginEndStepPct = z
|
|||||||
|
|
||||||
const zControlAdapterBase = z.object({
|
const zControlAdapterBase = z.object({
|
||||||
id: zId,
|
id: zId,
|
||||||
weight: z.number().gte(0).lte(1),
|
weight: z.number().gte(-1).lte(2),
|
||||||
image: zImageWithDims.nullable(),
|
image: zImageWithDims.nullable(),
|
||||||
processedImage: zImageWithDims.nullable(),
|
processedImage: zImageWithDims.nullable(),
|
||||||
processorConfig: zProcessorConfig.nullable(),
|
processorConfig: zProcessorConfig.nullable(),
|
||||||
@ -183,7 +189,7 @@ export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safePar
|
|||||||
export const zIPAdapterConfigV2 = z.object({
|
export const zIPAdapterConfigV2 = z.object({
|
||||||
id: zId,
|
id: zId,
|
||||||
type: z.literal('ip_adapter'),
|
type: z.literal('ip_adapter'),
|
||||||
weight: z.number().gte(0).lte(1),
|
weight: z.number().gte(-1).lte(2),
|
||||||
method: zIPMethodV2,
|
method: zIPMethodV2,
|
||||||
image: zImageWithDims.nullable(),
|
image: zImageWithDims.nullable(),
|
||||||
model: zModelIdentifierField.nullable(),
|
model: zModelIdentifierField.nullable(),
|
||||||
@ -216,10 +222,7 @@ type ProcessorData<T extends ProcessorTypeV2> = {
|
|||||||
labelTKey: string;
|
labelTKey: string;
|
||||||
descriptionTKey: string;
|
descriptionTKey: string;
|
||||||
buildDefaults(baseModel?: BaseModelType): Extract<ProcessorConfig, { type: T }>;
|
buildDefaults(baseModel?: BaseModelType): Extract<ProcessorConfig, { type: T }>;
|
||||||
buildNode(
|
buildNode(image: ImageWithDims, config: Extract<ProcessorConfig, { type: T }>): Extract<AnyInvocation, { type: T }>;
|
||||||
image: ImageWithDims,
|
|
||||||
config: Extract<ProcessorConfig, { type: T }>
|
|
||||||
): Extract<Graph['nodes'][string], { type: T }>;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height);
|
const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height);
|
||||||
|
@ -54,7 +54,7 @@ const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)';
|
|||||||
const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
|
const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
|
||||||
const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
|
const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
|
||||||
// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
|
// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
|
||||||
const STAGE_BG_DATAURL =
|
export const STAGE_BG_DATAURL =
|
||||||
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAAEsmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4KPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNS41LjAiPgogPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iCiAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyIKICAgIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIKICAgIHhtbG5zOnhtcD0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLyIKICAgIHhtbG5zOnhtcE1NPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvbW0vIgogICAgeG1sbnM6c3RFdnQ9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZUV2ZW50IyIKICAgZXhpZjpQaXhlbFhEaW1lbnNpb249IjIwIgogICBleGlmOlBpeGVsWURpbWVuc2lvbj0iMjAiCiAgIGV4aWY6Q29sb3JTcGFjZT0iMSIKICAgdGlmZjpJbWFnZVdpZHRoPSIyMCIKICAgdGlmZjpJbWFnZUxlbmd0aD0iMjAiCiAgIHRpZmY6UmVzb2x1dGlvblVuaXQ9IjIiCiAgIHRpZmY6WFJlc29sdXRpb249IjMwMC8xIgogICB0aWZmOllSZXNvbHV0aW9uPSIzMDAvMSIKICAgcGhvdG9zaG9wOkNvbG9yTW9kZT0iMyIKICAgcGhvdG9zaG9wOklDQ1Byb2ZpbGU9InNSR0IgSUVDNjE5NjYtMi4xIgogICB4bXA6TW9kaWZ5RGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCIKICAgeG1wOk1ldGFkYXRhRGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCI+CiAgIDx4bXBNTTpIaXN0b3J5PgogICAgPHJkZjpTZXE+CiAgICAgPHJkZjpsaQogICAgICBzdEV2dDphY3Rpb249InByb2R1Y2VkIgogICAgICBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZmZpbml0eSBQaG90byAxLjEwLjgiCiAgICAgIHN0RXZ0OndoZW49IjIwMjQtMDQtMjNUMDg6MjA6NDcrMTA6MDAiLz4KICAgIDwvcmRmOlNlcT4KICAgPC94bXBNTTpIaXN0b3J5PgogIDwvcmRmOkRlc2NyaXB0aW9uPgogPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KPD94cGFja2V0IGVuZD0iciI/Pn9pdVgAAAGBaUNDUHNSR0IgSUVDNjE5NjYtMi4xAAAokXWR3yuDURjHP5uJmKghFy6WxpVpqMWNMgm1tGbKr5vt3S+1d3t73y3JrXKrKHHj1wV/AbfKtVJESq53TdywXs9rakv2nJ7zfM73nOfpnOeAPZJRVMPhAzWb18NTAffC4pK7oYiDTjpw4YgqhjYeCgWpaR8P2Kx457Vq1T73rzXHE4YCtkbhMUXT88LTwsG1vGbxrnC7ko7Ghc+F+3W5oPC9pcfKXLQ4VeYvi/VIeALsbcLuVBXHqlhJ66qwvByPmikov/exXuJMZOfnJPaId2MQZooAbmaYZAI/g4zK7MfLEAOyoka+7yd/lpzkKjJrrKOzSoo0efpFLUj1hMSk6AkZGdat/v/tq5EcHipXdwag/sU033qhYQdK26b5eWyapROoe4arbCU/dwQj76JvVzTPIbRuwsV1RYvtweUWdD1pUT36I9WJ25NJeD2DlkVw3ULTcrlnv/ucPkJkQ77qBvYPoE/Ot658AxagZ8FoS/a7AAAACXBIWXMAAC4jAAAuIwF4pT92AAAAL0lEQVQ4jWM8ffo0A25gYmKCR5YJjxxBMKp5ZGhm/P//Px7pM2fO0MrmUc0jQzMAB2EIhZC3pUYAAAAASUVORK5CYII=';
|
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAAEsmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4KPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNS41LjAiPgogPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iCiAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyIKICAgIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIKICAgIHhtbG5zOnhtcD0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLyIKICAgIHhtbG5zOnhtcE1NPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvbW0vIgogICAgeG1sbnM6c3RFdnQ9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZUV2ZW50IyIKICAgZXhpZjpQaXhlbFhEaW1lbnNpb249IjIwIgogICBleGlmOlBpeGVsWURpbWVuc2lvbj0iMjAiCiAgIGV4aWY6Q29sb3JTcGFjZT0iMSIKICAgdGlmZjpJbWFnZVdpZHRoPSIyMCIKICAgdGlmZjpJbWFnZUxlbmd0aD0iMjAiCiAgIHRpZmY6UmVzb2x1dGlvblVuaXQ9IjIiCiAgIHRpZmY6WFJlc29sdXRpb249IjMwMC8xIgogICB0aWZmOllSZXNvbHV0aW9uPSIzMDAvMSIKICAgcGhvdG9zaG9wOkNvbG9yTW9kZT0iMyIKICAgcGhvdG9zaG9wOklDQ1Byb2ZpbGU9InNSR0IgSUVDNjE5NjYtMi4xIgogICB4bXA6TW9kaWZ5RGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCIKICAgeG1wOk1ldGFkYXRhRGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCI+CiAgIDx4bXBNTTpIaXN0b3J5PgogICAgPHJkZjpTZXE+CiAgICAgPHJkZjpsaQogICAgICBzdEV2dDphY3Rpb249InByb2R1Y2VkIgogICAgICBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZmZpbml0eSBQaG90byAxLjEwLjgiCiAgICAgIHN0RXZ0OndoZW49IjIwMjQtMDQtMjNUMDg6MjA6NDcrMTA6MDAiLz4KICAgIDwvcmRmOlNlcT4KICAgPC94bXBNTTpIaXN0b3J5PgogIDwvcmRmOkRlc2NyaXB0aW9uPgogPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KPD94cGFja2V0IGVuZD0iciI/Pn9pdVgAAAGBaUNDUHNSR0IgSUVDNjE5NjYtMi4xAAAokXWR3yuDURjHP5uJmKghFy6WxpVpqMWNMgm1tGbKr5vt3S+1d3t73y3JrXKrKHHj1wV/AbfKtVJESq53TdywXs9rakv2nJ7zfM73nOfpnOeAPZJRVMPhAzWb18NTAffC4pK7oYiDTjpw4YgqhjYeCgWpaR8P2Kx457Vq1T73rzXHE4YCtkbhMUXT88LTwsG1vGbxrnC7ko7Ghc+F+3W5oPC9pcfKXLQ4VeYvi/VIeALsbcLuVBXHqlhJ66qwvByPmikov/exXuJMZOfnJPaId2MQZooAbmaYZAI/g4zK7MfLEAOyoka+7yd/lpzkKjJrrKOzSoo0efpFLUj1hMSk6AkZGdat/v/tq5EcHipXdwag/sU033qhYQdK26b5eWyapROoe4arbCU/dwQj76JvVzTPIbRuwsV1RYvtweUWdD1pUT36I9WJ25NJeD2DlkVw3ULTcrlnv/ucPkJkQ77qBvYPoE/Ot658AxagZ8FoS/a7AAAACXBIWXMAAC4jAAAuIwF4pT92AAAAL0lEQVQ4jWM8ffo0A25gYmKCR5YJjxxBMKp5ZGhm/P//Px7pM2fO0MrmUc0jQzMAB2EIhZC3pUYAAAAASUVORK5CYII=';
|
||||||
|
|
||||||
const mapId = (object: { id: string }) => object.id;
|
const mapId = (object: { id: string }) => object.id;
|
||||||
|
@ -18,7 +18,7 @@ type BaseDropData = {
|
|||||||
id: string;
|
id: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
type CurrentImageDropData = BaseDropData & {
|
export type CurrentImageDropData = BaseDropData & {
|
||||||
actionType: 'SET_CURRENT_IMAGE';
|
actionType: 'SET_CURRENT_IMAGE';
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -79,6 +79,14 @@ export type RemoveFromBoardDropData = BaseDropData & {
|
|||||||
actionType: 'REMOVE_FROM_BOARD';
|
actionType: 'REMOVE_FROM_BOARD';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type SelectForCompareDropData = BaseDropData & {
|
||||||
|
actionType: 'SELECT_FOR_COMPARE';
|
||||||
|
context: {
|
||||||
|
firstImageName?: string | null;
|
||||||
|
secondImageName?: string | null;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
export type TypesafeDroppableData =
|
export type TypesafeDroppableData =
|
||||||
| CurrentImageDropData
|
| CurrentImageDropData
|
||||||
| ControlAdapterDropData
|
| ControlAdapterDropData
|
||||||
@ -89,7 +97,8 @@ export type TypesafeDroppableData =
|
|||||||
| CALayerImageDropData
|
| CALayerImageDropData
|
||||||
| IPALayerImageDropData
|
| IPALayerImageDropData
|
||||||
| RGLayerIPAdapterImageDropData
|
| RGLayerIPAdapterImageDropData
|
||||||
| IILayerImageDropData;
|
| IILayerImageDropData
|
||||||
|
| SelectForCompareDropData;
|
||||||
|
|
||||||
type BaseDragData = {
|
type BaseDragData = {
|
||||||
id: string;
|
id: string;
|
||||||
@ -134,7 +143,7 @@ export type UseDraggableTypesafeReturnValue = Omit<ReturnType<typeof useOriginal
|
|||||||
over: TypesafeOver | null;
|
over: TypesafeOver | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
export interface TypesafeActive extends Omit<Active, 'data'> {
|
interface TypesafeActive extends Omit<Active, 'data'> {
|
||||||
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
|
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
import type { TypesafeActive, TypesafeDroppableData } from 'features/dnd/types';
|
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||||
|
|
||||||
export const isValidDrop = (overData: TypesafeDroppableData | undefined, active: TypesafeActive | null) => {
|
export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?: TypesafeDraggableData | null) => {
|
||||||
if (!overData || !active?.data.current) {
|
if (!overData || !activeData) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { actionType } = overData;
|
const { actionType } = overData;
|
||||||
const { payloadType } = active.data.current;
|
const { payloadType } = activeData;
|
||||||
|
|
||||||
if (overData.id === active.data.current.id) {
|
if (overData.id === activeData.id) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29,6 +29,8 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active:
|
|||||||
return payloadType === 'IMAGE_DTO';
|
return payloadType === 'IMAGE_DTO';
|
||||||
case 'SET_NODES_IMAGE':
|
case 'SET_NODES_IMAGE':
|
||||||
return payloadType === 'IMAGE_DTO';
|
return payloadType === 'IMAGE_DTO';
|
||||||
|
case 'SELECT_FOR_COMPARE':
|
||||||
|
return payloadType === 'IMAGE_DTO';
|
||||||
case 'ADD_TO_BOARD': {
|
case 'ADD_TO_BOARD': {
|
||||||
// If the board is the same, don't allow the drop
|
// If the board is the same, don't allow the drop
|
||||||
|
|
||||||
@ -40,7 +42,7 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active:
|
|||||||
|
|
||||||
// Check if the image's board is the board we are dragging onto
|
// Check if the image's board is the board we are dragging onto
|
||||||
if (payloadType === 'IMAGE_DTO') {
|
if (payloadType === 'IMAGE_DTO') {
|
||||||
const { imageDTO } = active.data.current.payload;
|
const { imageDTO } = activeData.payload;
|
||||||
const currentBoard = imageDTO.board_id ?? 'none';
|
const currentBoard = imageDTO.board_id ?? 'none';
|
||||||
const destinationBoard = overData.context.boardId;
|
const destinationBoard = overData.context.boardId;
|
||||||
|
|
||||||
@ -49,7 +51,7 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active:
|
|||||||
|
|
||||||
if (payloadType === 'GALLERY_SELECTION') {
|
if (payloadType === 'GALLERY_SELECTION') {
|
||||||
// Assume all images are on the same board - this is true for the moment
|
// Assume all images are on the same board - this is true for the moment
|
||||||
const currentBoard = active.data.current.payload.boardId;
|
const currentBoard = activeData.payload.boardId;
|
||||||
const destinationBoard = overData.context.boardId;
|
const destinationBoard = overData.context.boardId;
|
||||||
return currentBoard !== destinationBoard;
|
return currentBoard !== destinationBoard;
|
||||||
}
|
}
|
||||||
@ -67,14 +69,14 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active:
|
|||||||
|
|
||||||
// Check if the image's board is the board we are dragging onto
|
// Check if the image's board is the board we are dragging onto
|
||||||
if (payloadType === 'IMAGE_DTO') {
|
if (payloadType === 'IMAGE_DTO') {
|
||||||
const { imageDTO } = active.data.current.payload;
|
const { imageDTO } = activeData.payload;
|
||||||
const currentBoard = imageDTO.board_id ?? 'none';
|
const currentBoard = imageDTO.board_id ?? 'none';
|
||||||
|
|
||||||
return currentBoard !== 'none';
|
return currentBoard !== 'none';
|
||||||
}
|
}
|
||||||
|
|
||||||
if (payloadType === 'GALLERY_SELECTION') {
|
if (payloadType === 'GALLERY_SELECTION') {
|
||||||
const currentBoard = active.data.current.payload.boardId;
|
const currentBoard = activeData.payload.boardId;
|
||||||
return currentBoard !== 'none';
|
return currentBoard !== 'none';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -162,7 +162,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
|||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
{isSelectedForAutoAdd && <AutoAddIcon />}
|
{isSelectedForAutoAdd && <AutoAddIcon />}
|
||||||
<SelectionOverlay isSelected={isSelected} isHovered={isHovered} />
|
<SelectionOverlay isSelected={isSelected} isSelectedForCompare={false} isHovered={isHovered} />
|
||||||
<Flex
|
<Flex
|
||||||
position="absolute"
|
position="absolute"
|
||||||
bottom={0}
|
bottom={0}
|
||||||
|
@ -117,7 +117,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
|||||||
>
|
>
|
||||||
{boardName}
|
{boardName}
|
||||||
</Flex>
|
</Flex>
|
||||||
<SelectionOverlay isSelected={isSelected} isHovered={isHovered} />
|
<SelectionOverlay isSelected={isSelected} isSelectedForCompare={false} isHovered={isHovered} />
|
||||||
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="md">{t('unifiedCanvas.move')}</Text>} />
|
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="md">{t('unifiedCanvas.move')}</Text>} />
|
||||||
</Flex>
|
</Flex>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
|
@ -10,6 +10,7 @@ import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice';
|
|||||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||||
import { useImageActions } from 'features/gallery/hooks/useImageActions';
|
import { useImageActions } from 'features/gallery/hooks/useImageActions';
|
||||||
import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions';
|
import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions';
|
||||||
|
import { imageToCompareChanged } from 'features/gallery/store/gallerySlice';
|
||||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
@ -27,6 +28,7 @@ import {
|
|||||||
PiDownloadSimpleBold,
|
PiDownloadSimpleBold,
|
||||||
PiFlowArrowBold,
|
PiFlowArrowBold,
|
||||||
PiFoldersBold,
|
PiFoldersBold,
|
||||||
|
PiImagesBold,
|
||||||
PiPlantBold,
|
PiPlantBold,
|
||||||
PiQuotesBold,
|
PiQuotesBold,
|
||||||
PiShareFatBold,
|
PiShareFatBold,
|
||||||
@ -44,6 +46,7 @@ type SingleSelectionMenuItemsProps = {
|
|||||||
const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||||
const { imageDTO } = props;
|
const { imageDTO } = props;
|
||||||
const optimalDimension = useAppSelector(selectOptimalDimension);
|
const optimalDimension = useAppSelector(selectOptimalDimension);
|
||||||
|
const maySelectForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name !== imageDTO.image_name);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const isCanvasEnabled = useFeatureStatus('canvas');
|
const isCanvasEnabled = useFeatureStatus('canvas');
|
||||||
@ -117,6 +120,10 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
downloadImage(imageDTO.image_url, imageDTO.image_name);
|
downloadImage(imageDTO.image_url, imageDTO.image_name);
|
||||||
}, [downloadImage, imageDTO.image_name, imageDTO.image_url]);
|
}, [downloadImage, imageDTO.image_name, imageDTO.image_url]);
|
||||||
|
|
||||||
|
const handleSelectImageForCompare = useCallback(() => {
|
||||||
|
dispatch(imageToCompareChanged(imageDTO));
|
||||||
|
}, [dispatch, imageDTO]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<MenuItem as="a" href={imageDTO.image_url} target="_blank" icon={<PiShareFatBold />}>
|
<MenuItem as="a" href={imageDTO.image_url} target="_blank" icon={<PiShareFatBold />}>
|
||||||
@ -130,6 +137,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
<MenuItem icon={<PiDownloadSimpleBold />} onClickCapture={handleDownloadImage}>
|
<MenuItem icon={<PiDownloadSimpleBold />} onClickCapture={handleDownloadImage}>
|
||||||
{t('parameters.downloadImage')}
|
{t('parameters.downloadImage')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
|
<MenuItem icon={<PiImagesBold />} isDisabled={!maySelectForCompare} onClick={handleSelectImageForCompare}>
|
||||||
|
{t('gallery.selectForCompare')}
|
||||||
|
</MenuItem>
|
||||||
<MenuDivider />
|
<MenuDivider />
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={getAndLoadEmbeddedWorkflowResult.isLoading ? <SpinnerIcon /> : <PiFlowArrowBold />}
|
icon={getAndLoadEmbeddedWorkflowResult.isLoading ? <SpinnerIcon /> : <PiFlowArrowBold />}
|
||||||
|
@ -11,7 +11,7 @@ import type { GallerySelectionDraggableData, ImageDraggableData, TypesafeDraggab
|
|||||||
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
|
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
|
||||||
import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
|
import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
|
||||||
import { useScrollIntoView } from 'features/gallery/hooks/useScrollIntoView';
|
import { useScrollIntoView } from 'features/gallery/hooks/useScrollIntoView';
|
||||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
import { imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||||
import type { MouseEvent } from 'react';
|
import type { MouseEvent } from 'react';
|
||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -46,6 +46,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
||||||
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
|
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
|
||||||
|
const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageName);
|
||||||
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
|
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
|
||||||
|
|
||||||
const customStarUi = useStore($customStarUI);
|
const customStarUi = useStore($customStarUI);
|
||||||
@ -105,6 +106,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
|
|
||||||
const onDoubleClick = useCallback(() => {
|
const onDoubleClick = useCallback(() => {
|
||||||
dispatch(isImageViewerOpenChanged(true));
|
dispatch(isImageViewerOpenChanged(true));
|
||||||
|
dispatch(imageToCompareChanged(null));
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
const handleMouseOut = useCallback(() => {
|
const handleMouseOut = useCallback(() => {
|
||||||
@ -152,6 +154,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
imageDTO={imageDTO}
|
imageDTO={imageDTO}
|
||||||
draggableData={draggableData}
|
draggableData={draggableData}
|
||||||
isSelected={isSelected}
|
isSelected={isSelected}
|
||||||
|
isSelectedForCompare={isSelectedForCompare}
|
||||||
minSize={0}
|
minSize={0}
|
||||||
imageSx={imageSx}
|
imageSx={imageSx}
|
||||||
isDropDisabled={true}
|
isDropDisabled={true}
|
||||||
|
@ -0,0 +1,140 @@
|
|||||||
|
import {
|
||||||
|
Button,
|
||||||
|
ButtonGroup,
|
||||||
|
Flex,
|
||||||
|
Icon,
|
||||||
|
IconButton,
|
||||||
|
Kbd,
|
||||||
|
ListItem,
|
||||||
|
Tooltip,
|
||||||
|
UnorderedList,
|
||||||
|
} from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import {
|
||||||
|
comparedImagesSwapped,
|
||||||
|
comparisonFitChanged,
|
||||||
|
comparisonModeChanged,
|
||||||
|
comparisonModeCycled,
|
||||||
|
imageToCompareChanged,
|
||||||
|
} from 'features/gallery/store/gallerySlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
import { Trans, useTranslation } from 'react-i18next';
|
||||||
|
import { PiArrowsOutBold, PiQuestion, PiSwapBold, PiXBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
export const CompareToolbar = memo(() => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode);
|
||||||
|
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
|
||||||
|
const setComparisonModeSlider = useCallback(() => {
|
||||||
|
dispatch(comparisonModeChanged('slider'));
|
||||||
|
}, [dispatch]);
|
||||||
|
const setComparisonModeSideBySide = useCallback(() => {
|
||||||
|
dispatch(comparisonModeChanged('side-by-side'));
|
||||||
|
}, [dispatch]);
|
||||||
|
const setComparisonModeHover = useCallback(() => {
|
||||||
|
dispatch(comparisonModeChanged('hover'));
|
||||||
|
}, [dispatch]);
|
||||||
|
const swapImages = useCallback(() => {
|
||||||
|
dispatch(comparedImagesSwapped());
|
||||||
|
}, [dispatch]);
|
||||||
|
useHotkeys('c', swapImages, [swapImages]);
|
||||||
|
const toggleComparisonFit = useCallback(() => {
|
||||||
|
dispatch(comparisonFitChanged(comparisonFit === 'contain' ? 'fill' : 'contain'));
|
||||||
|
}, [dispatch, comparisonFit]);
|
||||||
|
const exitCompare = useCallback(() => {
|
||||||
|
dispatch(imageToCompareChanged(null));
|
||||||
|
}, [dispatch]);
|
||||||
|
useHotkeys('esc', exitCompare, [exitCompare]);
|
||||||
|
const nextMode = useCallback(() => {
|
||||||
|
dispatch(comparisonModeCycled());
|
||||||
|
}, [dispatch]);
|
||||||
|
useHotkeys('m', nextMode, [nextMode]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex w="full" gap={2}>
|
||||||
|
<Flex flex={1} justifyContent="center">
|
||||||
|
<Flex gap={2} marginInlineEnd="auto">
|
||||||
|
<IconButton
|
||||||
|
icon={<PiSwapBold />}
|
||||||
|
aria-label={`${t('gallery.swapImages')} (C)`}
|
||||||
|
tooltip={`${t('gallery.swapImages')} (C)`}
|
||||||
|
onClick={swapImages}
|
||||||
|
/>
|
||||||
|
{comparisonMode !== 'side-by-side' && (
|
||||||
|
<IconButton
|
||||||
|
aria-label={t('gallery.stretchToFit')}
|
||||||
|
tooltip={t('gallery.stretchToFit')}
|
||||||
|
onClick={toggleComparisonFit}
|
||||||
|
colorScheme={comparisonFit === 'fill' ? 'invokeBlue' : 'base'}
|
||||||
|
variant="outline"
|
||||||
|
icon={<PiArrowsOutBold />}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
<Flex flex={1} gap={4} justifyContent="center">
|
||||||
|
<ButtonGroup variant="outline">
|
||||||
|
<Button
|
||||||
|
flexShrink={0}
|
||||||
|
onClick={setComparisonModeSlider}
|
||||||
|
colorScheme={comparisonMode === 'slider' ? 'invokeBlue' : 'base'}
|
||||||
|
>
|
||||||
|
{t('gallery.slider')}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
flexShrink={0}
|
||||||
|
onClick={setComparisonModeSideBySide}
|
||||||
|
colorScheme={comparisonMode === 'side-by-side' ? 'invokeBlue' : 'base'}
|
||||||
|
>
|
||||||
|
{t('gallery.sideBySide')}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
flexShrink={0}
|
||||||
|
onClick={setComparisonModeHover}
|
||||||
|
colorScheme={comparisonMode === 'hover' ? 'invokeBlue' : 'base'}
|
||||||
|
>
|
||||||
|
{t('gallery.hover')}
|
||||||
|
</Button>
|
||||||
|
</ButtonGroup>
|
||||||
|
</Flex>
|
||||||
|
<Flex flex={1} justifyContent="center">
|
||||||
|
<Flex gap={2} marginInlineStart="auto" alignItems="center">
|
||||||
|
<Tooltip label={<CompareHelp />}>
|
||||||
|
<Flex alignItems="center">
|
||||||
|
<Icon boxSize={8} color="base.500" as={PiQuestion} lineHeight={0} />
|
||||||
|
</Flex>
|
||||||
|
</Tooltip>
|
||||||
|
<IconButton
|
||||||
|
icon={<PiXBold />}
|
||||||
|
aria-label={`${t('gallery.exitCompare')} (Esc)`}
|
||||||
|
tooltip={`${t('gallery.exitCompare')} (Esc)`}
|
||||||
|
onClick={exitCompare}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
CompareToolbar.displayName = 'CompareToolbar';
|
||||||
|
|
||||||
|
const CompareHelp = () => {
|
||||||
|
return (
|
||||||
|
<UnorderedList>
|
||||||
|
<ListItem>
|
||||||
|
<Trans i18nKey="gallery.compareHelp1" components={{ Kbd: <Kbd /> }}></Trans>
|
||||||
|
</ListItem>
|
||||||
|
<ListItem>
|
||||||
|
<Trans i18nKey="gallery.compareHelp2" components={{ Kbd: <Kbd /> }}></Trans>
|
||||||
|
</ListItem>
|
||||||
|
<ListItem>
|
||||||
|
<Trans i18nKey="gallery.compareHelp3" components={{ Kbd: <Kbd /> }}></Trans>
|
||||||
|
</ListItem>
|
||||||
|
<ListItem>
|
||||||
|
<Trans i18nKey="gallery.compareHelp4" components={{ Kbd: <Kbd /> }}></Trans>
|
||||||
|
</ListItem>
|
||||||
|
</UnorderedList>
|
||||||
|
);
|
||||||
|
};
|
@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/query';
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
import type { TypesafeDraggableData } from 'features/dnd/types';
|
||||||
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
|
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
|
||||||
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
|
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
|
||||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||||
@ -22,21 +22,7 @@ const selectLastSelectedImageName = createSelector(
|
|||||||
(lastSelectedImage) => lastSelectedImage?.image_name
|
(lastSelectedImage) => lastSelectedImage?.image_name
|
||||||
);
|
);
|
||||||
|
|
||||||
type Props = {
|
const CurrentImagePreview = () => {
|
||||||
isDragDisabled?: boolean;
|
|
||||||
isDropDisabled?: boolean;
|
|
||||||
withNextPrevButtons?: boolean;
|
|
||||||
withMetadata?: boolean;
|
|
||||||
alwaysShowProgress?: boolean;
|
|
||||||
};
|
|
||||||
|
|
||||||
const CurrentImagePreview = ({
|
|
||||||
isDragDisabled = false,
|
|
||||||
isDropDisabled = false,
|
|
||||||
withNextPrevButtons = true,
|
|
||||||
withMetadata = true,
|
|
||||||
alwaysShowProgress = false,
|
|
||||||
}: Props) => {
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails);
|
const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails);
|
||||||
const imageName = useAppSelector(selectLastSelectedImageName);
|
const imageName = useAppSelector(selectLastSelectedImageName);
|
||||||
@ -55,14 +41,6 @@ const CurrentImagePreview = ({
|
|||||||
}
|
}
|
||||||
}, [imageDTO]);
|
}, [imageDTO]);
|
||||||
|
|
||||||
const droppableData = useMemo<TypesafeDroppableData | undefined>(
|
|
||||||
() => ({
|
|
||||||
id: 'current-image',
|
|
||||||
actionType: 'SET_CURRENT_IMAGE',
|
|
||||||
}),
|
|
||||||
[]
|
|
||||||
);
|
|
||||||
|
|
||||||
// Show and hide the next/prev buttons on mouse move
|
// Show and hide the next/prev buttons on mouse move
|
||||||
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = useState<boolean>(false);
|
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = useState<boolean>(false);
|
||||||
const timeoutId = useRef(0);
|
const timeoutId = useRef(0);
|
||||||
@ -86,30 +64,27 @@ const CurrentImagePreview = ({
|
|||||||
justifyContent="center"
|
justifyContent="center"
|
||||||
position="relative"
|
position="relative"
|
||||||
>
|
>
|
||||||
{hasDenoiseProgress && (shouldShowProgressInViewer || alwaysShowProgress) ? (
|
{hasDenoiseProgress && shouldShowProgressInViewer ? (
|
||||||
<ProgressImage />
|
<ProgressImage />
|
||||||
) : (
|
) : (
|
||||||
<IAIDndImage
|
<IAIDndImage
|
||||||
imageDTO={imageDTO}
|
imageDTO={imageDTO}
|
||||||
droppableData={droppableData}
|
|
||||||
draggableData={draggableData}
|
draggableData={draggableData}
|
||||||
isDragDisabled={isDragDisabled}
|
isDropDisabled={true}
|
||||||
isDropDisabled={isDropDisabled}
|
|
||||||
isUploadDisabled={true}
|
isUploadDisabled={true}
|
||||||
fitContainer
|
fitContainer
|
||||||
useThumbailFallback
|
useThumbailFallback
|
||||||
dropLabel={t('gallery.setCurrentImage')}
|
|
||||||
noContentFallback={<IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />}
|
noContentFallback={<IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />}
|
||||||
dataTestId="image-preview"
|
dataTestId="image-preview"
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{shouldShowImageDetails && imageDTO && withMetadata && (
|
{shouldShowImageDetails && imageDTO && (
|
||||||
<Box position="absolute" opacity={0.8} top={0} width="full" height="full" borderRadius="base">
|
<Box position="absolute" opacity={0.8} top={0} width="full" height="full" borderRadius="base">
|
||||||
<ImageMetadataViewer image={imageDTO} />
|
<ImageMetadataViewer image={imageDTO} />
|
||||||
</Box>
|
</Box>
|
||||||
)}
|
)}
|
||||||
<AnimatePresence>
|
<AnimatePresence>
|
||||||
{withNextPrevButtons && shouldShowNextPrevButtons && imageDTO && (
|
{shouldShowNextPrevButtons && imageDTO && (
|
||||||
<Box
|
<Box
|
||||||
as={motion.div}
|
as={motion.div}
|
||||||
key="nextPrevButtons"
|
key="nextPrevButtons"
|
||||||
|
@ -0,0 +1,41 @@
|
|||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||||
|
import { selectComparisonImages } from 'features/gallery/components/ImageViewer/common';
|
||||||
|
import { ImageComparisonHover } from 'features/gallery/components/ImageViewer/ImageComparisonHover';
|
||||||
|
import { ImageComparisonSideBySide } from 'features/gallery/components/ImageViewer/ImageComparisonSideBySide';
|
||||||
|
import { ImageComparisonSlider } from 'features/gallery/components/ImageViewer/ImageComparisonSlider';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiImagesBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
containerDims: Dimensions;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const ImageComparison = memo(({ containerDims }: Props) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode);
|
||||||
|
const { firstImage, secondImage } = useAppSelector(selectComparisonImages);
|
||||||
|
|
||||||
|
if (!firstImage || !secondImage) {
|
||||||
|
// Should rarely/never happen - we don't render this component unless we have images to compare
|
||||||
|
return <IAINoContentFallback label={t('gallery.selectAnImageToCompare')} icon={PiImagesBold} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (comparisonMode === 'slider') {
|
||||||
|
return <ImageComparisonSlider containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (comparisonMode === 'side-by-side') {
|
||||||
|
return (
|
||||||
|
<ImageComparisonSideBySide containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (comparisonMode === 'hover') {
|
||||||
|
return <ImageComparisonHover containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
ImageComparison.displayName = 'ImageComparison';
|
@ -0,0 +1,47 @@
|
|||||||
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIDroppable from 'common/components/IAIDroppable';
|
||||||
|
import type { CurrentImageDropData, SelectForCompareDropData } from 'features/dnd/types';
|
||||||
|
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||||
|
import { memo, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
import { selectComparisonImages } from './common';
|
||||||
|
|
||||||
|
const setCurrentImageDropData: CurrentImageDropData = {
|
||||||
|
id: 'current-image',
|
||||||
|
actionType: 'SET_CURRENT_IMAGE',
|
||||||
|
};
|
||||||
|
|
||||||
|
export const ImageComparisonDroppable = memo(() => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const imageViewer = useImageViewer();
|
||||||
|
const { firstImage, secondImage } = useAppSelector(selectComparisonImages);
|
||||||
|
const selectForCompareDropData = useMemo<SelectForCompareDropData>(
|
||||||
|
() => ({
|
||||||
|
id: 'image-comparison',
|
||||||
|
actionType: 'SELECT_FOR_COMPARE',
|
||||||
|
context: {
|
||||||
|
firstImageName: firstImage?.image_name,
|
||||||
|
secondImageName: secondImage?.image_name,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
[firstImage?.image_name, secondImage?.image_name]
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!imageViewer.isOpen) {
|
||||||
|
return (
|
||||||
|
<Flex position="absolute" top={0} right={0} bottom={0} left={0} gap={2} pointerEvents="none">
|
||||||
|
<IAIDroppable data={setCurrentImageDropData} dropLabel={t('gallery.openInViewer')} />
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex position="absolute" top={0} right={0} bottom={0} left={0} gap={2} pointerEvents="none">
|
||||||
|
<IAIDroppable data={selectForCompareDropData} dropLabel={t('gallery.selectForCompare')} />
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
ImageComparisonDroppable.displayName = 'ImageComparisonDroppable';
|
@ -0,0 +1,117 @@
|
|||||||
|
import { Box, Flex, Image } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { useBoolean } from 'common/hooks/useBoolean';
|
||||||
|
import { preventDefault } from 'common/util/stopPropagation';
|
||||||
|
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||||
|
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
|
||||||
|
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
|
||||||
|
import { memo, useMemo, useRef } from 'react';
|
||||||
|
|
||||||
|
import type { ComparisonProps } from './common';
|
||||||
|
import { fitDimsToContainer, getSecondImageDims } from './common';
|
||||||
|
|
||||||
|
export const ImageComparisonHover = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
|
||||||
|
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
|
||||||
|
const imageContainerRef = useRef<HTMLDivElement>(null);
|
||||||
|
const mouseOver = useBoolean(false);
|
||||||
|
const fittedDims = useMemo<Dimensions>(
|
||||||
|
() => fitDimsToContainer(containerDims, firstImage),
|
||||||
|
[containerDims, firstImage]
|
||||||
|
);
|
||||||
|
const compareImageDims = useMemo<Dimensions>(
|
||||||
|
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),
|
||||||
|
[comparisonFit, fittedDims, firstImage, secondImage]
|
||||||
|
);
|
||||||
|
return (
|
||||||
|
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
|
||||||
|
<Flex
|
||||||
|
id="image-comparison-wrapper"
|
||||||
|
w="full"
|
||||||
|
h="full"
|
||||||
|
maxW="full"
|
||||||
|
maxH="full"
|
||||||
|
position="absolute"
|
||||||
|
alignItems="center"
|
||||||
|
justifyContent="center"
|
||||||
|
>
|
||||||
|
<Box
|
||||||
|
ref={imageContainerRef}
|
||||||
|
position="relative"
|
||||||
|
id="image-comparison-hover-image-container"
|
||||||
|
w={fittedDims.width}
|
||||||
|
h={fittedDims.height}
|
||||||
|
maxW="full"
|
||||||
|
maxH="full"
|
||||||
|
userSelect="none"
|
||||||
|
overflow="hidden"
|
||||||
|
borderRadius="base"
|
||||||
|
>
|
||||||
|
<Image
|
||||||
|
id="image-comparison-hover-first-image"
|
||||||
|
src={firstImage.image_url}
|
||||||
|
fallbackSrc={firstImage.thumbnail_url}
|
||||||
|
w={fittedDims.width}
|
||||||
|
h={fittedDims.height}
|
||||||
|
maxW="full"
|
||||||
|
maxH="full"
|
||||||
|
objectFit="cover"
|
||||||
|
objectPosition="top left"
|
||||||
|
/>
|
||||||
|
<ImageComparisonLabel type="first" opacity={mouseOver.isTrue ? 0 : 1} />
|
||||||
|
|
||||||
|
<Box
|
||||||
|
id="image-comparison-hover-second-image-container"
|
||||||
|
position="absolute"
|
||||||
|
top={0}
|
||||||
|
left={0}
|
||||||
|
right={0}
|
||||||
|
bottom={0}
|
||||||
|
overflow="hidden"
|
||||||
|
opacity={mouseOver.isTrue ? 1 : 0}
|
||||||
|
transitionDuration="0.2s"
|
||||||
|
transitionProperty="common"
|
||||||
|
>
|
||||||
|
<Box
|
||||||
|
id="image-comparison-hover-bg"
|
||||||
|
position="absolute"
|
||||||
|
top={0}
|
||||||
|
left={0}
|
||||||
|
right={0}
|
||||||
|
bottom={0}
|
||||||
|
backgroundImage={STAGE_BG_DATAURL}
|
||||||
|
backgroundRepeat="repeat"
|
||||||
|
opacity={0.2}
|
||||||
|
/>
|
||||||
|
<Image
|
||||||
|
position="relative"
|
||||||
|
id="image-comparison-hover-second-image"
|
||||||
|
src={secondImage.image_url}
|
||||||
|
fallbackSrc={secondImage.thumbnail_url}
|
||||||
|
w={compareImageDims.width}
|
||||||
|
h={compareImageDims.height}
|
||||||
|
maxW={fittedDims.width}
|
||||||
|
maxH={fittedDims.height}
|
||||||
|
objectFit={comparisonFit}
|
||||||
|
objectPosition="top left"
|
||||||
|
/>
|
||||||
|
<ImageComparisonLabel type="second" opacity={mouseOver.isTrue ? 1 : 0} />
|
||||||
|
</Box>
|
||||||
|
<Box
|
||||||
|
id="image-comparison-hover-interaction-overlay"
|
||||||
|
position="absolute"
|
||||||
|
top={0}
|
||||||
|
right={0}
|
||||||
|
bottom={0}
|
||||||
|
left={0}
|
||||||
|
onMouseOver={mouseOver.setTrue}
|
||||||
|
onMouseOut={mouseOver.setFalse}
|
||||||
|
onContextMenu={preventDefault}
|
||||||
|
userSelect="none"
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
ImageComparisonHover.displayName = 'ImageComparisonHover';
|
@ -0,0 +1,33 @@
|
|||||||
|
import type { TextProps } from '@invoke-ai/ui-library';
|
||||||
|
import { Text } from '@invoke-ai/ui-library';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
import { DROP_SHADOW } from './common';
|
||||||
|
|
||||||
|
type Props = TextProps & {
|
||||||
|
type: 'first' | 'second';
|
||||||
|
};
|
||||||
|
|
||||||
|
export const ImageComparisonLabel = memo(({ type, ...rest }: Props) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
return (
|
||||||
|
<Text
|
||||||
|
position="absolute"
|
||||||
|
bottom={4}
|
||||||
|
insetInlineEnd={type === 'first' ? undefined : 4}
|
||||||
|
insetInlineStart={type === 'first' ? 4 : undefined}
|
||||||
|
textOverflow="clip"
|
||||||
|
whiteSpace="nowrap"
|
||||||
|
filter={DROP_SHADOW}
|
||||||
|
color="base.50"
|
||||||
|
transitionDuration="0.2s"
|
||||||
|
transitionProperty="common"
|
||||||
|
{...rest}
|
||||||
|
>
|
||||||
|
{type === 'first' ? t('gallery.viewerImage') : t('gallery.compareImage')}
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
ImageComparisonLabel.displayName = 'ImageComparisonLabel';
|
@ -0,0 +1,70 @@
|
|||||||
|
import { Flex, Image } from '@invoke-ai/ui-library';
|
||||||
|
import type { ComparisonProps } from 'features/gallery/components/ImageViewer/common';
|
||||||
|
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
|
||||||
|
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
|
||||||
|
import { memo, useCallback, useRef } from 'react';
|
||||||
|
import type { ImperativePanelGroupHandle } from 'react-resizable-panels';
|
||||||
|
import { Panel, PanelGroup } from 'react-resizable-panels';
|
||||||
|
|
||||||
|
export const ImageComparisonSideBySide = memo(({ firstImage, secondImage }: ComparisonProps) => {
|
||||||
|
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
|
||||||
|
const onDoubleClickHandle = useCallback(() => {
|
||||||
|
if (!panelGroupRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
panelGroupRef.current.setLayout([50, 50]);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
|
||||||
|
<Flex w="full" h="full" maxW="full" maxH="full" position="absolute" alignItems="center" justifyContent="center">
|
||||||
|
<PanelGroup ref={panelGroupRef} direction="horizontal" id="image-comparison-side-by-side">
|
||||||
|
<Panel minSize={20}>
|
||||||
|
<Flex position="relative" w="full" h="full" alignItems="center" justifyContent="center">
|
||||||
|
<Flex position="absolute" maxW="full" maxH="full" aspectRatio={firstImage.width / firstImage.height}>
|
||||||
|
<Image
|
||||||
|
id="image-comparison-side-by-side-first-image"
|
||||||
|
w={firstImage.width}
|
||||||
|
h={firstImage.height}
|
||||||
|
maxW="full"
|
||||||
|
maxH="full"
|
||||||
|
src={firstImage.image_url}
|
||||||
|
fallbackSrc={firstImage.thumbnail_url}
|
||||||
|
objectFit="contain"
|
||||||
|
borderRadius="base"
|
||||||
|
/>
|
||||||
|
<ImageComparisonLabel type="first" />
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
</Panel>
|
||||||
|
<ResizeHandle
|
||||||
|
id="image-comparison-side-by-side-handle"
|
||||||
|
onDoubleClick={onDoubleClickHandle}
|
||||||
|
orientation="vertical"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Panel minSize={20}>
|
||||||
|
<Flex position="relative" w="full" h="full" alignItems="center" justifyContent="center">
|
||||||
|
<Flex position="absolute" maxW="full" maxH="full" aspectRatio={secondImage.width / secondImage.height}>
|
||||||
|
<Image
|
||||||
|
id="image-comparison-side-by-side-first-image"
|
||||||
|
w={secondImage.width}
|
||||||
|
h={secondImage.height}
|
||||||
|
maxW="full"
|
||||||
|
maxH="full"
|
||||||
|
src={secondImage.image_url}
|
||||||
|
fallbackSrc={secondImage.thumbnail_url}
|
||||||
|
objectFit="contain"
|
||||||
|
borderRadius="base"
|
||||||
|
/>
|
||||||
|
<ImageComparisonLabel type="second" />
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
</Panel>
|
||||||
|
</PanelGroup>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
ImageComparisonSideBySide.displayName = 'ImageComparisonSideBySide';
|
@ -0,0 +1,215 @@
|
|||||||
|
import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { preventDefault } from 'common/util/stopPropagation';
|
||||||
|
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||||
|
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
|
||||||
|
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
|
||||||
|
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||||
|
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
import type { ComparisonProps } from './common';
|
||||||
|
import { DROP_SHADOW, fitDimsToContainer, getSecondImageDims } from './common';
|
||||||
|
|
||||||
|
const INITIAL_POS = '50%';
|
||||||
|
const HANDLE_WIDTH = 2;
|
||||||
|
const HANDLE_WIDTH_PX = `${HANDLE_WIDTH}px`;
|
||||||
|
const HANDLE_HITBOX = 20;
|
||||||
|
const HANDLE_HITBOX_PX = `${HANDLE_HITBOX}px`;
|
||||||
|
const HANDLE_INNER_LEFT_PX = `${HANDLE_HITBOX / 2 - HANDLE_WIDTH / 2}px`;
|
||||||
|
const HANDLE_LEFT_INITIAL_PX = `calc(${INITIAL_POS} - ${HANDLE_HITBOX / 2}px)`;
|
||||||
|
|
||||||
|
export const ImageComparisonSlider = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
|
||||||
|
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
|
||||||
|
// How far the handle is from the left - this will be a CSS calculation that takes into account the handle width
|
||||||
|
const [left, setLeft] = useState(HANDLE_LEFT_INITIAL_PX);
|
||||||
|
// How wide the first image is
|
||||||
|
const [width, setWidth] = useState(INITIAL_POS);
|
||||||
|
const handleRef = useRef<HTMLDivElement>(null);
|
||||||
|
// To manage aspect ratios, we need to know the size of the container
|
||||||
|
const imageContainerRef = useRef<HTMLDivElement>(null);
|
||||||
|
// To keep things smooth, we use RAF to update the handle position & gate it to 60fps
|
||||||
|
const rafRef = useRef<number | null>(null);
|
||||||
|
const lastMoveTimeRef = useRef<number>(0);
|
||||||
|
|
||||||
|
const fittedDims = useMemo<Dimensions>(
|
||||||
|
() => fitDimsToContainer(containerDims, firstImage),
|
||||||
|
[containerDims, firstImage]
|
||||||
|
);
|
||||||
|
|
||||||
|
const compareImageDims = useMemo<Dimensions>(
|
||||||
|
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),
|
||||||
|
[comparisonFit, fittedDims, firstImage, secondImage]
|
||||||
|
);
|
||||||
|
|
||||||
|
const updateHandlePos = useCallback((clientX: number) => {
|
||||||
|
if (!handleRef.current || !imageContainerRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
lastMoveTimeRef.current = performance.now();
|
||||||
|
const { x, width } = imageContainerRef.current.getBoundingClientRect();
|
||||||
|
const rawHandlePos = ((clientX - x) * 100) / width;
|
||||||
|
const handleWidthPct = (HANDLE_WIDTH * 100) / width;
|
||||||
|
const newHandlePos = Math.min(100 - handleWidthPct, Math.max(0, rawHandlePos));
|
||||||
|
setWidth(`${newHandlePos}%`);
|
||||||
|
setLeft(`calc(${newHandlePos}% - ${HANDLE_HITBOX / 2}px)`);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const onMouseMove = useCallback(
|
||||||
|
(e: MouseEvent) => {
|
||||||
|
if (rafRef.current === null && performance.now() > lastMoveTimeRef.current + 1000 / 60) {
|
||||||
|
rafRef.current = window.requestAnimationFrame(() => {
|
||||||
|
updateHandlePos(e.clientX);
|
||||||
|
rafRef.current = null;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[updateHandlePos]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onMouseUp = useCallback(() => {
|
||||||
|
window.removeEventListener('mousemove', onMouseMove);
|
||||||
|
}, [onMouseMove]);
|
||||||
|
|
||||||
|
const onMouseDown = useCallback(
|
||||||
|
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||||
|
// Update the handle position immediately on click
|
||||||
|
updateHandlePos(e.clientX);
|
||||||
|
window.addEventListener('mouseup', onMouseUp, { once: true });
|
||||||
|
window.addEventListener('mousemove', onMouseMove);
|
||||||
|
},
|
||||||
|
[onMouseMove, onMouseUp, updateHandlePos]
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(
|
||||||
|
() => () => {
|
||||||
|
if (rafRef.current !== null) {
|
||||||
|
cancelAnimationFrame(rafRef.current);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
|
||||||
|
<Flex
|
||||||
|
id="image-comparison-wrapper"
|
||||||
|
w="full"
|
||||||
|
h="full"
|
||||||
|
maxW="full"
|
||||||
|
maxH="full"
|
||||||
|
position="absolute"
|
||||||
|
alignItems="center"
|
||||||
|
justifyContent="center"
|
||||||
|
>
|
||||||
|
<Box
|
||||||
|
ref={imageContainerRef}
|
||||||
|
position="relative"
|
||||||
|
id="image-comparison-image-container"
|
||||||
|
w={fittedDims.width}
|
||||||
|
h={fittedDims.height}
|
||||||
|
maxW="full"
|
||||||
|
maxH="full"
|
||||||
|
userSelect="none"
|
||||||
|
overflow="hidden"
|
||||||
|
borderRadius="base"
|
||||||
|
>
|
||||||
|
<Box
|
||||||
|
id="image-comparison-bg"
|
||||||
|
position="absolute"
|
||||||
|
top={0}
|
||||||
|
left={0}
|
||||||
|
right={0}
|
||||||
|
bottom={0}
|
||||||
|
backgroundImage={STAGE_BG_DATAURL}
|
||||||
|
backgroundRepeat="repeat"
|
||||||
|
opacity={0.2}
|
||||||
|
/>
|
||||||
|
<Image
|
||||||
|
position="relative"
|
||||||
|
id="image-comparison-second-image"
|
||||||
|
src={secondImage.image_url}
|
||||||
|
fallbackSrc={secondImage.thumbnail_url}
|
||||||
|
w={compareImageDims.width}
|
||||||
|
h={compareImageDims.height}
|
||||||
|
maxW={fittedDims.width}
|
||||||
|
maxH={fittedDims.height}
|
||||||
|
objectFit={comparisonFit}
|
||||||
|
objectPosition="top left"
|
||||||
|
/>
|
||||||
|
<ImageComparisonLabel type="second" />
|
||||||
|
<Box
|
||||||
|
id="image-comparison-first-image-container"
|
||||||
|
position="absolute"
|
||||||
|
top={0}
|
||||||
|
left={0}
|
||||||
|
right={0}
|
||||||
|
bottom={0}
|
||||||
|
w={width}
|
||||||
|
overflow="hidden"
|
||||||
|
>
|
||||||
|
<Image
|
||||||
|
id="image-comparison-first-image"
|
||||||
|
src={firstImage.image_url}
|
||||||
|
fallbackSrc={firstImage.thumbnail_url}
|
||||||
|
w={fittedDims.width}
|
||||||
|
h={fittedDims.height}
|
||||||
|
objectFit="cover"
|
||||||
|
objectPosition="top left"
|
||||||
|
/>
|
||||||
|
<ImageComparisonLabel type="first" />
|
||||||
|
</Box>
|
||||||
|
<Flex
|
||||||
|
id="image-comparison-handle"
|
||||||
|
ref={handleRef}
|
||||||
|
position="absolute"
|
||||||
|
top={0}
|
||||||
|
bottom={0}
|
||||||
|
left={left}
|
||||||
|
w={HANDLE_HITBOX_PX}
|
||||||
|
cursor="ew-resize"
|
||||||
|
filter={DROP_SHADOW}
|
||||||
|
opacity={0.8}
|
||||||
|
color="base.50"
|
||||||
|
>
|
||||||
|
<Box
|
||||||
|
id="image-comparison-handle-divider"
|
||||||
|
w={HANDLE_WIDTH_PX}
|
||||||
|
h="full"
|
||||||
|
bg="currentColor"
|
||||||
|
shadow="dark-lg"
|
||||||
|
position="absolute"
|
||||||
|
top={0}
|
||||||
|
left={HANDLE_INNER_LEFT_PX}
|
||||||
|
/>
|
||||||
|
<Flex
|
||||||
|
id="image-comparison-handle-icons"
|
||||||
|
gap={4}
|
||||||
|
position="absolute"
|
||||||
|
left="50%"
|
||||||
|
top="50%"
|
||||||
|
transform="translate(-50%, 0)"
|
||||||
|
filter={DROP_SHADOW}
|
||||||
|
>
|
||||||
|
<Icon as={PiCaretLeftBold} />
|
||||||
|
<Icon as={PiCaretRightBold} />
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
<Box
|
||||||
|
id="image-comparison-interaction-overlay"
|
||||||
|
position="absolute"
|
||||||
|
top={0}
|
||||||
|
right={0}
|
||||||
|
bottom={0}
|
||||||
|
left={0}
|
||||||
|
onMouseDown={onMouseDown}
|
||||||
|
onContextMenu={preventDefault}
|
||||||
|
userSelect="none"
|
||||||
|
cursor="ew-resize"
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
ImageComparisonSlider.displayName = 'ImageComparisonSlider';
|
@ -1,36 +1,16 @@
|
|||||||
import { Flex } from '@invoke-ai/ui-library';
|
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { CompareToolbar } from 'features/gallery/components/ImageViewer/CompareToolbar';
|
||||||
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
|
import CurrentImagePreview from 'features/gallery/components/ImageViewer/CurrentImagePreview';
|
||||||
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
|
import { ImageComparison } from 'features/gallery/components/ImageViewer/ImageComparison';
|
||||||
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
import { ViewerToolbar } from 'features/gallery/components/ImageViewer/ViewerToolbar';
|
||||||
import type { InvokeTabName } from 'features/ui/store/tabMap';
|
import { memo } from 'react';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { useMeasure } from 'react-use';
|
||||||
import { memo, useMemo } from 'react';
|
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
|
||||||
|
|
||||||
import CurrentImageButtons from './CurrentImageButtons';
|
import { useImageViewer } from './useImageViewer';
|
||||||
import CurrentImagePreview from './CurrentImagePreview';
|
|
||||||
import { ViewerToggleMenu } from './ViewerToggleMenu';
|
|
||||||
|
|
||||||
const VIEWER_ENABLED_TABS: InvokeTabName[] = ['canvas', 'generation', 'workflows'];
|
|
||||||
|
|
||||||
export const ImageViewer = memo(() => {
|
export const ImageViewer = memo(() => {
|
||||||
const { isOpen, onToggle, onClose } = useImageViewer();
|
const imageViewer = useImageViewer();
|
||||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
const [containerRef, containerDims] = useMeasure<HTMLDivElement>();
|
||||||
const isViewerEnabled = useMemo(() => VIEWER_ENABLED_TABS.includes(activeTabName), [activeTabName]);
|
|
||||||
const shouldShowViewer = useMemo(() => {
|
|
||||||
if (!isViewerEnabled) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return isOpen;
|
|
||||||
}, [isOpen, isViewerEnabled]);
|
|
||||||
|
|
||||||
useHotkeys('z', onToggle, { enabled: isViewerEnabled }, [isViewerEnabled, onToggle]);
|
|
||||||
useHotkeys('esc', onClose, { enabled: isViewerEnabled }, [isViewerEnabled, onClose]);
|
|
||||||
|
|
||||||
if (!shouldShowViewer) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
@ -46,25 +26,13 @@ export const ImageViewer = memo(() => {
|
|||||||
rowGap={4}
|
rowGap={4}
|
||||||
alignItems="center"
|
alignItems="center"
|
||||||
justifyContent="center"
|
justifyContent="center"
|
||||||
zIndex={10} // reactflow puts its minimap at 5, so we need to be above that
|
|
||||||
>
|
>
|
||||||
<Flex w="full" gap={2}>
|
{imageViewer.isComparing && <CompareToolbar />}
|
||||||
<Flex flex={1} justifyContent="center">
|
{!imageViewer.isComparing && <ViewerToolbar />}
|
||||||
<Flex gap={2} marginInlineEnd="auto">
|
<Box ref={containerRef} w="full" h="full">
|
||||||
<ToggleProgressButton />
|
{!imageViewer.isComparing && <CurrentImagePreview />}
|
||||||
<ToggleMetadataViewerButton />
|
{imageViewer.isComparing && <ImageComparison containerDims={containerDims} />}
|
||||||
</Flex>
|
</Box>
|
||||||
</Flex>
|
|
||||||
<Flex flex={1} gap={2} justifyContent="center">
|
|
||||||
<CurrentImageButtons />
|
|
||||||
</Flex>
|
|
||||||
<Flex flex={1} justifyContent="center">
|
|
||||||
<Flex gap={2} marginInlineStart="auto">
|
|
||||||
<ViewerToggleMenu />
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
<CurrentImagePreview />
|
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
@ -1,45 +0,0 @@
|
|||||||
import { Flex } from '@invoke-ai/ui-library';
|
|
||||||
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
|
|
||||||
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
|
|
||||||
import { memo } from 'react';
|
|
||||||
|
|
||||||
import CurrentImageButtons from './CurrentImageButtons';
|
|
||||||
import CurrentImagePreview from './CurrentImagePreview';
|
|
||||||
|
|
||||||
export const ImageViewerWorkflows = memo(() => {
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
layerStyle="first"
|
|
||||||
borderRadius="base"
|
|
||||||
position="absolute"
|
|
||||||
flexDirection="column"
|
|
||||||
top={0}
|
|
||||||
right={0}
|
|
||||||
bottom={0}
|
|
||||||
left={0}
|
|
||||||
p={2}
|
|
||||||
rowGap={4}
|
|
||||||
alignItems="center"
|
|
||||||
justifyContent="center"
|
|
||||||
zIndex={10} // reactflow puts its minimap at 5, so we need to be above that
|
|
||||||
>
|
|
||||||
<Flex w="full" gap={2}>
|
|
||||||
<Flex flex={1} justifyContent="center">
|
|
||||||
<Flex gap={2} marginInlineEnd="auto">
|
|
||||||
<ToggleProgressButton />
|
|
||||||
<ToggleMetadataViewerButton />
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
<Flex flex={1} gap={2} justifyContent="center">
|
|
||||||
<CurrentImageButtons />
|
|
||||||
</Flex>
|
|
||||||
<Flex flex={1} justifyContent="center">
|
|
||||||
<Flex gap={2} marginInlineStart="auto" />
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
<CurrentImagePreview />
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
ImageViewerWorkflows.displayName = 'ImageViewerWorkflows';
|
|
@ -9,33 +9,35 @@ import {
|
|||||||
PopoverTrigger,
|
PopoverTrigger,
|
||||||
Text,
|
Text,
|
||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
|
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||||
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiCaretDownBold, PiCheckBold, PiEyeBold, PiPencilBold } from 'react-icons/pi';
|
import { PiCaretDownBold, PiCheckBold, PiEyeBold, PiPencilBold } from 'react-icons/pi';
|
||||||
|
|
||||||
import { useImageViewer } from './useImageViewer';
|
|
||||||
|
|
||||||
export const ViewerToggleMenu = () => {
|
export const ViewerToggleMenu = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { isOpen, onClose, onOpen } = useImageViewer();
|
const imageViewer = useImageViewer();
|
||||||
|
useHotkeys('z', imageViewer.onToggle, [imageViewer]);
|
||||||
|
useHotkeys('esc', imageViewer.onClose, [imageViewer]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Popover isLazy>
|
<Popover isLazy>
|
||||||
<PopoverTrigger>
|
<PopoverTrigger>
|
||||||
<Button variant="outline" data-testid="toggle-viewer-menu-button">
|
<Button variant="outline" data-testid="toggle-viewer-menu-button" pointerEvents="auto">
|
||||||
<Flex gap={3} w="full" alignItems="center">
|
<Flex gap={3} w="full" alignItems="center">
|
||||||
{isOpen ? <Icon as={PiEyeBold} /> : <Icon as={PiPencilBold} />}
|
{imageViewer.isOpen ? <Icon as={PiEyeBold} /> : <Icon as={PiPencilBold} />}
|
||||||
<Text fontSize="md">{isOpen ? t('common.viewing') : t('common.editing')}</Text>
|
<Text fontSize="md">{imageViewer.isOpen ? t('common.viewing') : t('common.editing')}</Text>
|
||||||
<Icon as={PiCaretDownBold} />
|
<Icon as={PiCaretDownBold} />
|
||||||
</Flex>
|
</Flex>
|
||||||
</Button>
|
</Button>
|
||||||
</PopoverTrigger>
|
</PopoverTrigger>
|
||||||
<PopoverContent p={2}>
|
<PopoverContent p={2} pointerEvents="auto">
|
||||||
<PopoverArrow />
|
<PopoverArrow />
|
||||||
<PopoverBody>
|
<PopoverBody>
|
||||||
<Flex flexDir="column">
|
<Flex flexDir="column">
|
||||||
<Button onClick={onOpen} variant="ghost" h="auto" w="auto" p={2}>
|
<Button onClick={imageViewer.onOpen} variant="ghost" h="auto" w="auto" p={2}>
|
||||||
<Flex gap={2} w="full">
|
<Flex gap={2} w="full">
|
||||||
<Icon as={PiCheckBold} visibility={isOpen ? 'visible' : 'hidden'} />
|
<Icon as={PiCheckBold} visibility={imageViewer.isOpen ? 'visible' : 'hidden'} />
|
||||||
<Flex flexDir="column" gap={2} alignItems="flex-start">
|
<Flex flexDir="column" gap={2} alignItems="flex-start">
|
||||||
<Text fontWeight="semibold" color="base.100">
|
<Text fontWeight="semibold" color="base.100">
|
||||||
{t('common.viewing')}
|
{t('common.viewing')}
|
||||||
@ -46,9 +48,9 @@ export const ViewerToggleMenu = () => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Button>
|
</Button>
|
||||||
<Button onClick={onClose} variant="ghost" h="auto" w="auto" p={2}>
|
<Button onClick={imageViewer.onClose} variant="ghost" h="auto" w="auto" p={2}>
|
||||||
<Flex gap={2} w="full">
|
<Flex gap={2} w="full">
|
||||||
<Icon as={PiCheckBold} visibility={isOpen ? 'hidden' : 'visible'} />
|
<Icon as={PiCheckBold} visibility={imageViewer.isOpen ? 'hidden' : 'visible'} />
|
||||||
<Flex flexDir="column" gap={2} alignItems="flex-start">
|
<Flex flexDir="column" gap={2} alignItems="flex-start">
|
||||||
<Text fontWeight="semibold" color="base.100">
|
<Text fontWeight="semibold" color="base.100">
|
||||||
{t('common.editing')}
|
{t('common.editing')}
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
|
||||||
|
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
|
||||||
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
|
import CurrentImageButtons from './CurrentImageButtons';
|
||||||
|
import { ViewerToggleMenu } from './ViewerToggleMenu';
|
||||||
|
|
||||||
|
export const ViewerToolbar = memo(() => {
|
||||||
|
const tab = useAppSelector(activeTabNameSelector);
|
||||||
|
return (
|
||||||
|
<Flex w="full" gap={2}>
|
||||||
|
<Flex flex={1} justifyContent="center">
|
||||||
|
<Flex gap={2} marginInlineEnd="auto">
|
||||||
|
<ToggleProgressButton />
|
||||||
|
<ToggleMetadataViewerButton />
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
<Flex flex={1} gap={2} justifyContent="center">
|
||||||
|
<CurrentImageButtons />
|
||||||
|
</Flex>
|
||||||
|
<Flex flex={1} justifyContent="center">
|
||||||
|
<Flex gap={2} marginInlineStart="auto">
|
||||||
|
{tab !== 'workflows' && <ViewerToggleMenu />}
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
ViewerToolbar.displayName = 'ViewerToolbar';
|
@ -0,0 +1,64 @@
|
|||||||
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
|
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||||
|
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||||
|
import type { ComparisonFit } from 'features/gallery/store/types';
|
||||||
|
import type { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
|
export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
|
||||||
|
|
||||||
|
export type ComparisonProps = {
|
||||||
|
firstImage: ImageDTO;
|
||||||
|
secondImage: ImageDTO;
|
||||||
|
containerDims: Dimensions;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const fitDimsToContainer = (containerDims: Dimensions, imageDims: Dimensions): Dimensions => {
|
||||||
|
// Fall back to the image's dimensions if the container has no dimensions
|
||||||
|
if (containerDims.width === 0 || containerDims.height === 0) {
|
||||||
|
return { width: imageDims.width, height: imageDims.height };
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to the image's dimensions if the image fits within the container
|
||||||
|
if (imageDims.width <= containerDims.width && imageDims.height <= containerDims.height) {
|
||||||
|
return { width: imageDims.width, height: imageDims.height };
|
||||||
|
}
|
||||||
|
|
||||||
|
const targetAspectRatio = containerDims.width / containerDims.height;
|
||||||
|
const imageAspectRatio = imageDims.width / imageDims.height;
|
||||||
|
|
||||||
|
let width: number;
|
||||||
|
let height: number;
|
||||||
|
|
||||||
|
if (imageAspectRatio > targetAspectRatio) {
|
||||||
|
// Image is wider than container's aspect ratio
|
||||||
|
width = containerDims.width;
|
||||||
|
height = width / imageAspectRatio;
|
||||||
|
} else {
|
||||||
|
// Image is taller than container's aspect ratio
|
||||||
|
height = containerDims.height;
|
||||||
|
width = height * imageAspectRatio;
|
||||||
|
}
|
||||||
|
return { width, height };
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the dimensions of the second image in a comparison based on the comparison fit mode.
|
||||||
|
*/
|
||||||
|
export const getSecondImageDims = (
|
||||||
|
comparisonFit: ComparisonFit,
|
||||||
|
fittedDims: Dimensions,
|
||||||
|
firstImageDims: Dimensions,
|
||||||
|
secondImageDims: Dimensions
|
||||||
|
): Dimensions => {
|
||||||
|
const width =
|
||||||
|
comparisonFit === 'fill' ? fittedDims.width : (fittedDims.width * secondImageDims.width) / firstImageDims.width;
|
||||||
|
const height =
|
||||||
|
comparisonFit === 'fill' ? fittedDims.height : (fittedDims.height * secondImageDims.height) / firstImageDims.height;
|
||||||
|
|
||||||
|
return { width, height };
|
||||||
|
};
|
||||||
|
export const selectComparisonImages = createMemoizedSelector(selectGallerySlice, (gallerySlice) => {
|
||||||
|
const firstImage = gallerySlice.selection.slice(-1)[0] ?? null;
|
||||||
|
const secondImage = gallerySlice.imageToCompare;
|
||||||
|
return { firstImage, secondImage };
|
||||||
|
});
|
@ -0,0 +1,31 @@
|
|||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
|
||||||
|
export const useImageViewer = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const isComparing = useAppSelector((s) => s.gallery.imageToCompare !== null);
|
||||||
|
const isOpen = useAppSelector((s) => s.gallery.isImageViewerOpen);
|
||||||
|
|
||||||
|
const onClose = useCallback(() => {
|
||||||
|
if (isComparing && isOpen) {
|
||||||
|
dispatch(imageToCompareChanged(null));
|
||||||
|
} else {
|
||||||
|
dispatch(isImageViewerOpenChanged(false));
|
||||||
|
}
|
||||||
|
}, [dispatch, isComparing, isOpen]);
|
||||||
|
|
||||||
|
const onOpen = useCallback(() => {
|
||||||
|
dispatch(isImageViewerOpenChanged(true));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
const onToggle = useCallback(() => {
|
||||||
|
if (isComparing && isOpen) {
|
||||||
|
dispatch(imageToCompareChanged(null));
|
||||||
|
} else {
|
||||||
|
dispatch(isImageViewerOpenChanged(!isOpen));
|
||||||
|
}
|
||||||
|
}, [dispatch, isComparing, isOpen]);
|
||||||
|
|
||||||
|
return { isOpen, onOpen, onClose, onToggle, isComparing };
|
||||||
|
};
|
@ -1,22 +0,0 @@
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
|
||||||
import { useCallback } from 'react';
|
|
||||||
|
|
||||||
export const useImageViewer = () => {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const isOpen = useAppSelector((s) => s.gallery.isImageViewerOpen);
|
|
||||||
|
|
||||||
const onClose = useCallback(() => {
|
|
||||||
dispatch(isImageViewerOpenChanged(false));
|
|
||||||
}, [dispatch]);
|
|
||||||
|
|
||||||
const onOpen = useCallback(() => {
|
|
||||||
dispatch(isImageViewerOpenChanged(true));
|
|
||||||
}, [dispatch]);
|
|
||||||
|
|
||||||
const onToggle = useCallback(() => {
|
|
||||||
dispatch(isImageViewerOpenChanged(!isOpen));
|
|
||||||
}, [dispatch, isOpen]);
|
|
||||||
|
|
||||||
return { isOpen, onOpen, onClose, onToggle };
|
|
||||||
};
|
|
@ -14,7 +14,7 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
|
|||||||
const NextPrevImageButtons = () => {
|
const NextPrevImageButtons = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { handleLeftImage, handleRightImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
|
const { prevImage, nextImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
|
||||||
|
|
||||||
const {
|
const {
|
||||||
areMoreImagesAvailable,
|
areMoreImagesAvailable,
|
||||||
@ -30,7 +30,7 @@ const NextPrevImageButtons = () => {
|
|||||||
aria-label={t('accessibility.previousImage')}
|
aria-label={t('accessibility.previousImage')}
|
||||||
icon={<PiCaretLeftBold size={64} />}
|
icon={<PiCaretLeftBold size={64} />}
|
||||||
variant="unstyled"
|
variant="unstyled"
|
||||||
onClick={handleLeftImage}
|
onClick={prevImage}
|
||||||
boxSize={16}
|
boxSize={16}
|
||||||
sx={nextPrevButtonStyles}
|
sx={nextPrevButtonStyles}
|
||||||
/>
|
/>
|
||||||
@ -42,7 +42,7 @@ const NextPrevImageButtons = () => {
|
|||||||
aria-label={t('accessibility.nextImage')}
|
aria-label={t('accessibility.nextImage')}
|
||||||
icon={<PiCaretRightBold size={64} />}
|
icon={<PiCaretRightBold size={64} />}
|
||||||
variant="unstyled"
|
variant="unstyled"
|
||||||
onClick={handleRightImage}
|
onClick={nextImage}
|
||||||
boxSize={16}
|
boxSize={16}
|
||||||
sx={nextPrevButtonStyles}
|
sx={nextPrevButtonStyles}
|
||||||
/>
|
/>
|
||||||
|
@ -27,16 +27,16 @@ export const useGalleryHotkeys = () => {
|
|||||||
useGalleryNavigation();
|
useGalleryNavigation();
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'left',
|
['left', 'alt+left'],
|
||||||
() => {
|
(e) => {
|
||||||
canNavigateGallery && handleLeftImage();
|
canNavigateGallery && handleLeftImage(e.altKey);
|
||||||
},
|
},
|
||||||
[handleLeftImage, canNavigateGallery]
|
[handleLeftImage, canNavigateGallery]
|
||||||
);
|
);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'right',
|
['right', 'alt+right'],
|
||||||
() => {
|
(e) => {
|
||||||
if (!canNavigateGallery) {
|
if (!canNavigateGallery) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -45,29 +45,29 @@ export const useGalleryHotkeys = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (!isOnLastImage) {
|
if (!isOnLastImage) {
|
||||||
handleRightImage();
|
handleRightImage(e.altKey);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[isOnLastImage, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleRightImage, canNavigateGallery]
|
[isOnLastImage, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleRightImage, canNavigateGallery]
|
||||||
);
|
);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'up',
|
['up', 'alt+up'],
|
||||||
() => {
|
(e) => {
|
||||||
handleUpImage();
|
handleUpImage(e.altKey);
|
||||||
},
|
},
|
||||||
{ preventDefault: true },
|
{ preventDefault: true },
|
||||||
[handleUpImage]
|
[handleUpImage]
|
||||||
);
|
);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'down',
|
['down', 'alt+down'],
|
||||||
() => {
|
(e) => {
|
||||||
if (!areImagesBelowCurrent && areMoreImagesAvailable && !isFetching) {
|
if (!areImagesBelowCurrent && areMoreImagesAvailable && !isFetching) {
|
||||||
handleLoadMoreImages();
|
handleLoadMoreImages();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
handleDownImage();
|
handleDownImage(e.altKey);
|
||||||
},
|
},
|
||||||
{ preventDefault: true },
|
{ preventDefault: true },
|
||||||
[areImagesBelowCurrent, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleDownImage]
|
[areImagesBelowCurrent, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleDownImage]
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
|
import { useAltModifier } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
|
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
|
||||||
import { imageItemContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridItemContainer';
|
import { imageItemContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridItemContainer';
|
||||||
import { imageListContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridListContainer';
|
import { imageListContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridListContainer';
|
||||||
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
|
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
|
||||||
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
|
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
|
||||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
import { imageSelected, imageToCompareChanged } from 'features/gallery/store/gallerySlice';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
|
||||||
import { getIsVisible } from 'features/gallery/util/getIsVisible';
|
import { getIsVisible } from 'features/gallery/util/getIsVisible';
|
||||||
import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign';
|
import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign';
|
||||||
import { clamp } from 'lodash-es';
|
import { clamp } from 'lodash-es';
|
||||||
@ -106,10 +106,12 @@ const getImageFuncs = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
type UseGalleryNavigationReturn = {
|
type UseGalleryNavigationReturn = {
|
||||||
handleLeftImage: () => void;
|
handleLeftImage: (alt?: boolean) => void;
|
||||||
handleRightImage: () => void;
|
handleRightImage: (alt?: boolean) => void;
|
||||||
handleUpImage: () => void;
|
handleUpImage: (alt?: boolean) => void;
|
||||||
handleDownImage: () => void;
|
handleDownImage: (alt?: boolean) => void;
|
||||||
|
prevImage: () => void;
|
||||||
|
nextImage: () => void;
|
||||||
isOnFirstImage: boolean;
|
isOnFirstImage: boolean;
|
||||||
isOnLastImage: boolean;
|
isOnLastImage: boolean;
|
||||||
areImagesBelowCurrent: boolean;
|
areImagesBelowCurrent: boolean;
|
||||||
@ -123,7 +125,15 @@ type UseGalleryNavigationReturn = {
|
|||||||
*/
|
*/
|
||||||
export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
|
const alt = useAltModifier();
|
||||||
|
const lastSelectedImage = useAppSelector((s) => {
|
||||||
|
const lastSelected = s.gallery.selection.slice(-1)[0] ?? null;
|
||||||
|
if (alt) {
|
||||||
|
return s.gallery.imageToCompare ?? lastSelected;
|
||||||
|
} else {
|
||||||
|
return lastSelected;
|
||||||
|
}
|
||||||
|
});
|
||||||
const {
|
const {
|
||||||
queryResult: { data },
|
queryResult: { data },
|
||||||
} = useGalleryImages();
|
} = useGalleryImages();
|
||||||
@ -136,7 +146,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
|||||||
}, [lastSelectedImage, data]);
|
}, [lastSelectedImage, data]);
|
||||||
|
|
||||||
const handleNavigation = useCallback(
|
const handleNavigation = useCallback(
|
||||||
(direction: 'left' | 'right' | 'up' | 'down') => {
|
(direction: 'left' | 'right' | 'up' | 'down', alt?: boolean) => {
|
||||||
if (!data) {
|
if (!data) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -144,10 +154,14 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
|||||||
if (!image || index === lastSelectedImageIndex) {
|
if (!image || index === lastSelectedImageIndex) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(imageSelected(image));
|
if (alt) {
|
||||||
|
dispatch(imageToCompareChanged(image));
|
||||||
|
} else {
|
||||||
|
dispatch(imageSelected(image));
|
||||||
|
}
|
||||||
scrollToImage(image.image_name, index);
|
scrollToImage(image.image_name, index);
|
||||||
},
|
},
|
||||||
[dispatch, lastSelectedImageIndex, data]
|
[data, lastSelectedImageIndex, dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
const isOnFirstImage = useMemo(() => lastSelectedImageIndex === 0, [lastSelectedImageIndex]);
|
const isOnFirstImage = useMemo(() => lastSelectedImageIndex === 0, [lastSelectedImageIndex]);
|
||||||
@ -162,21 +176,41 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
|||||||
return lastSelectedImageIndex + imagesPerRow < loadedImagesCount;
|
return lastSelectedImageIndex + imagesPerRow < loadedImagesCount;
|
||||||
}, [lastSelectedImageIndex, loadedImagesCount]);
|
}, [lastSelectedImageIndex, loadedImagesCount]);
|
||||||
|
|
||||||
const handleLeftImage = useCallback(() => {
|
const handleLeftImage = useCallback(
|
||||||
handleNavigation('left');
|
(alt?: boolean) => {
|
||||||
}, [handleNavigation]);
|
handleNavigation('left', alt);
|
||||||
|
},
|
||||||
|
[handleNavigation]
|
||||||
|
);
|
||||||
|
|
||||||
const handleRightImage = useCallback(() => {
|
const handleRightImage = useCallback(
|
||||||
handleNavigation('right');
|
(alt?: boolean) => {
|
||||||
}, [handleNavigation]);
|
handleNavigation('right', alt);
|
||||||
|
},
|
||||||
|
[handleNavigation]
|
||||||
|
);
|
||||||
|
|
||||||
const handleUpImage = useCallback(() => {
|
const handleUpImage = useCallback(
|
||||||
handleNavigation('up');
|
(alt?: boolean) => {
|
||||||
}, [handleNavigation]);
|
handleNavigation('up', alt);
|
||||||
|
},
|
||||||
|
[handleNavigation]
|
||||||
|
);
|
||||||
|
|
||||||
const handleDownImage = useCallback(() => {
|
const handleDownImage = useCallback(
|
||||||
handleNavigation('down');
|
(alt?: boolean) => {
|
||||||
}, [handleNavigation]);
|
handleNavigation('down', alt);
|
||||||
|
},
|
||||||
|
[handleNavigation]
|
||||||
|
);
|
||||||
|
|
||||||
|
const nextImage = useCallback(() => {
|
||||||
|
handleRightImage();
|
||||||
|
}, [handleRightImage]);
|
||||||
|
|
||||||
|
const prevImage = useCallback(() => {
|
||||||
|
handleLeftImage();
|
||||||
|
}, [handleLeftImage]);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
handleLeftImage,
|
handleLeftImage,
|
||||||
@ -186,5 +220,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
|||||||
isOnFirstImage,
|
isOnFirstImage,
|
||||||
isOnLastImage,
|
isOnLastImage,
|
||||||
areImagesBelowCurrent,
|
areImagesBelowCurrent,
|
||||||
|
nextImage,
|
||||||
|
prevImage,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@ -36,6 +36,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
|
|||||||
shiftKey: e.shiftKey,
|
shiftKey: e.shiftKey,
|
||||||
ctrlKey: e.ctrlKey,
|
ctrlKey: e.ctrlKey,
|
||||||
metaKey: e.metaKey,
|
metaKey: e.metaKey,
|
||||||
|
altKey: e.altKey,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
|
@ -6,7 +6,7 @@ import { boardsApi } from 'services/api/endpoints/boards';
|
|||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import type { ImageDTO } from 'services/api/types';
|
import type { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
import type { BoardId, GalleryState, GalleryView } from './types';
|
import type { BoardId, ComparisonMode, GalleryState, GalleryView } from './types';
|
||||||
import { IMAGE_LIMIT, INITIAL_IMAGE_LIMIT } from './types';
|
import { IMAGE_LIMIT, INITIAL_IMAGE_LIMIT } from './types';
|
||||||
|
|
||||||
const initialGalleryState: GalleryState = {
|
const initialGalleryState: GalleryState = {
|
||||||
@ -22,6 +22,9 @@ const initialGalleryState: GalleryState = {
|
|||||||
limit: INITIAL_IMAGE_LIMIT,
|
limit: INITIAL_IMAGE_LIMIT,
|
||||||
offset: 0,
|
offset: 0,
|
||||||
isImageViewerOpen: true,
|
isImageViewerOpen: true,
|
||||||
|
imageToCompare: null,
|
||||||
|
comparisonMode: 'slider',
|
||||||
|
comparisonFit: 'fill',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const gallerySlice = createSlice({
|
export const gallerySlice = createSlice({
|
||||||
@ -34,6 +37,28 @@ export const gallerySlice = createSlice({
|
|||||||
selectionChanged: (state, action: PayloadAction<ImageDTO[]>) => {
|
selectionChanged: (state, action: PayloadAction<ImageDTO[]>) => {
|
||||||
state.selection = uniqBy(action.payload, (i) => i.image_name);
|
state.selection = uniqBy(action.payload, (i) => i.image_name);
|
||||||
},
|
},
|
||||||
|
imageToCompareChanged: (state, action: PayloadAction<ImageDTO | null>) => {
|
||||||
|
state.imageToCompare = action.payload;
|
||||||
|
if (action.payload) {
|
||||||
|
state.isImageViewerOpen = true;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
comparisonModeChanged: (state, action: PayloadAction<ComparisonMode>) => {
|
||||||
|
state.comparisonMode = action.payload;
|
||||||
|
},
|
||||||
|
comparisonModeCycled: (state) => {
|
||||||
|
switch (state.comparisonMode) {
|
||||||
|
case 'slider':
|
||||||
|
state.comparisonMode = 'side-by-side';
|
||||||
|
break;
|
||||||
|
case 'side-by-side':
|
||||||
|
state.comparisonMode = 'hover';
|
||||||
|
break;
|
||||||
|
case 'hover':
|
||||||
|
state.comparisonMode = 'slider';
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
},
|
||||||
shouldAutoSwitchChanged: (state, action: PayloadAction<boolean>) => {
|
shouldAutoSwitchChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldAutoSwitch = action.payload;
|
state.shouldAutoSwitch = action.payload;
|
||||||
},
|
},
|
||||||
@ -79,6 +104,16 @@ export const gallerySlice = createSlice({
|
|||||||
isImageViewerOpenChanged: (state, action: PayloadAction<boolean>) => {
|
isImageViewerOpenChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
state.isImageViewerOpen = action.payload;
|
state.isImageViewerOpen = action.payload;
|
||||||
},
|
},
|
||||||
|
comparedImagesSwapped: (state) => {
|
||||||
|
if (state.imageToCompare) {
|
||||||
|
const oldSelection = state.selection;
|
||||||
|
state.selection = [state.imageToCompare];
|
||||||
|
state.imageToCompare = oldSelection[0] ?? null;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
comparisonFitChanged: (state, action: PayloadAction<'contain' | 'fill'>) => {
|
||||||
|
state.comparisonFit = action.payload;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
|
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
|
||||||
@ -117,6 +152,11 @@ export const {
|
|||||||
moreImagesLoaded,
|
moreImagesLoaded,
|
||||||
alwaysShowImageSizeBadgeChanged,
|
alwaysShowImageSizeBadgeChanged,
|
||||||
isImageViewerOpenChanged,
|
isImageViewerOpenChanged,
|
||||||
|
imageToCompareChanged,
|
||||||
|
comparisonModeChanged,
|
||||||
|
comparedImagesSwapped,
|
||||||
|
comparisonFitChanged,
|
||||||
|
comparisonModeCycled,
|
||||||
} = gallerySlice.actions;
|
} = gallerySlice.actions;
|
||||||
|
|
||||||
const isAnyBoardDeleted = isAnyOf(
|
const isAnyBoardDeleted = isAnyOf(
|
||||||
@ -138,5 +178,13 @@ export const galleryPersistConfig: PersistConfig<GalleryState> = {
|
|||||||
name: gallerySlice.name,
|
name: gallerySlice.name,
|
||||||
initialState: initialGalleryState,
|
initialState: initialGalleryState,
|
||||||
migrate: migrateGalleryState,
|
migrate: migrateGalleryState,
|
||||||
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'offset', 'limit', 'isImageViewerOpen'],
|
persistDenylist: [
|
||||||
|
'selection',
|
||||||
|
'selectedBoardId',
|
||||||
|
'galleryView',
|
||||||
|
'offset',
|
||||||
|
'limit',
|
||||||
|
'isImageViewerOpen',
|
||||||
|
'imageToCompare',
|
||||||
|
],
|
||||||
};
|
};
|
||||||
|
@ -7,6 +7,8 @@ export const IMAGE_LIMIT = 20;
|
|||||||
|
|
||||||
export type GalleryView = 'images' | 'assets';
|
export type GalleryView = 'images' | 'assets';
|
||||||
export type BoardId = 'none' | (string & Record<never, never>);
|
export type BoardId = 'none' | (string & Record<never, never>);
|
||||||
|
export type ComparisonMode = 'slider' | 'side-by-side' | 'hover';
|
||||||
|
export type ComparisonFit = 'contain' | 'fill';
|
||||||
|
|
||||||
export type GalleryState = {
|
export type GalleryState = {
|
||||||
selection: ImageDTO[];
|
selection: ImageDTO[];
|
||||||
@ -20,5 +22,8 @@ export type GalleryState = {
|
|||||||
offset: number;
|
offset: number;
|
||||||
limit: number;
|
limit: number;
|
||||||
alwaysShowImageSizeBadge: boolean;
|
alwaysShowImageSizeBadge: boolean;
|
||||||
|
imageToCompare: ImageDTO | null;
|
||||||
|
comparisonMode: ComparisonMode;
|
||||||
|
comparisonFit: ComparisonFit;
|
||||||
isImageViewerOpen: boolean;
|
isImageViewerOpen: boolean;
|
||||||
};
|
};
|
||||||
|
@ -19,7 +19,7 @@ import {
|
|||||||
redo,
|
redo,
|
||||||
undo,
|
undo,
|
||||||
} from 'features/nodes/store/nodesSlice';
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
import { $flow, $needsFit } from 'features/nodes/store/reactFlowInstance';
|
||||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||||
import type { CSSProperties, MouseEvent } from 'react';
|
import type { CSSProperties, MouseEvent } from 'react';
|
||||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||||
@ -68,6 +68,7 @@ export const Flow = memo(() => {
|
|||||||
const nodes = useAppSelector((s) => s.nodes.present.nodes);
|
const nodes = useAppSelector((s) => s.nodes.present.nodes);
|
||||||
const edges = useAppSelector((s) => s.nodes.present.edges);
|
const edges = useAppSelector((s) => s.nodes.present.edges);
|
||||||
const viewport = useStore($viewport);
|
const viewport = useStore($viewport);
|
||||||
|
const needsFit = useStore($needsFit);
|
||||||
const mayUndo = useAppSelector((s) => s.nodes.past.length > 0);
|
const mayUndo = useAppSelector((s) => s.nodes.past.length > 0);
|
||||||
const mayRedo = useAppSelector((s) => s.nodes.future.length > 0);
|
const mayRedo = useAppSelector((s) => s.nodes.future.length > 0);
|
||||||
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
|
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
|
||||||
@ -92,8 +93,16 @@ export const Flow = memo(() => {
|
|||||||
const onNodesChange: OnNodesChange = useCallback(
|
const onNodesChange: OnNodesChange = useCallback(
|
||||||
(nodeChanges) => {
|
(nodeChanges) => {
|
||||||
dispatch(nodesChanged(nodeChanges));
|
dispatch(nodesChanged(nodeChanges));
|
||||||
|
const flow = $flow.get();
|
||||||
|
if (!flow) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (needsFit) {
|
||||||
|
$needsFit.set(false);
|
||||||
|
flow.fitView();
|
||||||
|
}
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch, needsFit]
|
||||||
);
|
);
|
||||||
|
|
||||||
const onEdgesChange: OnEdgesChange = useCallback(
|
const onEdgesChange: OnEdgesChange = useCallback(
|
||||||
|
@ -15,27 +15,20 @@ const ViewportControls = () => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
// const shouldShowFieldTypeLegend = useAppSelector(
|
|
||||||
// (s) => s.nodes.present.shouldShowFieldTypeLegend
|
|
||||||
// );
|
|
||||||
const shouldShowMinimapPanel = useAppSelector((s) => s.workflowSettings.shouldShowMinimapPanel);
|
const shouldShowMinimapPanel = useAppSelector((s) => s.workflowSettings.shouldShowMinimapPanel);
|
||||||
|
|
||||||
const handleClickedZoomIn = useCallback(() => {
|
const handleClickedZoomIn = useCallback(() => {
|
||||||
zoomIn();
|
zoomIn({ duration: 300 });
|
||||||
}, [zoomIn]);
|
}, [zoomIn]);
|
||||||
|
|
||||||
const handleClickedZoomOut = useCallback(() => {
|
const handleClickedZoomOut = useCallback(() => {
|
||||||
zoomOut();
|
zoomOut({ duration: 300 });
|
||||||
}, [zoomOut]);
|
}, [zoomOut]);
|
||||||
|
|
||||||
const handleClickedFitView = useCallback(() => {
|
const handleClickedFitView = useCallback(() => {
|
||||||
fitView();
|
fitView({ duration: 300 });
|
||||||
}, [fitView]);
|
}, [fitView]);
|
||||||
|
|
||||||
// const handleClickedToggleFieldTypeLegend = useCallback(() => {
|
|
||||||
// dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
|
|
||||||
// }, [shouldShowFieldTypeLegend, dispatch]);
|
|
||||||
|
|
||||||
const handleClickedToggleMiniMapPanel = useCallback(() => {
|
const handleClickedToggleMiniMapPanel = useCallback(() => {
|
||||||
dispatch(shouldShowMinimapPanelChanged(!shouldShowMinimapPanel));
|
dispatch(shouldShowMinimapPanelChanged(!shouldShowMinimapPanel));
|
||||||
}, [shouldShowMinimapPanel, dispatch]);
|
}, [shouldShowMinimapPanel, dispatch]);
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
import 'reactflow/dist/style.css';
|
import 'reactflow/dist/style.css';
|
||||||
|
|
||||||
import { Flex } from '@invoke-ai/ui-library';
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
|
|
||||||
import QueueControls from 'features/queue/components/QueueControls';
|
import QueueControls from 'features/queue/components/QueueControls';
|
||||||
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
|
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
|
||||||
import { usePanelStorage } from 'features/ui/hooks/usePanelStorage';
|
import { usePanelStorage } from 'features/ui/hooks/usePanelStorage';
|
||||||
@ -21,14 +19,8 @@ import { WorkflowName } from './WorkflowName';
|
|||||||
|
|
||||||
const panelGroupStyles: CSSProperties = { height: '100%', width: '100%' };
|
const panelGroupStyles: CSSProperties = { height: '100%', width: '100%' };
|
||||||
|
|
||||||
const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
|
|
||||||
return {
|
|
||||||
mode: workflow.mode,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
const NodeEditorPanelGroup = () => {
|
const NodeEditorPanelGroup = () => {
|
||||||
const { mode } = useAppSelector(selector);
|
const mode = useAppSelector((s) => s.workflow.mode);
|
||||||
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
|
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
|
||||||
const panelStorage = usePanelStorage();
|
const panelStorage = usePanelStorage();
|
||||||
|
|
||||||
|
@ -1,20 +1,12 @@
|
|||||||
import { Flex } from '@invoke-ai/ui-library';
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
|
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
|
||||||
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
|
|
||||||
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
|
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
|
||||||
|
|
||||||
import { ModeToggle } from './ModeToggle';
|
import { ModeToggle } from './ModeToggle';
|
||||||
|
|
||||||
const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
|
|
||||||
return {
|
|
||||||
mode: workflow.mode,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
export const WorkflowMenu = () => {
|
export const WorkflowMenu = () => {
|
||||||
const { mode } = useAppSelector(selector);
|
const mode = useAppSelector((s) => s.workflow.mode);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex gap="2" alignItems="center">
|
<Flex gap="2" alignItems="center">
|
||||||
|
@ -11,8 +11,7 @@ import { selectLastSelectedNode } from 'features/nodes/store/selectors';
|
|||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { ImageOutput } from 'services/api/types';
|
import type { AnyInvocationOutput, ImageOutput } from 'services/api/types';
|
||||||
import type { AnyResult } from 'services/events/types';
|
|
||||||
|
|
||||||
import ImageOutputPreview from './outputs/ImageOutputPreview';
|
import ImageOutputPreview from './outputs/ImageOutputPreview';
|
||||||
|
|
||||||
@ -66,4 +65,4 @@ const InspectorOutputsTab = () => {
|
|||||||
|
|
||||||
export default memo(InspectorOutputsTab);
|
export default memo(InspectorOutputsTab);
|
||||||
|
|
||||||
const getKey = (result: AnyResult, i: number) => `${result.type}-${i}`;
|
const getKey = (result: AnyInvocationOutput, i: number) => `${result.type}-${i}`;
|
||||||
|
@ -2,3 +2,4 @@ import { atom } from 'nanostores';
|
|||||||
import type { ReactFlowInstance } from 'reactflow';
|
import type { ReactFlowInstance } from 'reactflow';
|
||||||
|
|
||||||
export const $flow = atom<ReactFlowInstance | null>(null);
|
export const $flow = atom<ReactFlowInstance | null>(null);
|
||||||
|
export const $needsFit = atom<boolean>(true);
|
||||||
|
@ -144,5 +144,4 @@ const zImageOutput = z.object({
|
|||||||
type: z.literal('image_output'),
|
type: z.literal('image_output'),
|
||||||
});
|
});
|
||||||
export type ImageOutput = z.infer<typeof zImageOutput>;
|
export type ImageOutput = z.infer<typeof zImageOutput>;
|
||||||
export const isImageOutput = (output: unknown): output is ImageOutput => zImageOutput.safeParse(output).success;
|
|
||||||
// #endregion
|
// #endregion
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
import type { NodesState } from 'features/nodes/store/types';
|
import type { NodesState } from 'features/nodes/store/types';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
import { omit, reduce } from 'lodash-es';
|
import { omit, reduce } from 'lodash-es';
|
||||||
import type { Graph } from 'services/api/types';
|
import type { AnyInvocation, Graph } from 'services/api/types';
|
||||||
import type { AnyInvocation } from 'services/events/types';
|
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { Box } from '@invoke-ai/ui-library';
|
import { Box } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { ImageViewerWorkflows } from 'features/gallery/components/ImageViewer/ImageViewerWorkflows';
|
import { ImageComparisonDroppable } from 'features/gallery/components/ImageViewer/ImageComparisonDroppable';
|
||||||
|
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
|
||||||
import NodeEditor from 'features/nodes/components/NodeEditor';
|
import NodeEditor from 'features/nodes/components/NodeEditor';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { ReactFlowProvider } from 'reactflow';
|
import { ReactFlowProvider } from 'reactflow';
|
||||||
@ -10,7 +11,8 @@ const NodesTab = () => {
|
|||||||
if (mode === 'view') {
|
if (mode === 'view') {
|
||||||
return (
|
return (
|
||||||
<Box layerStyle="first" position="relative" w="full" h="full" p={2} borderRadius="base">
|
<Box layerStyle="first" position="relative" w="full" h="full" p={2} borderRadius="base">
|
||||||
<ImageViewerWorkflows />
|
<ImageViewer />
|
||||||
|
<ImageComparisonDroppable />
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -1,13 +1,17 @@
|
|||||||
import { Box } from '@invoke-ai/ui-library';
|
import { Box } from '@invoke-ai/ui-library';
|
||||||
import { ControlLayersEditor } from 'features/controlLayers/components/ControlLayersEditor';
|
import { ControlLayersEditor } from 'features/controlLayers/components/ControlLayersEditor';
|
||||||
|
import { ImageComparisonDroppable } from 'features/gallery/components/ImageViewer/ImageComparisonDroppable';
|
||||||
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
|
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
|
||||||
|
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
|
||||||
const TextToImageTab = () => {
|
const TextToImageTab = () => {
|
||||||
|
const imageViewer = useImageViewer();
|
||||||
return (
|
return (
|
||||||
<Box layerStyle="first" position="relative" w="full" h="full" p={2} borderRadius="base">
|
<Box layerStyle="first" position="relative" w="full" h="full" p={2} borderRadius="base">
|
||||||
<ControlLayersEditor />
|
<ControlLayersEditor />
|
||||||
<ImageViewer />
|
{imageViewer.isOpen && <ImageViewer />}
|
||||||
|
<ImageComparisonDroppable />
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -41,7 +41,7 @@ const UnifiedCanvasTab = () => {
|
|||||||
>
|
>
|
||||||
<IAICanvasToolbar />
|
<IAICanvasToolbar />
|
||||||
<IAICanvas />
|
<IAICanvas />
|
||||||
{isValidDrop(droppableData, active) && (
|
{isValidDrop(droppableData, active?.data.current) && (
|
||||||
<IAIDropOverlay isOver={isOver} label={t('toast.setCanvasInitialImage')} />
|
<IAIDropOverlay isOver={isOver} label={t('toast.setCanvasInitialImage')} />
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
|
File diff suppressed because one or more lines are too long
@ -122,7 +122,6 @@ export type ModelInstallStatus = S['InstallStatus'];
|
|||||||
// Graphs
|
// Graphs
|
||||||
export type Graph = S['Graph'];
|
export type Graph = S['Graph'];
|
||||||
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
|
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
|
||||||
export type GraphExecutionState = S['GraphExecutionState'];
|
|
||||||
export type Batch = S['Batch'];
|
export type Batch = S['Batch'];
|
||||||
export type SessionQueueItemDTO = S['SessionQueueItemDTO'];
|
export type SessionQueueItemDTO = S['SessionQueueItemDTO'];
|
||||||
export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy'];
|
export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy'];
|
||||||
@ -132,14 +131,14 @@ export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO'];
|
|||||||
type KeysOfUnion<T> = T extends T ? keyof T : never;
|
type KeysOfUnion<T> = T extends T ? keyof T : never;
|
||||||
|
|
||||||
export type AnyInvocation = Exclude<
|
export type AnyInvocation = Exclude<
|
||||||
Graph['nodes'][string],
|
NonNullable<S['Graph']['nodes']>[string],
|
||||||
S['CoreMetadataInvocation'] | S['MetadataInvocation'] | S['MetadataItemInvocation'] | S['MergeMetadataInvocation']
|
S['CoreMetadataInvocation'] | S['MetadataInvocation'] | S['MetadataItemInvocation'] | S['MergeMetadataInvocation']
|
||||||
>;
|
>;
|
||||||
export type AnyInvocationIncMetadata = S['Graph']['nodes'][string];
|
export type AnyInvocationIncMetadata = NonNullable<S['Graph']['nodes']>[string];
|
||||||
|
|
||||||
export type InvocationType = AnyInvocation['type'];
|
export type InvocationType = AnyInvocation['type'];
|
||||||
type InvocationOutputMap = S['InvocationOutputMap'];
|
type InvocationOutputMap = S['InvocationOutputMap'];
|
||||||
type AnyInvocationOutput = InvocationOutputMap[InvocationType];
|
export type AnyInvocationOutput = InvocationOutputMap[InvocationType];
|
||||||
|
|
||||||
export type Invocation<T extends InvocationType> = Extract<AnyInvocation, { type: T }>;
|
export type Invocation<T extends InvocationType> = Extract<AnyInvocation, { type: T }>;
|
||||||
// export type InvocationOutput<T extends InvocationType> = InvocationOutputMap[T];
|
// export type InvocationOutput<T extends InvocationType> = InvocationOutputMap[T];
|
||||||
|
@ -1,21 +1,12 @@
|
|||||||
import type { Graph, GraphExecutionState, S } from 'services/api/types';
|
import type { S } from 'services/api/types';
|
||||||
|
|
||||||
export type AnyInvocation = NonNullable<NonNullable<Graph['nodes']>[string]>;
|
|
||||||
|
|
||||||
export type AnyResult = NonNullable<GraphExecutionState['results'][string]>;
|
|
||||||
|
|
||||||
export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
|
export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
|
||||||
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
|
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
|
||||||
|
|
||||||
export type InvocationStartedEvent = Omit<S['InvocationStartedEvent'], 'invocation'> & { invocation: AnyInvocation };
|
export type InvocationStartedEvent = S['InvocationStartedEvent'];
|
||||||
export type InvocationDenoiseProgressEvent = Omit<S['InvocationDenoiseProgressEvent'], 'invocation'> & {
|
export type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent'];
|
||||||
invocation: AnyInvocation;
|
export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
|
||||||
};
|
export type InvocationErrorEvent = S['InvocationErrorEvent'];
|
||||||
export type InvocationCompleteEvent = Omit<S['InvocationCompleteEvent'], 'result' | 'invocation'> & {
|
|
||||||
result: AnyResult;
|
|
||||||
invocation: AnyInvocation;
|
|
||||||
};
|
|
||||||
export type InvocationErrorEvent = Omit<S['InvocationErrorEvent'], 'invocation'> & { invocation: AnyInvocation };
|
|
||||||
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
|
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
|
||||||
|
|
||||||
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
|
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
|
||||||
|
@ -55,10 +55,10 @@ dependencies = [
|
|||||||
|
|
||||||
# Core application dependencies, pinned for reproducible builds.
|
# Core application dependencies, pinned for reproducible builds.
|
||||||
"fastapi-events==0.11.0",
|
"fastapi-events==0.11.0",
|
||||||
"fastapi==0.110.0",
|
"fastapi==0.111.0",
|
||||||
"huggingface-hub==0.23.1",
|
"huggingface-hub==0.23.1",
|
||||||
"pydantic-settings==2.2.1",
|
"pydantic-settings==2.2.1",
|
||||||
"pydantic==2.6.3",
|
"pydantic==2.7.2",
|
||||||
"python-socketio==5.11.1",
|
"python-socketio==5.11.1",
|
||||||
"uvicorn[standard]==0.28.0",
|
"uvicorn[standard]==0.28.0",
|
||||||
|
|
||||||
|
@ -7,9 +7,10 @@ def main():
|
|||||||
# Change working directory to the repo root
|
# Change working directory to the repo root
|
||||||
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
from invokeai.app.api_app import custom_openapi
|
from invokeai.app.api_app import app
|
||||||
|
from invokeai.app.util.custom_openapi import get_openapi_func
|
||||||
|
|
||||||
schema = custom_openapi()
|
schema = get_openapi_func(app)()
|
||||||
json.dump(schema, sys.stdout, indent=2)
|
json.dump(schema, sys.stdout, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
|
from pydantic.json_schema import models_json_schema
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -713,4 +714,4 @@ def test_iterate_accepts_collection():
|
|||||||
def test_graph_can_generate_schema():
|
def test_graph_can_generate_schema():
|
||||||
# Not throwing on this line is sufficient
|
# Not throwing on this line is sufficient
|
||||||
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
||||||
_ = Graph.model_json_schema()
|
models_json_schema([(Graph, "serialization")])
|
||||||
|
Loading…
Reference in New Issue
Block a user