diff --git a/Makefile b/Makefile index 7344b2e8d2..e858a89e2b 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,7 @@ help: @echo "frontend-typegen Generate types for the frontend from the OpenAPI schema" @echo "installer-zip Build the installer .zip file for the current version" @echo "tag-release Tag the GitHub repository with the current version (use at release time only!)" + @echo "openapi Generate the OpenAPI schema for the app, outputting to stdout" # Runs ruff, fixing any safely-fixable errors and formatting ruff: @@ -70,3 +71,6 @@ installer-zip: tag-release: cd installer && ./tag_release.sh +# Generate the OpenAPI Schema for the app +openapi: + python scripts/generate_openapi_schema.py diff --git a/docs/help/FAQ.md b/docs/help/FAQ.md index 4c297f442a..25880f7cd2 100644 --- a/docs/help/FAQ.md +++ b/docs/help/FAQ.md @@ -154,6 +154,18 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file. Check the [configuration docs] for more detail about the settings and how to specify them. +## `ModuleNotFoundError: No module named 'controlnet_aux'` + +`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control. + +If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed: + +- Run the Invoke launcher +- Choose the developer console option +- Run this command: `pip cache remove controlnet_aux` +- Close the terminal window +- Download and run the [installer](https://github.com/invoke-ai/InvokeAI/releases/latest), selecting your current install location + ## Out of Memory Issues The models are large, VRAM is expensive, and you may find yourself diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index b7da548377..e69d95af71 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -3,9 +3,7 @@ import logging import mimetypes import socket from contextlib import asynccontextmanager -from inspect import signature from pathlib import Path -from typing import Any import torch import uvicorn @@ -13,11 +11,9 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html -from fastapi.openapi.utils import get_openapi from fastapi.responses import HTMLResponse from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware -from pydantic.json_schema import models_json_schema from torch.backends.mps import is_available as is_mps_available # for PyCharm: @@ -25,10 +21,8 @@ from torch.backends.mps import is_available as is_mps_available import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.frontend.web as web_dir from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles -from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.config.config_default import get_config -from invokeai.app.services.events.events_common import EventBase -from invokeai.app.services.session_processor.session_processor_common import ProgressImage +from invokeai.app.util.custom_openapi import get_openapi_func from invokeai.backend.util.devices import TorchDevice from ..backend.util.logging import InvokeAILogger @@ -45,11 +39,6 @@ from .api.routers import ( 
workflows, ) from .api.sockets import SocketIO -from .invocations.baseinvocation import ( - BaseInvocation, - UIConfigBase, -) -from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra app_config = get_config() @@ -119,84 +108,7 @@ app.include_router(app_info.app_router, prefix="/api") app.include_router(session_queue.session_queue_router, prefix="/api") app.include_router(workflows.workflows_router, prefix="/api") - -# Build a custom OpenAPI to include all outputs -# TODO: can outputs be included on metadata of invocation schemas somehow? -def custom_openapi() -> dict[str, Any]: - if app.openapi_schema: - return app.openapi_schema - openapi_schema = get_openapi( - title=app.title, - description="An API for invoking AI image operations", - version="1.0.0", - routes=app.routes, - separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/ - ) - - # Add all outputs - all_invocations = BaseInvocation.get_invocations() - output_types = set() - output_type_titles = {} - for invoker in all_invocations: - output_type = signature(invoker.invoke).return_annotation - output_types.add(output_type) - - output_schemas = models_json_schema( - models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}" - ) - for schema_key, output_schema in output_schemas[1]["$defs"].items(): - # TODO: note that we assume the schema_key here is the TYPE.__name__ - # This could break in some cases, figure out a better way to do it - output_type_titles[schema_key] = output_schema["title"] - openapi_schema["components"]["schemas"][schema_key] = output_schema - openapi_schema["components"]["schemas"][schema_key]["class"] = "output" - - # Some models don't end up in the schemas as standalone definitions - additional_schemas = models_json_schema( - [ - (UIConfigBase, "serialization"), - (InputFieldJSONSchemaExtra, "serialization"), - (OutputFieldJSONSchemaExtra, "serialization"), - (ModelIdentifierField, "serialization"), - (ProgressImage, "serialization"), - ], - ref_template="#/components/schemas/{model}", - ) - for schema_key, schema_json in additional_schemas[1]["$defs"].items(): - openapi_schema["components"]["schemas"][schema_key] = schema_json - - openapi_schema["components"]["schemas"]["InvocationOutputMap"] = { - "type": "object", - "properties": {}, - "required": [], - } - - # Add a reference to the output type to additionalProperties of the invoker schema - for invoker in all_invocations: - invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute - output_type = signature(obj=invoker.invoke).return_annotation - output_type_title = output_type_titles[output_type.__name__] - invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"] - outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"} - invoker_schema["output"] = outputs_ref - openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref - openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type()) - invoker_schema["class"] = "invocation" - - # Add all event schemas - for event in sorted(EventBase.get_events(), key=lambda e: e.__name__): - json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}") - if "$defs" in json_schema: - for schema_key, schema in json_schema["$defs"].items(): - openapi_schema["components"]["schemas"][schema_key] = schema - del 
json_schema["$defs"] - openapi_schema["components"]["schemas"][event.__name__] = json_schema - - app.openapi_schema = openapi_schema - return app.openapi_schema - - -app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment +app.openapi = get_openapi_func(app) @app.get("/docs", include_in_schema=False) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 40c7b41cae..1d169f0a82 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -98,11 +98,13 @@ class BaseInvocationOutput(BaseModel): _output_classes: ClassVar[set[BaseInvocationOutput]] = set() _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None + _typeadapter_needs_update: ClassVar[bool] = False @classmethod def register_output(cls, output: BaseInvocationOutput) -> None: """Registers an invocation output.""" cls._output_classes.add(output) + cls._typeadapter_needs_update = True @classmethod def get_outputs(cls) -> Iterable[BaseInvocationOutput]: @@ -112,11 +114,12 @@ class BaseInvocationOutput(BaseModel): @classmethod def get_typeadapter(cls) -> TypeAdapter[Any]: """Gets a pydantc TypeAdapter for the union of all invocation output types.""" - if not cls._typeadapter: - InvocationOutputsUnion = TypeAliasType( - "InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] + if not cls._typeadapter or cls._typeadapter_needs_update: + AnyInvocationOutput = TypeAliasType( + "AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] ) - cls._typeadapter = TypeAdapter(InvocationOutputsUnion) + cls._typeadapter = TypeAdapter(AnyInvocationOutput) + cls._typeadapter_needs_update = False return cls._typeadapter @classmethod @@ -125,12 +128,13 @@ class BaseInvocationOutput(BaseModel): return (i.get_type() for i in BaseInvocationOutput.get_outputs()) @staticmethod - def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: + def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None: """Adds various UI-facing attributes to the invocation output's OpenAPI schema.""" # Because we use a pydantic Literal field with default value for the invocation type, # it will be typed as optional in the OpenAPI schema. Make it required manually. 
if "required" not in schema or not isinstance(schema["required"], list): schema["required"] = [] + schema["class"] = "output" schema["required"].extend(["type"]) @classmethod @@ -167,6 +171,7 @@ class BaseInvocation(ABC, BaseModel): _invocation_classes: ClassVar[set[BaseInvocation]] = set() _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None + _typeadapter_needs_update: ClassVar[bool] = False @classmethod def get_type(cls) -> str: @@ -177,15 +182,17 @@ class BaseInvocation(ABC, BaseModel): def register_invocation(cls, invocation: BaseInvocation) -> None: """Registers an invocation.""" cls._invocation_classes.add(invocation) + cls._typeadapter_needs_update = True @classmethod def get_typeadapter(cls) -> TypeAdapter[Any]: """Gets a pydantc TypeAdapter for the union of all invocation types.""" - if not cls._typeadapter: - InvocationsUnion = TypeAliasType( - "InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] + if not cls._typeadapter or cls._typeadapter_needs_update: + AnyInvocation = TypeAliasType( + "AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] ) - cls._typeadapter = TypeAdapter(InvocationsUnion) + cls._typeadapter = TypeAdapter(AnyInvocation) + cls._typeadapter_needs_update = False return cls._typeadapter @classmethod @@ -221,7 +228,7 @@ class BaseInvocation(ABC, BaseModel): return signature(cls.invoke).return_annotation @staticmethod - def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None: + def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None: """Adds various UI-facing attributes to the invocation's OpenAPI schema.""" uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None)) if uiconfig is not None: @@ -237,6 +244,7 @@ class BaseInvocation(ABC, BaseModel): schema["version"] = uiconfig.version if "required" not in schema or not isinstance(schema["required"], list): schema["required"] = [] + schema["class"] = "invocation" schema["required"].extend(["type", "id"]) @abstractmethod @@ -310,7 +318,7 @@ class BaseInvocation(ABC, BaseModel): protected_namespaces=(), validate_assignment=True, json_schema_extra=json_schema_extra, - json_schema_serialization_defaults_required=True, + json_schema_serialization_defaults_required=False, coerce_numbers_to_str=True, ) diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 7d3d489bf5..0adcaa2ab1 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -1,11 +1,10 @@ from math import floor -from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar from fastapi_events.handlers.local import local_handler from fastapi_events.registry.payload_schema import registry as payload_schema -from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny +from pydantic import BaseModel, ConfigDict, Field -from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_queue.session_queue_common import ( QUEUE_ITEM_STATUS, @@ -14,6 +13,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( SessionQueueItem, SessionQueueStatus, ) +from 
invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput from invokeai.app.util.misc import get_timestamp from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState @@ -33,6 +33,7 @@ class EventBase(BaseModel): A timestamp is automatically added to the event when it is created. """ + __event_name__: ClassVar[str] timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp) model_config = ConfigDict(json_schema_serialization_defaults_required=True) @@ -97,7 +98,7 @@ class InvocationEventBase(QueueItemEventBase): item_id: int = Field(description="The ID of the queue item") batch_id: str = Field(description="The ID of the queue batch") session_id: str = Field(description="The ID of the session (aka graph execution state)") - invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation") + invocation: AnyInvocation = Field(description="The ID of the invocation") invocation_source_id: str = Field(description="The ID of the prepared invocation's source node") @@ -108,7 +109,7 @@ class InvocationStartedEvent(InvocationEventBase): __event_name__ = "invocation_started" @classmethod - def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent": + def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent": return cls( queue_id=queue_item.queue_id, item_id=queue_item.item_id, @@ -135,7 +136,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase): def build( cls, queue_item: SessionQueueItem, - invocation: BaseInvocation, + invocation: AnyInvocation, intermediate_state: PipelineIntermediateState, progress_image: ProgressImage, ) -> "InvocationDenoiseProgressEvent": @@ -173,11 +174,11 @@ class InvocationCompleteEvent(InvocationEventBase): __event_name__ = "invocation_complete" - result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation") + result: AnyInvocationOutput = Field(description="The result of the invocation") @classmethod def build( - cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput + cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput ) -> "InvocationCompleteEvent": return cls( queue_id=queue_item.queue_id, @@ -206,7 +207,7 @@ class InvocationErrorEvent(InvocationEventBase): def build( cls, queue_item: SessionQueueItem, - invocation: BaseInvocation, + invocation: AnyInvocation, error_type: str, error_message: str, error_traceback: str, diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 8508d2484c..d745e73823 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -2,18 +2,19 @@ import copy import itertools -from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints +from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints import networkx as nx from pydantic import ( BaseModel, + GetCoreSchemaHandler, GetJsonSchemaHandler, ValidationError, field_validator, ) from pydantic.fields import Field from pydantic.json_schema import JsonSchemaValue -from pydantic_core import CoreSchema +from pydantic_core import core_schema # Importing * is bad karma but needed here for node detection from invokeai.app.invocations import * # noqa: F401 F403 @@ -277,73 
+278,58 @@ class CollectInvocation(BaseInvocation): return CollectInvocationOutput(collection=copy.copy(self.collection)) +class AnyInvocation(BaseInvocation): + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + def validate_invocation(v: Any) -> "AnyInvocation": + return BaseInvocation.get_typeadapter().validate_python(v) + + return core_schema.no_info_plain_validator_function(validate_invocation) + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + # Nodes are too powerful, we have to make our own OpenAPI schema manually + # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually + oneOf: list[dict[str, str]] = [] + names = [i.__name__ for i in BaseInvocation.get_invocations()] + for name in sorted(names): + oneOf.append({"$ref": f"#/components/schemas/{name}"}) + return {"oneOf": oneOf} + + +class AnyInvocationOutput(BaseInvocationOutput): + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler): + def validate_invocation_output(v: Any) -> "AnyInvocationOutput": + return BaseInvocationOutput.get_typeadapter().validate_python(v) + + return core_schema.no_info_plain_validator_function(validate_invocation_output) + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + # Nodes are too powerful, we have to make our own OpenAPI schema manually + # No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually + + oneOf: list[dict[str, str]] = [] + names = [i.__name__ for i in BaseInvocationOutput.get_outputs()] + for name in sorted(names): + oneOf.append({"$ref": f"#/components/schemas/{name}"}) + return {"oneOf": oneOf} + + class Graph(BaseModel): id: str = Field(description="The id of this graph", default_factory=uuid_string) # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me - nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict) + nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict) edges: list[Edge] = Field( description="The connections between nodes and their fields in this graph", default_factory=list, ) - @field_validator("nodes", mode="plain") - @classmethod - def validate_nodes(cls, v: dict[str, Any]): - """Validates the nodes in the graph by retrieving a union of all node types and validating each node.""" - - # Invocations register themselves as their python modules are executed. The union of all invocations is - # constructed at runtime. We use pydantic to validate `Graph.nodes` using that union. - # - # It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If - # we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing - # invocations will cause a graph to fail if they are used. - # - # We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the - # pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime. - # - # This same pattern is used in `GraphExecutionState`. 
- - nodes: dict[str, BaseInvocation] = {} - typeadapter = BaseInvocation.get_typeadapter() - for node_id, node in v.items(): - nodes[node_id] = typeadapter.validate_python(node) - return nodes - - @classmethod - def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: - # We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for - # fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to - # the generated schema as options for the `nodes` field. - # - # The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and - # with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as - # expected. - # - # You might be tempted to do something like this: - # - # ```py - # cloned_model = create_model(cls.__name__, __base__=cls, nodes=...) - # delattr(cloned_model, "validate_nodes") - # cloned_model.model_rebuild(force=True) - # json_schema = handler(cloned_model.__pydantic_core_schema__) - # ``` - # - # Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts - # to build the JSON Schema for the cloned model. Instead, we have to manually clone the model. - # - # This same pattern is used in `GraphExecutionState`. - - class Graph(BaseModel): - id: Optional[str] = Field(default=None, description="The id of this graph") - nodes: dict[ - str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")] - ] = Field(description="The nodes in this graph") - edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph") - - json_schema = handler(Graph.__pydantic_core_schema__) - json_schema = handler.resolve_ref_schema(json_schema) - return json_schema - def add_node(self, node: BaseInvocation) -> None: """Adds a node to a graph @@ -774,7 +760,7 @@ class GraphExecutionState(BaseModel): ) # The results of executed nodes - results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict) + results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict) # Errors raised when executing nodes errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict) @@ -791,52 +777,12 @@ class GraphExecutionState(BaseModel): default_factory=dict, ) - @field_validator("results", mode="plain") - @classmethod - def validate_results(cls, v: dict[str, BaseInvocationOutput]): - """Validates the results in the GES by retrieving a union of all output types and validating each result.""" - - # See the comment in `Graph.validate_nodes` for an explanation of this logic. - results: dict[str, BaseInvocationOutput] = {} - typeadapter = BaseInvocationOutput.get_typeadapter() - for result_id, result in v.items(): - results[result_id] = typeadapter.validate_python(result) - return results - @field_validator("graph") def graph_is_valid(cls, v: Graph): """Validates that the graph is valid""" v.validate_self() return v - @classmethod - def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: - # See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic. 
- class GraphExecutionState(BaseModel): - """Tracks the state of a graph execution""" - - id: str = Field(description="The id of the execution state") - graph: Graph = Field(description="The graph being executed") - execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes") - executed: set[str] = Field(description="The set of node ids that have been executed") - executed_history: list[str] = Field( - description="The list of node ids that have been executed, in order of execution" - ) - results: dict[ - str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")] - ] = Field(description="The results of node executions") - errors: dict[str, str] = Field(description="Errors raised when executing nodes") - prepared_source_mapping: dict[str, str] = Field( - description="The map of prepared nodes to original graph nodes" - ) - source_prepared_mapping: dict[str, set[str]] = Field( - description="The map of original graph nodes to prepared nodes" - ) - - json_schema = handler(GraphExecutionState.__pydantic_core_schema__) - json_schema = handler.resolve_ref_schema(json_schema) - return json_schema - def next(self) -> Optional[BaseInvocation]: """Gets the next node ready to execute.""" diff --git a/invokeai/app/util/custom_openapi.py b/invokeai/app/util/custom_openapi.py new file mode 100644 index 0000000000..50259c12cc --- /dev/null +++ b/invokeai/app/util/custom_openapi.py @@ -0,0 +1,116 @@ +from typing import Any, Callable, Optional + +from fastapi import FastAPI +from fastapi.openapi.utils import get_openapi +from pydantic.json_schema import models_json_schema + +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase +from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra +from invokeai.app.invocations.model import ModelIdentifierField +from invokeai.app.services.events.events_common import EventBase +from invokeai.app.services.session_processor.session_processor_common import ProgressImage + + +def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None: + """Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema + for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and + component_schema.""" + + defs = component_schema.pop("$defs", {}) + for schema_key, json_schema in defs.items(): + if schema_key in openapi_schema["components"]["schemas"]: + continue + openapi_schema["components"]["schemas"][schema_key] = json_schema + + +def get_openapi_func( + app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None +) -> Callable[[], dict[str, Any]]: + """Gets the OpenAPI schema generator function. + + Args: + app (FastAPI): The FastAPI app to generate the schema for. + post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the + generated schema before returning it. Defaults to None. + + Returns: + Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is + cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour + matches FastAPI's default schema generation caching. 
+ """ + + def openapi() -> dict[str, Any]: + if app.openapi_schema: + return app.openapi_schema + + openapi_schema = get_openapi( + title=app.title, + description="An API for invoking AI image operations", + version="1.0.0", + routes=app.routes, + separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/ + ) + + # We'll create a map of invocation type to output schema to make some types simpler on the client. + invocation_output_map_properties: dict[str, Any] = {} + invocation_output_map_required: list[str] = [] + + # We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly. + for output in BaseInvocationOutput.get_outputs(): + json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}") + move_defs_to_top_level(openapi_schema, json_schema) + openapi_schema["components"]["schemas"][output.__name__] = json_schema + + # Technically, invocations are added to the schema by pydantic, but we still need to manually set their output + # property, so we'll just do it all manually. + for invocation in BaseInvocation.get_invocations(): + json_schema = invocation.model_json_schema( + mode="serialization", ref_template="#/components/schemas/{model}" + ) + move_defs_to_top_level(openapi_schema, json_schema) + output_title = invocation.get_output_annotation().__name__ + outputs_ref = {"$ref": f"#/components/schemas/{output_title}"} + json_schema["output"] = outputs_ref + openapi_schema["components"]["schemas"][invocation.__name__] = json_schema + + # Add this invocation and its output to the output map + invocation_type = invocation.get_type() + invocation_output_map_properties[invocation_type] = json_schema["output"] + invocation_output_map_required.append(invocation_type) + + # Add the output map to the schema + openapi_schema["components"]["schemas"]["InvocationOutputMap"] = { + "type": "object", + "properties": invocation_output_map_properties, + "required": invocation_output_map_required, + } + + # Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API. + # We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get + # a schema. This has something to do with schema refs - not totally clear. For whatever reason, using + # `models_json_schema` seems to work fine. 
+ additional_models = [ + *EventBase.get_events(), + UIConfigBase, + InputFieldJSONSchemaExtra, + OutputFieldJSONSchemaExtra, + ModelIdentifierField, + ProgressImage, + ] + + additional_schemas = models_json_schema( + [(m, "serialization") for m in additional_models], + ref_template="#/components/schemas/{model}", + ) + # additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema + move_defs_to_top_level(openapi_schema, additional_schemas[1]) + + if post_transform is not None: + openapi_schema = post_transform(openapi_schema) + + openapi_schema["components"]["schemas"] = dict(sorted(openapi_schema["components"]["schemas"].items())) + + app.openapi_schema = openapi_schema + return app.openapi_schema + + return openapi diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py index e3cb7c8fff..57f5fcd657 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_locker.py +++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py @@ -53,5 +53,5 @@ class ModelLocker(ModelLockerBase): """Call upon exit from context.""" self._cache_entry.unlock() if not self._cache.lazy_offloading: - self._cache.offload_unlocked_models(self._cache_entry.size) + self._cache.offload_unlocked_models(0) self._cache.print_cuda_stats() diff --git a/invokeai/backend/textual_inversion.py b/invokeai/backend/textual_inversion.py index f7390979bb..98104f769e 100644 --- a/invokeai/backend/textual_inversion.py +++ b/invokeai/backend/textual_inversion.py @@ -1,7 +1,7 @@ """Textual Inversion wrapper class.""" from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Optional, Union import torch from compel.embeddings_provider import BaseTextualInversionManager @@ -66,35 +66,52 @@ class TextualInversionModelRaw(RawModel): return result -# no type hints for BaseTextualInversionManager? -class TextualInversionManager(BaseTextualInversionManager): # type: ignore - pad_tokens: Dict[int, List[int]] - tokenizer: CLIPTokenizer +class TextualInversionManager(BaseTextualInversionManager): + """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library.""" def __init__(self, tokenizer: CLIPTokenizer): - self.pad_tokens = {} + self.pad_tokens: dict[int, list[int]] = {} self.tokenizer = tokenizer def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: + """Given a list of token ids, expand any TI tokens to their corresponding pad tokens. + + For example, suppose we have a `<ti_name>` TI with 4 vectors that was added to the tokenizer with the following + mapping of tokens to token_ids: + ``` + <ti_name>: 49408 + <ti_name-!pad-1>: 49409 + <ti_name-!pad-2>: 49410 + <ti_name-!pad-3>: 49411 + ``` + `self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`. + This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`. + """ + # Short circuit if there are no pad tokens to save a little time. if len(self.pad_tokens) == 0: return token_ids + # This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify + # this assumption here. if token_ids[0] == self.tokenizer.bos_token_id: raise ValueError("token_ids must not start with bos_token_id") if token_ids[-1] == self.tokenizer.eos_token_id: raise ValueError("token_ids must not end with eos_token_id") - new_token_ids = [] + # Expand any TI tokens to their corresponding pad tokens.
+ new_token_ids: list[int] = [] for token_id in token_ids: new_token_ids.append(token_id) if token_id in self.pad_tokens: new_token_ids.extend(self.pad_tokens[token_id]) - # Do not exceed the max model input size - # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(), - # which first removes and then adds back the start and end tokens. - max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2 + # Do not exceed the max model input size. The -2 here is compensating for + # compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens. + max_length = self.tokenizer.model_max_length - 2 if len(new_token_ids) > max_length: + # HACK: If TI token expansion causes us to exceed the max text encoder input length, we silently discard + # tokens. Token expansion should happen in a way that is compatible with compel's default handling of long + # prompts. new_token_ids = new_token_ids[0:max_length] return new_token_ids diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index f7a91ef756..306151984a 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -148,6 +148,8 @@ "viewingDesc": "Review images in a large gallery view", "editing": "Editing", "editingDesc": "Edit on the Control Layers canvas", + "comparing": "Comparing", + "comparingDesc": "Comparing two images", "enabled": "Enabled", "disabled": "Disabled" }, @@ -375,7 +377,23 @@ "bulkDownloadRequestFailed": "Problem Preparing Download", "bulkDownloadFailed": "Download Failed", "problemDeletingImages": "Problem Deleting Images", - "problemDeletingImagesDesc": "One or more images could not be deleted" + "problemDeletingImagesDesc": "One or more images could not be deleted", + "viewerImage": "Viewer Image", + "compareImage": "Compare Image", + "openInViewer": "Open in Viewer", + "selectForCompare": "Select for Compare", + "selectAnImageToCompare": "Select an Image to Compare", + "slider": "Slider", + "sideBySide": "Side-by-Side", + "hover": "Hover", + "swapImages": "Swap Images", + "compareOptions": "Comparison Options", + "stretchToFit": "Stretch to Fit", + "exitCompare": "Exit Compare", + "compareHelp1": "Hold Alt while clicking a gallery image or using the arrow keys to change the compare image.", + "compareHelp2": "Press M to cycle through comparison modes.", + "compareHelp3": "Press C to swap the compared images.", + "compareHelp4": "Press Z or Esc to exit." 
}, "hotkeys": { "searchHotkeys": "Search Hotkeys", diff --git a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx index 0b4ca90933..aa3a24209c 100644 --- a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx +++ b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx @@ -19,6 +19,13 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) { return extendTheme({ ..._theme, direction, + shadows: { + ..._theme.shadows, + selectedForCompare: + '0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-400)', + hoverSelectedForCompare: + '0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-300)', + }, }); }, [direction]); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts index 581146c25c..ba04947a2d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts @@ -13,7 +13,6 @@ import { isControlAdapterLayer, } from 'features/controlLayers/store/controlLayersSlice'; import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters'; -import { isImageOutput } from 'features/nodes/types/common'; import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { isEqual } from 'lodash-es'; @@ -139,7 +138,7 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni // We still have to check the output type assert( - isImageOutput(invocationCompleteAction.payload.data.result), + invocationCompleteAction.payload.data.result.type === 'image_output', `Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}` ); const { image_name } = invocationCompleteAction.payload.data.result.image; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts index 1e485b31d5..574dad00eb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts @@ -9,7 +9,6 @@ import { selectControlAdapterById, } from 'features/controlAdapters/store/controlAdaptersSlice'; import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; -import { isImageOutput } from 'features/nodes/types/common'; import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; @@ -74,7 +73,7 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL ); // We still have to check the output type - if (isImageOutput(invocationCompleteAction.payload.data.result)) { + if (invocationCompleteAction.payload.data.result.type === 'image_output') { const { image_name } = invocationCompleteAction.payload.data.result.image; // Wait for the ImageDTO to be received diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts 
b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts index 67c6d076ee..43f9355125 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts @@ -1,7 +1,7 @@ import { createAction } from '@reduxjs/toolkit'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors'; -import { selectionChanged } from 'features/gallery/store/gallerySlice'; +import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice'; import { imagesApi } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; import { imagesSelectors } from 'services/api/util'; @@ -11,6 +11,7 @@ export const galleryImageClicked = createAction<{ shiftKey: boolean; ctrlKey: boolean; metaKey: boolean; + altKey: boolean; }>('gallery/imageClicked'); /** @@ -28,7 +29,7 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen startAppListening({ actionCreator: galleryImageClicked, effect: async (action, { dispatch, getState }) => { - const { imageDTO, shiftKey, ctrlKey, metaKey } = action.payload; + const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload; const state = getState(); const queryArgs = selectListImagesQueryArgs(state); const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state); @@ -41,7 +42,13 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen const imageDTOs = imagesSelectors.selectAll(listImagesData); const selection = state.gallery.selection; - if (shiftKey) { + if (altKey) { + if (state.gallery.imageToCompare?.image_name === imageDTO.image_name) { + dispatch(imageToCompareChanged(null)); + } else { + dispatch(imageToCompareChanged(imageDTO)); + } + } else if (shiftKey) { const rangeEndImageName = imageDTO.image_name; const lastSelectedImage = selection[selection.length - 1]?.image_name; const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts index 9bc9635299..7cb0703af8 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts @@ -14,7 +14,8 @@ import { rgLayerIPAdapterImageChanged, } from 'features/controlLayers/store/controlLayersSlice'; import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; +import { isValidDrop } from 'features/dnd/util/isValidDrop'; +import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; import { imagesApi } from 'services/api/endpoints/images'; @@ -30,6 +31,9 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) => effect: async (action, { dispatch, getState }) => { const log = logger('dnd'); const { activeData, 
overData } = action.payload; + if (!isValidDrop(overData, activeData)) { + return; + } if (activeData.payloadType === 'IMAGE_DTO') { log.debug({ activeData, overData }, 'Image dropped'); @@ -50,6 +54,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) => activeData.payload.imageDTO ) { dispatch(imageSelected(activeData.payload.imageDTO)); + dispatch(isImageViewerOpenChanged(true)); return; } @@ -182,24 +187,18 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) => } /** - * TODO - * Image selection dropped on node image collection field + * Image selected for compare */ - // if ( - // overData.actionType === 'SET_MULTI_NODES_IMAGE' && - // activeData.payloadType === 'IMAGE_DTO' && - // activeData.payload.imageDTO - // ) { - // const { fieldName, nodeId } = overData.context; - // dispatch( - // fieldValueChanged({ - // nodeId, - // fieldName, - // value: [activeData.payload.imageDTO], - // }) - // ); - // return; - // } + if ( + overData.actionType === 'SELECT_FOR_COMPARE' && + activeData.payloadType === 'IMAGE_DTO' && + activeData.payload.imageDTO + ) { + const { imageDTO } = activeData.payload; + dispatch(imageToCompareChanged(imageDTO)); + dispatch(isImageViewerOpenChanged(true)); + return; + } /** * Image dropped on user board diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index 1a04f9493a..2841493ca6 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -11,7 +11,6 @@ import { } from 'features/gallery/store/gallerySlice'; import { IMAGE_CATEGORIES } from 'features/gallery/store/types'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; -import { isImageOutput } from 'features/nodes/types/common'; import { zNodeStatus } from 'features/nodes/types/invocation'; import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants'; import { boardsApi } from 'services/api/endpoints/boards'; @@ -33,7 +32,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi const { result, invocation_source_id } = data; // This complete event has an associated image output - if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation.type)) { + if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) { const { image_name } = data.result.image; const { canvas, gallery } = getState(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts index e6fc5a526a..2c0caa0ec9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts @@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { parseify } from 'common/util/serialize'; import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions'; import { $templates } from 'features/nodes/store/nodesSlice'; -import 
{ $flow } from 'features/nodes/store/reactFlowInstance'; +import { $needsFit } from 'features/nodes/store/reactFlowInstance'; import type { Templates } from 'features/nodes/store/types'; import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error'; import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow'; @@ -65,9 +65,7 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList }); } - requestAnimationFrame(() => { - $flow.get()?.fitView(); - }); + $needsFit.set(true); } catch (e) { if (e instanceof WorkflowVersionError) { // The workflow version was not recognized in the valid list of versions diff --git a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx index 2712334e1e..f16aa3d4b4 100644 --- a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx @@ -35,6 +35,7 @@ type IAIDndImageProps = FlexProps & { draggableData?: TypesafeDraggableData; dropLabel?: ReactNode; isSelected?: boolean; + isSelectedForCompare?: boolean; thumbnail?: boolean; noContentFallback?: ReactElement; useThumbailFallback?: boolean; @@ -61,6 +62,7 @@ const IAIDndImage = (props: IAIDndImageProps) => { draggableData, dropLabel, isSelected = false, + isSelectedForCompare = false, thumbnail = false, noContentFallback = defaultNoContentFallback, uploadElement = defaultUploadElement, @@ -165,7 +167,11 @@ const IAIDndImage = (props: IAIDndImageProps) => { data-testid={dataTestId} /> {withMetadataOverlay && } - + )} {!imageDTO && !isUploadDisabled && ( diff --git a/invokeai/frontend/web/src/common/components/IAIDroppable.tsx b/invokeai/frontend/web/src/common/components/IAIDroppable.tsx index 258a6e9004..ef331c4377 100644 --- a/invokeai/frontend/web/src/common/components/IAIDroppable.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDroppable.tsx @@ -36,7 +36,7 @@ const IAIDroppable = (props: IAIDroppableProps) => { pointerEvents={active ? 'auto' : 'none'} > - {isValidDrop(data, active) && } + {isValidDrop(data, active?.data.current) && } ); diff --git a/invokeai/frontend/web/src/common/components/SelectionOverlay.tsx b/invokeai/frontend/web/src/common/components/SelectionOverlay.tsx index eb50a6b9d4..3e2ecca4ae 100644 --- a/invokeai/frontend/web/src/common/components/SelectionOverlay.tsx +++ b/invokeai/frontend/web/src/common/components/SelectionOverlay.tsx @@ -3,10 +3,17 @@ import { memo, useMemo } from 'react'; type Props = { isSelected: boolean; + isSelectedForCompare: boolean; isHovered: boolean; }; -const SelectionOverlay = ({ isSelected, isHovered }: Props) => { +const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props) => { const shadow = useMemo(() => { + if (isSelectedForCompare && isHovered) { + return 'hoverSelectedForCompare'; + } + if (isSelectedForCompare && !isHovered) { + return 'selectedForCompare'; + } if (isSelected && isHovered) { return 'hoverSelected'; } @@ -17,7 +24,7 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => { return 'hoverUnselected'; } return undefined; - }, [isHovered, isSelected]); + }, [isHovered, isSelected, isSelectedForCompare]); return ( { bottom={0} insetInlineStart={0} borderRadius="base" - opacity={isSelected ? 1 : 0.7} + opacity={isSelected || isSelectedForCompare ? 
1 : 0.7} transitionProperty="common" transitionDuration="0.1s" pointerEvents="none" diff --git a/invokeai/frontend/web/src/common/hooks/useBoolean.ts b/invokeai/frontend/web/src/common/hooks/useBoolean.ts new file mode 100644 index 0000000000..123e48cd75 --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/useBoolean.ts @@ -0,0 +1,21 @@ +import { useCallback, useMemo, useState } from 'react'; + +export const useBoolean = (initialValue: boolean) => { + const [isTrue, set] = useState(initialValue); + const setTrue = useCallback(() => set(true), []); + const setFalse = useCallback(() => set(false), []); + const toggle = useCallback(() => set((v) => !v), []); + + const api = useMemo( + () => ({ + isTrue, + set, + setTrue, + setFalse, + toggle, + }), + [isTrue, set, setTrue, setFalse, toggle] + ); + + return api; +}; diff --git a/invokeai/frontend/web/src/common/util/stopPropagation.ts b/invokeai/frontend/web/src/common/util/stopPropagation.ts index b3481b7c0e..0c6a1fc507 100644 --- a/invokeai/frontend/web/src/common/util/stopPropagation.ts +++ b/invokeai/frontend/web/src/common/util/stopPropagation.ts @@ -1,3 +1,7 @@ export const stopPropagation = (e: React.MouseEvent) => { e.stopPropagation(); }; + +export const preventDefault = (e: React.MouseEvent) => { + e.preventDefault(); +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.ts b/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.ts index 708e089008..13435bdb7c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.ts +++ b/invokeai/frontend/web/src/features/controlLayers/util/controlAdapters.ts @@ -1,7 +1,13 @@ import { deepClone } from 'common/util/deepClone'; import { zModelIdentifierField } from 'features/nodes/types/common'; import { merge, omit } from 'lodash-es'; -import type { BaseModelType, ControlNetModelConfig, Graph, ImageDTO, T2IAdapterModelConfig } from 'services/api/types'; +import type { + AnyInvocation, + BaseModelType, + ControlNetModelConfig, + ImageDTO, + T2IAdapterModelConfig, +} from 'services/api/types'; import { z } from 'zod'; const zId = z.string().min(1); @@ -147,7 +153,7 @@ const zBeginEndStepPct = z const zControlAdapterBase = z.object({ id: zId, - weight: z.number().gte(0).lte(1), + weight: z.number().gte(-1).lte(2), image: zImageWithDims.nullable(), processedImage: zImageWithDims.nullable(), processorConfig: zProcessorConfig.nullable(), @@ -183,7 +189,7 @@ export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safePar export const zIPAdapterConfigV2 = z.object({ id: zId, type: z.literal('ip_adapter'), - weight: z.number().gte(0).lte(1), + weight: z.number().gte(-1).lte(2), method: zIPMethodV2, image: zImageWithDims.nullable(), model: zModelIdentifierField.nullable(), @@ -216,10 +222,7 @@ type ProcessorData = { labelTKey: string; descriptionTKey: string; buildDefaults(baseModel?: BaseModelType): Extract; - buildNode( - image: ImageWithDims, - config: Extract - ): Extract; + buildNode(image: ImageWithDims, config: Extract): Extract; }; const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height); diff --git a/invokeai/frontend/web/src/features/controlLayers/util/renderers.ts b/invokeai/frontend/web/src/features/controlLayers/util/renderers.ts index 25ac30387b..79933e6b00 100644 --- a/invokeai/frontend/web/src/features/controlLayers/util/renderers.ts +++ b/invokeai/frontend/web/src/features/controlLayers/util/renderers.ts @@ -54,7 +54,7 @@ const BBOX_SELECTED_STROKE = 'rgba(78, 190, 
255, 1)'; const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)'; const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)'; // This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL -const STAGE_BG_DATAURL = +export const STAGE_BG_DATAURL = 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAAEsmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4KPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNS41LjAiPgogPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iCiAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyIKICAgIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIKICAgIHhtbG5zOnhtcD0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLyIKICAgIHhtbG5zOnhtcE1NPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvbW0vIgogICAgeG1sbnM6c3RFdnQ9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZUV2ZW50IyIKICAgZXhpZjpQaXhlbFhEaW1lbnNpb249IjIwIgogICBleGlmOlBpeGVsWURpbWVuc2lvbj0iMjAiCiAgIGV4aWY6Q29sb3JTcGFjZT0iMSIKICAgdGlmZjpJbWFnZVdpZHRoPSIyMCIKICAgdGlmZjpJbWFnZUxlbmd0aD0iMjAiCiAgIHRpZmY6UmVzb2x1dGlvblVuaXQ9IjIiCiAgIHRpZmY6WFJlc29sdXRpb249IjMwMC8xIgogICB0aWZmOllSZXNvbHV0aW9uPSIzMDAvMSIKICAgcGhvdG9zaG9wOkNvbG9yTW9kZT0iMyIKICAgcGhvdG9zaG9wOklDQ1Byb2ZpbGU9InNSR0IgSUVDNjE5NjYtMi4xIgogICB4bXA6TW9kaWZ5RGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCIKICAgeG1wOk1ldGFkYXRhRGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCI+CiAgIDx4bXBNTTpIaXN0b3J5PgogICAgPHJkZjpTZXE+CiAgICAgPHJkZjpsaQogICAgICBzdEV2dDphY3Rpb249InByb2R1Y2VkIgogICAgICBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZmZpbml0eSBQaG90byAxLjEwLjgiCiAgICAgIHN0RXZ0OndoZW49IjIwMjQtMDQtMjNUMDg6MjA6NDcrMTA6MDAiLz4KICAgIDwvcmRmOlNlcT4KICAgPC94bXBNTTpIaXN0b3J5PgogIDwvcmRmOkRlc2NyaXB0aW9uPgogPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KPD94cGFja2V0IGVuZD0iciI/Pn9pdVgAAAGBaUNDUHNSR0IgSUVDNjE5NjYtMi4xAAAokXWR3yuDURjHP5uJmKghFy6WxpVpqMWNMgm1tGbKr5vt3S+1d3t73y3JrXKrKHHj1wV/AbfKtVJESq53TdywXs9rakv2nJ7zfM73nOfpnOeAPZJRVMPhAzWb18NTAffC4pK7oYiDTjpw4YgqhjYeCgWpaR8P2Kx457Vq1T73rzXHE4YCtkbhMUXT88LTwsG1vGbxrnC7ko7Ghc+F+3W5oPC9pcfKXLQ4VeYvi/VIeALsbcLuVBXHqlhJ66qwvByPmikov/exXuJMZOfnJPaId2MQZooAbmaYZAI/g4zK7MfLEAOyoka+7yd/lpzkKjJrrKOzSoo0efpFLUj1hMSk6AkZGdat/v/tq5EcHipXdwag/sU033qhYQdK26b5eWyapROoe4arbCU/dwQj76JvVzTPIbRuwsV1RYvtweUWdD1pUT36I9WJ25NJeD2DlkVw3ULTcrlnv/ucPkJkQ77qBvYPoE/Ot658AxagZ8FoS/a7AAAACXBIWXMAAC4jAAAuIwF4pT92AAAAL0lEQVQ4jWM8ffo0A25gYmKCR5YJjxxBMKp5ZGhm/P//Px7pM2fO0MrmUc0jQzMAB2EIhZC3pUYAAAAASUVORK5CYII='; const mapId = (object: { id: string }) => object.id; diff --git a/invokeai/frontend/web/src/features/dnd/types/index.ts b/invokeai/frontend/web/src/features/dnd/types/index.ts index 4d09c759eb..6fcf18421e 100644 --- a/invokeai/frontend/web/src/features/dnd/types/index.ts +++ b/invokeai/frontend/web/src/features/dnd/types/index.ts @@ -18,7 +18,7 @@ type BaseDropData = { id: string; }; -type CurrentImageDropData = BaseDropData & { +export type CurrentImageDropData = BaseDropData & { actionType: 'SET_CURRENT_IMAGE'; }; @@ -79,6 +79,14 @@ export type RemoveFromBoardDropData = BaseDropData & { actionType: 'REMOVE_FROM_BOARD'; }; +export type SelectForCompareDropData = BaseDropData & { + actionType: 'SELECT_FOR_COMPARE'; + context: { + firstImageName?: string | null; + secondImageName?: string | null; + }; +}; + export type TypesafeDroppableData = | CurrentImageDropData | ControlAdapterDropData @@ -89,7 +97,8 @@ export type 
TypesafeDroppableData = | CALayerImageDropData | IPALayerImageDropData | RGLayerIPAdapterImageDropData - | IILayerImageDropData; + | IILayerImageDropData + | SelectForCompareDropData; type BaseDragData = { id: string; @@ -134,7 +143,7 @@ export type UseDraggableTypesafeReturnValue = Omit { +interface TypesafeActive extends Omit { data: React.MutableRefObject; } diff --git a/invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts b/invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts index b701c72947..6dec862345 100644 --- a/invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts +++ b/invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts @@ -1,14 +1,14 @@ -import type { TypesafeActive, TypesafeDroppableData } from 'features/dnd/types'; +import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types'; -export const isValidDrop = (overData: TypesafeDroppableData | undefined, active: TypesafeActive | null) => { - if (!overData || !active?.data.current) { +export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?: TypesafeDraggableData | null) => { + if (!overData || !activeData) { return false; } const { actionType } = overData; - const { payloadType } = active.data.current; + const { payloadType } = activeData; - if (overData.id === active.data.current.id) { + if (overData.id === activeData.id) { return false; } @@ -29,6 +29,8 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active: return payloadType === 'IMAGE_DTO'; case 'SET_NODES_IMAGE': return payloadType === 'IMAGE_DTO'; + case 'SELECT_FOR_COMPARE': + return payloadType === 'IMAGE_DTO'; case 'ADD_TO_BOARD': { // If the board is the same, don't allow the drop @@ -40,7 +42,7 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active: // Check if the image's board is the board we are dragging onto if (payloadType === 'IMAGE_DTO') { - const { imageDTO } = active.data.current.payload; + const { imageDTO } = activeData.payload; const currentBoard = imageDTO.board_id ?? 'none'; const destinationBoard = overData.context.boardId; @@ -49,7 +51,7 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active: if (payloadType === 'GALLERY_SELECTION') { // Assume all images are on the same board - this is true for the moment - const currentBoard = active.data.current.payload.boardId; + const currentBoard = activeData.payload.boardId; const destinationBoard = overData.context.boardId; return currentBoard !== destinationBoard; } @@ -67,14 +69,14 @@ export const isValidDrop = (overData: TypesafeDroppableData | undefined, active: // Check if the image's board is the board we are dragging onto if (payloadType === 'IMAGE_DTO') { - const { imageDTO } = active.data.current.payload; + const { imageDTO } = activeData.payload; const currentBoard = imageDTO.board_id ?? 
'none'; return currentBoard !== 'none'; } if (payloadType === 'GALLERY_SELECTION') { - const currentBoard = active.data.current.payload.boardId; + const currentBoard = activeData.payload.boardId; return currentBoard !== 'none'; } diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx index 0509305192..f8c4f5ebcf 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx @@ -162,7 +162,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps )} {isSelectedForAutoAdd && } - + { > {boardName} - + {t('unifiedCanvas.move')}} /> diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx index b3119aa8fa..31df113115 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx @@ -10,6 +10,7 @@ import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice'; import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import { useImageActions } from 'features/gallery/hooks/useImageActions'; import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions'; +import { imageToCompareChanged } from 'features/gallery/store/gallerySlice'; import { $templates } from 'features/nodes/store/nodesSlice'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; @@ -27,6 +28,7 @@ import { PiDownloadSimpleBold, PiFlowArrowBold, PiFoldersBold, + PiImagesBold, PiPlantBold, PiQuotesBold, PiShareFatBold, @@ -44,6 +46,7 @@ type SingleSelectionMenuItemsProps = { const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { const { imageDTO } = props; const optimalDimension = useAppSelector(selectOptimalDimension); + const maySelectForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name !== imageDTO.image_name); const dispatch = useAppDispatch(); const { t } = useTranslation(); const isCanvasEnabled = useFeatureStatus('canvas'); @@ -117,6 +120,10 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { downloadImage(imageDTO.image_url, imageDTO.image_name); }, [downloadImage, imageDTO.image_name, imageDTO.image_url]); + const handleSelectImageForCompare = useCallback(() => { + dispatch(imageToCompareChanged(imageDTO)); + }, [dispatch, imageDTO]); + return ( <> }> @@ -130,6 +137,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { } onClickCapture={handleDownloadImage}> {t('parameters.downloadImage')} + } isDisabled={!maySelectForCompare} onClick={handleSelectImageForCompare}> + {t('gallery.selectForCompare')} + : } diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx index 2c53599ba3..e5e216c97c 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx +++ 
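The context-menu item above dispatches `imageToCompareChanged(imageDTO)` and is disabled when the image is already the compare image. A minimal Redux Toolkit sketch of the gallery state these actions touch; the default values are illustrative, and the real `gallerySlice` also defines `comparedImagesSwapped`, `comparisonModeCycled`, and the rest of the gallery state:

```ts
import { createSlice, type PayloadAction } from '@reduxjs/toolkit';

type ImageDTO = { image_name: string };
type ComparisonMode = 'slider' | 'side-by-side' | 'hover';
type ComparisonFit = 'contain' | 'fill';

type GalleryState = {
  imageToCompare: ImageDTO | null;
  comparisonMode: ComparisonMode;
  comparisonFit: ComparisonFit;
};

// Defaults are illustrative; the real slice defines its own initial state.
const initialState: GalleryState = {
  imageToCompare: null,
  comparisonMode: 'slider',
  comparisonFit: 'fill',
};

export const gallerySliceSketch = createSlice({
  name: 'gallery',
  initialState,
  reducers: {
    // null exits compare mode; an ImageDTO selects it as the compare image.
    imageToCompareChanged: (state, action: PayloadAction<ImageDTO | null>) => {
      state.imageToCompare = action.payload;
    },
    comparisonModeChanged: (state, action: PayloadAction<ComparisonMode>) => {
      state.comparisonMode = action.payload;
    },
    comparisonFitChanged: (state, action: PayloadAction<ComparisonFit>) => {
      state.comparisonFit = action.payload;
    },
  },
});
```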
b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx @@ -11,7 +11,7 @@ import type { GallerySelectionDraggableData, ImageDraggableData, TypesafeDraggab import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId'; import { useMultiselect } from 'features/gallery/hooks/useMultiselect'; import { useScrollIntoView } from 'features/gallery/hooks/useScrollIntoView'; -import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; +import { imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice'; import type { MouseEvent } from 'react'; import { memo, useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; @@ -46,6 +46,7 @@ const GalleryImage = (props: HoverableImageProps) => { const { t } = useTranslation(); const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId); const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge); + const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageName); const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO); const customStarUi = useStore($customStarUI); @@ -105,6 +106,7 @@ const GalleryImage = (props: HoverableImageProps) => { const onDoubleClick = useCallback(() => { dispatch(isImageViewerOpenChanged(true)); + dispatch(imageToCompareChanged(null)); }, [dispatch]); const handleMouseOut = useCallback(() => { @@ -152,6 +154,7 @@ const GalleryImage = (props: HoverableImageProps) => { imageDTO={imageDTO} draggableData={draggableData} isSelected={isSelected} + isSelectedForCompare={isSelectedForCompare} minSize={0} imageSx={imageSx} isDropDisabled={true} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CompareToolbar.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CompareToolbar.tsx new file mode 100644 index 0000000000..4f525bc670 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CompareToolbar.tsx @@ -0,0 +1,140 @@ +import { + Button, + ButtonGroup, + Flex, + Icon, + IconButton, + Kbd, + ListItem, + Tooltip, + UnorderedList, +} from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { + comparedImagesSwapped, + comparisonFitChanged, + comparisonModeChanged, + comparisonModeCycled, + imageToCompareChanged, +} from 'features/gallery/store/gallerySlice'; +import { memo, useCallback } from 'react'; +import { useHotkeys } from 'react-hotkeys-hook'; +import { Trans, useTranslation } from 'react-i18next'; +import { PiArrowsOutBold, PiQuestion, PiSwapBold, PiXBold } from 'react-icons/pi'; + +export const CompareToolbar = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode); + const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit); + const setComparisonModeSlider = useCallback(() => { + dispatch(comparisonModeChanged('slider')); + }, [dispatch]); + const setComparisonModeSideBySide = useCallback(() => { + dispatch(comparisonModeChanged('side-by-side')); + }, [dispatch]); + const setComparisonModeHover = useCallback(() => { + dispatch(comparisonModeChanged('hover')); + }, [dispatch]); + const swapImages = useCallback(() => { + dispatch(comparedImagesSwapped()); + }, [dispatch]); + useHotkeys('c', swapImages, [swapImages]); + const 
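`CompareToolbar` wires three hotkeys: `C` swaps the compared images, `M` cycles the comparison mode, and `Esc` exits compare mode. The `comparisonModeCycled` reducer is not shown in this diff; assuming it rotates through the three modes the toolbar exposes, the rotation can be sketched as:

```ts
// Assumed rotation order for the "M" hotkey; the real comparisonModeCycled
// reducer lives in gallerySlice and may order the modes differently.
type ComparisonMode = 'slider' | 'side-by-side' | 'hover';

const MODE_ORDER: ComparisonMode[] = ['slider', 'side-by-side', 'hover'];

export const getNextComparisonMode = (mode: ComparisonMode): ComparisonMode =>
  MODE_ORDER[(MODE_ORDER.indexOf(mode) + 1) % MODE_ORDER.length];
```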
toggleComparisonFit = useCallback(() => { + dispatch(comparisonFitChanged(comparisonFit === 'contain' ? 'fill' : 'contain')); + }, [dispatch, comparisonFit]); + const exitCompare = useCallback(() => { + dispatch(imageToCompareChanged(null)); + }, [dispatch]); + useHotkeys('esc', exitCompare, [exitCompare]); + const nextMode = useCallback(() => { + dispatch(comparisonModeCycled()); + }, [dispatch]); + useHotkeys('m', nextMode, [nextMode]); + + return ( + + + + } + aria-label={`${t('gallery.swapImages')} (C)`} + tooltip={`${t('gallery.swapImages')} (C)`} + onClick={swapImages} + /> + {comparisonMode !== 'side-by-side' && ( + } + /> + )} + + + + + + + + + + + + }> + + + + + } + aria-label={`${t('gallery.exitCompare')} (Esc)`} + tooltip={`${t('gallery.exitCompare')} (Esc)`} + onClick={exitCompare} + /> + + + + ); +}); + +CompareToolbar.displayName = 'CompareToolbar'; + +const CompareHelp = () => { + return ( + + + }}> + + + }}> + + + }}> + + + }}> + + + ); +}; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx index f40ecfca32..a812391992 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx @@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/query'; import { useAppSelector } from 'app/store/storeHooks'; import IAIDndImage from 'common/components/IAIDndImage'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; -import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types'; +import type { TypesafeDraggableData } from 'features/dnd/types'; import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer'; import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; @@ -22,21 +22,7 @@ const selectLastSelectedImageName = createSelector( (lastSelectedImage) => lastSelectedImage?.image_name ); -type Props = { - isDragDisabled?: boolean; - isDropDisabled?: boolean; - withNextPrevButtons?: boolean; - withMetadata?: boolean; - alwaysShowProgress?: boolean; -}; - -const CurrentImagePreview = ({ - isDragDisabled = false, - isDropDisabled = false, - withNextPrevButtons = true, - withMetadata = true, - alwaysShowProgress = false, -}: Props) => { +const CurrentImagePreview = () => { const { t } = useTranslation(); const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails); const imageName = useAppSelector(selectLastSelectedImageName); @@ -55,14 +41,6 @@ const CurrentImagePreview = ({ } }, [imageDTO]); - const droppableData = useMemo( - () => ({ - id: 'current-image', - actionType: 'SET_CURRENT_IMAGE', - }), - [] - ); - // Show and hide the next/prev buttons on mouse move const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = useState(false); const timeoutId = useRef(0); @@ -86,30 +64,27 @@ const CurrentImagePreview = ({ justifyContent="center" position="relative" > - {hasDenoiseProgress && (shouldShowProgressInViewer || alwaysShowProgress) ? ( + {hasDenoiseProgress && shouldShowProgressInViewer ? 
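Exiting compare mode is simply clearing the compare image: the `Esc` hotkey above and the double-click handler in `GalleryImage` both dispatch `imageToCompareChanged(null)`. A small sketch of that wiring as a reusable hook, with the action creator stubbed since the real one comes from `gallerySlice`:

```ts
import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useDispatch } from 'react-redux';

// Stub for the real gallerySlice action creator.
declare const imageToCompareChanged: (payload: null) => { type: string; payload: null };

export const useExitCompareHotkey = () => {
  const dispatch = useDispatch();
  const exitCompare = useCallback(() => {
    dispatch(imageToCompareChanged(null));
  }, [dispatch]);
  // Same binding as CompareToolbar: Esc leaves compare mode.
  useHotkeys('esc', exitCompare, [exitCompare]);
  return exitCompare;
};
```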
( ) : ( } dataTestId="image-preview" /> )} - {shouldShowImageDetails && imageDTO && withMetadata && ( + {shouldShowImageDetails && imageDTO && ( )} - {withNextPrevButtons && shouldShowNextPrevButtons && imageDTO && ( + {shouldShowNextPrevButtons && imageDTO && ( { + const { t } = useTranslation(); + const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode); + const { firstImage, secondImage } = useAppSelector(selectComparisonImages); + + if (!firstImage || !secondImage) { + // Should rarely/never happen - we don't render this component unless we have images to compare + return ; + } + + if (comparisonMode === 'slider') { + return ; + } + + if (comparisonMode === 'side-by-side') { + return ( + + ); + } + + if (comparisonMode === 'hover') { + return ; + } +}); + +ImageComparison.displayName = 'ImageComparison'; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonDroppable.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonDroppable.tsx new file mode 100644 index 0000000000..3678c920c0 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonDroppable.tsx @@ -0,0 +1,47 @@ +import { Flex } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import IAIDroppable from 'common/components/IAIDroppable'; +import type { CurrentImageDropData, SelectForCompareDropData } from 'features/dnd/types'; +import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer'; +import { memo, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; + +import { selectComparisonImages } from './common'; + +const setCurrentImageDropData: CurrentImageDropData = { + id: 'current-image', + actionType: 'SET_CURRENT_IMAGE', +}; + +export const ImageComparisonDroppable = memo(() => { + const { t } = useTranslation(); + const imageViewer = useImageViewer(); + const { firstImage, secondImage } = useAppSelector(selectComparisonImages); + const selectForCompareDropData = useMemo( + () => ({ + id: 'image-comparison', + actionType: 'SELECT_FOR_COMPARE', + context: { + firstImageName: firstImage?.image_name, + secondImageName: secondImage?.image_name, + }, + }), + [firstImage?.image_name, secondImage?.image_name] + ); + + if (!imageViewer.isOpen) { + return ( + + + + ); + } + + return ( + + + + ); +}); + +ImageComparisonDroppable.displayName = 'ImageComparisonDroppable'; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonHover.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonHover.tsx new file mode 100644 index 0000000000..a02e94b547 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonHover.tsx @@ -0,0 +1,117 @@ +import { Box, Flex, Image } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { useBoolean } from 'common/hooks/useBoolean'; +import { preventDefault } from 'common/util/stopPropagation'; +import type { Dimensions } from 'features/canvas/store/canvasTypes'; +import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers'; +import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel'; +import { memo, useMemo, useRef } from 'react'; + +import type { ComparisonProps } from './common'; +import { fitDimsToContainer, getSecondImageDims } from './common'; + +export const ImageComparisonHover = memo(({ 
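`ImageComparison` and `ImageComparisonDroppable` both read the compared pair from a `selectComparisonImages` selector in `./common`, which is not part of this diff. Judging by the labels (`viewerImage` for the first image, `compareImage` for the second), it most likely pairs the currently viewed image with `gallery.imageToCompare`; the sketch below is an assumption about that behaviour, not the actual implementation:

```ts
import { createSelector } from '@reduxjs/toolkit';

type ImageDTO = { image_name: string; width: number; height: number };
type RootState = {
  gallery: { selection: ImageDTO[]; imageToCompare: ImageDTO | null };
};

// Assumed: first image = last-selected gallery image, second = imageToCompare.
export const selectComparisonImagesSketch = createSelector(
  (state: RootState) => state.gallery.selection[state.gallery.selection.length - 1] ?? null,
  (state: RootState) => state.gallery.imageToCompare,
  (firstImage, secondImage) => ({ firstImage, secondImage })
);
```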
firstImage, secondImage, containerDims }: ComparisonProps) => { + const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit); + const imageContainerRef = useRef(null); + const mouseOver = useBoolean(false); + const fittedDims = useMemo( + () => fitDimsToContainer(containerDims, firstImage), + [containerDims, firstImage] + ); + const compareImageDims = useMemo( + () => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage), + [comparisonFit, fittedDims, firstImage, secondImage] + ); + return ( + + + + + + + + + + + + + + + + ); +}); + +ImageComparisonHover.displayName = 'ImageComparisonHover'; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonLabel.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonLabel.tsx new file mode 100644 index 0000000000..a5a40dfc9c --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonLabel.tsx @@ -0,0 +1,33 @@ +import type { TextProps } from '@invoke-ai/ui-library'; +import { Text } from '@invoke-ai/ui-library'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; + +import { DROP_SHADOW } from './common'; + +type Props = TextProps & { + type: 'first' | 'second'; +}; + +export const ImageComparisonLabel = memo(({ type, ...rest }: Props) => { + const { t } = useTranslation(); + return ( + + {type === 'first' ? t('gallery.viewerImage') : t('gallery.compareImage')} + + ); +}); + +ImageComparisonLabel.displayName = 'ImageComparisonLabel'; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSideBySide.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSideBySide.tsx new file mode 100644 index 0000000000..8bac2bb45d --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSideBySide.tsx @@ -0,0 +1,70 @@ +import { Flex, Image } from '@invoke-ai/ui-library'; +import type { ComparisonProps } from 'features/gallery/components/ImageViewer/common'; +import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel'; +import ResizeHandle from 'features/ui/components/tabs/ResizeHandle'; +import { memo, useCallback, useRef } from 'react'; +import type { ImperativePanelGroupHandle } from 'react-resizable-panels'; +import { Panel, PanelGroup } from 'react-resizable-panels'; + +export const ImageComparisonSideBySide = memo(({ firstImage, secondImage }: ComparisonProps) => { + const panelGroupRef = useRef(null); + const onDoubleClickHandle = useCallback(() => { + if (!panelGroupRef.current) { + return; + } + panelGroupRef.current.setLayout([50, 50]); + }, []); + + return ( + + + + + + + + + + + + + + + + + + + + + + + + + ); +}); + +ImageComparisonSideBySide.displayName = 'ImageComparisonSideBySide'; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSlider.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSlider.tsx new file mode 100644 index 0000000000..8972af7d4f --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageComparisonSlider.tsx @@ -0,0 +1,215 @@ +import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { preventDefault } from 'common/util/stopPropagation'; +import type { Dimensions } from 'features/canvas/store/canvasTypes'; +import { STAGE_BG_DATAURL } from 
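The hover and slider views size their images with `fitDimsToContainer` and `getSecondImageDims` from `./common`, also not included in this diff (the real `getSecondImageDims` additionally receives the first image's dimensions). A plausible sketch, assuming `fitDimsToContainer` performs a standard contain-fit of the first image, `'fill'` stretches the second image over the first image's fitted box, and `'contain'` preserves the second image's own aspect ratio inside that box:

```ts
type Dimensions = { width: number; height: number };
type ComparisonFit = 'contain' | 'fill';

// Scale the first image down (never up) so it fits inside the measured container.
export const fitDimsToContainerSketch = (containerDims: Dimensions, image: Dimensions): Dimensions => {
  if (image.width <= containerDims.width && image.height <= containerDims.height) {
    return { width: image.width, height: image.height };
  }
  const scale = Math.min(containerDims.width / image.width, containerDims.height / image.height);
  return { width: image.width * scale, height: image.height * scale };
};

// Size the second image relative to the first image's fitted box.
export const getSecondImageDimsSketch = (
  comparisonFit: ComparisonFit,
  fittedDims: Dimensions,
  secondImage: Dimensions
): Dimensions => {
  if (comparisonFit === 'fill') {
    // Stretch to exactly cover the first image.
    return fittedDims;
  }
  // Keep the second image's own aspect ratio inside the same box.
  const scale = Math.min(fittedDims.width / secondImage.width, fittedDims.height / secondImage.height);
  return { width: secondImage.width * scale, height: secondImage.height * scale };
};
```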
'features/controlLayers/util/renderers'; +import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel'; +import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi'; + +import type { ComparisonProps } from './common'; +import { DROP_SHADOW, fitDimsToContainer, getSecondImageDims } from './common'; + +const INITIAL_POS = '50%'; +const HANDLE_WIDTH = 2; +const HANDLE_WIDTH_PX = `${HANDLE_WIDTH}px`; +const HANDLE_HITBOX = 20; +const HANDLE_HITBOX_PX = `${HANDLE_HITBOX}px`; +const HANDLE_INNER_LEFT_PX = `${HANDLE_HITBOX / 2 - HANDLE_WIDTH / 2}px`; +const HANDLE_LEFT_INITIAL_PX = `calc(${INITIAL_POS} - ${HANDLE_HITBOX / 2}px)`; + +export const ImageComparisonSlider = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => { + const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit); + // How far the handle is from the left - this will be a CSS calculation that takes into account the handle width + const [left, setLeft] = useState(HANDLE_LEFT_INITIAL_PX); + // How wide the first image is + const [width, setWidth] = useState(INITIAL_POS); + const handleRef = useRef(null); + // To manage aspect ratios, we need to know the size of the container + const imageContainerRef = useRef(null); + // To keep things smooth, we use RAF to update the handle position & gate it to 60fps + const rafRef = useRef(null); + const lastMoveTimeRef = useRef(0); + + const fittedDims = useMemo( + () => fitDimsToContainer(containerDims, firstImage), + [containerDims, firstImage] + ); + + const compareImageDims = useMemo( + () => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage), + [comparisonFit, fittedDims, firstImage, secondImage] + ); + + const updateHandlePos = useCallback((clientX: number) => { + if (!handleRef.current || !imageContainerRef.current) { + return; + } + lastMoveTimeRef.current = performance.now(); + const { x, width } = imageContainerRef.current.getBoundingClientRect(); + const rawHandlePos = ((clientX - x) * 100) / width; + const handleWidthPct = (HANDLE_WIDTH * 100) / width; + const newHandlePos = Math.min(100 - handleWidthPct, Math.max(0, rawHandlePos)); + setWidth(`${newHandlePos}%`); + setLeft(`calc(${newHandlePos}% - ${HANDLE_HITBOX / 2}px)`); + }, []); + + const onMouseMove = useCallback( + (e: MouseEvent) => { + if (rafRef.current === null && performance.now() > lastMoveTimeRef.current + 1000 / 60) { + rafRef.current = window.requestAnimationFrame(() => { + updateHandlePos(e.clientX); + rafRef.current = null; + }); + } + }, + [updateHandlePos] + ); + + const onMouseUp = useCallback(() => { + window.removeEventListener('mousemove', onMouseMove); + }, [onMouseMove]); + + const onMouseDown = useCallback( + (e: React.MouseEvent) => { + // Update the handle position immediately on click + updateHandlePos(e.clientX); + window.addEventListener('mouseup', onMouseUp, { once: true }); + window.addEventListener('mousemove', onMouseMove); + }, + [onMouseMove, onMouseUp, updateHandlePos] + ); + + useEffect( + () => () => { + if (rafRef.current !== null) { + cancelAnimationFrame(rafRef.current); + } + }, + [] + ); + + return ( + + + + + + + + + + + + + + + + + + + + + + ); +}); + +ImageComparisonSlider.displayName = 'ImageComparisonSlider'; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageViewer.tsx index 
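The slider math above is worth restating in isolation: the handle position is a percentage of the container width, clamped so the 2px handle never leaves the container, and mousemove updates are gated to roughly 60fps through `requestAnimationFrame`. The same arithmetic as standalone functions:

```ts
const HANDLE_WIDTH = 2; // px, matches HANDLE_WIDTH in ImageComparisonSlider

// Convert a clientX into the handle position as a percentage of the container,
// clamped to [0, 100 - handle width].
export const computeHandlePos = (clientX: number, containerX: number, containerWidth: number): number => {
  const rawHandlePos = ((clientX - containerX) * 100) / containerWidth;
  const handleWidthPct = (HANDLE_WIDTH * 100) / containerWidth;
  return Math.min(100 - handleWidthPct, Math.max(0, rawHandlePos));
};

// Only schedule a new animation frame if none is pending and at least one
// 60fps tick (1000/60 ms) has elapsed since the last applied move.
export const shouldScheduleFrame = (pendingFrame: number | null, lastMoveTime: number, now: number): boolean =>
  pendingFrame === null && now > lastMoveTime + 1000 / 60;
```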
7064e553dc..530431fc4c 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageViewer.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageViewer.tsx @@ -1,36 +1,16 @@ -import { Flex } from '@invoke-ai/ui-library'; -import { useAppSelector } from 'app/store/storeHooks'; -import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton'; -import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton'; -import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer'; -import type { InvokeTabName } from 'features/ui/store/tabMap'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { memo, useMemo } from 'react'; -import { useHotkeys } from 'react-hotkeys-hook'; +import { Box, Flex } from '@invoke-ai/ui-library'; +import { CompareToolbar } from 'features/gallery/components/ImageViewer/CompareToolbar'; +import CurrentImagePreview from 'features/gallery/components/ImageViewer/CurrentImagePreview'; +import { ImageComparison } from 'features/gallery/components/ImageViewer/ImageComparison'; +import { ViewerToolbar } from 'features/gallery/components/ImageViewer/ViewerToolbar'; +import { memo } from 'react'; +import { useMeasure } from 'react-use'; -import CurrentImageButtons from './CurrentImageButtons'; -import CurrentImagePreview from './CurrentImagePreview'; -import { ViewerToggleMenu } from './ViewerToggleMenu'; - -const VIEWER_ENABLED_TABS: InvokeTabName[] = ['canvas', 'generation', 'workflows']; +import { useImageViewer } from './useImageViewer'; export const ImageViewer = memo(() => { - const { isOpen, onToggle, onClose } = useImageViewer(); - const activeTabName = useAppSelector(activeTabNameSelector); - const isViewerEnabled = useMemo(() => VIEWER_ENABLED_TABS.includes(activeTabName), [activeTabName]); - const shouldShowViewer = useMemo(() => { - if (!isViewerEnabled) { - return false; - } - return isOpen; - }, [isOpen, isViewerEnabled]); - - useHotkeys('z', onToggle, { enabled: isViewerEnabled }, [isViewerEnabled, onToggle]); - useHotkeys('esc', onClose, { enabled: isViewerEnabled }, [isViewerEnabled, onClose]); - - if (!shouldShowViewer) { - return null; - } + const imageViewer = useImageViewer(); + const [containerRef, containerDims] = useMeasure(); return ( { rowGap={4} alignItems="center" justifyContent="center" - zIndex={10} // reactflow puts its minimap at 5, so we need to be above that > - - - - - - - - - - - - - - - - - + {imageViewer.isComparing && } + {!imageViewer.isComparing && } + + {!imageViewer.isComparing && } + {imageViewer.isComparing && } + ); }); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageViewerWorkflows.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageViewerWorkflows.tsx deleted file mode 100644 index fe09f11be6..0000000000 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ImageViewerWorkflows.tsx +++ /dev/null @@ -1,45 +0,0 @@ -import { Flex } from '@invoke-ai/ui-library'; -import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton'; -import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton'; -import { memo } from 'react'; - -import CurrentImageButtons from './CurrentImageButtons'; -import CurrentImagePreview from './CurrentImagePreview'; - -export const ImageViewerWorkflows = memo(() => { - 
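The rewritten `ImageViewer` no longer gates itself on the active tab or registers its own hotkeys; it measures its container with `useMeasure` and swaps between the viewer and comparison layouts based on `imageViewer.isComparing`. A skeleton with the internal pieces stubbed out (the real component also layers `ImageComparisonDroppable` over the preview area):

```tsx
import React, { memo } from 'react';
import { useMeasure } from 'react-use';

// Stubs standing in for the real hooks/components referenced by ImageViewer.
declare const useImageViewer: () => { isComparing: boolean };
declare const ViewerToolbar: React.FC;
declare const CompareToolbar: React.FC;
declare const CurrentImagePreview: React.FC;
declare const ImageComparison: React.FC<{ containerDims: { width: number; height: number } }>;

export const ImageViewerSketch = memo(() => {
  const imageViewer = useImageViewer();
  const [containerRef, containerDims] = useMeasure<HTMLDivElement>();

  return (
    <div>
      {imageViewer.isComparing ? <CompareToolbar /> : <ViewerToolbar />}
      <div ref={containerRef}>
        {imageViewer.isComparing ? <ImageComparison containerDims={containerDims} /> : <CurrentImagePreview />}
      </div>
    </div>
  );
});
ImageViewerSketch.displayName = 'ImageViewerSketch';
```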
return ( - - - - - - - - - - - - - - - - - - ); -}); - -ImageViewerWorkflows.displayName = 'ImageViewerWorkflows'; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ViewerToggleMenu.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ViewerToggleMenu.tsx index 3552c28a5b..7dc13afb48 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ViewerToggleMenu.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ViewerToggleMenu.tsx @@ -9,33 +9,35 @@ import { PopoverTrigger, Text, } from '@invoke-ai/ui-library'; +import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer'; +import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; import { PiCaretDownBold, PiCheckBold, PiEyeBold, PiPencilBold } from 'react-icons/pi'; -import { useImageViewer } from './useImageViewer'; - export const ViewerToggleMenu = () => { const { t } = useTranslation(); - const { isOpen, onClose, onOpen } = useImageViewer(); + const imageViewer = useImageViewer(); + useHotkeys('z', imageViewer.onToggle, [imageViewer]); + useHotkeys('esc', imageViewer.onClose, [imageViewer]); return ( - - + - -
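`ViewerToggleMenu` now pulls its `Z`/`Esc` hotkeys from the same `useImageViewer` hook the viewer itself uses. That hook's implementation is not in this diff; inferred from its call sites, it exposes open/close/toggle callbacks plus `isOpen` and `isComparing`, which a sketch might back with gallery state roughly like this:

```ts
import { useCallback } from 'react';
import { useDispatch, useSelector } from 'react-redux';

type ImageDTO = { image_name: string };
type RootState = { gallery: { isImageViewerOpen: boolean; imageToCompare: ImageDTO | null } };

// Stub for the real gallerySlice action creator.
declare const isImageViewerOpenChanged: (isOpen: boolean) => { type: string; payload: boolean };

// Assumed shape only; the real useImageViewer may derive this state differently.
export const useImageViewerSketch = () => {
  const dispatch = useDispatch();
  const isOpen = useSelector((s: RootState) => s.gallery.isImageViewerOpen);
  const isComparing = useSelector((s: RootState) => s.gallery.imageToCompare !== null);
  const onOpen = useCallback(() => dispatch(isImageViewerOpenChanged(true)), [dispatch]);
  const onClose = useCallback(() => dispatch(isImageViewerOpenChanged(false)), [dispatch]);
  const onToggle = useCallback(() => dispatch(isImageViewerOpenChanged(!isOpen)), [dispatch, isOpen]);
  return { isOpen, isComparing, onOpen, onClose, onToggle };
};
```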