mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/feat/simple-mm2-api
This commit is contained in:
commit
2276f327e5
4
Makefile
4
Makefile
@ -18,6 +18,7 @@ help:
|
||||
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
||||
@echo "installer-zip Build the installer .zip file for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
|
||||
|
||||
# Runs ruff, fixing any safely-fixable errors and formatting
|
||||
ruff:
|
||||
@ -70,3 +71,6 @@ installer-zip:
|
||||
tag-release:
|
||||
cd installer && ./tag_release.sh
|
||||
|
||||
# Generate the OpenAPI Schema for the app
|
||||
openapi:
|
||||
python scripts/generate_openapi_schema.py
|
||||
|
@ -154,6 +154,18 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file.
|
||||
|
||||
Check the [configuration docs] for more detail about the settings and how to specify them.
|
||||
|
||||
## `ModuleNotFoundError: No module named 'controlnet_aux'`
|
||||
|
||||
`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control.
|
||||
|
||||
If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed:
|
||||
|
||||
- Run the Invoke launcher
|
||||
- Choose the developer console option
|
||||
- Run this command: `pip cache remove controlnet_aux`
|
||||
- Close the terminal window
|
||||
- Download and run the [installer](https://github.com/invoke-ai/InvokeAI/releases/latest), selecting your current install location
|
||||
|
||||
## Out of Memory Issues
|
||||
|
||||
The models are large, VRAM is expensive, and you may find yourself
|
||||
|
@ -3,9 +3,7 @@ import logging
|
||||
import mimetypes
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
@ -13,11 +11,9 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.json_schema import models_json_schema
|
||||
from torch.backends.mps import is_available as is_mps_available
|
||||
|
||||
# for PyCharm:
|
||||
@ -25,10 +21,8 @@ from torch.backends.mps import is_available as is_mps_available
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.util.custom_openapi import get_openapi_func
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
@ -45,11 +39,6 @@ from .api.routers import (
|
||||
workflows,
|
||||
)
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
UIConfigBase,
|
||||
)
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
@ -119,84 +108,7 @@ app.include_router(app_info.app_router, prefix="/api")
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi() -> dict[str, Any]:
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
description="An API for invoking AI image operations",
|
||||
version="1.0.0",
|
||||
routes=app.routes,
|
||||
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
|
||||
)
|
||||
|
||||
# Add all outputs
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
output_types = set()
|
||||
output_type_titles = {}
|
||||
for invoker in all_invocations:
|
||||
output_type = signature(invoker.invoke).return_annotation
|
||||
output_types.add(output_type)
|
||||
|
||||
output_schemas = models_json_schema(
|
||||
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
|
||||
)
|
||||
for schema_key, output_schema in output_schemas[1]["$defs"].items():
|
||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||
# This could break in some cases, figure out a better way to do it
|
||||
output_type_titles[schema_key] = output_schema["title"]
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
|
||||
|
||||
# Some models don't end up in the schemas as standalone definitions
|
||||
additional_schemas = models_json_schema(
|
||||
[
|
||||
(UIConfigBase, "serialization"),
|
||||
(InputFieldJSONSchemaExtra, "serialization"),
|
||||
(OutputFieldJSONSchemaExtra, "serialization"),
|
||||
(ModelIdentifierField, "serialization"),
|
||||
(ProgressImage, "serialization"),
|
||||
],
|
||||
ref_template="#/components/schemas/{model}",
|
||||
)
|
||||
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = schema_json
|
||||
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
|
||||
output_type = signature(obj=invoker.invoke).return_annotation
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
invoker_schema["output"] = outputs_ref
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
# Add all event schemas
|
||||
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
|
||||
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
||||
if "$defs" in json_schema:
|
||||
for schema_key, schema in json_schema["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = schema
|
||||
del json_schema["$defs"]
|
||||
openapi_schema["components"]["schemas"][event.__name__] = json_schema
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
|
||||
app.openapi = get_openapi_func(app)
|
||||
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
|
@ -98,11 +98,13 @@ class BaseInvocationOutput(BaseModel):
|
||||
|
||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||
"""Registers an invocation output."""
|
||||
cls._output_classes.add(output)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
||||
@ -112,11 +114,12 @@ class BaseInvocationOutput(BaseModel):
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationOutputsUnion = TypeAliasType(
|
||||
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocationOutput = TypeAliasType(
|
||||
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
|
||||
cls._typeadapter_needs_update = False
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
@ -125,12 +128,13 @@ class BaseInvocationOutput(BaseModel):
|
||||
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
|
||||
# Because we use a pydantic Literal field with default value for the invocation type,
|
||||
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = []
|
||||
schema["class"] = "output"
|
||||
schema["required"].extend(["type"])
|
||||
|
||||
@classmethod
|
||||
@ -167,6 +171,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
|
||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
@ -177,15 +182,17 @@ class BaseInvocation(ABC, BaseModel):
|
||||
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
||||
"""Registers an invocation."""
|
||||
cls._invocation_classes.add(invocation)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationsUnion = TypeAliasType(
|
||||
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocation = TypeAliasType(
|
||||
"AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationsUnion)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocation)
|
||||
cls._typeadapter_needs_update = False
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
@ -221,7 +228,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
|
||||
if uiconfig is not None:
|
||||
@ -237,6 +244,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = []
|
||||
schema["class"] = "invocation"
|
||||
schema["required"].extend(["type", "id"])
|
||||
|
||||
@abstractmethod
|
||||
@ -310,7 +318,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
protected_namespaces=(),
|
||||
validate_assignment=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
json_schema_serialization_defaults_required=False,
|
||||
coerce_numbers_to_str=True,
|
||||
)
|
||||
|
||||
|
@ -1,11 +1,10 @@
|
||||
from math import floor
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.registry.payload_schema import registry as payload_schema
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
@ -14,6 +13,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
@ -33,6 +33,7 @@ class EventBase(BaseModel):
|
||||
A timestamp is automatically added to the event when it is created.
|
||||
"""
|
||||
|
||||
__event_name__: ClassVar[str]
|
||||
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
|
||||
|
||||
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
|
||||
@ -97,7 +98,7 @@ class InvocationEventBase(QueueItemEventBase):
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
|
||||
invocation: AnyInvocation = Field(description="The ID of the invocation")
|
||||
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
||||
|
||||
|
||||
@ -108,7 +109,7 @@ class InvocationStartedEvent(InvocationEventBase):
|
||||
__event_name__ = "invocation_started"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
|
||||
def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
@ -135,7 +136,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
invocation: AnyInvocation,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
progress_image: ProgressImage,
|
||||
) -> "InvocationDenoiseProgressEvent":
|
||||
@ -173,11 +174,11 @@ class InvocationCompleteEvent(InvocationEventBase):
|
||||
|
||||
__event_name__ = "invocation_complete"
|
||||
|
||||
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
|
||||
result: AnyInvocationOutput = Field(description="The result of the invocation")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
|
||||
cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
|
||||
) -> "InvocationCompleteEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
@ -206,7 +207,7 @@ class InvocationErrorEvent(InvocationEventBase):
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
invocation: AnyInvocation,
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
|
@ -2,18 +2,19 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
GetCoreSchemaHandler,
|
||||
GetJsonSchemaHandler,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
)
|
||||
from pydantic.fields import Field
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
from pydantic_core import CoreSchema
|
||||
from pydantic_core import core_schema
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from invokeai.app.invocations import * # noqa: F401 F403
|
||||
@ -277,73 +278,58 @@ class CollectInvocation(BaseInvocation):
|
||||
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
||||
|
||||
|
||||
class AnyInvocation(BaseInvocation):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
def validate_invocation(v: Any) -> "AnyInvocation":
|
||||
return BaseInvocation.get_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
oneOf: list[dict[str, str]] = []
|
||||
names = [i.__name__ for i in BaseInvocation.get_invocations()]
|
||||
for name in sorted(names):
|
||||
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||
return {"oneOf": oneOf}
|
||||
|
||||
|
||||
class AnyInvocationOutput(BaseInvocationOutput):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
||||
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
|
||||
return BaseInvocationOutput.get_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation_output)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
|
||||
oneOf: list[dict[str, str]] = []
|
||||
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
|
||||
for name in sorted(names):
|
||||
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||
return {"oneOf": oneOf}
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
||||
nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
||||
edges: list[Edge] = Field(
|
||||
description="The connections between nodes and their fields in this graph",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
@field_validator("nodes", mode="plain")
|
||||
@classmethod
|
||||
def validate_nodes(cls, v: dict[str, Any]):
|
||||
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
|
||||
|
||||
# Invocations register themselves as their python modules are executed. The union of all invocations is
|
||||
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
|
||||
#
|
||||
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
|
||||
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
|
||||
# invocations will cause a graph to fail if they are used.
|
||||
#
|
||||
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
|
||||
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
|
||||
#
|
||||
# This same pattern is used in `GraphExecutionState`.
|
||||
|
||||
nodes: dict[str, BaseInvocation] = {}
|
||||
typeadapter = BaseInvocation.get_typeadapter()
|
||||
for node_id, node in v.items():
|
||||
nodes[node_id] = typeadapter.validate_python(node)
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
|
||||
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
|
||||
# the generated schema as options for the `nodes` field.
|
||||
#
|
||||
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
|
||||
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
|
||||
# expected.
|
||||
#
|
||||
# You might be tempted to do something like this:
|
||||
#
|
||||
# ```py
|
||||
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
|
||||
# delattr(cloned_model, "validate_nodes")
|
||||
# cloned_model.model_rebuild(force=True)
|
||||
# json_schema = handler(cloned_model.__pydantic_core_schema__)
|
||||
# ```
|
||||
#
|
||||
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
|
||||
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
|
||||
#
|
||||
# This same pattern is used in `GraphExecutionState`.
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: Optional[str] = Field(default=None, description="The id of this graph")
|
||||
nodes: dict[
|
||||
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
|
||||
] = Field(description="The nodes in this graph")
|
||||
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
|
||||
|
||||
json_schema = handler(Graph.__pydantic_core_schema__)
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
return json_schema
|
||||
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
"""Adds a node to a graph
|
||||
|
||||
@ -774,7 +760,7 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
|
||||
# The results of executed nodes
|
||||
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
||||
results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
||||
|
||||
# Errors raised when executing nodes
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||
@ -791,52 +777,12 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
@field_validator("results", mode="plain")
|
||||
@classmethod
|
||||
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
|
||||
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
|
||||
|
||||
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
|
||||
results: dict[str, BaseInvocationOutput] = {}
|
||||
typeadapter = BaseInvocationOutput.get_typeadapter()
|
||||
for result_id, result in v.items():
|
||||
results[result_id] = typeadapter.validate_python(result)
|
||||
return results
|
||||
|
||||
@field_validator("graph")
|
||||
def graph_is_valid(cls, v: Graph):
|
||||
"""Validates that the graph is valid"""
|
||||
v.validate_self()
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(description="The id of the execution state")
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
|
||||
executed: set[str] = Field(description="The set of node ids that have been executed")
|
||||
executed_history: list[str] = Field(
|
||||
description="The list of node ids that have been executed, in order of execution"
|
||||
)
|
||||
results: dict[
|
||||
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
|
||||
] = Field(description="The results of node executions")
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
|
||||
prepared_source_mapping: dict[str, str] = Field(
|
||||
description="The map of prepared nodes to original graph nodes"
|
||||
)
|
||||
source_prepared_mapping: dict[str, set[str]] = Field(
|
||||
description="The map of original graph nodes to prepared nodes"
|
||||
)
|
||||
|
||||
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
return json_schema
|
||||
|
||||
def next(self) -> Optional[BaseInvocation]:
|
||||
"""Gets the next node ready to execute."""
|
||||
|
||||
|
116
invokeai/app/util/custom_openapi.py
Normal file
116
invokeai/app/util/custom_openapi.py
Normal file
@ -0,0 +1,116 @@
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from pydantic.json_schema import models_json_schema
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
|
||||
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
|
||||
|
||||
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
|
||||
"""Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema
|
||||
for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and
|
||||
component_schema."""
|
||||
|
||||
defs = component_schema.pop("$defs", {})
|
||||
for schema_key, json_schema in defs.items():
|
||||
if schema_key in openapi_schema["components"]["schemas"]:
|
||||
continue
|
||||
openapi_schema["components"]["schemas"][schema_key] = json_schema
|
||||
|
||||
|
||||
def get_openapi_func(
|
||||
app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
|
||||
) -> Callable[[], dict[str, Any]]:
|
||||
"""Gets the OpenAPI schema generator function.
|
||||
|
||||
Args:
|
||||
app (FastAPI): The FastAPI app to generate the schema for.
|
||||
post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the
|
||||
generated schema before returning it. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is
|
||||
cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour
|
||||
matches FastAPI's default schema generation caching.
|
||||
"""
|
||||
|
||||
def openapi() -> dict[str, Any]:
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
description="An API for invoking AI image operations",
|
||||
version="1.0.0",
|
||||
routes=app.routes,
|
||||
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
|
||||
)
|
||||
|
||||
# We'll create a map of invocation type to output schema to make some types simpler on the client.
|
||||
invocation_output_map_properties: dict[str, Any] = {}
|
||||
invocation_output_map_required: list[str] = []
|
||||
|
||||
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
|
||||
for output in BaseInvocationOutput.get_outputs():
|
||||
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
||||
move_defs_to_top_level(openapi_schema, json_schema)
|
||||
openapi_schema["components"]["schemas"][output.__name__] = json_schema
|
||||
|
||||
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
|
||||
# property, so we'll just do it all manually.
|
||||
for invocation in BaseInvocation.get_invocations():
|
||||
json_schema = invocation.model_json_schema(
|
||||
mode="serialization", ref_template="#/components/schemas/{model}"
|
||||
)
|
||||
move_defs_to_top_level(openapi_schema, json_schema)
|
||||
output_title = invocation.get_output_annotation().__name__
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_title}"}
|
||||
json_schema["output"] = outputs_ref
|
||||
openapi_schema["components"]["schemas"][invocation.__name__] = json_schema
|
||||
|
||||
# Add this invocation and its output to the output map
|
||||
invocation_type = invocation.get_type()
|
||||
invocation_output_map_properties[invocation_type] = json_schema["output"]
|
||||
invocation_output_map_required.append(invocation_type)
|
||||
|
||||
# Add the output map to the schema
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||
"type": "object",
|
||||
"properties": invocation_output_map_properties,
|
||||
"required": invocation_output_map_required,
|
||||
}
|
||||
|
||||
# Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API.
|
||||
# We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get
|
||||
# a schema. This has something to do with schema refs - not totally clear. For whatever reason, using
|
||||
# `models_json_schema` seems to work fine.
|
||||
additional_models = [
|
||||
*EventBase.get_events(),
|
||||
UIConfigBase,
|
||||
InputFieldJSONSchemaExtra,
|
||||
OutputFieldJSONSchemaExtra,
|
||||
ModelIdentifierField,
|
||||
ProgressImage,
|
||||
]
|
||||
|
||||
additional_schemas = models_json_schema(
|
||||
[(m, "serialization") for m in additional_models],
|
||||
ref_template="#/components/schemas/{model}",
|
||||
)
|
||||
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
|
||||
move_defs_to_top_level(openapi_schema, additional_schemas[1])
|
||||
|
||||
if post_transform is not None:
|
||||
openapi_schema = post_transform(openapi_schema)
|
||||
|
||||
openapi_schema["components"]["schemas"] = dict(sorted(openapi_schema["components"]["schemas"].items()))
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
return openapi
|
@ -53,5 +53,5 @@ class ModelLocker(ModelLockerBase):
|
||||
"""Call upon exit from context."""
|
||||
self._cache_entry.unlock()
|
||||
if not self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||
self._cache.offload_unlocked_models(0)
|
||||
self._cache.print_cuda_stats()
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Textual Inversion wrapper class."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
@ -66,35 +66,52 @@ class TextualInversionModelRaw(RawModel):
|
||||
return result
|
||||
|
||||
|
||||
# no type hints for BaseTextualInversionManager?
|
||||
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
|
||||
pad_tokens: Dict[int, List[int]]
|
||||
tokenizer: CLIPTokenizer
|
||||
class TextualInversionManager(BaseTextualInversionManager):
|
||||
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
|
||||
|
||||
def __init__(self, tokenizer: CLIPTokenizer):
|
||||
self.pad_tokens = {}
|
||||
self.pad_tokens: dict[int, list[int]] = {}
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
||||
"""Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens.
|
||||
|
||||
For example, suppose we have a `<ti_dog>` TI with 4 vectors that was added to the tokenizer with the following
|
||||
mapping of tokens to token_ids:
|
||||
```
|
||||
<ti_dog>: 49408
|
||||
<ti_dog-!pad-1>: 49409
|
||||
<ti_dog-!pad-2>: 49410
|
||||
<ti_dog-!pad-3>: 49411
|
||||
```
|
||||
`self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`.
|
||||
This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`.
|
||||
"""
|
||||
# Short circuit if there are no pad tokens to save a little time.
|
||||
if len(self.pad_tokens) == 0:
|
||||
return token_ids
|
||||
|
||||
# This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify
|
||||
# this assumption here.
|
||||
if token_ids[0] == self.tokenizer.bos_token_id:
|
||||
raise ValueError("token_ids must not start with bos_token_id")
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
raise ValueError("token_ids must not end with eos_token_id")
|
||||
|
||||
new_token_ids = []
|
||||
# Expand any TI tokens to their corresponding pad tokens.
|
||||
new_token_ids: list[int] = []
|
||||
for token_id in token_ids:
|
||||
new_token_ids.append(token_id)
|
||||
if token_id in self.pad_tokens:
|
||||
new_token_ids.extend(self.pad_tokens[token_id])
|
||||
|
||||
# Do not exceed the max model input size
|
||||
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
|
||||
# which first removes and then adds back the start and end tokens.
|
||||
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
|
||||
# Do not exceed the max model input size. The -2 here is compensating for
|
||||
# compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens.
|
||||
max_length = self.tokenizer.model_max_length - 2
|
||||
if len(new_token_ids) > max_length:
|
||||
# HACK: If TI token expansion causes us to exceed the max text encoder input length, we silently discard
|
||||
# tokens. Token expansion should happen in a way that is compatible with compel's default handling of long
|
||||
# prompts.
|
||||
new_token_ids = new_token_ids[0:max_length]
|
||||
|
||||
return new_token_ids
|
||||
|
@ -148,6 +148,8 @@
|
||||
"viewingDesc": "Review images in a large gallery view",
|
||||
"editing": "Editing",
|
||||
"editingDesc": "Edit on the Control Layers canvas",
|
||||
"comparing": "Comparing",
|
||||
"comparingDesc": "Comparing two images",
|
||||
"enabled": "Enabled",
|
||||
"disabled": "Disabled"
|
||||
},
|
||||
@ -375,7 +377,23 @@
|
||||
"bulkDownloadRequestFailed": "Problem Preparing Download",
|
||||
"bulkDownloadFailed": "Download Failed",
|
||||
"problemDeletingImages": "Problem Deleting Images",
|
||||
"problemDeletingImagesDesc": "One or more images could not be deleted"
|
||||
"problemDeletingImagesDesc": "One or more images could not be deleted",
|
||||
"viewerImage": "Viewer Image",
|
||||
"compareImage": "Compare Image",
|
||||
"openInViewer": "Open in Viewer",
|
||||
"selectForCompare": "Select for Compare",
|
||||
"selectAnImageToCompare": "Select an Image to Compare",
|
||||
"slider": "Slider",
|
||||
"sideBySide": "Side-by-Side",
|
||||
"hover": "Hover",
|
||||
"swapImages": "Swap Images",
|
||||
"compareOptions": "Comparison Options",
|
||||
"stretchToFit": "Stretch to Fit",
|
||||
"exitCompare": "Exit Compare",
|
||||
"compareHelp1": "Hold <Kbd>Alt</Kbd> while clicking a gallery image or using the arrow keys to change the compare image.",
|
||||
"compareHelp2": "Press <Kbd>M</Kbd> to cycle through comparison modes.",
|
||||
"compareHelp3": "Press <Kbd>C</Kbd> to swap the compared images.",
|
||||
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit."
|
||||
},
|
||||
"hotkeys": {
|
||||
"searchHotkeys": "Search Hotkeys",
|
||||
|
@ -19,6 +19,13 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||
return extendTheme({
|
||||
..._theme,
|
||||
direction,
|
||||
shadows: {
|
||||
..._theme.shadows,
|
||||
selectedForCompare:
|
||||
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-400)',
|
||||
hoverSelectedForCompare:
|
||||
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-300)',
|
||||
},
|
||||
});
|
||||
}, [direction]);
|
||||
|
||||
|
@ -13,7 +13,6 @@ import {
|
||||
isControlAdapterLayer,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { isEqual } from 'lodash-es';
|
||||
@ -139,7 +138,7 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
|
||||
|
||||
// We still have to check the output type
|
||||
assert(
|
||||
isImageOutput(invocationCompleteAction.payload.data.result),
|
||||
invocationCompleteAction.payload.data.result.type === 'image_output',
|
||||
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
|
||||
);
|
||||
const { image_name } = invocationCompleteAction.payload.data.result.image;
|
||||
|
@ -9,7 +9,6 @@ import {
|
||||
selectControlAdapterById,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
@ -74,7 +73,7 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
|
||||
);
|
||||
|
||||
// We still have to check the output type
|
||||
if (isImageOutput(invocationCompleteAction.payload.data.result)) {
|
||||
if (invocationCompleteAction.payload.data.result.type === 'image_output') {
|
||||
const { image_name } = invocationCompleteAction.payload.data.result.image;
|
||||
|
||||
// Wait for the ImageDTO to be received
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
@ -11,6 +11,7 @@ export const galleryImageClicked = createAction<{
|
||||
shiftKey: boolean;
|
||||
ctrlKey: boolean;
|
||||
metaKey: boolean;
|
||||
altKey: boolean;
|
||||
}>('gallery/imageClicked');
|
||||
|
||||
/**
|
||||
@ -28,7 +29,7 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
|
||||
startAppListening({
|
||||
actionCreator: galleryImageClicked,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { imageDTO, shiftKey, ctrlKey, metaKey } = action.payload;
|
||||
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
|
||||
const state = getState();
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
@ -41,7 +42,13 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
|
||||
const imageDTOs = imagesSelectors.selectAll(listImagesData);
|
||||
const selection = state.gallery.selection;
|
||||
|
||||
if (shiftKey) {
|
||||
if (altKey) {
|
||||
if (state.gallery.imageToCompare?.image_name === imageDTO.image_name) {
|
||||
dispatch(imageToCompareChanged(null));
|
||||
} else {
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
}
|
||||
} else if (shiftKey) {
|
||||
const rangeEndImageName = imageDTO.image_name;
|
||||
const lastSelectedImage = selection[selection.length - 1]?.image_name;
|
||||
const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage);
|
||||
|
@ -14,7 +14,8 @@ import {
|
||||
rgLayerIPAdapterImageChanged,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
@ -30,6 +31,9 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const log = logger('dnd');
|
||||
const { activeData, overData } = action.payload;
|
||||
if (!isValidDrop(overData, activeData)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (activeData.payloadType === 'IMAGE_DTO') {
|
||||
log.debug({ activeData, overData }, 'Image dropped');
|
||||
@ -50,6 +54,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
dispatch(imageSelected(activeData.payload.imageDTO));
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
return;
|
||||
}
|
||||
|
||||
@ -182,24 +187,18 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
}
|
||||
|
||||
/**
|
||||
* TODO
|
||||
* Image selection dropped on node image collection field
|
||||
* Image selected for compare
|
||||
*/
|
||||
// if (
|
||||
// overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
|
||||
// activeData.payloadType === 'IMAGE_DTO' &&
|
||||
// activeData.payload.imageDTO
|
||||
// ) {
|
||||
// const { fieldName, nodeId } = overData.context;
|
||||
// dispatch(
|
||||
// fieldValueChanged({
|
||||
// nodeId,
|
||||
// fieldName,
|
||||
// value: [activeData.payload.imageDTO],
|
||||
// })
|
||||
// );
|
||||
// return;
|
||||
// }
|
||||
if (
|
||||
overData.actionType === 'SELECT_FOR_COMPARE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { imageDTO } = activeData.payload;
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Image dropped on user board
|
||||
|
@ -11,7 +11,6 @@ import {
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
@ -33,7 +32,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
|
||||
const { result, invocation_source_id } = data;
|
||||
// This complete event has an associated image output
|
||||
if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation.type)) {
|
||||
if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) {
|
||||
const { image_name } = data.result.image;
|
||||
const { canvas, gallery } = getState();
|
||||
|
||||
|
@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { $needsFit } from 'features/nodes/store/reactFlowInstance';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
|
||||
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
|
||||
@ -65,9 +65,7 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
|
||||
});
|
||||
}
|
||||
|
||||
requestAnimationFrame(() => {
|
||||
$flow.get()?.fitView();
|
||||
});
|
||||
$needsFit.set(true);
|
||||
} catch (e) {
|
||||
if (e instanceof WorkflowVersionError) {
|
||||
// The workflow version was not recognized in the valid list of versions
|
||||
|
@ -35,6 +35,7 @@ type IAIDndImageProps = FlexProps & {
|
||||
draggableData?: TypesafeDraggableData;
|
||||
dropLabel?: ReactNode;
|
||||
isSelected?: boolean;
|
||||
isSelectedForCompare?: boolean;
|
||||
thumbnail?: boolean;
|
||||
noContentFallback?: ReactElement;
|
||||
useThumbailFallback?: boolean;
|
||||
@ -61,6 +62,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
draggableData,
|
||||
dropLabel,
|
||||
isSelected = false,
|
||||
isSelectedForCompare = false,
|
||||
thumbnail = false,
|
||||
noContentFallback = defaultNoContentFallback,
|
||||
uploadElement = defaultUploadElement,
|
||||
@ -165,7 +167,11 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
data-testid={dataTestId}
|
||||
/>
|
||||
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
|
||||
<SelectionOverlay isSelected={isSelected} isHovered={withHoverOverlay ? isHovered : false} />
|
||||
<SelectionOverlay
|
||||
isSelected={isSelected}
|
||||
isSelectedForCompare={isSelectedForCompare}
|
||||
isHovered={withHoverOverlay ? isHovered : false}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
{!imageDTO && !isUploadDisabled && (
|
||||
|
@ -36,7 +36,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
|
||||
pointerEvents={active ? 'auto' : 'none'}
|
||||
>
|
||||
<AnimatePresence>
|
||||
{isValidDrop(data, active) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
|
||||
{isValidDrop(data, active?.data.current) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
|
||||
</AnimatePresence>
|
||||
</Box>
|
||||
);
|
||||
|
@ -3,10 +3,17 @@ import { memo, useMemo } from 'react';
|
||||
|
||||
type Props = {
|
||||
isSelected: boolean;
|
||||
isSelectedForCompare: boolean;
|
||||
isHovered: boolean;
|
||||
};
|
||||
const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
|
||||
const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props) => {
|
||||
const shadow = useMemo(() => {
|
||||
if (isSelectedForCompare && isHovered) {
|
||||
return 'hoverSelectedForCompare';
|
||||
}
|
||||
if (isSelectedForCompare && !isHovered) {
|
||||
return 'selectedForCompare';
|
||||
}
|
||||
if (isSelected && isHovered) {
|
||||
return 'hoverSelected';
|
||||
}
|
||||
@ -17,7 +24,7 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
|
||||
return 'hoverUnselected';
|
||||
}
|
||||
return undefined;
|
||||
}, [isHovered, isSelected]);
|
||||
}, [isHovered, isSelected, isSelectedForCompare]);
|
||||
return (
|
||||
<Box
|
||||
className="selection-box"
|
||||
@ -27,7 +34,7 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
|
||||
bottom={0}
|
||||
insetInlineStart={0}
|
||||
borderRadius="base"
|
||||
opacity={isSelected ? 1 : 0.7}
|
||||
opacity={isSelected || isSelectedForCompare ? 1 : 0.7}
|
||||
transitionProperty="common"
|
||||
transitionDuration="0.1s"
|
||||
pointerEvents="none"
|
||||
|
21
invokeai/frontend/web/src/common/hooks/useBoolean.ts
Normal file
21
invokeai/frontend/web/src/common/hooks/useBoolean.ts
Normal file
@ -0,0 +1,21 @@
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
|
||||
export const useBoolean = (initialValue: boolean) => {
|
||||
const [isTrue, set] = useState(initialValue);
|
||||
const setTrue = useCallback(() => set(true), []);
|
||||
const setFalse = useCallback(() => set(false), []);
|
||||
const toggle = useCallback(() => set((v) => !v), []);
|
||||
|
||||
const api = useMemo(
|
||||
() => ({
|
||||
isTrue,
|
||||
set,
|
||||
setTrue,
|
||||
setFalse,
|
||||
toggle,
|
||||
}),
|
||||
[isTrue, set, setTrue, setFalse, toggle]
|
||||
);
|
||||
|
||||
return api;
|
||||
};
|
@ -1,3 +1,7 @@
|
||||
export const stopPropagation = (e: React.MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
};
|
||||
|
||||
export const preventDefault = (e: React.MouseEvent) => {
|
||||
e.preventDefault();
|
||||
};
|
||||
|
@ -1,7 +1,13 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { merge, omit } from 'lodash-es';
|
||||
import type { BaseModelType, ControlNetModelConfig, Graph, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import type {
|
||||
AnyInvocation,
|
||||
BaseModelType,
|
||||
ControlNetModelConfig,
|
||||
ImageDTO,
|
||||
T2IAdapterModelConfig,
|
||||
} from 'services/api/types';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zId = z.string().min(1);
|
||||
@ -147,7 +153,7 @@ const zBeginEndStepPct = z
|
||||
|
||||
const zControlAdapterBase = z.object({
|
||||
id: zId,
|
||||
weight: z.number().gte(0).lte(1),
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
image: zImageWithDims.nullable(),
|
||||
processedImage: zImageWithDims.nullable(),
|
||||
processorConfig: zProcessorConfig.nullable(),
|
||||
@ -183,7 +189,7 @@ export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safePar
|
||||
export const zIPAdapterConfigV2 = z.object({
|
||||
id: zId,
|
||||
type: z.literal('ip_adapter'),
|
||||
weight: z.number().gte(0).lte(1),
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
method: zIPMethodV2,
|
||||
image: zImageWithDims.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
@ -216,10 +222,7 @@ type ProcessorData<T extends ProcessorTypeV2> = {
|
||||
labelTKey: string;
|
||||
descriptionTKey: string;
|
||||
buildDefaults(baseModel?: BaseModelType): Extract<ProcessorConfig, { type: T }>;
|
||||
buildNode(
|
||||
image: ImageWithDims,
|
||||
config: Extract<ProcessorConfig, { type: T }>
|
||||
): Extract<Graph['nodes'][string], { type: T }>;
|
||||
buildNode(image: ImageWithDims, config: Extract<ProcessorConfig, { type: T }>): Extract<AnyInvocation, { type: T }>;
|
||||
};
|
||||
|
||||
const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height);
|
||||
|
@ -54,7 +54,7 @@ const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)';
|
||||
const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
|
||||
const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
|
||||
// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
|
||||
const STAGE_BG_DATAURL =
|
||||
export const STAGE_BG_DATAURL =
|
||||
'';
|
||||
|
||||
const mapId = (object: { id: string }) => object.id;
|
||||
|
@ -18,7 +18,7 @@ type BaseDropData = {
|
||||
id: string;
|
||||
};
|
||||
|
||||
type CurrentImageDropData = BaseDropData & {
|
||||
export type CurrentImageDropData = BaseDropData & {
|
||||
actionType: 'SET_CURRENT_IMAGE';
|
||||
};
|
||||
|
||||
@ -79,6 +79,14 @@ export type RemoveFromBoardDropData = BaseDropData & {
|
||||
actionType: 'REMOVE_FROM_BOARD';
|
||||
};
|
||||
|
||||
export type SelectForCompareDropData = BaseDropData & {
|
||||
actionType: 'SELECT_FOR_COMPARE';
|
||||
context: {
|
||||
firstImageName?: string | null;
|
||||
secondImageName?: string | null;
|
||||
};
|
||||
};
|
||||
|
||||
export type TypesafeDroppableData =
|
||||
| CurrentImageDropData
|
||||
| ControlAdapterDropData
|
||||
@ -89,7 +97,8 @@ export type TypesafeDroppableData =
|
||||
| CALayerImageDropData
|
||||
| IPALayerImageDropData
|
||||
| RGLayerIPAdapterImageDropData
|
||||
| IILayerImageDropData;
|
||||
| IILayerImageDropData
|
||||
| SelectForCompareDropData;
|
||||
|
||||
type BaseDragData = {
|
||||
id: string;
|
||||
@ -134,7 +143,7 @@ export type UseDraggableTypesafeReturnValue = Omit<ReturnType<typeof useOriginal
|
||||
over: TypesafeOver | null;
|
||||
};
|
||||
|
||||
export interface TypesafeActive extends Omit<Active, 'data'> {
|
||||
interface TypesafeActive extends Omit<Active, 'data'> {
|
||||
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
|
||||
}
|
||||
|
||||
|
@ -1,14 +1,14 @@
|
||||
import type { TypesafeActive, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
|
||||
export const isValidDrop = (overData: TypesafeDroppableData | undefined, active: TypesafeActive | null) => {
|
||||
if (!overData || !active?.data.current) {
|
||||
export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?: TypesafeDraggableData | null) => {
|
||||
if (!overData || !activeData) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const { actionType } = overData;
|
||||
const { payloadType } = active.data.current;
|
||||
const { payloadType } = activeData;
|
||||
|
||||
if (overData.id === active.data.current.id) {
|
||||
if (overData.id === activeData.id) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -29,6 +29,8 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active:
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_NODES_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SELECT_FOR_COMPARE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'ADD_TO_BOARD': {
|
||||
// If the board is the same, don't allow the drop
|
||||
|
||||
@ -40,7 +42,7 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active:
|
||||
|
||||
// Check if the image's board is the board we are dragging onto
|
||||
if (payloadType === 'IMAGE_DTO') {
|
||||
const { imageDTO } = active.data.current.payload;
|
||||
const { imageDTO } = activeData.payload;
|
||||
const currentBoard = imageDTO.board_id ?? 'none';
|
||||
const destinationBoard = overData.context.boardId;
|
||||
|
||||
@ -49,7 +51,7 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active:
|
||||
|
||||
if (payloadType === 'GALLERY_SELECTION') {
|
||||
// Assume all images are on the same board - this is true for the moment
|
||||
const currentBoard = active.data.current.payload.boardId;
|
||||
const currentBoard = activeData.payload.boardId;
|
||||
const destinationBoard = overData.context.boardId;
|
||||
return currentBoard !== destinationBoard;
|
||||
}
|
||||
@ -67,14 +69,14 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active:
|
||||
|
||||
// Check if the image's board is the board we are dragging onto
|
||||
if (payloadType === 'IMAGE_DTO') {
|
||||
const { imageDTO } = active.data.current.payload;
|
||||
const { imageDTO } = activeData.payload;
|
||||
const currentBoard = imageDTO.board_id ?? 'none';
|
||||
|
||||
return currentBoard !== 'none';
|
||||
}
|
||||
|
||||
if (payloadType === 'GALLERY_SELECTION') {
|
||||
const currentBoard = active.data.current.payload.boardId;
|
||||
const currentBoard = activeData.payload.boardId;
|
||||
return currentBoard !== 'none';
|
||||
}
|
||||
|
||||
|
@ -162,7 +162,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
</Flex>
|
||||
)}
|
||||
{isSelectedForAutoAdd && <AutoAddIcon />}
|
||||
<SelectionOverlay isSelected={isSelected} isHovered={isHovered} />
|
||||
<SelectionOverlay isSelected={isSelected} isSelectedForCompare={false} isHovered={isHovered} />
|
||||
<Flex
|
||||
position="absolute"
|
||||
bottom={0}
|
||||
|
@ -117,7 +117,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
||||
>
|
||||
{boardName}
|
||||
</Flex>
|
||||
<SelectionOverlay isSelected={isSelected} isHovered={isHovered} />
|
||||
<SelectionOverlay isSelected={isSelected} isSelectedForCompare={false} isHovered={isHovered} />
|
||||
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="md">{t('unifiedCanvas.move')}</Text>} />
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
|
@ -10,6 +10,7 @@ import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||
import { useImageActions } from 'features/gallery/hooks/useImageActions';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions';
|
||||
import { imageToCompareChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
@ -27,6 +28,7 @@ import {
|
||||
PiDownloadSimpleBold,
|
||||
PiFlowArrowBold,
|
||||
PiFoldersBold,
|
||||
PiImagesBold,
|
||||
PiPlantBold,
|
||||
PiQuotesBold,
|
||||
PiShareFatBold,
|
||||
@ -44,6 +46,7 @@ type SingleSelectionMenuItemsProps = {
|
||||
const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
const { imageDTO } = props;
|
||||
const optimalDimension = useAppSelector(selectOptimalDimension);
|
||||
const maySelectForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name !== imageDTO.image_name);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const isCanvasEnabled = useFeatureStatus('canvas');
|
||||
@ -117,6 +120,10 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
downloadImage(imageDTO.image_url, imageDTO.image_name);
|
||||
}, [downloadImage, imageDTO.image_name, imageDTO.image_url]);
|
||||
|
||||
const handleSelectImageForCompare = useCallback(() => {
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
}, [dispatch, imageDTO]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<MenuItem as="a" href={imageDTO.image_url} target="_blank" icon={<PiShareFatBold />}>
|
||||
@ -130,6 +137,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
<MenuItem icon={<PiDownloadSimpleBold />} onClickCapture={handleDownloadImage}>
|
||||
{t('parameters.downloadImage')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiImagesBold />} isDisabled={!maySelectForCompare} onClick={handleSelectImageForCompare}>
|
||||
{t('gallery.selectForCompare')}
|
||||
</MenuItem>
|
||||
<MenuDivider />
|
||||
<MenuItem
|
||||
icon={getAndLoadEmbeddedWorkflowResult.isLoading ? <SpinnerIcon /> : <PiFlowArrowBold />}
|
||||
|
@ -11,7 +11,7 @@ import type { GallerySelectionDraggableData, ImageDraggableData, TypesafeDraggab
|
||||
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
|
||||
import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
|
||||
import { useScrollIntoView } from 'features/gallery/hooks/useScrollIntoView';
|
||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import type { MouseEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -46,6 +46,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
const { t } = useTranslation();
|
||||
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
||||
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
|
||||
const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageName);
|
||||
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
|
||||
|
||||
const customStarUi = useStore($customStarUI);
|
||||
@ -105,6 +106,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
|
||||
const onDoubleClick = useCallback(() => {
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
dispatch(imageToCompareChanged(null));
|
||||
}, [dispatch]);
|
||||
|
||||
const handleMouseOut = useCallback(() => {
|
||||
@ -152,6 +154,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
imageDTO={imageDTO}
|
||||
draggableData={draggableData}
|
||||
isSelected={isSelected}
|
||||
isSelectedForCompare={isSelectedForCompare}
|
||||
minSize={0}
|
||||
imageSx={imageSx}
|
||||
isDropDisabled={true}
|
||||
|
@ -0,0 +1,140 @@
|
||||
import {
|
||||
Button,
|
||||
ButtonGroup,
|
||||
Flex,
|
||||
Icon,
|
||||
IconButton,
|
||||
Kbd,
|
||||
ListItem,
|
||||
Tooltip,
|
||||
UnorderedList,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
comparedImagesSwapped,
|
||||
comparisonFitChanged,
|
||||
comparisonModeChanged,
|
||||
comparisonModeCycled,
|
||||
imageToCompareChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiArrowsOutBold, PiQuestion, PiSwapBold, PiXBold } from 'react-icons/pi';
|
||||
|
||||
export const CompareToolbar = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode);
|
||||
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
|
||||
const setComparisonModeSlider = useCallback(() => {
|
||||
dispatch(comparisonModeChanged('slider'));
|
||||
}, [dispatch]);
|
||||
const setComparisonModeSideBySide = useCallback(() => {
|
||||
dispatch(comparisonModeChanged('side-by-side'));
|
||||
}, [dispatch]);
|
||||
const setComparisonModeHover = useCallback(() => {
|
||||
dispatch(comparisonModeChanged('hover'));
|
||||
}, [dispatch]);
|
||||
const swapImages = useCallback(() => {
|
||||
dispatch(comparedImagesSwapped());
|
||||
}, [dispatch]);
|
||||
useHotkeys('c', swapImages, [swapImages]);
|
||||
const toggleComparisonFit = useCallback(() => {
|
||||
dispatch(comparisonFitChanged(comparisonFit === 'contain' ? 'fill' : 'contain'));
|
||||
}, [dispatch, comparisonFit]);
|
||||
const exitCompare = useCallback(() => {
|
||||
dispatch(imageToCompareChanged(null));
|
||||
}, [dispatch]);
|
||||
useHotkeys('esc', exitCompare, [exitCompare]);
|
||||
const nextMode = useCallback(() => {
|
||||
dispatch(comparisonModeCycled());
|
||||
}, [dispatch]);
|
||||
useHotkeys('m', nextMode, [nextMode]);
|
||||
|
||||
return (
|
||||
<Flex w="full" gap={2}>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineEnd="auto">
|
||||
<IconButton
|
||||
icon={<PiSwapBold />}
|
||||
aria-label={`${t('gallery.swapImages')} (C)`}
|
||||
tooltip={`${t('gallery.swapImages')} (C)`}
|
||||
onClick={swapImages}
|
||||
/>
|
||||
{comparisonMode !== 'side-by-side' && (
|
||||
<IconButton
|
||||
aria-label={t('gallery.stretchToFit')}
|
||||
tooltip={t('gallery.stretchToFit')}
|
||||
onClick={toggleComparisonFit}
|
||||
colorScheme={comparisonFit === 'fill' ? 'invokeBlue' : 'base'}
|
||||
variant="outline"
|
||||
icon={<PiArrowsOutBold />}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex flex={1} gap={4} justifyContent="center">
|
||||
<ButtonGroup variant="outline">
|
||||
<Button
|
||||
flexShrink={0}
|
||||
onClick={setComparisonModeSlider}
|
||||
colorScheme={comparisonMode === 'slider' ? 'invokeBlue' : 'base'}
|
||||
>
|
||||
{t('gallery.slider')}
|
||||
</Button>
|
||||
<Button
|
||||
flexShrink={0}
|
||||
onClick={setComparisonModeSideBySide}
|
||||
colorScheme={comparisonMode === 'side-by-side' ? 'invokeBlue' : 'base'}
|
||||
>
|
||||
{t('gallery.sideBySide')}
|
||||
</Button>
|
||||
<Button
|
||||
flexShrink={0}
|
||||
onClick={setComparisonModeHover}
|
||||
colorScheme={comparisonMode === 'hover' ? 'invokeBlue' : 'base'}
|
||||
>
|
||||
{t('gallery.hover')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineStart="auto" alignItems="center">
|
||||
<Tooltip label={<CompareHelp />}>
|
||||
<Flex alignItems="center">
|
||||
<Icon boxSize={8} color="base.500" as={PiQuestion} lineHeight={0} />
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
<IconButton
|
||||
icon={<PiXBold />}
|
||||
aria-label={`${t('gallery.exitCompare')} (Esc)`}
|
||||
tooltip={`${t('gallery.exitCompare')} (Esc)`}
|
||||
onClick={exitCompare}
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
CompareToolbar.displayName = 'CompareToolbar';
|
||||
|
||||
const CompareHelp = () => {
|
||||
return (
|
||||
<UnorderedList>
|
||||
<ListItem>
|
||||
<Trans i18nKey="gallery.compareHelp1" components={{ Kbd: <Kbd /> }}></Trans>
|
||||
</ListItem>
|
||||
<ListItem>
|
||||
<Trans i18nKey="gallery.compareHelp2" components={{ Kbd: <Kbd /> }}></Trans>
|
||||
</ListItem>
|
||||
<ListItem>
|
||||
<Trans i18nKey="gallery.compareHelp3" components={{ Kbd: <Kbd /> }}></Trans>
|
||||
</ListItem>
|
||||
<ListItem>
|
||||
<Trans i18nKey="gallery.compareHelp4" components={{ Kbd: <Kbd /> }}></Trans>
|
||||
</ListItem>
|
||||
</UnorderedList>
|
||||
);
|
||||
};
|
@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import type { TypesafeDraggableData } from 'features/dnd/types';
|
||||
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
@ -22,21 +22,7 @@ const selectLastSelectedImageName = createSelector(
|
||||
(lastSelectedImage) => lastSelectedImage?.image_name
|
||||
);
|
||||
|
||||
type Props = {
|
||||
isDragDisabled?: boolean;
|
||||
isDropDisabled?: boolean;
|
||||
withNextPrevButtons?: boolean;
|
||||
withMetadata?: boolean;
|
||||
alwaysShowProgress?: boolean;
|
||||
};
|
||||
|
||||
const CurrentImagePreview = ({
|
||||
isDragDisabled = false,
|
||||
isDropDisabled = false,
|
||||
withNextPrevButtons = true,
|
||||
withMetadata = true,
|
||||
alwaysShowProgress = false,
|
||||
}: Props) => {
|
||||
const CurrentImagePreview = () => {
|
||||
const { t } = useTranslation();
|
||||
const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails);
|
||||
const imageName = useAppSelector(selectLastSelectedImageName);
|
||||
@ -55,14 +41,6 @@ const CurrentImagePreview = ({
|
||||
}
|
||||
}, [imageDTO]);
|
||||
|
||||
const droppableData = useMemo<TypesafeDroppableData | undefined>(
|
||||
() => ({
|
||||
id: 'current-image',
|
||||
actionType: 'SET_CURRENT_IMAGE',
|
||||
}),
|
||||
[]
|
||||
);
|
||||
|
||||
// Show and hide the next/prev buttons on mouse move
|
||||
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = useState<boolean>(false);
|
||||
const timeoutId = useRef(0);
|
||||
@ -86,30 +64,27 @@ const CurrentImagePreview = ({
|
||||
justifyContent="center"
|
||||
position="relative"
|
||||
>
|
||||
{hasDenoiseProgress && (shouldShowProgressInViewer || alwaysShowProgress) ? (
|
||||
{hasDenoiseProgress && shouldShowProgressInViewer ? (
|
||||
<ProgressImage />
|
||||
) : (
|
||||
<IAIDndImage
|
||||
imageDTO={imageDTO}
|
||||
droppableData={droppableData}
|
||||
draggableData={draggableData}
|
||||
isDragDisabled={isDragDisabled}
|
||||
isDropDisabled={isDropDisabled}
|
||||
isDropDisabled={true}
|
||||
isUploadDisabled={true}
|
||||
fitContainer
|
||||
useThumbailFallback
|
||||
dropLabel={t('gallery.setCurrentImage')}
|
||||
noContentFallback={<IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />}
|
||||
dataTestId="image-preview"
|
||||
/>
|
||||
)}
|
||||
{shouldShowImageDetails && imageDTO && withMetadata && (
|
||||
{shouldShowImageDetails && imageDTO && (
|
||||
<Box position="absolute" opacity={0.8} top={0} width="full" height="full" borderRadius="base">
|
||||
<ImageMetadataViewer image={imageDTO} />
|
||||
</Box>
|
||||
)}
|
||||
<AnimatePresence>
|
||||
{withNextPrevButtons && shouldShowNextPrevButtons && imageDTO && (
|
||||
{shouldShowNextPrevButtons && imageDTO && (
|
||||
<Box
|
||||
as={motion.div}
|
||||
key="nextPrevButtons"
|
||||
|
@ -0,0 +1,41 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||
import { selectComparisonImages } from 'features/gallery/components/ImageViewer/common';
|
||||
import { ImageComparisonHover } from 'features/gallery/components/ImageViewer/ImageComparisonHover';
|
||||
import { ImageComparisonSideBySide } from 'features/gallery/components/ImageViewer/ImageComparisonSideBySide';
|
||||
import { ImageComparisonSlider } from 'features/gallery/components/ImageViewer/ImageComparisonSlider';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiImagesBold } from 'react-icons/pi';
|
||||
|
||||
type Props = {
|
||||
containerDims: Dimensions;
|
||||
};
|
||||
|
||||
export const ImageComparison = memo(({ containerDims }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode);
|
||||
const { firstImage, secondImage } = useAppSelector(selectComparisonImages);
|
||||
|
||||
if (!firstImage || !secondImage) {
|
||||
// Should rarely/never happen - we don't render this component unless we have images to compare
|
||||
return <IAINoContentFallback label={t('gallery.selectAnImageToCompare')} icon={PiImagesBold} />;
|
||||
}
|
||||
|
||||
if (comparisonMode === 'slider') {
|
||||
return <ImageComparisonSlider containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />;
|
||||
}
|
||||
|
||||
if (comparisonMode === 'side-by-side') {
|
||||
return (
|
||||
<ImageComparisonSideBySide containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />
|
||||
);
|
||||
}
|
||||
|
||||
if (comparisonMode === 'hover') {
|
||||
return <ImageComparisonHover containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />;
|
||||
}
|
||||
});
|
||||
|
||||
ImageComparison.displayName = 'ImageComparison';
|
@ -0,0 +1,47 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDroppable from 'common/components/IAIDroppable';
|
||||
import type { CurrentImageDropData, SelectForCompareDropData } from 'features/dnd/types';
|
||||
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { selectComparisonImages } from './common';
|
||||
|
||||
const setCurrentImageDropData: CurrentImageDropData = {
|
||||
id: 'current-image',
|
||||
actionType: 'SET_CURRENT_IMAGE',
|
||||
};
|
||||
|
||||
export const ImageComparisonDroppable = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const imageViewer = useImageViewer();
|
||||
const { firstImage, secondImage } = useAppSelector(selectComparisonImages);
|
||||
const selectForCompareDropData = useMemo<SelectForCompareDropData>(
|
||||
() => ({
|
||||
id: 'image-comparison',
|
||||
actionType: 'SELECT_FOR_COMPARE',
|
||||
context: {
|
||||
firstImageName: firstImage?.image_name,
|
||||
secondImageName: secondImage?.image_name,
|
||||
},
|
||||
}),
|
||||
[firstImage?.image_name, secondImage?.image_name]
|
||||
);
|
||||
|
||||
if (!imageViewer.isOpen) {
|
||||
return (
|
||||
<Flex position="absolute" top={0} right={0} bottom={0} left={0} gap={2} pointerEvents="none">
|
||||
<IAIDroppable data={setCurrentImageDropData} dropLabel={t('gallery.openInViewer')} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex position="absolute" top={0} right={0} bottom={0} left={0} gap={2} pointerEvents="none">
|
||||
<IAIDroppable data={selectForCompareDropData} dropLabel={t('gallery.selectForCompare')} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ImageComparisonDroppable.displayName = 'ImageComparisonDroppable';
|
@ -0,0 +1,117 @@
|
||||
import { Box, Flex, Image } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useBoolean } from 'common/hooks/useBoolean';
|
||||
import { preventDefault } from 'common/util/stopPropagation';
|
||||
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
|
||||
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
|
||||
import { memo, useMemo, useRef } from 'react';
|
||||
|
||||
import type { ComparisonProps } from './common';
|
||||
import { fitDimsToContainer, getSecondImageDims } from './common';
|
||||
|
||||
export const ImageComparisonHover = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
|
||||
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
|
||||
const imageContainerRef = useRef<HTMLDivElement>(null);
|
||||
const mouseOver = useBoolean(false);
|
||||
const fittedDims = useMemo<Dimensions>(
|
||||
() => fitDimsToContainer(containerDims, firstImage),
|
||||
[containerDims, firstImage]
|
||||
);
|
||||
const compareImageDims = useMemo<Dimensions>(
|
||||
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),
|
||||
[comparisonFit, fittedDims, firstImage, secondImage]
|
||||
);
|
||||
return (
|
||||
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
|
||||
<Flex
|
||||
id="image-comparison-wrapper"
|
||||
w="full"
|
||||
h="full"
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
position="absolute"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
<Box
|
||||
ref={imageContainerRef}
|
||||
position="relative"
|
||||
id="image-comparison-hover-image-container"
|
||||
w={fittedDims.width}
|
||||
h={fittedDims.height}
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
userSelect="none"
|
||||
overflow="hidden"
|
||||
borderRadius="base"
|
||||
>
|
||||
<Image
|
||||
id="image-comparison-hover-first-image"
|
||||
src={firstImage.image_url}
|
||||
fallbackSrc={firstImage.thumbnail_url}
|
||||
w={fittedDims.width}
|
||||
h={fittedDims.height}
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
objectFit="cover"
|
||||
objectPosition="top left"
|
||||
/>
|
||||
<ImageComparisonLabel type="first" opacity={mouseOver.isTrue ? 0 : 1} />
|
||||
|
||||
<Box
|
||||
id="image-comparison-hover-second-image-container"
|
||||
position="absolute"
|
||||
top={0}
|
||||
left={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
overflow="hidden"
|
||||
opacity={mouseOver.isTrue ? 1 : 0}
|
||||
transitionDuration="0.2s"
|
||||
transitionProperty="common"
|
||||
>
|
||||
<Box
|
||||
id="image-comparison-hover-bg"
|
||||
position="absolute"
|
||||
top={0}
|
||||
left={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
backgroundImage={STAGE_BG_DATAURL}
|
||||
backgroundRepeat="repeat"
|
||||
opacity={0.2}
|
||||
/>
|
||||
<Image
|
||||
position="relative"
|
||||
id="image-comparison-hover-second-image"
|
||||
src={secondImage.image_url}
|
||||
fallbackSrc={secondImage.thumbnail_url}
|
||||
w={compareImageDims.width}
|
||||
h={compareImageDims.height}
|
||||
maxW={fittedDims.width}
|
||||
maxH={fittedDims.height}
|
||||
objectFit={comparisonFit}
|
||||
objectPosition="top left"
|
||||
/>
|
||||
<ImageComparisonLabel type="second" opacity={mouseOver.isTrue ? 1 : 0} />
|
||||
</Box>
|
||||
<Box
|
||||
id="image-comparison-hover-interaction-overlay"
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
onMouseOver={mouseOver.setTrue}
|
||||
onMouseOut={mouseOver.setFalse}
|
||||
onContextMenu={preventDefault}
|
||||
userSelect="none"
|
||||
/>
|
||||
</Box>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ImageComparisonHover.displayName = 'ImageComparisonHover';
|
@ -0,0 +1,33 @@
|
||||
import type { TextProps } from '@invoke-ai/ui-library';
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { DROP_SHADOW } from './common';
|
||||
|
||||
type Props = TextProps & {
|
||||
type: 'first' | 'second';
|
||||
};
|
||||
|
||||
export const ImageComparisonLabel = memo(({ type, ...rest }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Text
|
||||
position="absolute"
|
||||
bottom={4}
|
||||
insetInlineEnd={type === 'first' ? undefined : 4}
|
||||
insetInlineStart={type === 'first' ? 4 : undefined}
|
||||
textOverflow="clip"
|
||||
whiteSpace="nowrap"
|
||||
filter={DROP_SHADOW}
|
||||
color="base.50"
|
||||
transitionDuration="0.2s"
|
||||
transitionProperty="common"
|
||||
{...rest}
|
||||
>
|
||||
{type === 'first' ? t('gallery.viewerImage') : t('gallery.compareImage')}
|
||||
</Text>
|
||||
);
|
||||
});
|
||||
|
||||
ImageComparisonLabel.displayName = 'ImageComparisonLabel';
|
@ -0,0 +1,70 @@
|
||||
import { Flex, Image } from '@invoke-ai/ui-library';
|
||||
import type { ComparisonProps } from 'features/gallery/components/ImageViewer/common';
|
||||
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
|
||||
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import type { ImperativePanelGroupHandle } from 'react-resizable-panels';
|
||||
import { Panel, PanelGroup } from 'react-resizable-panels';
|
||||
|
||||
export const ImageComparisonSideBySide = memo(({ firstImage, secondImage }: ComparisonProps) => {
|
||||
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
|
||||
const onDoubleClickHandle = useCallback(() => {
|
||||
if (!panelGroupRef.current) {
|
||||
return;
|
||||
}
|
||||
panelGroupRef.current.setLayout([50, 50]);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
|
||||
<Flex w="full" h="full" maxW="full" maxH="full" position="absolute" alignItems="center" justifyContent="center">
|
||||
<PanelGroup ref={panelGroupRef} direction="horizontal" id="image-comparison-side-by-side">
|
||||
<Panel minSize={20}>
|
||||
<Flex position="relative" w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<Flex position="absolute" maxW="full" maxH="full" aspectRatio={firstImage.width / firstImage.height}>
|
||||
<Image
|
||||
id="image-comparison-side-by-side-first-image"
|
||||
w={firstImage.width}
|
||||
h={firstImage.height}
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
src={firstImage.image_url}
|
||||
fallbackSrc={firstImage.thumbnail_url}
|
||||
objectFit="contain"
|
||||
borderRadius="base"
|
||||
/>
|
||||
<ImageComparisonLabel type="first" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Panel>
|
||||
<ResizeHandle
|
||||
id="image-comparison-side-by-side-handle"
|
||||
onDoubleClick={onDoubleClickHandle}
|
||||
orientation="vertical"
|
||||
/>
|
||||
|
||||
<Panel minSize={20}>
|
||||
<Flex position="relative" w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<Flex position="absolute" maxW="full" maxH="full" aspectRatio={secondImage.width / secondImage.height}>
|
||||
<Image
|
||||
id="image-comparison-side-by-side-first-image"
|
||||
w={secondImage.width}
|
||||
h={secondImage.height}
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
src={secondImage.image_url}
|
||||
fallbackSrc={secondImage.thumbnail_url}
|
||||
objectFit="contain"
|
||||
borderRadius="base"
|
||||
/>
|
||||
<ImageComparisonLabel type="second" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ImageComparisonSideBySide.displayName = 'ImageComparisonSideBySide';
|
@ -0,0 +1,215 @@
|
||||
import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { preventDefault } from 'common/util/stopPropagation';
|
||||
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
|
||||
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
|
||||
|
||||
import type { ComparisonProps } from './common';
|
||||
import { DROP_SHADOW, fitDimsToContainer, getSecondImageDims } from './common';
|
||||
|
||||
const INITIAL_POS = '50%';
|
||||
const HANDLE_WIDTH = 2;
|
||||
const HANDLE_WIDTH_PX = `${HANDLE_WIDTH}px`;
|
||||
const HANDLE_HITBOX = 20;
|
||||
const HANDLE_HITBOX_PX = `${HANDLE_HITBOX}px`;
|
||||
const HANDLE_INNER_LEFT_PX = `${HANDLE_HITBOX / 2 - HANDLE_WIDTH / 2}px`;
|
||||
const HANDLE_LEFT_INITIAL_PX = `calc(${INITIAL_POS} - ${HANDLE_HITBOX / 2}px)`;
|
||||
|
||||
export const ImageComparisonSlider = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
|
||||
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
|
||||
// How far the handle is from the left - this will be a CSS calculation that takes into account the handle width
|
||||
const [left, setLeft] = useState(HANDLE_LEFT_INITIAL_PX);
|
||||
// How wide the first image is
|
||||
const [width, setWidth] = useState(INITIAL_POS);
|
||||
const handleRef = useRef<HTMLDivElement>(null);
|
||||
// To manage aspect ratios, we need to know the size of the container
|
||||
const imageContainerRef = useRef<HTMLDivElement>(null);
|
||||
// To keep things smooth, we use RAF to update the handle position & gate it to 60fps
|
||||
const rafRef = useRef<number | null>(null);
|
||||
const lastMoveTimeRef = useRef<number>(0);
|
||||
|
||||
const fittedDims = useMemo<Dimensions>(
|
||||
() => fitDimsToContainer(containerDims, firstImage),
|
||||
[containerDims, firstImage]
|
||||
);
|
||||
|
||||
const compareImageDims = useMemo<Dimensions>(
|
||||
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),
|
||||
[comparisonFit, fittedDims, firstImage, secondImage]
|
||||
);
|
||||
|
||||
const updateHandlePos = useCallback((clientX: number) => {
|
||||
if (!handleRef.current || !imageContainerRef.current) {
|
||||
return;
|
||||
}
|
||||
lastMoveTimeRef.current = performance.now();
|
||||
const { x, width } = imageContainerRef.current.getBoundingClientRect();
|
||||
const rawHandlePos = ((clientX - x) * 100) / width;
|
||||
const handleWidthPct = (HANDLE_WIDTH * 100) / width;
|
||||
const newHandlePos = Math.min(100 - handleWidthPct, Math.max(0, rawHandlePos));
|
||||
setWidth(`${newHandlePos}%`);
|
||||
setLeft(`calc(${newHandlePos}% - ${HANDLE_HITBOX / 2}px)`);
|
||||
}, []);
|
||||
|
||||
const onMouseMove = useCallback(
|
||||
(e: MouseEvent) => {
|
||||
if (rafRef.current === null && performance.now() > lastMoveTimeRef.current + 1000 / 60) {
|
||||
rafRef.current = window.requestAnimationFrame(() => {
|
||||
updateHandlePos(e.clientX);
|
||||
rafRef.current = null;
|
||||
});
|
||||
}
|
||||
},
|
||||
[updateHandlePos]
|
||||
);
|
||||
|
||||
const onMouseUp = useCallback(() => {
|
||||
window.removeEventListener('mousemove', onMouseMove);
|
||||
}, [onMouseMove]);
|
||||
|
||||
const onMouseDown = useCallback(
|
||||
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||
// Update the handle position immediately on click
|
||||
updateHandlePos(e.clientX);
|
||||
window.addEventListener('mouseup', onMouseUp, { once: true });
|
||||
window.addEventListener('mousemove', onMouseMove);
|
||||
},
|
||||
[onMouseMove, onMouseUp, updateHandlePos]
|
||||
);
|
||||
|
||||
useEffect(
|
||||
() => () => {
|
||||
if (rafRef.current !== null) {
|
||||
cancelAnimationFrame(rafRef.current);
|
||||
}
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
|
||||
<Flex
|
||||
id="image-comparison-wrapper"
|
||||
w="full"
|
||||
h="full"
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
position="absolute"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
<Box
|
||||
ref={imageContainerRef}
|
||||
position="relative"
|
||||
id="image-comparison-image-container"
|
||||
w={fittedDims.width}
|
||||
h={fittedDims.height}
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
userSelect="none"
|
||||
overflow="hidden"
|
||||
borderRadius="base"
|
||||
>
|
||||
<Box
|
||||
id="image-comparison-bg"
|
||||
position="absolute"
|
||||
top={0}
|
||||
left={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
backgroundImage={STAGE_BG_DATAURL}
|
||||
backgroundRepeat="repeat"
|
||||
opacity={0.2}
|
||||
/>
|
||||
<Image
|
||||
position="relative"
|
||||
id="image-comparison-second-image"
|
||||
src={secondImage.image_url}
|
||||
fallbackSrc={secondImage.thumbnail_url}
|
||||
w={compareImageDims.width}
|
||||
h={compareImageDims.height}
|
||||
maxW={fittedDims.width}
|
||||
maxH={fittedDims.height}
|
||||
objectFit={comparisonFit}
|
||||
objectPosition="top left"
|
||||
/>
|
||||
<ImageComparisonLabel type="second" />
|
||||
<Box
|
||||
id="image-comparison-first-image-container"
|
||||
position="absolute"
|
||||
top={0}
|
||||
left={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
w={width}
|
||||
overflow="hidden"
|
||||
>
|
||||
<Image
|
||||
id="image-comparison-first-image"
|
||||
src={firstImage.image_url}
|
||||
fallbackSrc={firstImage.thumbnail_url}
|
||||
w={fittedDims.width}
|
||||
h={fittedDims.height}
|
||||
objectFit="cover"
|
||||
objectPosition="top left"
|
||||
/>
|
||||
<ImageComparisonLabel type="first" />
|
||||
</Box>
|
||||
<Flex
|
||||
id="image-comparison-handle"
|
||||
ref={handleRef}
|
||||
position="absolute"
|
||||
top={0}
|
||||
bottom={0}
|
||||
left={left}
|
||||
w={HANDLE_HITBOX_PX}
|
||||
cursor="ew-resize"
|
||||
filter={DROP_SHADOW}
|
||||
opacity={0.8}
|
||||
color="base.50"
|
||||
>
|
||||
<Box
|
||||
id="image-comparison-handle-divider"
|
||||
w={HANDLE_WIDTH_PX}
|
||||
h="full"
|
||||
bg="currentColor"
|
||||
shadow="dark-lg"
|
||||
position="absolute"
|
||||
top={0}
|
||||
left={HANDLE_INNER_LEFT_PX}
|
||||
/>
|
||||
<Flex
|
||||
id="image-comparison-handle-icons"
|
||||
gap={4}
|
||||
position="absolute"
|
||||
left="50%"
|
||||
top="50%"
|
||||
transform="translate(-50%, 0)"
|
||||
filter={DROP_SHADOW}
|
||||
>
|
||||
<Icon as={PiCaretLeftBold} />
|
||||
<Icon as={PiCaretRightBold} />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Box
|
||||
id="image-comparison-interaction-overlay"
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
onMouseDown={onMouseDown}
|
||||
onContextMenu={preventDefault}
|
||||
userSelect="none"
|
||||
cursor="ew-resize"
|
||||
/>
|
||||
</Box>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ImageComparisonSlider.displayName = 'ImageComparisonSlider';
|
@ -1,36 +1,16 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
|
||||
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
|
||||
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import type { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { CompareToolbar } from 'features/gallery/components/ImageViewer/CompareToolbar';
|
||||
import CurrentImagePreview from 'features/gallery/components/ImageViewer/CurrentImagePreview';
|
||||
import { ImageComparison } from 'features/gallery/components/ImageViewer/ImageComparison';
|
||||
import { ViewerToolbar } from 'features/gallery/components/ImageViewer/ViewerToolbar';
|
||||
import { memo } from 'react';
|
||||
import { useMeasure } from 'react-use';
|
||||
|
||||
import CurrentImageButtons from './CurrentImageButtons';
|
||||
import CurrentImagePreview from './CurrentImagePreview';
|
||||
import { ViewerToggleMenu } from './ViewerToggleMenu';
|
||||
|
||||
const VIEWER_ENABLED_TABS: InvokeTabName[] = ['canvas', 'generation', 'workflows'];
|
||||
import { useImageViewer } from './useImageViewer';
|
||||
|
||||
export const ImageViewer = memo(() => {
|
||||
const { isOpen, onToggle, onClose } = useImageViewer();
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const isViewerEnabled = useMemo(() => VIEWER_ENABLED_TABS.includes(activeTabName), [activeTabName]);
|
||||
const shouldShowViewer = useMemo(() => {
|
||||
if (!isViewerEnabled) {
|
||||
return false;
|
||||
}
|
||||
return isOpen;
|
||||
}, [isOpen, isViewerEnabled]);
|
||||
|
||||
useHotkeys('z', onToggle, { enabled: isViewerEnabled }, [isViewerEnabled, onToggle]);
|
||||
useHotkeys('esc', onClose, { enabled: isViewerEnabled }, [isViewerEnabled, onClose]);
|
||||
|
||||
if (!shouldShowViewer) {
|
||||
return null;
|
||||
}
|
||||
const imageViewer = useImageViewer();
|
||||
const [containerRef, containerDims] = useMeasure<HTMLDivElement>();
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@ -46,25 +26,13 @@ export const ImageViewer = memo(() => {
|
||||
rowGap={4}
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
zIndex={10} // reactflow puts its minimap at 5, so we need to be above that
|
||||
>
|
||||
<Flex w="full" gap={2}>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineEnd="auto">
|
||||
<ToggleProgressButton />
|
||||
<ToggleMetadataViewerButton />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex flex={1} gap={2} justifyContent="center">
|
||||
<CurrentImageButtons />
|
||||
</Flex>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineStart="auto">
|
||||
<ViewerToggleMenu />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<CurrentImagePreview />
|
||||
{imageViewer.isComparing && <CompareToolbar />}
|
||||
{!imageViewer.isComparing && <ViewerToolbar />}
|
||||
<Box ref={containerRef} w="full" h="full">
|
||||
{!imageViewer.isComparing && <CurrentImagePreview />}
|
||||
{imageViewer.isComparing && <ImageComparison containerDims={containerDims} />}
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
@ -1,45 +0,0 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
|
||||
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
|
||||
import { memo } from 'react';
|
||||
|
||||
import CurrentImageButtons from './CurrentImageButtons';
|
||||
import CurrentImagePreview from './CurrentImagePreview';
|
||||
|
||||
export const ImageViewerWorkflows = memo(() => {
|
||||
return (
|
||||
<Flex
|
||||
layerStyle="first"
|
||||
borderRadius="base"
|
||||
position="absolute"
|
||||
flexDirection="column"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
p={2}
|
||||
rowGap={4}
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
zIndex={10} // reactflow puts its minimap at 5, so we need to be above that
|
||||
>
|
||||
<Flex w="full" gap={2}>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineEnd="auto">
|
||||
<ToggleProgressButton />
|
||||
<ToggleMetadataViewerButton />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex flex={1} gap={2} justifyContent="center">
|
||||
<CurrentImageButtons />
|
||||
</Flex>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineStart="auto" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<CurrentImagePreview />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ImageViewerWorkflows.displayName = 'ImageViewerWorkflows';
|
@ -9,33 +9,35 @@ import {
|
||||
PopoverTrigger,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold, PiCheckBold, PiEyeBold, PiPencilBold } from 'react-icons/pi';
|
||||
|
||||
import { useImageViewer } from './useImageViewer';
|
||||
|
||||
export const ViewerToggleMenu = () => {
|
||||
const { t } = useTranslation();
|
||||
const { isOpen, onClose, onOpen } = useImageViewer();
|
||||
const imageViewer = useImageViewer();
|
||||
useHotkeys('z', imageViewer.onToggle, [imageViewer]);
|
||||
useHotkeys('esc', imageViewer.onClose, [imageViewer]);
|
||||
|
||||
return (
|
||||
<Popover isLazy>
|
||||
<PopoverTrigger>
|
||||
<Button variant="outline" data-testid="toggle-viewer-menu-button">
|
||||
<Button variant="outline" data-testid="toggle-viewer-menu-button" pointerEvents="auto">
|
||||
<Flex gap={3} w="full" alignItems="center">
|
||||
{isOpen ? <Icon as={PiEyeBold} /> : <Icon as={PiPencilBold} />}
|
||||
<Text fontSize="md">{isOpen ? t('common.viewing') : t('common.editing')}</Text>
|
||||
{imageViewer.isOpen ? <Icon as={PiEyeBold} /> : <Icon as={PiPencilBold} />}
|
||||
<Text fontSize="md">{imageViewer.isOpen ? t('common.viewing') : t('common.editing')}</Text>
|
||||
<Icon as={PiCaretDownBold} />
|
||||
</Flex>
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent p={2}>
|
||||
<PopoverContent p={2} pointerEvents="auto">
|
||||
<PopoverArrow />
|
||||
<PopoverBody>
|
||||
<Flex flexDir="column">
|
||||
<Button onClick={onOpen} variant="ghost" h="auto" w="auto" p={2}>
|
||||
<Button onClick={imageViewer.onOpen} variant="ghost" h="auto" w="auto" p={2}>
|
||||
<Flex gap={2} w="full">
|
||||
<Icon as={PiCheckBold} visibility={isOpen ? 'visible' : 'hidden'} />
|
||||
<Icon as={PiCheckBold} visibility={imageViewer.isOpen ? 'visible' : 'hidden'} />
|
||||
<Flex flexDir="column" gap={2} alignItems="flex-start">
|
||||
<Text fontWeight="semibold" color="base.100">
|
||||
{t('common.viewing')}
|
||||
@ -46,9 +48,9 @@ export const ViewerToggleMenu = () => {
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Button>
|
||||
<Button onClick={onClose} variant="ghost" h="auto" w="auto" p={2}>
|
||||
<Button onClick={imageViewer.onClose} variant="ghost" h="auto" w="auto" p={2}>
|
||||
<Flex gap={2} w="full">
|
||||
<Icon as={PiCheckBold} visibility={isOpen ? 'hidden' : 'visible'} />
|
||||
<Icon as={PiCheckBold} visibility={imageViewer.isOpen ? 'hidden' : 'visible'} />
|
||||
<Flex flexDir="column" gap={2} alignItems="flex-start">
|
||||
<Text fontWeight="semibold" color="base.100">
|
||||
{t('common.editing')}
|
||||
|
@ -0,0 +1,33 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
|
||||
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
import CurrentImageButtons from './CurrentImageButtons';
|
||||
import { ViewerToggleMenu } from './ViewerToggleMenu';
|
||||
|
||||
export const ViewerToolbar = memo(() => {
|
||||
const tab = useAppSelector(activeTabNameSelector);
|
||||
return (
|
||||
<Flex w="full" gap={2}>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineEnd="auto">
|
||||
<ToggleProgressButton />
|
||||
<ToggleMetadataViewerButton />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex flex={1} gap={2} justifyContent="center">
|
||||
<CurrentImageButtons />
|
||||
</Flex>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineStart="auto">
|
||||
{tab !== 'workflows' && <ViewerToggleMenu />}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ViewerToolbar.displayName = 'ViewerToolbar';
|
@ -0,0 +1,64 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import type { ComparisonFit } from 'features/gallery/store/types';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
|
||||
|
||||
export type ComparisonProps = {
|
||||
firstImage: ImageDTO;
|
||||
secondImage: ImageDTO;
|
||||
containerDims: Dimensions;
|
||||
};
|
||||
|
||||
export const fitDimsToContainer = (containerDims: Dimensions, imageDims: Dimensions): Dimensions => {
|
||||
// Fall back to the image's dimensions if the container has no dimensions
|
||||
if (containerDims.width === 0 || containerDims.height === 0) {
|
||||
return { width: imageDims.width, height: imageDims.height };
|
||||
}
|
||||
|
||||
// Fall back to the image's dimensions if the image fits within the container
|
||||
if (imageDims.width <= containerDims.width && imageDims.height <= containerDims.height) {
|
||||
return { width: imageDims.width, height: imageDims.height };
|
||||
}
|
||||
|
||||
const targetAspectRatio = containerDims.width / containerDims.height;
|
||||
const imageAspectRatio = imageDims.width / imageDims.height;
|
||||
|
||||
let width: number;
|
||||
let height: number;
|
||||
|
||||
if (imageAspectRatio > targetAspectRatio) {
|
||||
// Image is wider than container's aspect ratio
|
||||
width = containerDims.width;
|
||||
height = width / imageAspectRatio;
|
||||
} else {
|
||||
// Image is taller than container's aspect ratio
|
||||
height = containerDims.height;
|
||||
width = height * imageAspectRatio;
|
||||
}
|
||||
return { width, height };
|
||||
};
|
||||
|
||||
/**
|
||||
* Gets the dimensions of the second image in a comparison based on the comparison fit mode.
|
||||
*/
|
||||
export const getSecondImageDims = (
|
||||
comparisonFit: ComparisonFit,
|
||||
fittedDims: Dimensions,
|
||||
firstImageDims: Dimensions,
|
||||
secondImageDims: Dimensions
|
||||
): Dimensions => {
|
||||
const width =
|
||||
comparisonFit === 'fill' ? fittedDims.width : (fittedDims.width * secondImageDims.width) / firstImageDims.width;
|
||||
const height =
|
||||
comparisonFit === 'fill' ? fittedDims.height : (fittedDims.height * secondImageDims.height) / firstImageDims.height;
|
||||
|
||||
return { width, height };
|
||||
};
|
||||
export const selectComparisonImages = createMemoizedSelector(selectGallerySlice, (gallerySlice) => {
|
||||
const firstImage = gallerySlice.selection.slice(-1)[0] ?? null;
|
||||
const secondImage = gallerySlice.imageToCompare;
|
||||
return { firstImage, secondImage };
|
||||
});
|
@ -0,0 +1,31 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
export const useImageViewer = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const isComparing = useAppSelector((s) => s.gallery.imageToCompare !== null);
|
||||
const isOpen = useAppSelector((s) => s.gallery.isImageViewerOpen);
|
||||
|
||||
const onClose = useCallback(() => {
|
||||
if (isComparing && isOpen) {
|
||||
dispatch(imageToCompareChanged(null));
|
||||
} else {
|
||||
dispatch(isImageViewerOpenChanged(false));
|
||||
}
|
||||
}, [dispatch, isComparing, isOpen]);
|
||||
|
||||
const onOpen = useCallback(() => {
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
}, [dispatch]);
|
||||
|
||||
const onToggle = useCallback(() => {
|
||||
if (isComparing && isOpen) {
|
||||
dispatch(imageToCompareChanged(null));
|
||||
} else {
|
||||
dispatch(isImageViewerOpenChanged(!isOpen));
|
||||
}
|
||||
}, [dispatch, isComparing, isOpen]);
|
||||
|
||||
return { isOpen, onOpen, onClose, onToggle, isComparing };
|
||||
};
|
@ -1,22 +0,0 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
export const useImageViewer = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const isOpen = useAppSelector((s) => s.gallery.isImageViewerOpen);
|
||||
|
||||
const onClose = useCallback(() => {
|
||||
dispatch(isImageViewerOpenChanged(false));
|
||||
}, [dispatch]);
|
||||
|
||||
const onOpen = useCallback(() => {
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
}, [dispatch]);
|
||||
|
||||
const onToggle = useCallback(() => {
|
||||
dispatch(isImageViewerOpenChanged(!isOpen));
|
||||
}, [dispatch, isOpen]);
|
||||
|
||||
return { isOpen, onOpen, onClose, onToggle };
|
||||
};
|
@ -14,7 +14,7 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
|
||||
const NextPrevImageButtons = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { handleLeftImage, handleRightImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
|
||||
const { prevImage, nextImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
|
||||
|
||||
const {
|
||||
areMoreImagesAvailable,
|
||||
@ -30,7 +30,7 @@ const NextPrevImageButtons = () => {
|
||||
aria-label={t('accessibility.previousImage')}
|
||||
icon={<PiCaretLeftBold size={64} />}
|
||||
variant="unstyled"
|
||||
onClick={handleLeftImage}
|
||||
onClick={prevImage}
|
||||
boxSize={16}
|
||||
sx={nextPrevButtonStyles}
|
||||
/>
|
||||
@ -42,7 +42,7 @@ const NextPrevImageButtons = () => {
|
||||
aria-label={t('accessibility.nextImage')}
|
||||
icon={<PiCaretRightBold size={64} />}
|
||||
variant="unstyled"
|
||||
onClick={handleRightImage}
|
||||
onClick={nextImage}
|
||||
boxSize={16}
|
||||
sx={nextPrevButtonStyles}
|
||||
/>
|
||||
|
@ -27,16 +27,16 @@ export const useGalleryHotkeys = () => {
|
||||
useGalleryNavigation();
|
||||
|
||||
useHotkeys(
|
||||
'left',
|
||||
() => {
|
||||
canNavigateGallery && handleLeftImage();
|
||||
['left', 'alt+left'],
|
||||
(e) => {
|
||||
canNavigateGallery && handleLeftImage(e.altKey);
|
||||
},
|
||||
[handleLeftImage, canNavigateGallery]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
'right',
|
||||
() => {
|
||||
['right', 'alt+right'],
|
||||
(e) => {
|
||||
if (!canNavigateGallery) {
|
||||
return;
|
||||
}
|
||||
@ -45,29 +45,29 @@ export const useGalleryHotkeys = () => {
|
||||
return;
|
||||
}
|
||||
if (!isOnLastImage) {
|
||||
handleRightImage();
|
||||
handleRightImage(e.altKey);
|
||||
}
|
||||
},
|
||||
[isOnLastImage, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleRightImage, canNavigateGallery]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
'up',
|
||||
() => {
|
||||
handleUpImage();
|
||||
['up', 'alt+up'],
|
||||
(e) => {
|
||||
handleUpImage(e.altKey);
|
||||
},
|
||||
{ preventDefault: true },
|
||||
[handleUpImage]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
'down',
|
||||
() => {
|
||||
['down', 'alt+down'],
|
||||
(e) => {
|
||||
if (!areImagesBelowCurrent && areMoreImagesAvailable && !isFetching) {
|
||||
handleLoadMoreImages();
|
||||
return;
|
||||
}
|
||||
handleDownImage();
|
||||
handleDownImage(e.altKey);
|
||||
},
|
||||
{ preventDefault: true },
|
||||
[areImagesBelowCurrent, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleDownImage]
|
||||
|
@ -1,11 +1,11 @@
|
||||
import { useAltModifier } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
|
||||
import { imageItemContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridItemContainer';
|
||||
import { imageListContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridListContainer';
|
||||
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
|
||||
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { imageSelected, imageToCompareChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { getIsVisible } from 'features/gallery/util/getIsVisible';
|
||||
import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign';
|
||||
import { clamp } from 'lodash-es';
|
||||
@ -106,10 +106,12 @@ const getImageFuncs = {
|
||||
};
|
||||
|
||||
type UseGalleryNavigationReturn = {
|
||||
handleLeftImage: () => void;
|
||||
handleRightImage: () => void;
|
||||
handleUpImage: () => void;
|
||||
handleDownImage: () => void;
|
||||
handleLeftImage: (alt?: boolean) => void;
|
||||
handleRightImage: (alt?: boolean) => void;
|
||||
handleUpImage: (alt?: boolean) => void;
|
||||
handleDownImage: (alt?: boolean) => void;
|
||||
prevImage: () => void;
|
||||
nextImage: () => void;
|
||||
isOnFirstImage: boolean;
|
||||
isOnLastImage: boolean;
|
||||
areImagesBelowCurrent: boolean;
|
||||
@ -123,7 +125,15 @@ type UseGalleryNavigationReturn = {
|
||||
*/
|
||||
export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
const dispatch = useAppDispatch();
|
||||
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
|
||||
const alt = useAltModifier();
|
||||
const lastSelectedImage = useAppSelector((s) => {
|
||||
const lastSelected = s.gallery.selection.slice(-1)[0] ?? null;
|
||||
if (alt) {
|
||||
return s.gallery.imageToCompare ?? lastSelected;
|
||||
} else {
|
||||
return lastSelected;
|
||||
}
|
||||
});
|
||||
const {
|
||||
queryResult: { data },
|
||||
} = useGalleryImages();
|
||||
@ -136,7 +146,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
}, [lastSelectedImage, data]);
|
||||
|
||||
const handleNavigation = useCallback(
|
||||
(direction: 'left' | 'right' | 'up' | 'down') => {
|
||||
(direction: 'left' | 'right' | 'up' | 'down', alt?: boolean) => {
|
||||
if (!data) {
|
||||
return;
|
||||
}
|
||||
@ -144,10 +154,14 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
if (!image || index === lastSelectedImageIndex) {
|
||||
return;
|
||||
}
|
||||
if (alt) {
|
||||
dispatch(imageToCompareChanged(image));
|
||||
} else {
|
||||
dispatch(imageSelected(image));
|
||||
}
|
||||
scrollToImage(image.image_name, index);
|
||||
},
|
||||
[dispatch, lastSelectedImageIndex, data]
|
||||
[data, lastSelectedImageIndex, dispatch]
|
||||
);
|
||||
|
||||
const isOnFirstImage = useMemo(() => lastSelectedImageIndex === 0, [lastSelectedImageIndex]);
|
||||
@ -162,21 +176,41 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
return lastSelectedImageIndex + imagesPerRow < loadedImagesCount;
|
||||
}, [lastSelectedImageIndex, loadedImagesCount]);
|
||||
|
||||
const handleLeftImage = useCallback(() => {
|
||||
handleNavigation('left');
|
||||
}, [handleNavigation]);
|
||||
const handleLeftImage = useCallback(
|
||||
(alt?: boolean) => {
|
||||
handleNavigation('left', alt);
|
||||
},
|
||||
[handleNavigation]
|
||||
);
|
||||
|
||||
const handleRightImage = useCallback(() => {
|
||||
handleNavigation('right');
|
||||
}, [handleNavigation]);
|
||||
const handleRightImage = useCallback(
|
||||
(alt?: boolean) => {
|
||||
handleNavigation('right', alt);
|
||||
},
|
||||
[handleNavigation]
|
||||
);
|
||||
|
||||
const handleUpImage = useCallback(() => {
|
||||
handleNavigation('up');
|
||||
}, [handleNavigation]);
|
||||
const handleUpImage = useCallback(
|
||||
(alt?: boolean) => {
|
||||
handleNavigation('up', alt);
|
||||
},
|
||||
[handleNavigation]
|
||||
);
|
||||
|
||||
const handleDownImage = useCallback(() => {
|
||||
handleNavigation('down');
|
||||
}, [handleNavigation]);
|
||||
const handleDownImage = useCallback(
|
||||
(alt?: boolean) => {
|
||||
handleNavigation('down', alt);
|
||||
},
|
||||
[handleNavigation]
|
||||
);
|
||||
|
||||
const nextImage = useCallback(() => {
|
||||
handleRightImage();
|
||||
}, [handleRightImage]);
|
||||
|
||||
const prevImage = useCallback(() => {
|
||||
handleLeftImage();
|
||||
}, [handleLeftImage]);
|
||||
|
||||
return {
|
||||
handleLeftImage,
|
||||
@ -186,5 +220,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
isOnFirstImage,
|
||||
isOnLastImage,
|
||||
areImagesBelowCurrent,
|
||||
nextImage,
|
||||
prevImage,
|
||||
};
|
||||
};
|
||||
|
@ -36,6 +36,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
|
||||
shiftKey: e.shiftKey,
|
||||
ctrlKey: e.ctrlKey,
|
||||
metaKey: e.metaKey,
|
||||
altKey: e.altKey,
|
||||
})
|
||||
);
|
||||
},
|
||||
|
@ -6,7 +6,7 @@ import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import type { BoardId, GalleryState, GalleryView } from './types';
|
||||
import type { BoardId, ComparisonMode, GalleryState, GalleryView } from './types';
|
||||
import { IMAGE_LIMIT, INITIAL_IMAGE_LIMIT } from './types';
|
||||
|
||||
const initialGalleryState: GalleryState = {
|
||||
@ -22,6 +22,9 @@ const initialGalleryState: GalleryState = {
|
||||
limit: INITIAL_IMAGE_LIMIT,
|
||||
offset: 0,
|
||||
isImageViewerOpen: true,
|
||||
imageToCompare: null,
|
||||
comparisonMode: 'slider',
|
||||
comparisonFit: 'fill',
|
||||
};
|
||||
|
||||
export const gallerySlice = createSlice({
|
||||
@ -34,6 +37,28 @@ export const gallerySlice = createSlice({
|
||||
selectionChanged: (state, action: PayloadAction<ImageDTO[]>) => {
|
||||
state.selection = uniqBy(action.payload, (i) => i.image_name);
|
||||
},
|
||||
imageToCompareChanged: (state, action: PayloadAction<ImageDTO | null>) => {
|
||||
state.imageToCompare = action.payload;
|
||||
if (action.payload) {
|
||||
state.isImageViewerOpen = true;
|
||||
}
|
||||
},
|
||||
comparisonModeChanged: (state, action: PayloadAction<ComparisonMode>) => {
|
||||
state.comparisonMode = action.payload;
|
||||
},
|
||||
comparisonModeCycled: (state) => {
|
||||
switch (state.comparisonMode) {
|
||||
case 'slider':
|
||||
state.comparisonMode = 'side-by-side';
|
||||
break;
|
||||
case 'side-by-side':
|
||||
state.comparisonMode = 'hover';
|
||||
break;
|
||||
case 'hover':
|
||||
state.comparisonMode = 'slider';
|
||||
break;
|
||||
}
|
||||
},
|
||||
shouldAutoSwitchChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldAutoSwitch = action.payload;
|
||||
},
|
||||
@ -79,6 +104,16 @@ export const gallerySlice = createSlice({
|
||||
isImageViewerOpenChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.isImageViewerOpen = action.payload;
|
||||
},
|
||||
comparedImagesSwapped: (state) => {
|
||||
if (state.imageToCompare) {
|
||||
const oldSelection = state.selection;
|
||||
state.selection = [state.imageToCompare];
|
||||
state.imageToCompare = oldSelection[0] ?? null;
|
||||
}
|
||||
},
|
||||
comparisonFitChanged: (state, action: PayloadAction<'contain' | 'fill'>) => {
|
||||
state.comparisonFit = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
|
||||
@ -117,6 +152,11 @@ export const {
|
||||
moreImagesLoaded,
|
||||
alwaysShowImageSizeBadgeChanged,
|
||||
isImageViewerOpenChanged,
|
||||
imageToCompareChanged,
|
||||
comparisonModeChanged,
|
||||
comparedImagesSwapped,
|
||||
comparisonFitChanged,
|
||||
comparisonModeCycled,
|
||||
} = gallerySlice.actions;
|
||||
|
||||
const isAnyBoardDeleted = isAnyOf(
|
||||
@ -138,5 +178,13 @@ export const galleryPersistConfig: PersistConfig<GalleryState> = {
|
||||
name: gallerySlice.name,
|
||||
initialState: initialGalleryState,
|
||||
migrate: migrateGalleryState,
|
||||
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'offset', 'limit', 'isImageViewerOpen'],
|
||||
persistDenylist: [
|
||||
'selection',
|
||||
'selectedBoardId',
|
||||
'galleryView',
|
||||
'offset',
|
||||
'limit',
|
||||
'isImageViewerOpen',
|
||||
'imageToCompare',
|
||||
],
|
||||
};
|
||||
|
@ -7,6 +7,8 @@ export const IMAGE_LIMIT = 20;
|
||||
|
||||
export type GalleryView = 'images' | 'assets';
|
||||
export type BoardId = 'none' | (string & Record<never, never>);
|
||||
export type ComparisonMode = 'slider' | 'side-by-side' | 'hover';
|
||||
export type ComparisonFit = 'contain' | 'fill';
|
||||
|
||||
export type GalleryState = {
|
||||
selection: ImageDTO[];
|
||||
@ -20,5 +22,8 @@ export type GalleryState = {
|
||||
offset: number;
|
||||
limit: number;
|
||||
alwaysShowImageSizeBadge: boolean;
|
||||
imageToCompare: ImageDTO | null;
|
||||
comparisonMode: ComparisonMode;
|
||||
comparisonFit: ComparisonFit;
|
||||
isImageViewerOpen: boolean;
|
||||
};
|
||||
|
@ -19,7 +19,7 @@ import {
|
||||
redo,
|
||||
undo,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { $flow, $needsFit } from 'features/nodes/store/reactFlowInstance';
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import type { CSSProperties, MouseEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
@ -68,6 +68,7 @@ export const Flow = memo(() => {
|
||||
const nodes = useAppSelector((s) => s.nodes.present.nodes);
|
||||
const edges = useAppSelector((s) => s.nodes.present.edges);
|
||||
const viewport = useStore($viewport);
|
||||
const needsFit = useStore($needsFit);
|
||||
const mayUndo = useAppSelector((s) => s.nodes.past.length > 0);
|
||||
const mayRedo = useAppSelector((s) => s.nodes.future.length > 0);
|
||||
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
|
||||
@ -92,8 +93,16 @@ export const Flow = memo(() => {
|
||||
const onNodesChange: OnNodesChange = useCallback(
|
||||
(nodeChanges) => {
|
||||
dispatch(nodesChanged(nodeChanges));
|
||||
const flow = $flow.get();
|
||||
if (!flow) {
|
||||
return;
|
||||
}
|
||||
if (needsFit) {
|
||||
$needsFit.set(false);
|
||||
flow.fitView();
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
[dispatch, needsFit]
|
||||
);
|
||||
|
||||
const onEdgesChange: OnEdgesChange = useCallback(
|
||||
|
@ -15,27 +15,20 @@ const ViewportControls = () => {
|
||||
const { t } = useTranslation();
|
||||
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
||||
const dispatch = useAppDispatch();
|
||||
// const shouldShowFieldTypeLegend = useAppSelector(
|
||||
// (s) => s.nodes.present.shouldShowFieldTypeLegend
|
||||
// );
|
||||
const shouldShowMinimapPanel = useAppSelector((s) => s.workflowSettings.shouldShowMinimapPanel);
|
||||
|
||||
const handleClickedZoomIn = useCallback(() => {
|
||||
zoomIn();
|
||||
zoomIn({ duration: 300 });
|
||||
}, [zoomIn]);
|
||||
|
||||
const handleClickedZoomOut = useCallback(() => {
|
||||
zoomOut();
|
||||
zoomOut({ duration: 300 });
|
||||
}, [zoomOut]);
|
||||
|
||||
const handleClickedFitView = useCallback(() => {
|
||||
fitView();
|
||||
fitView({ duration: 300 });
|
||||
}, [fitView]);
|
||||
|
||||
// const handleClickedToggleFieldTypeLegend = useCallback(() => {
|
||||
// dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
|
||||
// }, [shouldShowFieldTypeLegend, dispatch]);
|
||||
|
||||
const handleClickedToggleMiniMapPanel = useCallback(() => {
|
||||
dispatch(shouldShowMinimapPanelChanged(!shouldShowMinimapPanel));
|
||||
}, [shouldShowMinimapPanel, dispatch]);
|
||||
|
@ -1,9 +1,7 @@
|
||||
import 'reactflow/dist/style.css';
|
||||
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import QueueControls from 'features/queue/components/QueueControls';
|
||||
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
|
||||
import { usePanelStorage } from 'features/ui/hooks/usePanelStorage';
|
||||
@ -21,14 +19,8 @@ import { WorkflowName } from './WorkflowName';
|
||||
|
||||
const panelGroupStyles: CSSProperties = { height: '100%', width: '100%' };
|
||||
|
||||
const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
|
||||
return {
|
||||
mode: workflow.mode,
|
||||
};
|
||||
});
|
||||
|
||||
const NodeEditorPanelGroup = () => {
|
||||
const { mode } = useAppSelector(selector);
|
||||
const mode = useAppSelector((s) => s.workflow.mode);
|
||||
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
|
||||
const panelStorage = usePanelStorage();
|
||||
|
||||
|
@ -1,20 +1,12 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
|
||||
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
|
||||
|
||||
import { ModeToggle } from './ModeToggle';
|
||||
|
||||
const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
|
||||
return {
|
||||
mode: workflow.mode,
|
||||
};
|
||||
});
|
||||
|
||||
export const WorkflowMenu = () => {
|
||||
const { mode } = useAppSelector(selector);
|
||||
const mode = useAppSelector((s) => s.workflow.mode);
|
||||
|
||||
return (
|
||||
<Flex gap="2" alignItems="center">
|
||||
|
@ -11,8 +11,7 @@ import { selectLastSelectedNode } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { ImageOutput } from 'services/api/types';
|
||||
import type { AnyResult } from 'services/events/types';
|
||||
import type { AnyInvocationOutput, ImageOutput } from 'services/api/types';
|
||||
|
||||
import ImageOutputPreview from './outputs/ImageOutputPreview';
|
||||
|
||||
@ -66,4 +65,4 @@ const InspectorOutputsTab = () => {
|
||||
|
||||
export default memo(InspectorOutputsTab);
|
||||
|
||||
const getKey = (result: AnyResult, i: number) => `${result.type}-${i}`;
|
||||
const getKey = (result: AnyInvocationOutput, i: number) => `${result.type}-${i}`;
|
||||
|
@ -2,3 +2,4 @@ import { atom } from 'nanostores';
|
||||
import type { ReactFlowInstance } from 'reactflow';
|
||||
|
||||
export const $flow = atom<ReactFlowInstance | null>(null);
|
||||
export const $needsFit = atom<boolean>(true);
|
||||
|
@ -144,5 +144,4 @@ const zImageOutput = z.object({
|
||||
type: z.literal('image_output'),
|
||||
});
|
||||
export type ImageOutput = z.infer<typeof zImageOutput>;
|
||||
export const isImageOutput = (output: unknown): output is ImageOutput => zImageOutput.safeParse(output).success;
|
||||
// #endregion
|
||||
|
@ -1,8 +1,7 @@
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { omit, reduce } from 'lodash-es';
|
||||
import type { Graph } from 'services/api/types';
|
||||
import type { AnyInvocation } from 'services/events/types';
|
||||
import type { AnyInvocation, Graph } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
/**
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ImageViewerWorkflows } from 'features/gallery/components/ImageViewer/ImageViewerWorkflows';
|
||||
import { ImageComparisonDroppable } from 'features/gallery/components/ImageViewer/ImageComparisonDroppable';
|
||||
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
|
||||
import NodeEditor from 'features/nodes/components/NodeEditor';
|
||||
import { memo } from 'react';
|
||||
import { ReactFlowProvider } from 'reactflow';
|
||||
@ -10,7 +11,8 @@ const NodesTab = () => {
|
||||
if (mode === 'view') {
|
||||
return (
|
||||
<Box layerStyle="first" position="relative" w="full" h="full" p={2} borderRadius="base">
|
||||
<ImageViewerWorkflows />
|
||||
<ImageViewer />
|
||||
<ImageComparisonDroppable />
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
@ -1,13 +1,17 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import { ControlLayersEditor } from 'features/controlLayers/components/ControlLayersEditor';
|
||||
import { ImageComparisonDroppable } from 'features/gallery/components/ImageViewer/ImageComparisonDroppable';
|
||||
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
|
||||
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import { memo } from 'react';
|
||||
|
||||
const TextToImageTab = () => {
|
||||
const imageViewer = useImageViewer();
|
||||
return (
|
||||
<Box layerStyle="first" position="relative" w="full" h="full" p={2} borderRadius="base">
|
||||
<ControlLayersEditor />
|
||||
<ImageViewer />
|
||||
{imageViewer.isOpen && <ImageViewer />}
|
||||
<ImageComparisonDroppable />
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
@ -41,7 +41,7 @@ const UnifiedCanvasTab = () => {
|
||||
>
|
||||
<IAICanvasToolbar />
|
||||
<IAICanvas />
|
||||
{isValidDrop(droppableData, active) && (
|
||||
{isValidDrop(droppableData, active?.data.current) && (
|
||||
<IAIDropOverlay isOver={isOver} label={t('toast.setCanvasInitialImage')} />
|
||||
)}
|
||||
</Flex>
|
||||
|
File diff suppressed because one or more lines are too long
@ -122,7 +122,6 @@ export type ModelInstallStatus = S['InstallStatus'];
|
||||
// Graphs
|
||||
export type Graph = S['Graph'];
|
||||
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
|
||||
export type GraphExecutionState = S['GraphExecutionState'];
|
||||
export type Batch = S['Batch'];
|
||||
export type SessionQueueItemDTO = S['SessionQueueItemDTO'];
|
||||
export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy'];
|
||||
@ -132,14 +131,14 @@ export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO'];
|
||||
type KeysOfUnion<T> = T extends T ? keyof T : never;
|
||||
|
||||
export type AnyInvocation = Exclude<
|
||||
Graph['nodes'][string],
|
||||
NonNullable<S['Graph']['nodes']>[string],
|
||||
S['CoreMetadataInvocation'] | S['MetadataInvocation'] | S['MetadataItemInvocation'] | S['MergeMetadataInvocation']
|
||||
>;
|
||||
export type AnyInvocationIncMetadata = S['Graph']['nodes'][string];
|
||||
export type AnyInvocationIncMetadata = NonNullable<S['Graph']['nodes']>[string];
|
||||
|
||||
export type InvocationType = AnyInvocation['type'];
|
||||
type InvocationOutputMap = S['InvocationOutputMap'];
|
||||
type AnyInvocationOutput = InvocationOutputMap[InvocationType];
|
||||
export type AnyInvocationOutput = InvocationOutputMap[InvocationType];
|
||||
|
||||
export type Invocation<T extends InvocationType> = Extract<AnyInvocation, { type: T }>;
|
||||
// export type InvocationOutput<T extends InvocationType> = InvocationOutputMap[T];
|
||||
|
@ -1,21 +1,12 @@
|
||||
import type { Graph, GraphExecutionState, S } from 'services/api/types';
|
||||
|
||||
export type AnyInvocation = NonNullable<NonNullable<Graph['nodes']>[string]>;
|
||||
|
||||
export type AnyResult = NonNullable<GraphExecutionState['results'][string]>;
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
|
||||
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
|
||||
|
||||
export type InvocationStartedEvent = Omit<S['InvocationStartedEvent'], 'invocation'> & { invocation: AnyInvocation };
|
||||
export type InvocationDenoiseProgressEvent = Omit<S['InvocationDenoiseProgressEvent'], 'invocation'> & {
|
||||
invocation: AnyInvocation;
|
||||
};
|
||||
export type InvocationCompleteEvent = Omit<S['InvocationCompleteEvent'], 'result' | 'invocation'> & {
|
||||
result: AnyResult;
|
||||
invocation: AnyInvocation;
|
||||
};
|
||||
export type InvocationErrorEvent = Omit<S['InvocationErrorEvent'], 'invocation'> & { invocation: AnyInvocation };
|
||||
export type InvocationStartedEvent = S['InvocationStartedEvent'];
|
||||
export type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent'];
|
||||
export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
|
||||
export type InvocationErrorEvent = S['InvocationErrorEvent'];
|
||||
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
|
||||
|
||||
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
|
||||
|
@ -55,10 +55,10 @@ dependencies = [
|
||||
|
||||
# Core application dependencies, pinned for reproducible builds.
|
||||
"fastapi-events==0.11.0",
|
||||
"fastapi==0.110.0",
|
||||
"fastapi==0.111.0",
|
||||
"huggingface-hub==0.23.1",
|
||||
"pydantic-settings==2.2.1",
|
||||
"pydantic==2.6.3",
|
||||
"pydantic==2.7.2",
|
||||
"python-socketio==5.11.1",
|
||||
"uvicorn[standard]==0.28.0",
|
||||
|
||||
|
@ -7,9 +7,10 @@ def main():
|
||||
# Change working directory to the repo root
|
||||
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from invokeai.app.api_app import custom_openapi
|
||||
from invokeai.app.api_app import app
|
||||
from invokeai.app.util.custom_openapi import get_openapi_func
|
||||
|
||||
schema = custom_openapi()
|
||||
schema = get_openapi_func(app)()
|
||||
json.dump(schema, sys.stdout, indent=2)
|
||||
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic.json_schema import models_json_schema
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -713,4 +714,4 @@ def test_iterate_accepts_collection():
|
||||
def test_graph_can_generate_schema():
|
||||
# Not throwing on this line is sufficient
|
||||
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
|
||||
_ = Graph.model_json_schema()
|
||||
models_json_schema([(Graph, "serialization")])
|
||||
|
Loading…
Reference in New Issue
Block a user