Merge branch 'main' into lstein/feat/simple-mm2-api

This commit is contained in:
Lincoln Stein 2024-06-02 09:45:31 -04:00 committed by GitHub
commit 2276f327e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
68 changed files with 3664 additions and 1949 deletions

View File

@ -18,6 +18,7 @@ help:
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
@ -70,3 +71,6 @@ installer-zip:
tag-release:
cd installer && ./tag_release.sh
# Generate the OpenAPI Schema for the app
openapi:
python scripts/generate_openapi_schema.py

View File

@ -154,6 +154,18 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file.
Check the [configuration docs] for more detail about the settings and how to specify them.
## `ModuleNotFoundError: No module named 'controlnet_aux'`
`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control.
If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed:
- Run the Invoke launcher
- Choose the developer console option
- Run this command: `pip cache remove controlnet_aux`
- Close the terminal window
- Download and run the [installer](https://github.com/invoke-ai/InvokeAI/releases/latest), selecting your current install location
## Out of Memory Issues
The models are large, VRAM is expensive, and you may find yourself

View File

@ -3,9 +3,7 @@ import logging
import mimetypes
import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any
import torch
import uvicorn
@ -13,11 +11,9 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available
# for PyCharm:
@ -25,10 +21,8 @@ from torch.backends.mps import is_available as is_mps_available
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.backend.util.devices import TorchDevice
from ..backend.util.logging import InvokeAILogger
@ -45,11 +39,6 @@ from .api.routers import (
workflows,
)
from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
app_config = get_config()
@ -119,84 +108,7 @@ app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = {}
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
output_schemas = models_json_schema(
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
)
for schema_key, output_schema in output_schemas[1]["$defs"].items():
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
# Some models don't end up in the schemas as standalone definitions
additional_schemas = models_json_schema(
[
(UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"),
(ModelIdentifierField, "serialization"),
(ProgressImage, "serialization"),
],
ref_template="#/components/schemas/{model}",
)
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": {},
"required": [],
}
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
output_type = signature(obj=invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation"
# Add all event schemas
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
if "$defs" in json_schema:
for schema_key, schema in json_schema["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema
del json_schema["$defs"]
openapi_schema["components"]["schemas"][event.__name__] = json_schema
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
app.openapi = get_openapi_func(app)
@app.get("/docs", include_in_schema=False)

View File

@ -98,11 +98,13 @@ class BaseInvocationOutput(BaseModel):
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls._typeadapter_needs_update = True
@classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
@ -112,11 +114,12 @@ class BaseInvocationOutput(BaseModel):
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter:
InvocationOutputsUnion = TypeAliasType(
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocationOutput = TypeAliasType(
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
cls._typeadapter_needs_update = False
return cls._typeadapter
@classmethod
@ -125,12 +128,13 @@ class BaseInvocationOutput(BaseModel):
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
# Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "output"
schema["required"].extend(["type"])
@classmethod
@ -167,6 +171,7 @@ class BaseInvocation(ABC, BaseModel):
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def get_type(cls) -> str:
@ -177,15 +182,17 @@ class BaseInvocation(ABC, BaseModel):
def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls._typeadapter_needs_update = True
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter:
InvocationsUnion = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocation = TypeAliasType(
"AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationsUnion)
cls._typeadapter = TypeAdapter(AnyInvocation)
cls._typeadapter_needs_update = False
return cls._typeadapter
@classmethod
@ -221,7 +228,7 @@ class BaseInvocation(ABC, BaseModel):
return signature(cls.invoke).return_annotation
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None:
@ -237,6 +244,7 @@ class BaseInvocation(ABC, BaseModel):
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "invocation"
schema["required"].extend(["type", "id"])
@abstractmethod
@ -310,7 +318,7 @@ class BaseInvocation(ABC, BaseModel):
protected_namespaces=(),
validate_assignment=True,
json_schema_extra=json_schema_extra,
json_schema_serialization_defaults_required=True,
json_schema_serialization_defaults_required=False,
coerce_numbers_to_str=True,
)

View File

@ -1,11 +1,10 @@
from math import floor
from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
@ -14,6 +13,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@ -33,6 +33,7 @@ class EventBase(BaseModel):
A timestamp is automatically added to the event when it is created.
"""
__event_name__: ClassVar[str]
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
@ -97,7 +98,7 @@ class InvocationEventBase(QueueItemEventBase):
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
invocation: AnyInvocation = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@ -108,7 +109,7 @@ class InvocationStartedEvent(InvocationEventBase):
__event_name__ = "invocation_started"
@classmethod
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
@ -135,7 +136,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
invocation: AnyInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent":
@ -173,11 +174,11 @@ class InvocationCompleteEvent(InvocationEventBase):
__event_name__ = "invocation_complete"
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
result: AnyInvocationOutput = Field(description="The result of the invocation")
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
) -> "InvocationCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
@ -206,7 +207,7 @@ class InvocationErrorEvent(InvocationEventBase):
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
invocation: AnyInvocation,
error_type: str,
error_message: str,
error_traceback: str,

View File

@ -2,18 +2,19 @@
import copy
import itertools
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import (
BaseModel,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
ValidationError,
field_validator,
)
from pydantic.fields import Field
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema
from pydantic_core import core_schema
# Importing * is bad karma but needed here for node detection
from invokeai.app.invocations import * # noqa: F401 F403
@ -277,73 +278,58 @@ class CollectInvocation(BaseInvocation):
return CollectInvocationOutput(collection=copy.copy(self.collection))
class AnyInvocation(BaseInvocation):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def validate_invocation(v: Any) -> "AnyInvocation":
return BaseInvocation.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocation.get_invocations()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
class AnyInvocationOutput(BaseInvocationOutput):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
return BaseInvocationOutput.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation_output)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid_string)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph",
default_factory=list,
)
@field_validator("nodes", mode="plain")
@classmethod
def validate_nodes(cls, v: dict[str, Any]):
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
# Invocations register themselves as their python modules are executed. The union of all invocations is
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
#
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
# invocations will cause a graph to fail if they are used.
#
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
#
# This same pattern is used in `GraphExecutionState`.
nodes: dict[str, BaseInvocation] = {}
typeadapter = BaseInvocation.get_typeadapter()
for node_id, node in v.items():
nodes[node_id] = typeadapter.validate_python(node)
return nodes
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
# the generated schema as options for the `nodes` field.
#
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
# expected.
#
# You might be tempted to do something like this:
#
# ```py
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
# delattr(cloned_model, "validate_nodes")
# cloned_model.model_rebuild(force=True)
# json_schema = handler(cloned_model.__pydantic_core_schema__)
# ```
#
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
#
# This same pattern is used in `GraphExecutionState`.
class Graph(BaseModel):
id: Optional[str] = Field(default=None, description="The id of this graph")
nodes: dict[
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
] = Field(description="The nodes in this graph")
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
json_schema = handler(Graph.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
@ -774,7 +760,7 @@ class GraphExecutionState(BaseModel):
)
# The results of executed nodes
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
# Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
@ -791,52 +777,12 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
@field_validator("results", mode="plain")
@classmethod
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
results: dict[str, BaseInvocationOutput] = {}
typeadapter = BaseInvocationOutput.get_typeadapter()
for result_id, result in v.items():
results[result_id] = typeadapter.validate_python(result)
return results
@field_validator("graph")
def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid"""
v.validate_self()
return v
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state")
graph: Graph = Field(description="The graph being executed")
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
executed: set[str] = Field(description="The set of node ids that have been executed")
executed_history: list[str] = Field(
description="The list of node ids that have been executed, in order of execution"
)
results: dict[
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
] = Field(description="The results of node executions")
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
prepared_source_mapping: dict[str, str] = Field(
description="The map of prepared nodes to original graph nodes"
)
source_prepared_mapping: dict[str, set[str]] = Field(
description="The map of original graph nodes to prepared nodes"
)
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def next(self) -> Optional[BaseInvocation]:
"""Gets the next node ready to execute."""

View File

@ -0,0 +1,116 @@
from typing import Any, Callable, Optional
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
"""Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema
for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and
component_schema."""
defs = component_schema.pop("$defs", {})
for schema_key, json_schema in defs.items():
if schema_key in openapi_schema["components"]["schemas"]:
continue
openapi_schema["components"]["schemas"][schema_key] = json_schema
def get_openapi_func(
app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
) -> Callable[[], dict[str, Any]]:
"""Gets the OpenAPI schema generator function.
Args:
app (FastAPI): The FastAPI app to generate the schema for.
post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the
generated schema before returning it. Defaults to None.
Returns:
Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is
cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour
matches FastAPI's default schema generation caching.
"""
def openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# We'll create a map of invocation type to output schema to make some types simpler on the client.
invocation_output_map_properties: dict[str, Any] = {}
invocation_output_map_required: list[str] = []
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in BaseInvocationOutput.get_outputs():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
# property, so we'll just do it all manually.
for invocation in BaseInvocation.get_invocations():
json_schema = invocation.model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)
move_defs_to_top_level(openapi_schema, json_schema)
output_title = invocation.get_output_annotation().__name__
outputs_ref = {"$ref": f"#/components/schemas/{output_title}"}
json_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"][invocation.__name__] = json_schema
# Add this invocation and its output to the output map
invocation_type = invocation.get_type()
invocation_output_map_properties[invocation_type] = json_schema["output"]
invocation_output_map_required.append(invocation_type)
# Add the output map to the schema
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": invocation_output_map_properties,
"required": invocation_output_map_required,
}
# Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API.
# We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get
# a schema. This has something to do with schema refs - not totally clear. For whatever reason, using
# `models_json_schema` seems to work fine.
additional_models = [
*EventBase.get_events(),
UIConfigBase,
InputFieldJSONSchemaExtra,
OutputFieldJSONSchemaExtra,
ModelIdentifierField,
ProgressImage,
]
additional_schemas = models_json_schema(
[(m, "serialization") for m in additional_models],
ref_template="#/components/schemas/{model}",
)
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
move_defs_to_top_level(openapi_schema, additional_schemas[1])
if post_transform is not None:
openapi_schema = post_transform(openapi_schema)
openapi_schema["components"]["schemas"] = dict(sorted(openapi_schema["components"]["schemas"].items()))
app.openapi_schema = openapi_schema
return app.openapi_schema
return openapi

View File

@ -53,5 +53,5 @@ class ModelLocker(ModelLockerBase):
"""Call upon exit from context."""
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.offload_unlocked_models(0)
self._cache.print_cuda_stats()

View File

@ -1,7 +1,7 @@
"""Textual Inversion wrapper class."""
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import torch
from compel.embeddings_provider import BaseTextualInversionManager
@ -66,35 +66,52 @@ class TextualInversionModelRaw(RawModel):
return result
# no type hints for BaseTextualInversionManager?
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
class TextualInversionManager(BaseTextualInversionManager):
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = {}
self.pad_tokens: dict[int, list[int]] = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
"""Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens.
For example, suppose we have a `<ti_dog>` TI with 4 vectors that was added to the tokenizer with the following
mapping of tokens to token_ids:
```
<ti_dog>: 49408
<ti_dog-!pad-1>: 49409
<ti_dog-!pad-2>: 49410
<ti_dog-!pad-3>: 49411
```
`self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`.
This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`.
"""
# Short circuit if there are no pad tokens to save a little time.
if len(self.pad_tokens) == 0:
return token_ids
# This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify
# this assumption here.
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = []
# Expand any TI tokens to their corresponding pad tokens.
new_token_ids: list[int] = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
# Do not exceed the max model input size. The -2 here is compensating for
# compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens.
max_length = self.tokenizer.model_max_length - 2
if len(new_token_ids) > max_length:
# HACK: If TI token expansion causes us to exceed the max text encoder input length, we silently discard
# tokens. Token expansion should happen in a way that is compatible with compel's default handling of long
# prompts.
new_token_ids = new_token_ids[0:max_length]
return new_token_ids

View File

@ -148,6 +148,8 @@
"viewingDesc": "Review images in a large gallery view",
"editing": "Editing",
"editingDesc": "Edit on the Control Layers canvas",
"comparing": "Comparing",
"comparingDesc": "Comparing two images",
"enabled": "Enabled",
"disabled": "Disabled"
},
@ -375,7 +377,23 @@
"bulkDownloadRequestFailed": "Problem Preparing Download",
"bulkDownloadFailed": "Download Failed",
"problemDeletingImages": "Problem Deleting Images",
"problemDeletingImagesDesc": "One or more images could not be deleted"
"problemDeletingImagesDesc": "One or more images could not be deleted",
"viewerImage": "Viewer Image",
"compareImage": "Compare Image",
"openInViewer": "Open in Viewer",
"selectForCompare": "Select for Compare",
"selectAnImageToCompare": "Select an Image to Compare",
"slider": "Slider",
"sideBySide": "Side-by-Side",
"hover": "Hover",
"swapImages": "Swap Images",
"compareOptions": "Comparison Options",
"stretchToFit": "Stretch to Fit",
"exitCompare": "Exit Compare",
"compareHelp1": "Hold <Kbd>Alt</Kbd> while clicking a gallery image or using the arrow keys to change the compare image.",
"compareHelp2": "Press <Kbd>M</Kbd> to cycle through comparison modes.",
"compareHelp3": "Press <Kbd>C</Kbd> to swap the compared images.",
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit."
},
"hotkeys": {
"searchHotkeys": "Search Hotkeys",

View File

@ -19,6 +19,13 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
return extendTheme({
..._theme,
direction,
shadows: {
..._theme.shadows,
selectedForCompare:
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-400)',
hoverSelectedForCompare:
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-300)',
},
});
}, [direction]);

View File

@ -13,7 +13,6 @@ import {
isControlAdapterLayer,
} from 'features/controlLayers/store/controlLayersSlice';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { isImageOutput } from 'features/nodes/types/common';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { isEqual } from 'lodash-es';
@ -139,7 +138,7 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
// We still have to check the output type
assert(
isImageOutput(invocationCompleteAction.payload.data.result),
invocationCompleteAction.payload.data.result.type === 'image_output',
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
);
const { image_name } = invocationCompleteAction.payload.data.result.image;

View File

@ -9,7 +9,6 @@ import {
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { isImageOutput } from 'features/nodes/types/common';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
@ -74,7 +73,7 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
);
// We still have to check the output type
if (isImageOutput(invocationCompleteAction.payload.data.result)) {
if (invocationCompleteAction.payload.data.result.type === 'image_output') {
const { image_name } = invocationCompleteAction.payload.data.result.image;
// Wait for the ImageDTO to be received

View File

@ -1,7 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { selectionChanged } from 'features/gallery/store/gallerySlice';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
@ -11,6 +11,7 @@ export const galleryImageClicked = createAction<{
shiftKey: boolean;
ctrlKey: boolean;
metaKey: boolean;
altKey: boolean;
}>('gallery/imageClicked');
/**
@ -28,7 +29,7 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
startAppListening({
actionCreator: galleryImageClicked,
effect: async (action, { dispatch, getState }) => {
const { imageDTO, shiftKey, ctrlKey, metaKey } = action.payload;
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
@ -41,7 +42,13 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
const imageDTOs = imagesSelectors.selectAll(listImagesData);
const selection = state.gallery.selection;
if (shiftKey) {
if (altKey) {
if (state.gallery.imageToCompare?.image_name === imageDTO.image_name) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(imageToCompareChanged(imageDTO));
}
} else if (shiftKey) {
const rangeEndImageName = imageDTO.image_name;
const lastSelectedImage = selection[selection.length - 1]?.image_name;
const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage);

View File

@ -14,7 +14,8 @@ import {
rgLayerIPAdapterImageChanged,
} from 'features/controlLayers/store/controlLayersSlice';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
@ -30,6 +31,9 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
effect: async (action, { dispatch, getState }) => {
const log = logger('dnd');
const { activeData, overData } = action.payload;
if (!isValidDrop(overData, activeData)) {
return;
}
if (activeData.payloadType === 'IMAGE_DTO') {
log.debug({ activeData, overData }, 'Image dropped');
@ -50,6 +54,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
activeData.payload.imageDTO
) {
dispatch(imageSelected(activeData.payload.imageDTO));
dispatch(isImageViewerOpenChanged(true));
return;
}
@ -182,24 +187,18 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
}
/**
* TODO
* Image selection dropped on node image collection field
* Image selected for compare
*/
// if (
// overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
// activeData.payloadType === 'IMAGE_DTO' &&
// activeData.payload.imageDTO
// ) {
// const { fieldName, nodeId } = overData.context;
// dispatch(
// fieldValueChanged({
// nodeId,
// fieldName,
// value: [activeData.payload.imageDTO],
// })
// );
// return;
// }
if (
overData.actionType === 'SELECT_FOR_COMPARE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(imageToCompareChanged(imageDTO));
dispatch(isImageViewerOpenChanged(true));
return;
}
/**
* Image dropped on user board

View File

@ -11,7 +11,6 @@ import {
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { isImageOutput } from 'features/nodes/types/common';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
@ -33,7 +32,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
const { result, invocation_source_id } = data;
// This complete event has an associated image output
if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation.type)) {
if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) {
const { image_name } = data.result.image;
const { canvas, gallery } = getState();

View File

@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { parseify } from 'common/util/serialize';
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { $needsFit } from 'features/nodes/store/reactFlowInstance';
import type { Templates } from 'features/nodes/store/types';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
@ -65,9 +65,7 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
});
}
requestAnimationFrame(() => {
$flow.get()?.fitView();
});
$needsFit.set(true);
} catch (e) {
if (e instanceof WorkflowVersionError) {
// The workflow version was not recognized in the valid list of versions

View File

@ -35,6 +35,7 @@ type IAIDndImageProps = FlexProps & {
draggableData?: TypesafeDraggableData;
dropLabel?: ReactNode;
isSelected?: boolean;
isSelectedForCompare?: boolean;
thumbnail?: boolean;
noContentFallback?: ReactElement;
useThumbailFallback?: boolean;
@ -61,6 +62,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
draggableData,
dropLabel,
isSelected = false,
isSelectedForCompare = false,
thumbnail = false,
noContentFallback = defaultNoContentFallback,
uploadElement = defaultUploadElement,
@ -165,7 +167,11 @@ const IAIDndImage = (props: IAIDndImageProps) => {
data-testid={dataTestId}
/>
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
<SelectionOverlay isSelected={isSelected} isHovered={withHoverOverlay ? isHovered : false} />
<SelectionOverlay
isSelected={isSelected}
isSelectedForCompare={isSelectedForCompare}
isHovered={withHoverOverlay ? isHovered : false}
/>
</Flex>
)}
{!imageDTO && !isUploadDisabled && (

View File

@ -36,7 +36,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
pointerEvents={active ? 'auto' : 'none'}
>
<AnimatePresence>
{isValidDrop(data, active) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
{isValidDrop(data, active?.data.current) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
</AnimatePresence>
</Box>
);

View File

@ -3,10 +3,17 @@ import { memo, useMemo } from 'react';
type Props = {
isSelected: boolean;
isSelectedForCompare: boolean;
isHovered: boolean;
};
const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props) => {
const shadow = useMemo(() => {
if (isSelectedForCompare && isHovered) {
return 'hoverSelectedForCompare';
}
if (isSelectedForCompare && !isHovered) {
return 'selectedForCompare';
}
if (isSelected && isHovered) {
return 'hoverSelected';
}
@ -17,7 +24,7 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
return 'hoverUnselected';
}
return undefined;
}, [isHovered, isSelected]);
}, [isHovered, isSelected, isSelectedForCompare]);
return (
<Box
className="selection-box"
@ -27,7 +34,7 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
bottom={0}
insetInlineStart={0}
borderRadius="base"
opacity={isSelected ? 1 : 0.7}
opacity={isSelected || isSelectedForCompare ? 1 : 0.7}
transitionProperty="common"
transitionDuration="0.1s"
pointerEvents="none"

View File

@ -0,0 +1,21 @@
import { useCallback, useMemo, useState } from 'react';
export const useBoolean = (initialValue: boolean) => {
const [isTrue, set] = useState(initialValue);
const setTrue = useCallback(() => set(true), []);
const setFalse = useCallback(() => set(false), []);
const toggle = useCallback(() => set((v) => !v), []);
const api = useMemo(
() => ({
isTrue,
set,
setTrue,
setFalse,
toggle,
}),
[isTrue, set, setTrue, setFalse, toggle]
);
return api;
};

View File

@ -1,3 +1,7 @@
export const stopPropagation = (e: React.MouseEvent) => {
e.stopPropagation();
};
export const preventDefault = (e: React.MouseEvent) => {
e.preventDefault();
};

View File

@ -1,7 +1,13 @@
import { deepClone } from 'common/util/deepClone';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { merge, omit } from 'lodash-es';
import type { BaseModelType, ControlNetModelConfig, Graph, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
import type {
AnyInvocation,
BaseModelType,
ControlNetModelConfig,
ImageDTO,
T2IAdapterModelConfig,
} from 'services/api/types';
import { z } from 'zod';
const zId = z.string().min(1);
@ -147,7 +153,7 @@ const zBeginEndStepPct = z
const zControlAdapterBase = z.object({
id: zId,
weight: z.number().gte(0).lte(1),
weight: z.number().gte(-1).lte(2),
image: zImageWithDims.nullable(),
processedImage: zImageWithDims.nullable(),
processorConfig: zProcessorConfig.nullable(),
@ -183,7 +189,7 @@ export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safePar
export const zIPAdapterConfigV2 = z.object({
id: zId,
type: z.literal('ip_adapter'),
weight: z.number().gte(0).lte(1),
weight: z.number().gte(-1).lte(2),
method: zIPMethodV2,
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
@ -216,10 +222,7 @@ type ProcessorData<T extends ProcessorTypeV2> = {
labelTKey: string;
descriptionTKey: string;
buildDefaults(baseModel?: BaseModelType): Extract<ProcessorConfig, { type: T }>;
buildNode(
image: ImageWithDims,
config: Extract<ProcessorConfig, { type: T }>
): Extract<Graph['nodes'][string], { type: T }>;
buildNode(image: ImageWithDims, config: Extract<ProcessorConfig, { type: T }>): Extract<AnyInvocation, { type: T }>;
};
const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height);

View File

@ -54,7 +54,7 @@ const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)';
const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
const STAGE_BG_DATAURL =
export const STAGE_BG_DATAURL =
'';
const mapId = (object: { id: string }) => object.id;

View File

@ -18,7 +18,7 @@ type BaseDropData = {
id: string;
};
type CurrentImageDropData = BaseDropData & {
export type CurrentImageDropData = BaseDropData & {
actionType: 'SET_CURRENT_IMAGE';
};
@ -79,6 +79,14 @@ export type RemoveFromBoardDropData = BaseDropData & {
actionType: 'REMOVE_FROM_BOARD';
};
export type SelectForCompareDropData = BaseDropData & {
actionType: 'SELECT_FOR_COMPARE';
context: {
firstImageName?: string | null;
secondImageName?: string | null;
};
};
export type TypesafeDroppableData =
| CurrentImageDropData
| ControlAdapterDropData
@ -89,7 +97,8 @@ export type TypesafeDroppableData =
| CALayerImageDropData
| IPALayerImageDropData
| RGLayerIPAdapterImageDropData
| IILayerImageDropData;
| IILayerImageDropData
| SelectForCompareDropData;
type BaseDragData = {
id: string;
@ -134,7 +143,7 @@ export type UseDraggableTypesafeReturnValue = Omit<ReturnType<typeof useOriginal
over: TypesafeOver | null;
};
export interface TypesafeActive extends Omit<Active, 'data'> {
interface TypesafeActive extends Omit<Active, 'data'> {
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
}

View File

@ -1,14 +1,14 @@
import type { TypesafeActive, TypesafeDroppableData } from 'features/dnd/types';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
export const isValidDrop = (overData: TypesafeDroppableData | undefined, active: TypesafeActive | null) => {
if (!overData || !active?.data.current) {
export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?: TypesafeDraggableData | null) => {
if (!overData || !activeData) {
return false;
}
const { actionType } = overData;
const { payloadType } = active.data.current;
const { payloadType } = activeData;
if (overData.id === active.data.current.id) {
if (overData.id === activeData.id) {
return false;
}
@ -29,6 +29,8 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active:
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SELECT_FOR_COMPARE':
return payloadType === 'IMAGE_DTO';
case 'ADD_TO_BOARD': {
// If the board is the same, don't allow the drop
@ -40,7 +42,7 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active:
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const { imageDTO } = activeData.payload;
const currentBoard = imageDTO.board_id ?? 'none';
const destinationBoard = overData.context.boardId;
@ -49,7 +51,7 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active:
if (payloadType === 'GALLERY_SELECTION') {
// Assume all images are on the same board - this is true for the moment
const currentBoard = active.data.current.payload.boardId;
const currentBoard = activeData.payload.boardId;
const destinationBoard = overData.context.boardId;
return currentBoard !== destinationBoard;
}
@ -67,14 +69,14 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active:
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const { imageDTO } = activeData.payload;
const currentBoard = imageDTO.board_id ?? 'none';
return currentBoard !== 'none';
}
if (payloadType === 'GALLERY_SELECTION') {
const currentBoard = active.data.current.payload.boardId;
const currentBoard = activeData.payload.boardId;
return currentBoard !== 'none';
}

View File

@ -162,7 +162,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
</Flex>
)}
{isSelectedForAutoAdd && <AutoAddIcon />}
<SelectionOverlay isSelected={isSelected} isHovered={isHovered} />
<SelectionOverlay isSelected={isSelected} isSelectedForCompare={false} isHovered={isHovered} />
<Flex
position="absolute"
bottom={0}

View File

@ -117,7 +117,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
>
{boardName}
</Flex>
<SelectionOverlay isSelected={isSelected} isHovered={isHovered} />
<SelectionOverlay isSelected={isSelected} isSelectedForCompare={false} isHovered={isHovered} />
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="md">{t('unifiedCanvas.move')}</Text>} />
</Flex>
</Tooltip>

View File

@ -10,6 +10,7 @@ import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions';
import { imageToCompareChanged } from 'features/gallery/store/gallerySlice';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@ -27,6 +28,7 @@ import {
PiDownloadSimpleBold,
PiFlowArrowBold,
PiFoldersBold,
PiImagesBold,
PiPlantBold,
PiQuotesBold,
PiShareFatBold,
@ -44,6 +46,7 @@ type SingleSelectionMenuItemsProps = {
const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { imageDTO } = props;
const optimalDimension = useAppSelector(selectOptimalDimension);
const maySelectForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name !== imageDTO.image_name);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isCanvasEnabled = useFeatureStatus('canvas');
@ -117,6 +120,10 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
downloadImage(imageDTO.image_url, imageDTO.image_name);
}, [downloadImage, imageDTO.image_name, imageDTO.image_url]);
const handleSelectImageForCompare = useCallback(() => {
dispatch(imageToCompareChanged(imageDTO));
}, [dispatch, imageDTO]);
return (
<>
<MenuItem as="a" href={imageDTO.image_url} target="_blank" icon={<PiShareFatBold />}>
@ -130,6 +137,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
<MenuItem icon={<PiDownloadSimpleBold />} onClickCapture={handleDownloadImage}>
{t('parameters.downloadImage')}
</MenuItem>
<MenuItem icon={<PiImagesBold />} isDisabled={!maySelectForCompare} onClick={handleSelectImageForCompare}>
{t('gallery.selectForCompare')}
</MenuItem>
<MenuDivider />
<MenuItem
icon={getAndLoadEmbeddedWorkflowResult.isLoading ? <SpinnerIcon /> : <PiFlowArrowBold />}

View File

@ -11,7 +11,7 @@ import type { GallerySelectionDraggableData, ImageDraggableData, TypesafeDraggab
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
import { useScrollIntoView } from 'features/gallery/hooks/useScrollIntoView';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import type { MouseEvent } from 'react';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
@ -46,6 +46,7 @@ const GalleryImage = (props: HoverableImageProps) => {
const { t } = useTranslation();
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageName);
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
const customStarUi = useStore($customStarUI);
@ -105,6 +106,7 @@ const GalleryImage = (props: HoverableImageProps) => {
const onDoubleClick = useCallback(() => {
dispatch(isImageViewerOpenChanged(true));
dispatch(imageToCompareChanged(null));
}, [dispatch]);
const handleMouseOut = useCallback(() => {
@ -152,6 +154,7 @@ const GalleryImage = (props: HoverableImageProps) => {
imageDTO={imageDTO}
draggableData={draggableData}
isSelected={isSelected}
isSelectedForCompare={isSelectedForCompare}
minSize={0}
imageSx={imageSx}
isDropDisabled={true}

View File

@ -0,0 +1,140 @@
import {
Button,
ButtonGroup,
Flex,
Icon,
IconButton,
Kbd,
ListItem,
Tooltip,
UnorderedList,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
comparedImagesSwapped,
comparisonFitChanged,
comparisonModeChanged,
comparisonModeCycled,
imageToCompareChanged,
} from 'features/gallery/store/gallerySlice';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { Trans, useTranslation } from 'react-i18next';
import { PiArrowsOutBold, PiQuestion, PiSwapBold, PiXBold } from 'react-icons/pi';
export const CompareToolbar = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode);
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
const setComparisonModeSlider = useCallback(() => {
dispatch(comparisonModeChanged('slider'));
}, [dispatch]);
const setComparisonModeSideBySide = useCallback(() => {
dispatch(comparisonModeChanged('side-by-side'));
}, [dispatch]);
const setComparisonModeHover = useCallback(() => {
dispatch(comparisonModeChanged('hover'));
}, [dispatch]);
const swapImages = useCallback(() => {
dispatch(comparedImagesSwapped());
}, [dispatch]);
useHotkeys('c', swapImages, [swapImages]);
const toggleComparisonFit = useCallback(() => {
dispatch(comparisonFitChanged(comparisonFit === 'contain' ? 'fill' : 'contain'));
}, [dispatch, comparisonFit]);
const exitCompare = useCallback(() => {
dispatch(imageToCompareChanged(null));
}, [dispatch]);
useHotkeys('esc', exitCompare, [exitCompare]);
const nextMode = useCallback(() => {
dispatch(comparisonModeCycled());
}, [dispatch]);
useHotkeys('m', nextMode, [nextMode]);
return (
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<IconButton
icon={<PiSwapBold />}
aria-label={`${t('gallery.swapImages')} (C)`}
tooltip={`${t('gallery.swapImages')} (C)`}
onClick={swapImages}
/>
{comparisonMode !== 'side-by-side' && (
<IconButton
aria-label={t('gallery.stretchToFit')}
tooltip={t('gallery.stretchToFit')}
onClick={toggleComparisonFit}
colorScheme={comparisonFit === 'fill' ? 'invokeBlue' : 'base'}
variant="outline"
icon={<PiArrowsOutBold />}
/>
)}
</Flex>
</Flex>
<Flex flex={1} gap={4} justifyContent="center">
<ButtonGroup variant="outline">
<Button
flexShrink={0}
onClick={setComparisonModeSlider}
colorScheme={comparisonMode === 'slider' ? 'invokeBlue' : 'base'}
>
{t('gallery.slider')}
</Button>
<Button
flexShrink={0}
onClick={setComparisonModeSideBySide}
colorScheme={comparisonMode === 'side-by-side' ? 'invokeBlue' : 'base'}
>
{t('gallery.sideBySide')}
</Button>
<Button
flexShrink={0}
onClick={setComparisonModeHover}
colorScheme={comparisonMode === 'hover' ? 'invokeBlue' : 'base'}
>
{t('gallery.hover')}
</Button>
</ButtonGroup>
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto" alignItems="center">
<Tooltip label={<CompareHelp />}>
<Flex alignItems="center">
<Icon boxSize={8} color="base.500" as={PiQuestion} lineHeight={0} />
</Flex>
</Tooltip>
<IconButton
icon={<PiXBold />}
aria-label={`${t('gallery.exitCompare')} (Esc)`}
tooltip={`${t('gallery.exitCompare')} (Esc)`}
onClick={exitCompare}
/>
</Flex>
</Flex>
</Flex>
);
});
CompareToolbar.displayName = 'CompareToolbar';
const CompareHelp = () => {
return (
<UnorderedList>
<ListItem>
<Trans i18nKey="gallery.compareHelp1" components={{ Kbd: <Kbd /> }}></Trans>
</ListItem>
<ListItem>
<Trans i18nKey="gallery.compareHelp2" components={{ Kbd: <Kbd /> }}></Trans>
</ListItem>
<ListItem>
<Trans i18nKey="gallery.compareHelp3" components={{ Kbd: <Kbd /> }}></Trans>
</ListItem>
<ListItem>
<Trans i18nKey="gallery.compareHelp4" components={{ Kbd: <Kbd /> }}></Trans>
</ListItem>
</UnorderedList>
);
};

View File

@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import type { TypesafeDraggableData } from 'features/dnd/types';
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
@ -22,21 +22,7 @@ const selectLastSelectedImageName = createSelector(
(lastSelectedImage) => lastSelectedImage?.image_name
);
type Props = {
isDragDisabled?: boolean;
isDropDisabled?: boolean;
withNextPrevButtons?: boolean;
withMetadata?: boolean;
alwaysShowProgress?: boolean;
};
const CurrentImagePreview = ({
isDragDisabled = false,
isDropDisabled = false,
withNextPrevButtons = true,
withMetadata = true,
alwaysShowProgress = false,
}: Props) => {
const CurrentImagePreview = () => {
const { t } = useTranslation();
const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails);
const imageName = useAppSelector(selectLastSelectedImageName);
@ -55,14 +41,6 @@ const CurrentImagePreview = ({
}
}, [imageDTO]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id: 'current-image',
actionType: 'SET_CURRENT_IMAGE',
}),
[]
);
// Show and hide the next/prev buttons on mouse move
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = useState<boolean>(false);
const timeoutId = useRef(0);
@ -86,30 +64,27 @@ const CurrentImagePreview = ({
justifyContent="center"
position="relative"
>
{hasDenoiseProgress && (shouldShowProgressInViewer || alwaysShowProgress) ? (
{hasDenoiseProgress && shouldShowProgressInViewer ? (
<ProgressImage />
) : (
<IAIDndImage
imageDTO={imageDTO}
droppableData={droppableData}
draggableData={draggableData}
isDragDisabled={isDragDisabled}
isDropDisabled={isDropDisabled}
isDropDisabled={true}
isUploadDisabled={true}
fitContainer
useThumbailFallback
dropLabel={t('gallery.setCurrentImage')}
noContentFallback={<IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />}
dataTestId="image-preview"
/>
)}
{shouldShowImageDetails && imageDTO && withMetadata && (
{shouldShowImageDetails && imageDTO && (
<Box position="absolute" opacity={0.8} top={0} width="full" height="full" borderRadius="base">
<ImageMetadataViewer image={imageDTO} />
</Box>
)}
<AnimatePresence>
{withNextPrevButtons && shouldShowNextPrevButtons && imageDTO && (
{shouldShowNextPrevButtons && imageDTO && (
<Box
as={motion.div}
key="nextPrevButtons"

View File

@ -0,0 +1,41 @@
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { selectComparisonImages } from 'features/gallery/components/ImageViewer/common';
import { ImageComparisonHover } from 'features/gallery/components/ImageViewer/ImageComparisonHover';
import { ImageComparisonSideBySide } from 'features/gallery/components/ImageViewer/ImageComparisonSideBySide';
import { ImageComparisonSlider } from 'features/gallery/components/ImageViewer/ImageComparisonSlider';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImagesBold } from 'react-icons/pi';
type Props = {
containerDims: Dimensions;
};
export const ImageComparison = memo(({ containerDims }: Props) => {
const { t } = useTranslation();
const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode);
const { firstImage, secondImage } = useAppSelector(selectComparisonImages);
if (!firstImage || !secondImage) {
// Should rarely/never happen - we don't render this component unless we have images to compare
return <IAINoContentFallback label={t('gallery.selectAnImageToCompare')} icon={PiImagesBold} />;
}
if (comparisonMode === 'slider') {
return <ImageComparisonSlider containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />;
}
if (comparisonMode === 'side-by-side') {
return (
<ImageComparisonSideBySide containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />
);
}
if (comparisonMode === 'hover') {
return <ImageComparisonHover containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />;
}
});
ImageComparison.displayName = 'ImageComparison';

View File

@ -0,0 +1,47 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDroppable from 'common/components/IAIDroppable';
import type { CurrentImageDropData, SelectForCompareDropData } from 'features/dnd/types';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { selectComparisonImages } from './common';
const setCurrentImageDropData: CurrentImageDropData = {
id: 'current-image',
actionType: 'SET_CURRENT_IMAGE',
};
export const ImageComparisonDroppable = memo(() => {
const { t } = useTranslation();
const imageViewer = useImageViewer();
const { firstImage, secondImage } = useAppSelector(selectComparisonImages);
const selectForCompareDropData = useMemo<SelectForCompareDropData>(
() => ({
id: 'image-comparison',
actionType: 'SELECT_FOR_COMPARE',
context: {
firstImageName: firstImage?.image_name,
secondImageName: secondImage?.image_name,
},
}),
[firstImage?.image_name, secondImage?.image_name]
);
if (!imageViewer.isOpen) {
return (
<Flex position="absolute" top={0} right={0} bottom={0} left={0} gap={2} pointerEvents="none">
<IAIDroppable data={setCurrentImageDropData} dropLabel={t('gallery.openInViewer')} />
</Flex>
);
}
return (
<Flex position="absolute" top={0} right={0} bottom={0} left={0} gap={2} pointerEvents="none">
<IAIDroppable data={selectForCompareDropData} dropLabel={t('gallery.selectForCompare')} />
</Flex>
);
});
ImageComparisonDroppable.displayName = 'ImageComparisonDroppable';

View File

@ -0,0 +1,117 @@
import { Box, Flex, Image } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useBoolean } from 'common/hooks/useBoolean';
import { preventDefault } from 'common/util/stopPropagation';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
import { memo, useMemo, useRef } from 'react';
import type { ComparisonProps } from './common';
import { fitDimsToContainer, getSecondImageDims } from './common';
export const ImageComparisonHover = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
const imageContainerRef = useRef<HTMLDivElement>(null);
const mouseOver = useBoolean(false);
const fittedDims = useMemo<Dimensions>(
() => fitDimsToContainer(containerDims, firstImage),
[containerDims, firstImage]
);
const compareImageDims = useMemo<Dimensions>(
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),
[comparisonFit, fittedDims, firstImage, secondImage]
);
return (
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
<Flex
id="image-comparison-wrapper"
w="full"
h="full"
maxW="full"
maxH="full"
position="absolute"
alignItems="center"
justifyContent="center"
>
<Box
ref={imageContainerRef}
position="relative"
id="image-comparison-hover-image-container"
w={fittedDims.width}
h={fittedDims.height}
maxW="full"
maxH="full"
userSelect="none"
overflow="hidden"
borderRadius="base"
>
<Image
id="image-comparison-hover-first-image"
src={firstImage.image_url}
fallbackSrc={firstImage.thumbnail_url}
w={fittedDims.width}
h={fittedDims.height}
maxW="full"
maxH="full"
objectFit="cover"
objectPosition="top left"
/>
<ImageComparisonLabel type="first" opacity={mouseOver.isTrue ? 0 : 1} />
<Box
id="image-comparison-hover-second-image-container"
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
overflow="hidden"
opacity={mouseOver.isTrue ? 1 : 0}
transitionDuration="0.2s"
transitionProperty="common"
>
<Box
id="image-comparison-hover-bg"
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
backgroundImage={STAGE_BG_DATAURL}
backgroundRepeat="repeat"
opacity={0.2}
/>
<Image
position="relative"
id="image-comparison-hover-second-image"
src={secondImage.image_url}
fallbackSrc={secondImage.thumbnail_url}
w={compareImageDims.width}
h={compareImageDims.height}
maxW={fittedDims.width}
maxH={fittedDims.height}
objectFit={comparisonFit}
objectPosition="top left"
/>
<ImageComparisonLabel type="second" opacity={mouseOver.isTrue ? 1 : 0} />
</Box>
<Box
id="image-comparison-hover-interaction-overlay"
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
onMouseOver={mouseOver.setTrue}
onMouseOut={mouseOver.setFalse}
onContextMenu={preventDefault}
userSelect="none"
/>
</Box>
</Flex>
</Flex>
);
});
ImageComparisonHover.displayName = 'ImageComparisonHover';

View File

@ -0,0 +1,33 @@
import type { TextProps } from '@invoke-ai/ui-library';
import { Text } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { DROP_SHADOW } from './common';
type Props = TextProps & {
type: 'first' | 'second';
};
export const ImageComparisonLabel = memo(({ type, ...rest }: Props) => {
const { t } = useTranslation();
return (
<Text
position="absolute"
bottom={4}
insetInlineEnd={type === 'first' ? undefined : 4}
insetInlineStart={type === 'first' ? 4 : undefined}
textOverflow="clip"
whiteSpace="nowrap"
filter={DROP_SHADOW}
color="base.50"
transitionDuration="0.2s"
transitionProperty="common"
{...rest}
>
{type === 'first' ? t('gallery.viewerImage') : t('gallery.compareImage')}
</Text>
);
});
ImageComparisonLabel.displayName = 'ImageComparisonLabel';

View File

@ -0,0 +1,70 @@
import { Flex, Image } from '@invoke-ai/ui-library';
import type { ComparisonProps } from 'features/gallery/components/ImageViewer/common';
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
import { memo, useCallback, useRef } from 'react';
import type { ImperativePanelGroupHandle } from 'react-resizable-panels';
import { Panel, PanelGroup } from 'react-resizable-panels';
export const ImageComparisonSideBySide = memo(({ firstImage, secondImage }: ComparisonProps) => {
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
const onDoubleClickHandle = useCallback(() => {
if (!panelGroupRef.current) {
return;
}
panelGroupRef.current.setLayout([50, 50]);
}, []);
return (
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
<Flex w="full" h="full" maxW="full" maxH="full" position="absolute" alignItems="center" justifyContent="center">
<PanelGroup ref={panelGroupRef} direction="horizontal" id="image-comparison-side-by-side">
<Panel minSize={20}>
<Flex position="relative" w="full" h="full" alignItems="center" justifyContent="center">
<Flex position="absolute" maxW="full" maxH="full" aspectRatio={firstImage.width / firstImage.height}>
<Image
id="image-comparison-side-by-side-first-image"
w={firstImage.width}
h={firstImage.height}
maxW="full"
maxH="full"
src={firstImage.image_url}
fallbackSrc={firstImage.thumbnail_url}
objectFit="contain"
borderRadius="base"
/>
<ImageComparisonLabel type="first" />
</Flex>
</Flex>
</Panel>
<ResizeHandle
id="image-comparison-side-by-side-handle"
onDoubleClick={onDoubleClickHandle}
orientation="vertical"
/>
<Panel minSize={20}>
<Flex position="relative" w="full" h="full" alignItems="center" justifyContent="center">
<Flex position="absolute" maxW="full" maxH="full" aspectRatio={secondImage.width / secondImage.height}>
<Image
id="image-comparison-side-by-side-first-image"
w={secondImage.width}
h={secondImage.height}
maxW="full"
maxH="full"
src={secondImage.image_url}
fallbackSrc={secondImage.thumbnail_url}
objectFit="contain"
borderRadius="base"
/>
<ImageComparisonLabel type="second" />
</Flex>
</Flex>
</Panel>
</PanelGroup>
</Flex>
</Flex>
);
});
ImageComparisonSideBySide.displayName = 'ImageComparisonSideBySide';

View File

@ -0,0 +1,215 @@
import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { preventDefault } from 'common/util/stopPropagation';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
import type { ComparisonProps } from './common';
import { DROP_SHADOW, fitDimsToContainer, getSecondImageDims } from './common';
const INITIAL_POS = '50%';
const HANDLE_WIDTH = 2;
const HANDLE_WIDTH_PX = `${HANDLE_WIDTH}px`;
const HANDLE_HITBOX = 20;
const HANDLE_HITBOX_PX = `${HANDLE_HITBOX}px`;
const HANDLE_INNER_LEFT_PX = `${HANDLE_HITBOX / 2 - HANDLE_WIDTH / 2}px`;
const HANDLE_LEFT_INITIAL_PX = `calc(${INITIAL_POS} - ${HANDLE_HITBOX / 2}px)`;
export const ImageComparisonSlider = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
// How far the handle is from the left - this will be a CSS calculation that takes into account the handle width
const [left, setLeft] = useState(HANDLE_LEFT_INITIAL_PX);
// How wide the first image is
const [width, setWidth] = useState(INITIAL_POS);
const handleRef = useRef<HTMLDivElement>(null);
// To manage aspect ratios, we need to know the size of the container
const imageContainerRef = useRef<HTMLDivElement>(null);
// To keep things smooth, we use RAF to update the handle position & gate it to 60fps
const rafRef = useRef<number | null>(null);
const lastMoveTimeRef = useRef<number>(0);
const fittedDims = useMemo<Dimensions>(
() => fitDimsToContainer(containerDims, firstImage),
[containerDims, firstImage]
);
const compareImageDims = useMemo<Dimensions>(
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),
[comparisonFit, fittedDims, firstImage, secondImage]
);
const updateHandlePos = useCallback((clientX: number) => {
if (!handleRef.current || !imageContainerRef.current) {
return;
}
lastMoveTimeRef.current = performance.now();
const { x, width } = imageContainerRef.current.getBoundingClientRect();
const rawHandlePos = ((clientX - x) * 100) / width;
const handleWidthPct = (HANDLE_WIDTH * 100) / width;
const newHandlePos = Math.min(100 - handleWidthPct, Math.max(0, rawHandlePos));
setWidth(`${newHandlePos}%`);
setLeft(`calc(${newHandlePos}% - ${HANDLE_HITBOX / 2}px)`);
}, []);
const onMouseMove = useCallback(
(e: MouseEvent) => {
if (rafRef.current === null && performance.now() > lastMoveTimeRef.current + 1000 / 60) {
rafRef.current = window.requestAnimationFrame(() => {
updateHandlePos(e.clientX);
rafRef.current = null;
});
}
},
[updateHandlePos]
);
const onMouseUp = useCallback(() => {
window.removeEventListener('mousemove', onMouseMove);
}, [onMouseMove]);
const onMouseDown = useCallback(
(e: React.MouseEvent<HTMLDivElement>) => {
// Update the handle position immediately on click
updateHandlePos(e.clientX);
window.addEventListener('mouseup', onMouseUp, { once: true });
window.addEventListener('mousemove', onMouseMove);
},
[onMouseMove, onMouseUp, updateHandlePos]
);
useEffect(
() => () => {
if (rafRef.current !== null) {
cancelAnimationFrame(rafRef.current);
}
},
[]
);
return (
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
<Flex
id="image-comparison-wrapper"
w="full"
h="full"
maxW="full"
maxH="full"
position="absolute"
alignItems="center"
justifyContent="center"
>
<Box
ref={imageContainerRef}
position="relative"
id="image-comparison-image-container"
w={fittedDims.width}
h={fittedDims.height}
maxW="full"
maxH="full"
userSelect="none"
overflow="hidden"
borderRadius="base"
>
<Box
id="image-comparison-bg"
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
backgroundImage={STAGE_BG_DATAURL}
backgroundRepeat="repeat"
opacity={0.2}
/>
<Image
position="relative"
id="image-comparison-second-image"
src={secondImage.image_url}
fallbackSrc={secondImage.thumbnail_url}
w={compareImageDims.width}
h={compareImageDims.height}
maxW={fittedDims.width}
maxH={fittedDims.height}
objectFit={comparisonFit}
objectPosition="top left"
/>
<ImageComparisonLabel type="second" />
<Box
id="image-comparison-first-image-container"
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
w={width}
overflow="hidden"
>
<Image
id="image-comparison-first-image"
src={firstImage.image_url}
fallbackSrc={firstImage.thumbnail_url}
w={fittedDims.width}
h={fittedDims.height}
objectFit="cover"
objectPosition="top left"
/>
<ImageComparisonLabel type="first" />
</Box>
<Flex
id="image-comparison-handle"
ref={handleRef}
position="absolute"
top={0}
bottom={0}
left={left}
w={HANDLE_HITBOX_PX}
cursor="ew-resize"
filter={DROP_SHADOW}
opacity={0.8}
color="base.50"
>
<Box
id="image-comparison-handle-divider"
w={HANDLE_WIDTH_PX}
h="full"
bg="currentColor"
shadow="dark-lg"
position="absolute"
top={0}
left={HANDLE_INNER_LEFT_PX}
/>
<Flex
id="image-comparison-handle-icons"
gap={4}
position="absolute"
left="50%"
top="50%"
transform="translate(-50%, 0)"
filter={DROP_SHADOW}
>
<Icon as={PiCaretLeftBold} />
<Icon as={PiCaretRightBold} />
</Flex>
</Flex>
<Box
id="image-comparison-interaction-overlay"
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
onMouseDown={onMouseDown}
onContextMenu={preventDefault}
userSelect="none"
cursor="ew-resize"
/>
</Box>
</Flex>
</Flex>
);
});
ImageComparisonSlider.displayName = 'ImageComparisonSlider';

View File

@ -1,36 +1,16 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { Box, Flex } from '@invoke-ai/ui-library';
import { CompareToolbar } from 'features/gallery/components/ImageViewer/CompareToolbar';
import CurrentImagePreview from 'features/gallery/components/ImageViewer/CurrentImagePreview';
import { ImageComparison } from 'features/gallery/components/ImageViewer/ImageComparison';
import { ViewerToolbar } from 'features/gallery/components/ImageViewer/ViewerToolbar';
import { memo } from 'react';
import { useMeasure } from 'react-use';
import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview';
import { ViewerToggleMenu } from './ViewerToggleMenu';
const VIEWER_ENABLED_TABS: InvokeTabName[] = ['canvas', 'generation', 'workflows'];
import { useImageViewer } from './useImageViewer';
export const ImageViewer = memo(() => {
const { isOpen, onToggle, onClose } = useImageViewer();
const activeTabName = useAppSelector(activeTabNameSelector);
const isViewerEnabled = useMemo(() => VIEWER_ENABLED_TABS.includes(activeTabName), [activeTabName]);
const shouldShowViewer = useMemo(() => {
if (!isViewerEnabled) {
return false;
}
return isOpen;
}, [isOpen, isViewerEnabled]);
useHotkeys('z', onToggle, { enabled: isViewerEnabled }, [isViewerEnabled, onToggle]);
useHotkeys('esc', onClose, { enabled: isViewerEnabled }, [isViewerEnabled, onClose]);
if (!shouldShowViewer) {
return null;
}
const imageViewer = useImageViewer();
const [containerRef, containerDims] = useMeasure<HTMLDivElement>();
return (
<Flex
@ -46,25 +26,13 @@ export const ImageViewer = memo(() => {
rowGap={4}
alignItems="center"
justifyContent="center"
zIndex={10} // reactflow puts its minimap at 5, so we need to be above that
>
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<ToggleProgressButton />
<ToggleMetadataViewerButton />
</Flex>
</Flex>
<Flex flex={1} gap={2} justifyContent="center">
<CurrentImageButtons />
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto">
<ViewerToggleMenu />
</Flex>
</Flex>
</Flex>
<CurrentImagePreview />
{imageViewer.isComparing && <CompareToolbar />}
{!imageViewer.isComparing && <ViewerToolbar />}
<Box ref={containerRef} w="full" h="full">
{!imageViewer.isComparing && <CurrentImagePreview />}
{imageViewer.isComparing && <ImageComparison containerDims={containerDims} />}
</Box>
</Flex>
);
});

View File

@ -1,45 +0,0 @@
import { Flex } from '@invoke-ai/ui-library';
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { memo } from 'react';
import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview';
export const ImageViewerWorkflows = memo(() => {
return (
<Flex
layerStyle="first"
borderRadius="base"
position="absolute"
flexDirection="column"
top={0}
right={0}
bottom={0}
left={0}
p={2}
rowGap={4}
alignItems="center"
justifyContent="center"
zIndex={10} // reactflow puts its minimap at 5, so we need to be above that
>
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<ToggleProgressButton />
<ToggleMetadataViewerButton />
</Flex>
</Flex>
<Flex flex={1} gap={2} justifyContent="center">
<CurrentImageButtons />
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto" />
</Flex>
</Flex>
<CurrentImagePreview />
</Flex>
);
});
ImageViewerWorkflows.displayName = 'ImageViewerWorkflows';

View File

@ -9,33 +9,35 @@ import {
PopoverTrigger,
Text,
} from '@invoke-ai/ui-library';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold, PiCheckBold, PiEyeBold, PiPencilBold } from 'react-icons/pi';
import { useImageViewer } from './useImageViewer';
export const ViewerToggleMenu = () => {
const { t } = useTranslation();
const { isOpen, onClose, onOpen } = useImageViewer();
const imageViewer = useImageViewer();
useHotkeys('z', imageViewer.onToggle, [imageViewer]);
useHotkeys('esc', imageViewer.onClose, [imageViewer]);
return (
<Popover isLazy>
<PopoverTrigger>
<Button variant="outline" data-testid="toggle-viewer-menu-button">
<Button variant="outline" data-testid="toggle-viewer-menu-button" pointerEvents="auto">
<Flex gap={3} w="full" alignItems="center">
{isOpen ? <Icon as={PiEyeBold} /> : <Icon as={PiPencilBold} />}
<Text fontSize="md">{isOpen ? t('common.viewing') : t('common.editing')}</Text>
{imageViewer.isOpen ? <Icon as={PiEyeBold} /> : <Icon as={PiPencilBold} />}
<Text fontSize="md">{imageViewer.isOpen ? t('common.viewing') : t('common.editing')}</Text>
<Icon as={PiCaretDownBold} />
</Flex>
</Button>
</PopoverTrigger>
<PopoverContent p={2}>
<PopoverContent p={2} pointerEvents="auto">
<PopoverArrow />
<PopoverBody>
<Flex flexDir="column">
<Button onClick={onOpen} variant="ghost" h="auto" w="auto" p={2}>
<Button onClick={imageViewer.onOpen} variant="ghost" h="auto" w="auto" p={2}>
<Flex gap={2} w="full">
<Icon as={PiCheckBold} visibility={isOpen ? 'visible' : 'hidden'} />
<Icon as={PiCheckBold} visibility={imageViewer.isOpen ? 'visible' : 'hidden'} />
<Flex flexDir="column" gap={2} alignItems="flex-start">
<Text fontWeight="semibold" color="base.100">
{t('common.viewing')}
@ -46,9 +48,9 @@ export const ViewerToggleMenu = () => {
</Flex>
</Flex>
</Button>
<Button onClick={onClose} variant="ghost" h="auto" w="auto" p={2}>
<Button onClick={imageViewer.onClose} variant="ghost" h="auto" w="auto" p={2}>
<Flex gap={2} w="full">
<Icon as={PiCheckBold} visibility={isOpen ? 'hidden' : 'visible'} />
<Icon as={PiCheckBold} visibility={imageViewer.isOpen ? 'hidden' : 'visible'} />
<Flex flexDir="column" gap={2} alignItems="flex-start">
<Text fontWeight="semibold" color="base.100">
{t('common.editing')}

View File

@ -0,0 +1,33 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react';
import CurrentImageButtons from './CurrentImageButtons';
import { ViewerToggleMenu } from './ViewerToggleMenu';
export const ViewerToolbar = memo(() => {
const tab = useAppSelector(activeTabNameSelector);
return (
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<ToggleProgressButton />
<ToggleMetadataViewerButton />
</Flex>
</Flex>
<Flex flex={1} gap={2} justifyContent="center">
<CurrentImageButtons />
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto">
{tab !== 'workflows' && <ViewerToggleMenu />}
</Flex>
</Flex>
</Flex>
);
});
ViewerToolbar.displayName = 'ViewerToolbar';

View File

@ -0,0 +1,64 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import type { ComparisonFit } from 'features/gallery/store/types';
import type { ImageDTO } from 'services/api/types';
export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
export type ComparisonProps = {
firstImage: ImageDTO;
secondImage: ImageDTO;
containerDims: Dimensions;
};
export const fitDimsToContainer = (containerDims: Dimensions, imageDims: Dimensions): Dimensions => {
// Fall back to the image's dimensions if the container has no dimensions
if (containerDims.width === 0 || containerDims.height === 0) {
return { width: imageDims.width, height: imageDims.height };
}
// Fall back to the image's dimensions if the image fits within the container
if (imageDims.width <= containerDims.width && imageDims.height <= containerDims.height) {
return { width: imageDims.width, height: imageDims.height };
}
const targetAspectRatio = containerDims.width / containerDims.height;
const imageAspectRatio = imageDims.width / imageDims.height;
let width: number;
let height: number;
if (imageAspectRatio > targetAspectRatio) {
// Image is wider than container's aspect ratio
width = containerDims.width;
height = width / imageAspectRatio;
} else {
// Image is taller than container's aspect ratio
height = containerDims.height;
width = height * imageAspectRatio;
}
return { width, height };
};
/**
* Gets the dimensions of the second image in a comparison based on the comparison fit mode.
*/
export const getSecondImageDims = (
comparisonFit: ComparisonFit,
fittedDims: Dimensions,
firstImageDims: Dimensions,
secondImageDims: Dimensions
): Dimensions => {
const width =
comparisonFit === 'fill' ? fittedDims.width : (fittedDims.width * secondImageDims.width) / firstImageDims.width;
const height =
comparisonFit === 'fill' ? fittedDims.height : (fittedDims.height * secondImageDims.height) / firstImageDims.height;
return { width, height };
};
export const selectComparisonImages = createMemoizedSelector(selectGallerySlice, (gallerySlice) => {
const firstImage = gallerySlice.selection.slice(-1)[0] ?? null;
const secondImage = gallerySlice.imageToCompare;
return { firstImage, secondImage };
});

View File

@ -0,0 +1,31 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { useCallback } from 'react';
export const useImageViewer = () => {
const dispatch = useAppDispatch();
const isComparing = useAppSelector((s) => s.gallery.imageToCompare !== null);
const isOpen = useAppSelector((s) => s.gallery.isImageViewerOpen);
const onClose = useCallback(() => {
if (isComparing && isOpen) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(isImageViewerOpenChanged(false));
}
}, [dispatch, isComparing, isOpen]);
const onOpen = useCallback(() => {
dispatch(isImageViewerOpenChanged(true));
}, [dispatch]);
const onToggle = useCallback(() => {
if (isComparing && isOpen) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(isImageViewerOpenChanged(!isOpen));
}
}, [dispatch, isComparing, isOpen]);
return { isOpen, onOpen, onClose, onToggle, isComparing };
};

View File

@ -1,22 +0,0 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { useCallback } from 'react';
export const useImageViewer = () => {
const dispatch = useAppDispatch();
const isOpen = useAppSelector((s) => s.gallery.isImageViewerOpen);
const onClose = useCallback(() => {
dispatch(isImageViewerOpenChanged(false));
}, [dispatch]);
const onOpen = useCallback(() => {
dispatch(isImageViewerOpenChanged(true));
}, [dispatch]);
const onToggle = useCallback(() => {
dispatch(isImageViewerOpenChanged(!isOpen));
}, [dispatch, isOpen]);
return { isOpen, onOpen, onClose, onToggle };
};

View File

@ -14,7 +14,7 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
const NextPrevImageButtons = () => {
const { t } = useTranslation();
const { handleLeftImage, handleRightImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
const { prevImage, nextImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
const {
areMoreImagesAvailable,
@ -30,7 +30,7 @@ const NextPrevImageButtons = () => {
aria-label={t('accessibility.previousImage')}
icon={<PiCaretLeftBold size={64} />}
variant="unstyled"
onClick={handleLeftImage}
onClick={prevImage}
boxSize={16}
sx={nextPrevButtonStyles}
/>
@ -42,7 +42,7 @@ const NextPrevImageButtons = () => {
aria-label={t('accessibility.nextImage')}
icon={<PiCaretRightBold size={64} />}
variant="unstyled"
onClick={handleRightImage}
onClick={nextImage}
boxSize={16}
sx={nextPrevButtonStyles}
/>

View File

@ -27,16 +27,16 @@ export const useGalleryHotkeys = () => {
useGalleryNavigation();
useHotkeys(
'left',
() => {
canNavigateGallery && handleLeftImage();
['left', 'alt+left'],
(e) => {
canNavigateGallery && handleLeftImage(e.altKey);
},
[handleLeftImage, canNavigateGallery]
);
useHotkeys(
'right',
() => {
['right', 'alt+right'],
(e) => {
if (!canNavigateGallery) {
return;
}
@ -45,29 +45,29 @@ export const useGalleryHotkeys = () => {
return;
}
if (!isOnLastImage) {
handleRightImage();
handleRightImage(e.altKey);
}
},
[isOnLastImage, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleRightImage, canNavigateGallery]
);
useHotkeys(
'up',
() => {
handleUpImage();
['up', 'alt+up'],
(e) => {
handleUpImage(e.altKey);
},
{ preventDefault: true },
[handleUpImage]
);
useHotkeys(
'down',
() => {
['down', 'alt+down'],
(e) => {
if (!areImagesBelowCurrent && areMoreImagesAvailable && !isFetching) {
handleLoadMoreImages();
return;
}
handleDownImage();
handleDownImage(e.altKey);
},
{ preventDefault: true },
[areImagesBelowCurrent, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleDownImage]

View File

@ -1,11 +1,11 @@
import { useAltModifier } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
import { imageItemContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridItemContainer';
import { imageListContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridListContainer';
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageSelected, imageToCompareChanged } from 'features/gallery/store/gallerySlice';
import { getIsVisible } from 'features/gallery/util/getIsVisible';
import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign';
import { clamp } from 'lodash-es';
@ -106,10 +106,12 @@ const getImageFuncs = {
};
type UseGalleryNavigationReturn = {
handleLeftImage: () => void;
handleRightImage: () => void;
handleUpImage: () => void;
handleDownImage: () => void;
handleLeftImage: (alt?: boolean) => void;
handleRightImage: (alt?: boolean) => void;
handleUpImage: (alt?: boolean) => void;
handleDownImage: (alt?: boolean) => void;
prevImage: () => void;
nextImage: () => void;
isOnFirstImage: boolean;
isOnLastImage: boolean;
areImagesBelowCurrent: boolean;
@ -123,7 +125,15 @@ type UseGalleryNavigationReturn = {
*/
export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
const dispatch = useAppDispatch();
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
const alt = useAltModifier();
const lastSelectedImage = useAppSelector((s) => {
const lastSelected = s.gallery.selection.slice(-1)[0] ?? null;
if (alt) {
return s.gallery.imageToCompare ?? lastSelected;
} else {
return lastSelected;
}
});
const {
queryResult: { data },
} = useGalleryImages();
@ -136,7 +146,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
}, [lastSelectedImage, data]);
const handleNavigation = useCallback(
(direction: 'left' | 'right' | 'up' | 'down') => {
(direction: 'left' | 'right' | 'up' | 'down', alt?: boolean) => {
if (!data) {
return;
}
@ -144,10 +154,14 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
if (!image || index === lastSelectedImageIndex) {
return;
}
if (alt) {
dispatch(imageToCompareChanged(image));
} else {
dispatch(imageSelected(image));
}
scrollToImage(image.image_name, index);
},
[dispatch, lastSelectedImageIndex, data]
[data, lastSelectedImageIndex, dispatch]
);
const isOnFirstImage = useMemo(() => lastSelectedImageIndex === 0, [lastSelectedImageIndex]);
@ -162,21 +176,41 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
return lastSelectedImageIndex + imagesPerRow < loadedImagesCount;
}, [lastSelectedImageIndex, loadedImagesCount]);
const handleLeftImage = useCallback(() => {
handleNavigation('left');
}, [handleNavigation]);
const handleLeftImage = useCallback(
(alt?: boolean) => {
handleNavigation('left', alt);
},
[handleNavigation]
);
const handleRightImage = useCallback(() => {
handleNavigation('right');
}, [handleNavigation]);
const handleRightImage = useCallback(
(alt?: boolean) => {
handleNavigation('right', alt);
},
[handleNavigation]
);
const handleUpImage = useCallback(() => {
handleNavigation('up');
}, [handleNavigation]);
const handleUpImage = useCallback(
(alt?: boolean) => {
handleNavigation('up', alt);
},
[handleNavigation]
);
const handleDownImage = useCallback(() => {
handleNavigation('down');
}, [handleNavigation]);
const handleDownImage = useCallback(
(alt?: boolean) => {
handleNavigation('down', alt);
},
[handleNavigation]
);
const nextImage = useCallback(() => {
handleRightImage();
}, [handleRightImage]);
const prevImage = useCallback(() => {
handleLeftImage();
}, [handleLeftImage]);
return {
handleLeftImage,
@ -186,5 +220,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
isOnFirstImage,
isOnLastImage,
areImagesBelowCurrent,
nextImage,
prevImage,
};
};

View File

@ -36,6 +36,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
shiftKey: e.shiftKey,
ctrlKey: e.ctrlKey,
metaKey: e.metaKey,
altKey: e.altKey,
})
);
},

View File

@ -6,7 +6,7 @@ import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import type { BoardId, GalleryState, GalleryView } from './types';
import type { BoardId, ComparisonMode, GalleryState, GalleryView } from './types';
import { IMAGE_LIMIT, INITIAL_IMAGE_LIMIT } from './types';
const initialGalleryState: GalleryState = {
@ -22,6 +22,9 @@ const initialGalleryState: GalleryState = {
limit: INITIAL_IMAGE_LIMIT,
offset: 0,
isImageViewerOpen: true,
imageToCompare: null,
comparisonMode: 'slider',
comparisonFit: 'fill',
};
export const gallerySlice = createSlice({
@ -34,6 +37,28 @@ export const gallerySlice = createSlice({
selectionChanged: (state, action: PayloadAction<ImageDTO[]>) => {
state.selection = uniqBy(action.payload, (i) => i.image_name);
},
imageToCompareChanged: (state, action: PayloadAction<ImageDTO | null>) => {
state.imageToCompare = action.payload;
if (action.payload) {
state.isImageViewerOpen = true;
}
},
comparisonModeChanged: (state, action: PayloadAction<ComparisonMode>) => {
state.comparisonMode = action.payload;
},
comparisonModeCycled: (state) => {
switch (state.comparisonMode) {
case 'slider':
state.comparisonMode = 'side-by-side';
break;
case 'side-by-side':
state.comparisonMode = 'hover';
break;
case 'hover':
state.comparisonMode = 'slider';
break;
}
},
shouldAutoSwitchChanged: (state, action: PayloadAction<boolean>) => {
state.shouldAutoSwitch = action.payload;
},
@ -79,6 +104,16 @@ export const gallerySlice = createSlice({
isImageViewerOpenChanged: (state, action: PayloadAction<boolean>) => {
state.isImageViewerOpen = action.payload;
},
comparedImagesSwapped: (state) => {
if (state.imageToCompare) {
const oldSelection = state.selection;
state.selection = [state.imageToCompare];
state.imageToCompare = oldSelection[0] ?? null;
}
},
comparisonFitChanged: (state, action: PayloadAction<'contain' | 'fill'>) => {
state.comparisonFit = action.payload;
},
},
extraReducers: (builder) => {
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
@ -117,6 +152,11 @@ export const {
moreImagesLoaded,
alwaysShowImageSizeBadgeChanged,
isImageViewerOpenChanged,
imageToCompareChanged,
comparisonModeChanged,
comparedImagesSwapped,
comparisonFitChanged,
comparisonModeCycled,
} = gallerySlice.actions;
const isAnyBoardDeleted = isAnyOf(
@ -138,5 +178,13 @@ export const galleryPersistConfig: PersistConfig<GalleryState> = {
name: gallerySlice.name,
initialState: initialGalleryState,
migrate: migrateGalleryState,
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'offset', 'limit', 'isImageViewerOpen'],
persistDenylist: [
'selection',
'selectedBoardId',
'galleryView',
'offset',
'limit',
'isImageViewerOpen',
'imageToCompare',
],
};

View File

@ -7,6 +7,8 @@ export const IMAGE_LIMIT = 20;
export type GalleryView = 'images' | 'assets';
export type BoardId = 'none' | (string & Record<never, never>);
export type ComparisonMode = 'slider' | 'side-by-side' | 'hover';
export type ComparisonFit = 'contain' | 'fill';
export type GalleryState = {
selection: ImageDTO[];
@ -20,5 +22,8 @@ export type GalleryState = {
offset: number;
limit: number;
alwaysShowImageSizeBadge: boolean;
imageToCompare: ImageDTO | null;
comparisonMode: ComparisonMode;
comparisonFit: ComparisonFit;
isImageViewerOpen: boolean;
};

View File

@ -19,7 +19,7 @@ import {
redo,
undo,
} from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { $flow, $needsFit } from 'features/nodes/store/reactFlowInstance';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import type { CSSProperties, MouseEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
@ -68,6 +68,7 @@ export const Flow = memo(() => {
const nodes = useAppSelector((s) => s.nodes.present.nodes);
const edges = useAppSelector((s) => s.nodes.present.edges);
const viewport = useStore($viewport);
const needsFit = useStore($needsFit);
const mayUndo = useAppSelector((s) => s.nodes.past.length > 0);
const mayRedo = useAppSelector((s) => s.nodes.future.length > 0);
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
@ -92,8 +93,16 @@ export const Flow = memo(() => {
const onNodesChange: OnNodesChange = useCallback(
(nodeChanges) => {
dispatch(nodesChanged(nodeChanges));
const flow = $flow.get();
if (!flow) {
return;
}
if (needsFit) {
$needsFit.set(false);
flow.fitView();
}
},
[dispatch]
[dispatch, needsFit]
);
const onEdgesChange: OnEdgesChange = useCallback(

View File

@ -15,27 +15,20 @@ const ViewportControls = () => {
const { t } = useTranslation();
const { zoomIn, zoomOut, fitView } = useReactFlow();
const dispatch = useAppDispatch();
// const shouldShowFieldTypeLegend = useAppSelector(
// (s) => s.nodes.present.shouldShowFieldTypeLegend
// );
const shouldShowMinimapPanel = useAppSelector((s) => s.workflowSettings.shouldShowMinimapPanel);
const handleClickedZoomIn = useCallback(() => {
zoomIn();
zoomIn({ duration: 300 });
}, [zoomIn]);
const handleClickedZoomOut = useCallback(() => {
zoomOut();
zoomOut({ duration: 300 });
}, [zoomOut]);
const handleClickedFitView = useCallback(() => {
fitView();
fitView({ duration: 300 });
}, [fitView]);
// const handleClickedToggleFieldTypeLegend = useCallback(() => {
// dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
// }, [shouldShowFieldTypeLegend, dispatch]);
const handleClickedToggleMiniMapPanel = useCallback(() => {
dispatch(shouldShowMinimapPanelChanged(!shouldShowMinimapPanel));
}, [shouldShowMinimapPanel, dispatch]);

View File

@ -1,9 +1,7 @@
import 'reactflow/dist/style.css';
import { Flex } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
import QueueControls from 'features/queue/components/QueueControls';
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
import { usePanelStorage } from 'features/ui/hooks/usePanelStorage';
@ -21,14 +19,8 @@ import { WorkflowName } from './WorkflowName';
const panelGroupStyles: CSSProperties = { height: '100%', width: '100%' };
const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
return {
mode: workflow.mode,
};
});
const NodeEditorPanelGroup = () => {
const { mode } = useAppSelector(selector);
const mode = useAppSelector((s) => s.workflow.mode);
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
const panelStorage = usePanelStorage();

View File

@ -1,20 +1,12 @@
import { Flex } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
import { ModeToggle } from './ModeToggle';
const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
return {
mode: workflow.mode,
};
});
export const WorkflowMenu = () => {
const { mode } = useAppSelector(selector);
const mode = useAppSelector((s) => s.workflow.mode);
return (
<Flex gap="2" alignItems="center">

View File

@ -11,8 +11,7 @@ import { selectLastSelectedNode } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { ImageOutput } from 'services/api/types';
import type { AnyResult } from 'services/events/types';
import type { AnyInvocationOutput, ImageOutput } from 'services/api/types';
import ImageOutputPreview from './outputs/ImageOutputPreview';
@ -66,4 +65,4 @@ const InspectorOutputsTab = () => {
export default memo(InspectorOutputsTab);
const getKey = (result: AnyResult, i: number) => `${result.type}-${i}`;
const getKey = (result: AnyInvocationOutput, i: number) => `${result.type}-${i}`;

View File

@ -2,3 +2,4 @@ import { atom } from 'nanostores';
import type { ReactFlowInstance } from 'reactflow';
export const $flow = atom<ReactFlowInstance | null>(null);
export const $needsFit = atom<boolean>(true);

View File

@ -144,5 +144,4 @@ const zImageOutput = z.object({
type: z.literal('image_output'),
});
export type ImageOutput = z.infer<typeof zImageOutput>;
export const isImageOutput = (output: unknown): output is ImageOutput => zImageOutput.safeParse(output).success;
// #endregion

View File

@ -1,8 +1,7 @@
import type { NodesState } from 'features/nodes/store/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { omit, reduce } from 'lodash-es';
import type { Graph } from 'services/api/types';
import type { AnyInvocation } from 'services/events/types';
import type { AnyInvocation, Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
/**

View File

@ -1,6 +1,7 @@
import { Box } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ImageViewerWorkflows } from 'features/gallery/components/ImageViewer/ImageViewerWorkflows';
import { ImageComparisonDroppable } from 'features/gallery/components/ImageViewer/ImageComparisonDroppable';
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
import NodeEditor from 'features/nodes/components/NodeEditor';
import { memo } from 'react';
import { ReactFlowProvider } from 'reactflow';
@ -10,7 +11,8 @@ const NodesTab = () => {
if (mode === 'view') {
return (
<Box layerStyle="first" position="relative" w="full" h="full" p={2} borderRadius="base">
<ImageViewerWorkflows />
<ImageViewer />
<ImageComparisonDroppable />
</Box>
);
}

View File

@ -1,13 +1,17 @@
import { Box } from '@invoke-ai/ui-library';
import { ControlLayersEditor } from 'features/controlLayers/components/ControlLayersEditor';
import { ImageComparisonDroppable } from 'features/gallery/components/ImageViewer/ImageComparisonDroppable';
import { ImageViewer } from 'features/gallery/components/ImageViewer/ImageViewer';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { memo } from 'react';
const TextToImageTab = () => {
const imageViewer = useImageViewer();
return (
<Box layerStyle="first" position="relative" w="full" h="full" p={2} borderRadius="base">
<ControlLayersEditor />
<ImageViewer />
{imageViewer.isOpen && <ImageViewer />}
<ImageComparisonDroppable />
</Box>
);
};

View File

@ -41,7 +41,7 @@ const UnifiedCanvasTab = () => {
>
<IAICanvasToolbar />
<IAICanvas />
{isValidDrop(droppableData, active) && (
{isValidDrop(droppableData, active?.data.current) && (
<IAIDropOverlay isOver={isOver} label={t('toast.setCanvasInitialImage')} />
)}
</Flex>

File diff suppressed because one or more lines are too long

View File

@ -122,7 +122,6 @@ export type ModelInstallStatus = S['InstallStatus'];
// Graphs
export type Graph = S['Graph'];
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
export type GraphExecutionState = S['GraphExecutionState'];
export type Batch = S['Batch'];
export type SessionQueueItemDTO = S['SessionQueueItemDTO'];
export type WorkflowRecordOrderBy = S['WorkflowRecordOrderBy'];
@ -132,14 +131,14 @@ export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO'];
type KeysOfUnion<T> = T extends T ? keyof T : never;
export type AnyInvocation = Exclude<
Graph['nodes'][string],
NonNullable<S['Graph']['nodes']>[string],
S['CoreMetadataInvocation'] | S['MetadataInvocation'] | S['MetadataItemInvocation'] | S['MergeMetadataInvocation']
>;
export type AnyInvocationIncMetadata = S['Graph']['nodes'][string];
export type AnyInvocationIncMetadata = NonNullable<S['Graph']['nodes']>[string];
export type InvocationType = AnyInvocation['type'];
type InvocationOutputMap = S['InvocationOutputMap'];
type AnyInvocationOutput = InvocationOutputMap[InvocationType];
export type AnyInvocationOutput = InvocationOutputMap[InvocationType];
export type Invocation<T extends InvocationType> = Extract<AnyInvocation, { type: T }>;
// export type InvocationOutput<T extends InvocationType> = InvocationOutputMap[T];

View File

@ -1,21 +1,12 @@
import type { Graph, GraphExecutionState, S } from 'services/api/types';
export type AnyInvocation = NonNullable<NonNullable<Graph['nodes']>[string]>;
export type AnyResult = NonNullable<GraphExecutionState['results'][string]>;
import type { S } from 'services/api/types';
export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
export type InvocationStartedEvent = Omit<S['InvocationStartedEvent'], 'invocation'> & { invocation: AnyInvocation };
export type InvocationDenoiseProgressEvent = Omit<S['InvocationDenoiseProgressEvent'], 'invocation'> & {
invocation: AnyInvocation;
};
export type InvocationCompleteEvent = Omit<S['InvocationCompleteEvent'], 'result' | 'invocation'> & {
result: AnyResult;
invocation: AnyInvocation;
};
export type InvocationErrorEvent = Omit<S['InvocationErrorEvent'], 'invocation'> & { invocation: AnyInvocation };
export type InvocationStartedEvent = S['InvocationStartedEvent'];
export type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent'];
export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
export type InvocationErrorEvent = S['InvocationErrorEvent'];
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];

View File

@ -55,10 +55,10 @@ dependencies = [
# Core application dependencies, pinned for reproducible builds.
"fastapi-events==0.11.0",
"fastapi==0.110.0",
"fastapi==0.111.0",
"huggingface-hub==0.23.1",
"pydantic-settings==2.2.1",
"pydantic==2.6.3",
"pydantic==2.7.2",
"python-socketio==5.11.1",
"uvicorn[standard]==0.28.0",

View File

@ -7,9 +7,10 @@ def main():
# Change working directory to the repo root
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from invokeai.app.api_app import custom_openapi
from invokeai.app.api_app import app
from invokeai.app.util.custom_openapi import get_openapi_func
schema = custom_openapi()
schema = get_openapi_func(app)()
json.dump(schema, sys.stdout, indent=2)

View File

@ -1,5 +1,6 @@
import pytest
from pydantic import TypeAdapter
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@ -713,4 +714,4 @@ def test_iterate_accepts_collection():
def test_graph_can_generate_schema():
# Not throwing on this line is sufficient
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation
_ = Graph.model_json_schema()
models_json_schema([(Graph, "serialization")])