From 35dd58e273c64059ce0d708db36e85c8670ac520 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Sat, 29 Jul 2023 12:59:56 +0530 Subject: [PATCH 01/33] chore: move PR template to `.github/` dir --- pull_request_template.md => .github/pull_request_template.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename pull_request_template.md => .github/pull_request_template.md (100%) diff --git a/pull_request_template.md b/.github/pull_request_template.md similarity index 100% rename from pull_request_template.md rename to .github/pull_request_template.md From fd7b84241988a57d7b9dae24798cc0bb475dd92c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 1 Aug 2023 17:44:09 -0400 Subject: [PATCH 02/33] add execution stat reporting after each invocation --- invokeai/app/services/invocation_stats.py | 115 ++++++++++++++++++++++ invokeai/app/services/processor.py | 43 ++++---- 2 files changed, 139 insertions(+), 19 deletions(-) create mode 100644 invokeai/app/services/invocation_stats.py diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py new file mode 100644 index 0000000000..8d41b60d49 --- /dev/null +++ b/invokeai/app/services/invocation_stats.py @@ -0,0 +1,115 @@ +# Copyright 2023 Lincoln D. Stein +"""Utility to collect execution time and GPU usage stats on invocations in flight""" + +""" +Usage: +statistics = InvocationStats() # keep track of performance metrics +... +with statistics.collect_stats(invocation, graph_execution_state): + outputs = invocation.invoke( + InvocationContext( + services=self.__invoker.services, + graph_execution_state_id=graph_execution_state.id, + ) + ) +... +statistics.log_stats() + +Typical output: +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> Node Calls Seconds +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> main_model_loader 1 0.006s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> clip_skip 1 0.005s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> compel 2 0.351s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> rand_int 1 0.001s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> range_of_size 1 0.001s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> iterate 1 0.001s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> noise 1 0.002s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> t2l 1 3.117s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> l2i 1 0.377s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> TOTAL: 3.865s +[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> Max VRAM used for execution: 3.12G. +[2023-08-01 17:34:44,586]::[InvokeAI]::INFO --> Current VRAM utilization 2.31G. +""" + +import time +from typing import Dict, List + +import torch + +from .graph import GraphExecutionState +from .invocation_queue import InvocationQueueItem +from ..invocations.baseinvocation import BaseInvocation + +import invokeai.backend.util.logging as logger + +class InvocationStats(): + """Accumulate performance information about a running graph. 
Collects time spent in each node, + as well as the maximum and current VRAM utilisation for CUDA systems""" + + def __init__(self): + self._stats: Dict[str, int] = {} + + class StatsContext(): + def __init__(self, invocation: BaseInvocation, collector): + self.invocation = invocation + self.collector = collector + self.start_time = 0 + + def __enter__(self): + self.start_time = time.time() + + def __exit__(self, *args): + self.collector.log_time(self.invocation.type, time.time() - self.start_time) + + def collect_stats(self, + invocation: BaseInvocation, + graph_execution_state: GraphExecutionState, + ) -> StatsContext: + """ + Return a context object that will capture the statistics. + :param invocation: BaseInvocation object from the current graph. + :param graph_execution_state: GraphExecutionState object from the current session. + """ + if len(graph_execution_state.executed)==0: # new graph is starting + self.reset_stats() + self._current_graph_state = graph_execution_state + sc = self.StatsContext(invocation, self) + return self.StatsContext(invocation, self) + + def reset_stats(self): + """Zero the statistics. Ordinarily called internally.""" + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + self._stats: Dict[str, List[int, float]] = {} + + + def log_time(self, invocation_type: str, time_used: float): + """ + Add timing information on execution of a node. Usually + used internally. + :param invocation_type: String literal type of the node + :param time_used: Floating point seconds used by node's exection + """ + if not self._stats.get(invocation_type): + self._stats[invocation_type] = [0, 0.0] + self._stats[invocation_type][0] += 1 + self._stats[invocation_type][1] += time_used + + def log_stats(self): + """ + Send the statistics to the system logger at the info level. + Stats will only be printed if when the execution of the graph + is complete. 
+ """ + if self._current_graph_state.is_complete(): + logger.info('Node Calls Seconds') + for node_type, (calls, time_used) in self._stats.items(): + logger.info(f'{node_type:<20} {calls:>5} {time_used:4.3f}s') + + total_time = sum([ticks for _,ticks in self._stats.values()]) + logger.info(f'TOTAL: {total_time:4.3f}s') + if torch.cuda.is_available(): + logger.info('Max VRAM used for execution: '+'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9)) + logger.info('Current VRAM utilization '+'%4.2fG' % (torch.cuda.memory_allocated() / 1e9)) + diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index 50fe217e05..a43e2878ac 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -5,6 +5,7 @@ from threading import Event, Thread, BoundedSemaphore from ..invocations.baseinvocation import InvocationContext from .invocation_queue import InvocationQueueItem from .invoker import InvocationProcessorABC, Invoker +from .invocation_stats import InvocationStats from ..models.exceptions import CanceledException import invokeai.backend.util.logging as logger @@ -35,6 +36,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): def __process(self, stop_event: Event): try: self.__threadLimit.acquire() + statistics = InvocationStats() # keep track of performance metrics while not stop_event.is_set(): try: queue_item: InvocationQueueItem = self.__invoker.services.queue.get() @@ -83,30 +85,32 @@ class DefaultInvocationProcessor(InvocationProcessorABC): # Invoke try: - outputs = invocation.invoke( - InvocationContext( - services=self.__invoker.services, - graph_execution_state_id=graph_execution_state.id, + with statistics.collect_stats(invocation, graph_execution_state): + outputs = invocation.invoke( + InvocationContext( + services=self.__invoker.services, + graph_execution_state_id=graph_execution_state.id, + ) ) - ) - # Check queue to see if this is canceled, and skip if so - if self.__invoker.services.queue.is_canceled(graph_execution_state.id): - continue + # Check queue to see if this is canceled, and skip if so + if self.__invoker.services.queue.is_canceled(graph_execution_state.id): + continue - # Save outputs and history - graph_execution_state.complete(invocation.id, outputs) + # Save outputs and history + graph_execution_state.complete(invocation.id, outputs) - # Save the state changes - self.__invoker.services.graph_execution_manager.set(graph_execution_state) + # Save the state changes + self.__invoker.services.graph_execution_manager.set(graph_execution_state) - # Send complete event - self.__invoker.services.events.emit_invocation_complete( - graph_execution_state_id=graph_execution_state.id, - node=invocation.dict(), - source_node_id=source_node_id, - result=outputs.dict(), - ) + # Send complete event + self.__invoker.services.events.emit_invocation_complete( + graph_execution_state_id=graph_execution_state.id, + node=invocation.dict(), + source_node_id=source_node_id, + result=outputs.dict(), + ) + statistics.log_stats() except KeyboardInterrupt: pass @@ -161,3 +165,4 @@ class DefaultInvocationProcessor(InvocationProcessorABC): pass # Log something? 
KeyboardInterrupt is probably not going to be seen by the processor finally: self.__threadLimit.release() + From 8a4e5f73aa6e914494885438a560a7d6694b6ce2 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 1 Aug 2023 19:39:42 -0400 Subject: [PATCH 03/33] reset stats on exception --- invokeai/app/services/invocation_stats.py | 38 +++++++++++------------ invokeai/app/services/processor.py | 4 +-- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py index 8d41b60d49..24a5662647 100644 --- a/invokeai/app/services/invocation_stats.py +++ b/invokeai/app/services/invocation_stats.py @@ -43,14 +43,15 @@ from ..invocations.baseinvocation import BaseInvocation import invokeai.backend.util.logging as logger -class InvocationStats(): + +class InvocationStats: """Accumulate performance information about a running graph. Collects time spent in each node, as well as the maximum and current VRAM utilisation for CUDA systems""" def __init__(self): self._stats: Dict[str, int] = {} - - class StatsContext(): + + class StatsContext: def __init__(self, invocation: BaseInvocation, collector): self.invocation = invocation self.collector = collector @@ -61,17 +62,18 @@ class InvocationStats(): def __exit__(self, *args): self.collector.log_time(self.invocation.type, time.time() - self.start_time) - - def collect_stats(self, - invocation: BaseInvocation, - graph_execution_state: GraphExecutionState, - ) -> StatsContext: + + def collect_stats( + self, + invocation: BaseInvocation, + graph_execution_state: GraphExecutionState, + ) -> StatsContext: """ Return a context object that will capture the statistics. :param invocation: BaseInvocation object from the current graph. :param graph_execution_state: GraphExecutionState object from the current session. """ - if len(graph_execution_state.executed)==0: # new graph is starting + if len(graph_execution_state.executed) == 0: # new graph is starting self.reset_stats() self._current_graph_state = graph_execution_state sc = self.StatsContext(invocation, self) @@ -83,7 +85,6 @@ class InvocationStats(): torch.cuda.reset_peak_memory_stats() self._stats: Dict[str, List[int, float]] = {} - def log_time(self, invocation_type: str, time_used: float): """ Add timing information on execution of a node. Usually @@ -95,7 +96,7 @@ class InvocationStats(): self._stats[invocation_type] = [0, 0.0] self._stats[invocation_type][0] += 1 self._stats[invocation_type][1] += time_used - + def log_stats(self): """ Send the statistics to the system logger at the info level. @@ -103,13 +104,12 @@ class InvocationStats(): is complete. 
""" if self._current_graph_state.is_complete(): - logger.info('Node Calls Seconds') + logger.info("Node Calls Seconds") for node_type, (calls, time_used) in self._stats.items(): - logger.info(f'{node_type:<20} {calls:>5} {time_used:4.3f}s') - - total_time = sum([ticks for _,ticks in self._stats.values()]) - logger.info(f'TOTAL: {total_time:4.3f}s') + logger.info(f"{node_type:<20} {calls:>5} {time_used:4.3f}s") + + total_time = sum([ticks for _, ticks in self._stats.values()]) + logger.info(f"TOTAL: {total_time:4.3f}s") if torch.cuda.is_available(): - logger.info('Max VRAM used for execution: '+'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9)) - logger.info('Current VRAM utilization '+'%4.2fG' % (torch.cuda.memory_allocated() / 1e9)) - + logger.info("Max VRAM used for execution: " + "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)) + logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)) diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index a43e2878ac..e9511aa283 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -116,6 +116,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): pass except CanceledException: + statistics.reset_stats() pass except Exception as e: @@ -137,7 +138,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): error_type=e.__class__.__name__, error=error, ) - + statistics.reset_stats() pass # Check queue to see if this is canceled, and skip if so @@ -165,4 +166,3 @@ class DefaultInvocationProcessor(InvocationProcessorABC): pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor finally: self.__threadLimit.release() - From ed76250dbaa7ea27d0a0d06051337cace9f6a23e Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Wed, 2 Aug 2023 07:21:21 -0400 Subject: [PATCH 04/33] Stop checking for unet/model.onnx when a model_index.json is detected --- invokeai/backend/install/model_install_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index c0a7244367..ac032d4955 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -303,7 +303,7 @@ class ModelInstall(object): with TemporaryDirectory(dir=self.config.models_path) as staging: staging = Path(staging) - if "model_index.json" in files and "unet/model.onnx" not in files: + if "model_index.json" in files: location = self._download_hf_pipeline(repo_id, staging) # pipeline elif "unet/model.onnx" in files: location = self._download_hf_model(repo_id, files, staging) From 8fc75a71ee20dd96ceb59f9a4a904125d71e8ad6 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 2 Aug 2023 18:10:52 -0400 Subject: [PATCH 05/33] integrate correctly into app API and add features - Create abstract base class InvocationStatsServiceBase - Store InvocationStatsService in the InvocationServices object - Collect and report stats on simultaneous graph execution independently for each graph id - Track VRAM usage for each node - Handle cancellations and other exceptions gracefully --- invokeai/app/api/dependencies.py | 3 +- invokeai/app/services/invocation_services.py | 3 + invokeai/app/services/invocation_stats.py | 222 ++++++++++++++----- invokeai/app/services/processor.py | 23 +- 4 files changed, 182 insertions(+), 69 deletions(-) diff --git a/invokeai/app/api/dependencies.py 
b/invokeai/app/api/dependencies.py index d609ce3be2..b25009c8c9 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -2,7 +2,6 @@ from typing import Optional from logging import Logger -import os from invokeai.app.services.board_image_record_storage import ( SqliteBoardImageRecordStorage, ) @@ -30,6 +29,7 @@ from ..services.invoker import Invoker from ..services.processor import DefaultInvocationProcessor from ..services.sqlite import SqliteItemStorage from ..services.model_manager_service import ModelManagerService +from ..services.invocation_stats import InvocationStatsService from .events import FastAPIEventService @@ -128,6 +128,7 @@ class ApiDependencies: graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), configuration=config, + performance_statistics=InvocationStatsService(graph_execution_manager), logger=logger, ) diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 8af17c7643..d7d9aae024 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -32,6 +32,7 @@ class InvocationServices: logger: "Logger" model_manager: "ModelManagerServiceBase" processor: "InvocationProcessorABC" + performance_statistics: "InvocationStatsServiceBase" queue: "InvocationQueueABC" def __init__( @@ -47,6 +48,7 @@ class InvocationServices: logger: "Logger", model_manager: "ModelManagerServiceBase", processor: "InvocationProcessorABC", + performance_statistics: "InvocationStatsServiceBase", queue: "InvocationQueueABC", ): self.board_images = board_images @@ -61,4 +63,5 @@ class InvocationServices: self.logger = logger self.model_manager = model_manager self.processor = processor + self.performance_statistics = performance_statistics self.queue = queue diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py index 24a5662647..aca1dba550 100644 --- a/invokeai/app/services/invocation_stats.py +++ b/invokeai/app/services/invocation_stats.py @@ -3,99 +3,196 @@ """ Usage: -statistics = InvocationStats() # keep track of performance metrics -... -with statistics.collect_stats(invocation, graph_execution_state): - outputs = invocation.invoke( - InvocationContext( - services=self.__invoker.services, - graph_execution_state_id=graph_execution_state.id, - ) - ) -... + +statistics = InvocationStatsService(graph_execution_manager) +with statistics.collect_stats(invocation, graph_execution_state.id): + ... execute graphs... statistics.log_stats() Typical output: -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> Node Calls Seconds -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> main_model_loader 1 0.006s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> clip_skip 1 0.005s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> compel 2 0.351s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> rand_int 1 0.001s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> range_of_size 1 0.001s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> iterate 1 0.001s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> noise 1 0.002s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> t2l 1 3.117s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> l2i 1 0.377s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> TOTAL: 3.865s -[2023-08-01 17:34:44,585]::[InvokeAI]::INFO --> Max VRAM used for execution: 3.12G. 
-[2023-08-01 17:34:44,586]::[InvokeAI]::INFO --> Current VRAM utilization 2.31G. +[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3 +[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used +[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s +[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G + +The abstract base class for this class is InvocationStatsServiceBase. An implementing class which +writes to the system log is stored in InvocationServices.performance_statistics. """ import time -from typing import Dict, List +from abc import ABC, abstractmethod +from contextlib import AbstractContextManager +from dataclasses import dataclass, field +from typing import Dict import torch -from .graph import GraphExecutionState -from .invocation_queue import InvocationQueueItem -from ..invocations.baseinvocation import BaseInvocation - import invokeai.backend.util.logging as logger +from ..invocations.baseinvocation import BaseInvocation +from .graph import GraphExecutionState +from .item_storage import ItemStorageABC -class InvocationStats: + +class InvocationStatsServiceBase(ABC): + "Abstract base class for recording node memory/time performance statistics" + + @abstractmethod + def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): + """ + Initialize the InvocationStatsService and reset counters to zero + :param graph_execution_manager: Graph execution manager for this session + """ + pass + + @abstractmethod + def collect_stats( + self, + invocation: BaseInvocation, + graph_execution_state_id: str, + ) -> AbstractContextManager: + """ + Return a context object that will capture the statistics on the execution + of invocaation. Use with: to place around the part of the code that executes the invocation. + :param invocation: BaseInvocation object from the current graph. + :param graph_execution_state: GraphExecutionState object from the current session. + """ + pass + + @abstractmethod + def reset_stats(self, graph_execution_state_id: str): + """ + Reset all statistics for the indicated graph + :param graph_execution_state_id + """ + pass + + @abstractmethod + def reset_all_stats(self): + """Zero all statistics""" + pass + + @abstractmethod + def update_invocation_stats( + self, + graph_id: str, + invocation_type: str, + time_used: float, + vram_used: float, + ): + """ + Add timing information on execution of a node. Usually + used internally. 
+ :param graph_id: ID of the graph that is currently executing + :param invocation_type: String literal type of the node + :param time_used: Time used by node's exection (sec) + :param vram_used: Maximum VRAM used during exection (GB) + """ + pass + + @abstractmethod + def log_stats(self): + """ + Write out the accumulated statistics to the log or somewhere else. + """ + pass + + +@dataclass +class NodeStats: + """Class for tracking execution stats of an invocation node""" + + calls: int = 0 + time_used: float = 0.0 # seconds + max_vram: float = 0.0 # GB + + +@dataclass +class NodeLog: + """Class for tracking node usage""" + + # {node_type => NodeStats} + nodes: Dict[str, NodeStats] = field(default_factory=dict) + + +class InvocationStatsService(InvocationStatsServiceBase): """Accumulate performance information about a running graph. Collects time spent in each node, as well as the maximum and current VRAM utilisation for CUDA systems""" - def __init__(self): - self._stats: Dict[str, int] = {} + def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]): + self.graph_execution_manager = graph_execution_manager + # {graph_id => NodeLog} + self._stats: Dict[str, NodeLog] = {} class StatsContext: - def __init__(self, invocation: BaseInvocation, collector): + def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"): self.invocation = invocation self.collector = collector + self.graph_id = graph_id self.start_time = 0 def __enter__(self): self.start_time = time.time() + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() def __exit__(self, *args): - self.collector.log_time(self.invocation.type, time.time() - self.start_time) + self.collector.update_invocation_stats( + self.graph_id, + self.invocation.type, + time.time() - self.start_time, + torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0, + ) def collect_stats( self, invocation: BaseInvocation, - graph_execution_state: GraphExecutionState, + graph_execution_state_id: str, ) -> StatsContext: """ Return a context object that will capture the statistics. :param invocation: BaseInvocation object from the current graph. :param graph_execution_state: GraphExecutionState object from the current session. """ - if len(graph_execution_state.executed) == 0: # new graph is starting - self.reset_stats() - self._current_graph_state = graph_execution_state - sc = self.StatsContext(invocation, self) - return self.StatsContext(invocation, self) + if not self._stats.get(graph_execution_state_id): # first time we're seeing this + self._stats[graph_execution_state_id] = NodeLog() + return self.StatsContext(invocation, graph_execution_state_id, self) - def reset_stats(self): - """Zero the statistics. Ordinarily called internally.""" - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - self._stats: Dict[str, List[int, float]] = {} + def reset_all_stats(self): + """Zero all statistics""" + self._stats = {} - def log_time(self, invocation_type: str, time_used: float): + def reset_stats(self, graph_execution_id: str): + """Zero the statistics for the indicated graph.""" + try: + self._stats.pop(graph_execution_id) + except KeyError: + logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}") + + def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float): """ Add timing information on execution of a node. Usually used internally. 
+ :param graph_id: ID of the graph that is currently executing :param invocation_type: String literal type of the node :param time_used: Floating point seconds used by node's exection """ - if not self._stats.get(invocation_type): - self._stats[invocation_type] = [0, 0.0] - self._stats[invocation_type][0] += 1 - self._stats[invocation_type][1] += time_used + if not self._stats[graph_id].nodes.get(invocation_type): + self._stats[graph_id].nodes[invocation_type] = NodeStats() + stats = self._stats[graph_id].nodes[invocation_type] + stats.calls += 1 + stats.time_used += time_used + stats.max_vram = max(stats.max_vram, vram_used) def log_stats(self): """ @@ -103,13 +200,24 @@ class InvocationStats: Stats will only be printed if when the execution of the graph is complete. """ - if self._current_graph_state.is_complete(): - logger.info("Node Calls Seconds") - for node_type, (calls, time_used) in self._stats.items(): - logger.info(f"{node_type:<20} {calls:>5} {time_used:4.3f}s") + completed = set() + for graph_id, node_log in self._stats.items(): + current_graph_state = self.graph_execution_manager.get(graph_id) + if not current_graph_state.is_complete(): + continue - total_time = sum([ticks for _, ticks in self._stats.values()]) - logger.info(f"TOTAL: {total_time:4.3f}s") + total_time = 0 + logger.info(f"Graph stats: {graph_id}") + logger.info("Node Calls Seconds VRAM Used") + for node_type, stats in self._stats[graph_id].nodes.items(): + logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:4.3f}s {stats.max_vram:4.2f}G") + total_time += stats.time_used + + logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:4.3f}s") if torch.cuda.is_available(): - logger.info("Max VRAM used for execution: " + "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)) logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)) + + completed.add(graph_id) + + for graph_id in completed: + del self._stats[graph_id] diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index e9511aa283..41170a304b 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -1,15 +1,15 @@ import time import traceback -from threading import Event, Thread, BoundedSemaphore - -from ..invocations.baseinvocation import InvocationContext -from .invocation_queue import InvocationQueueItem -from .invoker import InvocationProcessorABC, Invoker -from .invocation_stats import InvocationStats -from ..models.exceptions import CanceledException +from threading import BoundedSemaphore, Event, Thread import invokeai.backend.util.logging as logger +from ..invocations.baseinvocation import InvocationContext +from ..models.exceptions import CanceledException +from .invocation_queue import InvocationQueueItem +from .invocation_stats import InvocationStatsServiceBase +from .invoker import InvocationProcessorABC, Invoker + class DefaultInvocationProcessor(InvocationProcessorABC): __invoker_thread: Thread @@ -36,7 +36,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC): def __process(self, stop_event: Event): try: self.__threadLimit.acquire() - statistics = InvocationStats() # keep track of performance metrics + statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics + while not stop_event.is_set(): try: queue_item: InvocationQueueItem = self.__invoker.services.queue.get() @@ -85,7 +86,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): # Invoke try: - with statistics.collect_stats(invocation, 
graph_execution_state): + with statistics.collect_stats(invocation, graph_execution_state.id): outputs = invocation.invoke( InvocationContext( services=self.__invoker.services, @@ -116,7 +117,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): pass except CanceledException: - statistics.reset_stats() + statistics.reset_stats(graph_execution_state.id) pass except Exception as e: @@ -138,7 +139,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): error_type=e.__class__.__name__, error=error, ) - statistics.reset_stats() + statistics.reset_stats(graph_execution_state.id) pass # Check queue to see if this is canceled, and skip if so From 3fc789a7eeb9efa6255e331ab7dda03097c451d6 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 2 Aug 2023 18:31:10 -0400 Subject: [PATCH 06/33] fix unit tests --- tests/nodes/test_graph_execution_state.py | 9 ++++++--- tests/nodes/test_invoker.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index e0ee120b54..248bc6fee1 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -16,6 +16,7 @@ from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.invocation_stats import InvocationStatsService from invokeai.app.services.graph import ( Graph, CollectInvocation, @@ -41,6 +42,9 @@ def simple_graph(): @pytest.fixture def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations + graph_execution_manager = SqliteItemStorage[GraphExecutionState]( + filename=sqlite_memory, table_name="graph_executions" + ) return InvocationServices( model_manager=None, # type: ignore events=TestEventService(), @@ -51,9 +55,8 @@ def mock_services() -> InvocationServices: board_images=None, # type: ignore queue=MemoryInvocationQueue(), graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"), - graph_execution_manager=SqliteItemStorage[GraphExecutionState]( - filename=sqlite_memory, table_name="graph_executions" - ), + graph_execution_manager=graph_execution_manager, + performance_statistics=InvocationStatsService(graph_execution_manager), processor=DefaultInvocationProcessor(), configuration=None, # type: ignore ) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 8eba6d468f..5985c7e8bb 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -11,6 +11,7 @@ from invokeai.app.services.processor import DefaultInvocationProcessor from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.invoker import Invoker from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.invocation_stats import InvocationStatsService from invokeai.app.services.graph import ( Graph, GraphExecutionState, @@ -34,6 +35,9 @@ def simple_graph(): @pytest.fixture def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations + graph_execution_manager = SqliteItemStorage[GraphExecutionState]( + filename=sqlite_memory, table_name="graph_executions" + ) return InvocationServices( model_manager=None, # type: ignore events=TestEventService(), @@ 
-44,10 +48,9 @@ def mock_services() -> InvocationServices: board_images=None, # type: ignore queue=MemoryInvocationQueue(), graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"), - graph_execution_manager=SqliteItemStorage[GraphExecutionState]( - filename=sqlite_memory, table_name="graph_executions" - ), + graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), + performance_statistics=InvocationStatsService(graph_execution_manager), configuration=None, # type: ignore ) From 921ccad04d2b903617cd4dd6d15c4f7fa8fe9c66 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 2 Aug 2023 18:41:43 -0400 Subject: [PATCH 07/33] added stats service to the cli_app startup --- invokeai/app/cli_app.py | 2 ++ scripts/dream.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index bad95bb559..4558c9219f 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -37,6 +37,7 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.urls import LocalUrlService +from invokeai.app.services.invocation_stats import InvocationStatsService from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage @@ -311,6 +312,7 @@ def invoke_cli(): graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"), graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), + performance_statistics=InvocationStatsService(graph_execution_manager), logger=logger, configuration=config, ) diff --git a/scripts/dream.py b/scripts/dream.py index 12176db41e..d2735046a4 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -2,10 +2,12 @@ # Copyright (c) 2022 Lincoln D. 
Stein (https://github.com/lstein) import warnings -from invokeai.frontend.CLI import invokeai_command_line_interface as main warnings.warn( "dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API", DeprecationWarning, ) -main() + +from invokeai.app.cli_app import invoke_cli + +invoke_cli() From 0ba8a0ea6c4f874db22d91b3192f8c084f8a6fd2 Mon Sep 17 00:00:00 2001 From: Kevin Brack Date: Sun, 30 Jul 2023 13:05:29 -0500 Subject: [PATCH 08/33] Board assignment changing on click --- .../components/GallerySettingsPopover.tsx | 19 ++++++++++++++++--- .../features/gallery/store/gallerySlice.ts | 6 ++++++ .../web/src/features/gallery/store/types.ts | 1 + 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx b/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx index 21a580d9a9..04cc98edb7 100644 --- a/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx @@ -8,6 +8,7 @@ import IAIPopover from 'common/components/IAIPopover'; import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAISlider from 'common/components/IAISlider'; import { + autoAssignBoardOnClickChanged, setGalleryImageMinimumWidth, shouldAutoSwitchChanged, } from 'features/gallery/store/gallerySlice'; @@ -19,11 +20,16 @@ import BoardAutoAddSelect from './Boards/BoardAutoAddSelect'; const selector = createSelector( [stateSelector], (state) => { - const { galleryImageMinimumWidth, shouldAutoSwitch } = state.gallery; + const { + galleryImageMinimumWidth, + shouldAutoSwitch, + autoAssignBoardOnClick, + } = state.gallery; return { galleryImageMinimumWidth, shouldAutoSwitch, + autoAssignBoardOnClick, }; }, defaultSelectorOptions @@ -33,7 +39,7 @@ const GallerySettingsPopover = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { galleryImageMinimumWidth, shouldAutoSwitch } = + const { galleryImageMinimumWidth, shouldAutoSwitch, autoAssignBoardOnClick } = useAppSelector(selector); const handleChangeGalleryImageMinimumWidth = (v: number) => { @@ -69,7 +75,14 @@ const GallerySettingsPopover = () => { dispatch(shouldAutoSwitchChanged(e.target.checked)) } /> - + ) => + dispatch(autoAssignBoardOnClickChanged(e.target.checked)) + } + /> + {!autoAssignBoardOnClick && } ); diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index 5eabe5de26..851e1b6c3b 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -8,6 +8,7 @@ export const initialGalleryState: GalleryState = { selection: [], shouldAutoSwitch: true, autoAddBoardId: undefined, + autoAssignBoardOnClick: true, galleryImageMinimumWidth: 96, selectedBoardId: undefined, galleryView: 'images', @@ -66,9 +67,13 @@ export const gallerySlice = createSlice({ setGalleryImageMinimumWidth: (state, action: PayloadAction) => { state.galleryImageMinimumWidth = action.payload; }, + autoAssignBoardOnClickChanged: (state, action: PayloadAction) => { + state.autoAssignBoardOnClick = action.payload; + }, boardIdSelected: (state, action: PayloadAction) => { state.selectedBoardId = action.payload; state.galleryView = 'images'; + state.autoAssignBoardOnClick && (state.autoAddBoardId = action.payload); }, 
isBatchEnabledChanged: (state, action: PayloadAction) => { state.isBatchEnabled = action.payload; @@ -140,6 +145,7 @@ export const { imageSelectionToggled, imageSelected, shouldAutoSwitchChanged, + autoAssignBoardOnClickChanged, setGalleryImageMinimumWidth, boardIdSelected, isBatchEnabledChanged, diff --git a/invokeai/frontend/web/src/features/gallery/store/types.ts b/invokeai/frontend/web/src/features/gallery/store/types.ts index d19a6fded3..298b792362 100644 --- a/invokeai/frontend/web/src/features/gallery/store/types.ts +++ b/invokeai/frontend/web/src/features/gallery/store/types.ts @@ -18,6 +18,7 @@ export type GalleryState = { selection: string[]; shouldAutoSwitch: boolean; autoAddBoardId: string | undefined; + autoAssignBoardOnClick: boolean; galleryImageMinimumWidth: number; selectedBoardId: BoardId; galleryView: GalleryView; From 450e95de592ac623bce8acca2ebeb60bbeb1a24a Mon Sep 17 00:00:00 2001 From: Kevin Brack Date: Sun, 30 Jul 2023 18:35:18 -0500 Subject: [PATCH 09/33] auto change board waiting for isReady --- .../Boards/BoardsList/GalleryBoard.tsx | 19 +++++++++++++++---- .../Boards/BoardsList/NoBoardBoard.tsx | 18 +++++++++++++----- .../features/gallery/store/gallerySlice.ts | 1 - 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx index 67c45c131b..a4124e2393 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx @@ -16,7 +16,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIDroppable from 'common/components/IAIDroppable'; import SelectionOverlay from 'common/components/SelectionOverlay'; -import { boardIdSelected } from 'features/gallery/store/gallerySlice'; +import { + autoAddBoardIdChanged, + boardIdSelected, +} from 'features/gallery/store/gallerySlice'; import { memo, useCallback, useMemo, useState } from 'react'; import { FaUser } from 'react-icons/fa'; import { useUpdateBoardMutation } from 'services/api/endpoints/boards'; @@ -24,6 +27,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { BoardDTO } from 'services/api/types'; import AutoAddIcon from '../AutoAddIcon'; import BoardContextMenu from '../BoardContextMenu'; +import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; interface GalleryBoardProps { board: BoardDTO; @@ -41,15 +45,17 @@ const GalleryBoard = memo( ({ gallery }) => { const isSelectedForAutoAdd = board.board_id === gallery.autoAddBoardId; + const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick; - return { isSelectedForAutoAdd }; + return { isSelectedForAutoAdd, autoAssignBoardOnClick }; }, defaultSelectorOptions ), [board.board_id] ); - const { isSelectedForAutoAdd } = useAppSelector(selector); + const { isSelectedForAutoAdd, autoAssignBoardOnClick } = + useAppSelector(selector); const [isHovered, setIsHovered] = useState(false); const handleMouseOver = useCallback(() => { setIsHovered(true); @@ -64,9 +70,14 @@ const GalleryBoard = memo( const { board_name, board_id } = board; const [localBoardName, setLocalBoardName] = useState(board_name); + const isReady = useIsReadyToInvoke(); + const handleSelectBoard = useCallback(() => { 
dispatch(boardIdSelected(board_id)); - }, [board_id, dispatch]); + if (autoAssignBoardOnClick && isReady) { + dispatch(autoAddBoardIdChanged(board_id)); + } + }, [board_id, autoAssignBoardOnClick, isReady, dispatch]); const [updateBoard, { isLoading: isUpdateBoardLoading }] = useUpdateBoardMutation(); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx index ee1d8f6bea..3963fc04d9 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx @@ -7,11 +7,15 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import InvokeAILogoImage from 'assets/images/logo.png'; import IAIDroppable from 'common/components/IAIDroppable'; import SelectionOverlay from 'common/components/SelectionOverlay'; -import { boardIdSelected } from 'features/gallery/store/gallerySlice'; +import { + boardIdSelected, + autoAddBoardIdChanged, +} from 'features/gallery/store/gallerySlice'; import { memo, useCallback, useMemo, useState } from 'react'; import { useBoardName } from 'services/api/hooks/useBoardName'; import AutoAddIcon from '../AutoAddIcon'; import BoardContextMenu from '../BoardContextMenu'; +import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; interface Props { isSelected: boolean; } @@ -19,19 +23,23 @@ interface Props { const selector = createSelector( stateSelector, ({ gallery }) => { - const { autoAddBoardId } = gallery; - return { autoAddBoardId }; + const { autoAddBoardId, autoAssignBoardOnClick } = gallery; + return { autoAddBoardId, autoAssignBoardOnClick }; }, defaultSelectorOptions ); const NoBoardBoard = memo(({ isSelected }: Props) => { const dispatch = useAppDispatch(); - const { autoAddBoardId } = useAppSelector(selector); + const { autoAddBoardId, autoAssignBoardOnClick } = useAppSelector(selector); const boardName = useBoardName(undefined); + const isReady = useIsReadyToInvoke(); const handleSelectBoard = useCallback(() => { dispatch(boardIdSelected(undefined)); - }, [dispatch]); + if (autoAssignBoardOnClick && isReady) { + dispatch(autoAddBoardIdChanged(undefined)); + } + }, [dispatch, autoAssignBoardOnClick, isReady]); const [isHovered, setIsHovered] = useState(false); const handleMouseOver = useCallback(() => { setIsHovered(true); diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index 851e1b6c3b..9c65e818f4 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -73,7 +73,6 @@ export const gallerySlice = createSlice({ boardIdSelected: (state, action: PayloadAction) => { state.selectedBoardId = action.payload; state.galleryView = 'images'; - state.autoAssignBoardOnClick && (state.autoAddBoardId = action.payload); }, isBatchEnabledChanged: (state, action: PayloadAction) => { state.isBatchEnabled = action.payload; From 366952f810fcedd676b735046c9e384d08aa244f Mon Sep 17 00:00:00 2001 From: Kevin Brack Date: Sun, 30 Jul 2023 19:02:53 -0500 Subject: [PATCH 10/33] fix localization --- invokeai/frontend/web/public/locales/en.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 
cf84e4d773..63380a19fa 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -124,7 +124,8 @@ "deleteImageBin": "Deleted images will be sent to your operating system's Bin.", "deleteImagePermanent": "Deleted images cannot be restored.", "images": "Images", - "assets": "Assets" + "assets": "Assets", + "autoAssignBoardOnClick": "Auto-Assign Board on Click" }, "hotkeys": { "keyboardShortcuts": "Keyboard Shortcuts", From 87424be95d17434bb698bb4742acbaa7fb63d2d6 Mon Sep 17 00:00:00 2001 From: Kevin Brack Date: Mon, 31 Jul 2023 19:34:24 -0500 Subject: [PATCH 11/33] block auto add board change during generation. Switch condition to isProcessing --- .../components/Boards/BoardAutoAddSelect.tsx | 12 ++++++++---- .../Boards/BoardsList/GalleryBoard.tsx | 18 ++++++++++-------- .../Boards/BoardsList/NoBoardBoard.tsx | 14 +++++++------- .../components/GallerySettingsPopover.tsx | 2 +- 4 files changed, 26 insertions(+), 20 deletions(-) diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardAutoAddSelect.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardAutoAddSelect.tsx index ad0e5ab80d..9f02a29f10 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardAutoAddSelect.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardAutoAddSelect.tsx @@ -11,11 +11,14 @@ import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; const selector = createSelector( [stateSelector], - ({ gallery }) => { - const { autoAddBoardId } = gallery; + ({ gallery, system }) => { + const { autoAddBoardId, autoAssignBoardOnClick } = gallery; + const { isProcessing } = system; return { autoAddBoardId, + autoAssignBoardOnClick, + isProcessing, }; }, defaultSelectorOptions @@ -23,7 +26,8 @@ const selector = createSelector( const BoardAutoAddSelect = () => { const dispatch = useAppDispatch(); - const { autoAddBoardId } = useAppSelector(selector); + const { autoAddBoardId, autoAssignBoardOnClick, isProcessing } = + useAppSelector(selector); const inputRef = useRef(null); const { boards, hasBoards } = useListAllBoardsQuery(undefined, { selectFromResult: ({ data }) => { @@ -67,7 +71,7 @@ const BoardAutoAddSelect = () => { data={boards} nothingFound="No matching Boards" itemComponent={IAIMantineSelectItemWithTooltip} - disabled={!hasBoards} + disabled={!hasBoards || autoAssignBoardOnClick || isProcessing} filter={(value, item: SelectItem) => item.label?.toLowerCase().includes(value.toLowerCase().trim()) || item.value.toLowerCase().includes(value.toLowerCase().trim()) diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx index a4124e2393..3b591ee00f 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx @@ -27,7 +27,6 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { BoardDTO } from 'services/api/types'; import AutoAddIcon from '../AutoAddIcon'; import BoardContextMenu from '../BoardContextMenu'; -import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; interface GalleryBoardProps { board: BoardDTO; @@ -42,19 +41,24 @@ const GalleryBoard = memo( () => createSelector( stateSelector, - ({ gallery }) => { + ({ gallery, system }) => { const isSelectedForAutoAdd = 
board.board_id === gallery.autoAddBoardId; const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick; + const isProcessing = system.isProcessing; - return { isSelectedForAutoAdd, autoAssignBoardOnClick }; + return { + isSelectedForAutoAdd, + autoAssignBoardOnClick, + isProcessing, + }; }, defaultSelectorOptions ), [board.board_id] ); - const { isSelectedForAutoAdd, autoAssignBoardOnClick } = + const { isSelectedForAutoAdd, autoAssignBoardOnClick, isProcessing } = useAppSelector(selector); const [isHovered, setIsHovered] = useState(false); const handleMouseOver = useCallback(() => { @@ -70,14 +74,12 @@ const GalleryBoard = memo( const { board_name, board_id } = board; const [localBoardName, setLocalBoardName] = useState(board_name); - const isReady = useIsReadyToInvoke(); - const handleSelectBoard = useCallback(() => { dispatch(boardIdSelected(board_id)); - if (autoAssignBoardOnClick && isReady) { + if (autoAssignBoardOnClick && !isProcessing) { dispatch(autoAddBoardIdChanged(board_id)); } - }, [board_id, autoAssignBoardOnClick, isReady, dispatch]); + }, [board_id, autoAssignBoardOnClick, isProcessing, dispatch]); const [updateBoard, { isLoading: isUpdateBoardLoading }] = useUpdateBoardMutation(); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx index 3963fc04d9..118b2108f7 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx @@ -15,31 +15,31 @@ import { memo, useCallback, useMemo, useState } from 'react'; import { useBoardName } from 'services/api/hooks/useBoardName'; import AutoAddIcon from '../AutoAddIcon'; import BoardContextMenu from '../BoardContextMenu'; -import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; interface Props { isSelected: boolean; } const selector = createSelector( stateSelector, - ({ gallery }) => { + ({ gallery, system }) => { const { autoAddBoardId, autoAssignBoardOnClick } = gallery; - return { autoAddBoardId, autoAssignBoardOnClick }; + const { isProcessing } = system; + return { autoAddBoardId, autoAssignBoardOnClick, isProcessing }; }, defaultSelectorOptions ); const NoBoardBoard = memo(({ isSelected }: Props) => { const dispatch = useAppDispatch(); - const { autoAddBoardId, autoAssignBoardOnClick } = useAppSelector(selector); + const { autoAddBoardId, autoAssignBoardOnClick, isProcessing } = + useAppSelector(selector); const boardName = useBoardName(undefined); - const isReady = useIsReadyToInvoke(); const handleSelectBoard = useCallback(() => { dispatch(boardIdSelected(undefined)); - if (autoAssignBoardOnClick && isReady) { + if (autoAssignBoardOnClick && !isProcessing) { dispatch(autoAddBoardIdChanged(undefined)); } - }, [dispatch, autoAssignBoardOnClick, isReady]); + }, [dispatch, autoAssignBoardOnClick, isProcessing]); const [isHovered, setIsHovered] = useState(false); const handleMouseOver = useCallback(() => { setIsHovered(true); diff --git a/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx b/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx index 04cc98edb7..796cc542e7 100644 --- a/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx @@ -82,7 +82,7 @@ const 
GallerySettingsPopover = () => { dispatch(autoAssignBoardOnClickChanged(e.target.checked)) } /> - {!autoAssignBoardOnClick && } + ); From 26ef5249b1902132493ef6cbc5f1c8d6c4c7bdd8 Mon Sep 17 00:00:00 2001 From: Kevin Brack Date: Mon, 31 Jul 2023 19:44:41 -0500 Subject: [PATCH 12/33] guard board switching in board context menu --- .../gallery/components/Boards/BoardContextMenu.tsx | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx index 35fcbd87f7..2774288612 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx @@ -25,14 +25,17 @@ const BoardContextMenu = memo( const selector = useMemo( () => - createSelector(stateSelector, ({ gallery }) => { + createSelector(stateSelector, ({ gallery, system }) => { const isAutoAdd = gallery.autoAddBoardId === board_id; - return { isAutoAdd }; + const isProcessing = system.isProcessing; + const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick; + return { isAutoAdd, isProcessing, autoAssignBoardOnClick }; }), [board_id] ); - const { isAutoAdd } = useAppSelector(selector); + const { isAutoAdd, isProcessing, autoAssignBoardOnClick } = + useAppSelector(selector); const boardName = useBoardName(board_id); const handleSetAutoAdd = useCallback(() => { @@ -59,7 +62,7 @@ const BoardContextMenu = memo( } - isDisabled={isAutoAdd} + isDisabled={isAutoAdd || isProcessing || autoAssignBoardOnClick} onClick={handleSetAutoAdd} > Auto-add to this Board From 7021467048925d413e9fd132545491ee04ac603e Mon Sep 17 00:00:00 2001 From: Eugene Brodsky Date: Wed, 2 Aug 2023 19:46:02 -0400 Subject: [PATCH 13/33] (ci) do not install all dependencies when running static checks (#4036) Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> --- .github/workflows/style-checks.yml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/style-checks.yml b/.github/workflows/style-checks.yml index 8aceb6469e..0bb19e95e5 100644 --- a/.github/workflows/style-checks.yml +++ b/.github/workflows/style-checks.yml @@ -1,13 +1,15 @@ -name: Black # TODO: add isort and flake8 later +name: style checks +# just formatting for now +# TODO: add isort and flake8 later on: - pull_request: {} + pull_request: push: - branches: master + branches: main tags: "*" jobs: - test: + black: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 @@ -19,8 +21,7 @@ jobs: - name: Install dependencies with pip run: | - pip install --upgrade pip wheel - pip install .[test] + pip install black # - run: isort --check-only . - run: black --check . 
From 4e0949fa5503c4ed2a61641aa9deb0f3e1134fb4 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sun, 30 Jul 2023 20:12:53 +0200 Subject: [PATCH 14/33] fix .swap() by reverting improperly merged @classmethod change --- .../diffusion/shared_invokeai_diffusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 272518e928..c01cf82c57 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -78,10 +78,9 @@ class InvokeAIDiffuserComponent: self.cross_attention_control_context = None self.sequential_guidance = config.sequential_guidance - @classmethod @contextmanager def custom_attention_context( - cls, + self, unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int, @@ -91,18 +90,19 @@ class InvokeAIDiffuserComponent: old_attn_processors = unet.attn_processors # Load lora conditions into the model if extra_conditioning_info.wants_cross_attention_control: - cross_attention_control_context = Context( + self.cross_attention_control_context = Context( arguments=extra_conditioning_info.cross_attention_control_args, step_count=step_count, ) setup_cross_attention_control_attention_processors( unet, - cross_attention_control_context, + self.cross_attention_control_context, ) try: yield None finally: + self.cross_attention_control_context = None if old_attn_processors is not None: unet.set_attn_processor(old_attn_processors) # TODO resuscitate attention map saving From 118d5b387b05c1e906cb186c1a7540379e13cc01 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Tue, 1 Aug 2023 16:22:20 -0400 Subject: [PATCH 15/33] deploy: refactor github workflows Currently we use some workflow trigger conditionals to run either a real test workflow (installing the app and running it) or a fake workflow, disguised as the real one, that just auto-passes. This change refactors the workflow to use a single workflow that can be skipped, using another github action to determine which things to run depending on the paths changed. 
--- .github/workflows/style-checks.yml | 1 - .github/workflows/test-invoke-pip-skip.yml | 50 ---------------------- .github/workflows/test-invoke-pip.yml | 24 +++++++---- 3 files changed, 15 insertions(+), 60 deletions(-) delete mode 100644 .github/workflows/test-invoke-pip-skip.yml diff --git a/.github/workflows/style-checks.yml b/.github/workflows/style-checks.yml index 0bb19e95e5..d29b489418 100644 --- a/.github/workflows/style-checks.yml +++ b/.github/workflows/style-checks.yml @@ -6,7 +6,6 @@ on: pull_request: push: branches: main - tags: "*" jobs: black: diff --git a/.github/workflows/test-invoke-pip-skip.yml b/.github/workflows/test-invoke-pip-skip.yml deleted file mode 100644 index 004b46d5a8..0000000000 --- a/.github/workflows/test-invoke-pip-skip.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: Test invoke.py pip - -# This is a dummy stand-in for the actual tests -# we don't need to run python tests on non-Python changes -# But PRs require passing tests to be mergeable - -on: - pull_request: - paths: - - '**' - - '!pyproject.toml' - - '!invokeai/**' - - '!tests/**' - - 'invokeai/frontend/web/**' - merge_group: - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - matrix: - if: github.event.pull_request.draft == false - strategy: - matrix: - python-version: - - '3.10' - pytorch: - - linux-cuda-11_7 - - linux-rocm-5_2 - - linux-cpu - - macos-default - - windows-cpu - include: - - pytorch: linux-cuda-11_7 - os: ubuntu-22.04 - - pytorch: linux-rocm-5_2 - os: ubuntu-22.04 - - pytorch: linux-cpu - os: ubuntu-22.04 - - pytorch: macos-default - os: macOS-12 - - pytorch: windows-cpu - os: windows-2022 - name: ${{ matrix.pytorch }} on ${{ matrix.python-version }} - runs-on: ${{ matrix.os }} - steps: - - name: skip - run: echo "no build required" diff --git a/.github/workflows/test-invoke-pip.yml b/.github/workflows/test-invoke-pip.yml index 40be0a529e..6086d10069 100644 --- a/.github/workflows/test-invoke-pip.yml +++ b/.github/workflows/test-invoke-pip.yml @@ -3,16 +3,7 @@ on: push: branches: - 'main' - paths: - - 'pyproject.toml' - - 'invokeai/**' - - '!invokeai/frontend/web/**' pull_request: - paths: - - 'pyproject.toml' - - 'invokeai/**' - - 'tests/**' - - '!invokeai/frontend/web/**' types: - 'ready_for_review' - 'opened' @@ -65,10 +56,23 @@ jobs: id: checkout-sources uses: actions/checkout@v3 + - name: Check for changed python files + id: changed-files + uses: tj-actions/changed-files@v37 + with: + files_yaml: | + python: + - 'pyproject.toml' + - 'invokeai/**' + - '!invokeai/frontend/web/**' + - 'tests/**' + - name: set test prompt to main branch validation + if: steps.changed-files.outputs.python_any_changed == 'true' run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }} - name: setup python + if: steps.changed-files.outputs.python_any_changed == 'true' uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -76,6 +80,7 @@ jobs: cache-dependency-path: pyproject.toml - name: install invokeai + if: steps.changed-files.outputs.python_any_changed == 'true' env: PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }} run: > @@ -83,6 +88,7 @@ jobs: --editable=".[test]" - name: run pytest + if: steps.changed-files.outputs.python_any_changed == 'true' id: run-pytest run: pytest From a6f9396a3021f561dda0ddd5e35bdc2d07207bbd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 1 Aug 2023 18:08:17 +1000 
Subject: [PATCH 16/33] fix(db): retrieve metadata even when no `session_id` this was unnecessarily skipped if there was no `session_id`. --- invokeai/app/services/images.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index f8376eb626..2240846dac 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -289,9 +289,10 @@ class ImageService(ImageServiceABC): def get_metadata(self, image_name: str) -> Optional[ImageMetadata]: try: image_record = self._services.image_records.get(image_name) + metadata = self._services.image_records.get_metadata(image_name) if not image_record.session_id: - return ImageMetadata() + return ImageMetadata(metadata=metadata) session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id) graph = None @@ -303,7 +304,6 @@ class ImageService(ImageServiceABC): self._services.logger.warn(f"Failed to parse session graph: {e}") graph = None - metadata = self._services.image_records.get_metadata(image_name) return ImageMetadata(graph=graph, metadata=metadata) except ImageRecordNotFoundException: self._services.logger.error("Image record not found") From 5c9787c14540190a4960edd483cf328da2097012 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Wed, 2 Aug 2023 09:46:29 -0400 Subject: [PATCH 17/33] add project-id header to requests --- .../frontend/web/src/app/components/InvokeAIUI.tsx | 11 +++++++++-- invokeai/frontend/web/src/services/api/client.ts | 12 ++++++++++-- invokeai/frontend/web/src/services/api/index.ts | 6 +++++- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index 7df390bce6..d6e6c42728 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -13,7 +13,7 @@ import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares'; import Loading from '../../common/components/Loading/Loading'; import { Middleware } from '@reduxjs/toolkit'; -import { $authToken, $baseUrl } from 'services/api/client'; +import { $authToken, $baseUrl, $projectId } from 'services/api/client'; import { socketMiddleware } from 'services/events/middleware'; import '../../i18n'; import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext'; @@ -37,6 +37,7 @@ const InvokeAIUI = ({ config, headerComponent, middleware, + projectId, }: Props) => { useEffect(() => { // configure API client token @@ -49,6 +50,11 @@ const InvokeAIUI = ({ $baseUrl.set(apiUrl); } + // configure API client project header + if (apiUrl) { + $projectId.set(projectId); + } + // reset dynamically added middlewares resetMiddlewares(); @@ -68,8 +74,9 @@ const InvokeAIUI = ({ // Reset the API client token and base url on unmount $baseUrl.set(undefined); $authToken.set(undefined); + $projectId.set(undefined); }; - }, [apiUrl, token, middleware]); + }, [apiUrl, token, middleware, projectId]); return ( diff --git a/invokeai/frontend/web/src/services/api/client.ts b/invokeai/frontend/web/src/services/api/client.ts index dd4caa460e..87deda7d36 100644 --- a/invokeai/frontend/web/src/services/api/client.ts +++ b/invokeai/frontend/web/src/services/api/client.ts @@ -16,6 +16,11 @@ export const $authToken = atom(); */ export const $baseUrl = atom(); +/** + * The optional project-id header. 
+ */ +export const $projectId = atom(); + /** * Autogenerated, type-safe fetch client for the API. Used when RTK Query is not an option. * Dynamically updates when the token or base url changes. @@ -24,9 +29,12 @@ export const $baseUrl = atom(); * @example * const { get, post, del } = $client.get(); */ -export const $client = computed([$authToken, $baseUrl], (authToken, baseUrl) => +export const $client = computed([$authToken, $baseUrl, $projectId], (authToken, baseUrl, projectId) => createClient({ - headers: authToken ? { Authorization: `Bearer ${authToken}` } : {}, + headers: { + ...(authToken ? { Authorization: `Bearer ${authToken}` } : {}), + ...(projectId ? { "project-id": projectId } : {}) + }, // do not include `api/v1` in the base url for this client baseUrl: `${baseUrl ?? ''}`, }) diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 0a0391898c..a9de7130c9 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -6,7 +6,7 @@ import { createApi, fetchBaseQuery, } from '@reduxjs/toolkit/query/react'; -import { $authToken, $baseUrl } from 'services/api/client'; +import { $authToken, $baseUrl, $projectId } from 'services/api/client'; export const tagTypes = [ 'Board', @@ -30,6 +30,7 @@ const dynamicBaseQuery: BaseQueryFn< > = async (args, api, extraOptions) => { const baseUrl = $baseUrl.get(); const authToken = $authToken.get(); + const projectId = $projectId.get(); const rawBaseQuery = fetchBaseQuery({ baseUrl: `${baseUrl ?? ''}/api/v1`, @@ -37,6 +38,9 @@ const dynamicBaseQuery: BaseQueryFn< if (authToken) { headers.set('Authorization', `Bearer ${authToken}`); } + if (projectId) { + headers.set("project-id", projectId) + } return headers; }, From b3b94b5a8d20daf428791b8dc15274c4cd9871d6 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Wed, 2 Aug 2023 09:52:54 -0400 Subject: [PATCH 18/33] use correct prop --- invokeai/frontend/web/src/app/components/InvokeAIUI.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index d6e6c42728..cffbaa5574 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -51,7 +51,7 @@ const InvokeAIUI = ({ } // configure API client project header - if (apiUrl) { + if (projectId) { $projectId.set(projectId); } From eeef1e08f827eb404bd8b53c91007e089ab8fd96 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 30 Jul 2023 18:20:12 -0400 Subject: [PATCH 19/33] restore ability to convert merged inpaint .safetensors files --- invokeai/backend/model_management/models/base.py | 5 +++-- invokeai/backend/model_management/models/stable_diffusion.py | 5 ++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index e6a20e79ec..1219d4277d 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -292,8 +292,9 @@ class DiffusersModel(ModelBase): ) break except Exception as e: - # print("====ERR LOAD====") - # print(f"{variant}: {e}") + if not str(e).startswith('Error no file'): + print("====ERR LOAD====") + print(f"{variant}: {e}") pass else: raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model") diff --git 
a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index d81b0150e5..9e0c130e6a 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -4,6 +4,7 @@ from enum import Enum from pydantic import Field from pathlib import Path from typing import Literal, Optional, Union +from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline from .base import ( ModelConfigBase, BaseModelType, @@ -21,7 +22,6 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from omegaconf import OmegaConf - class StableDiffusion1ModelFormat(str, Enum): Checkpoint = "checkpoint" Diffusers = "diffusers" @@ -263,6 +263,8 @@ def _convert_ckpt_and_cache( weights = app_config.models_path / model_config.path config_file = app_config.root_path / model_config.config output_path = Path(output_path) + variant = model_config.variant + pipeline_class = StableDiffusionInpaintPipeline if variant=='inpaint' else StableDiffusionPipeline # return cached version if it exists if output_path.exists(): @@ -289,6 +291,7 @@ def _convert_ckpt_and_cache( original_config_file=config_file, extract_ema=True, scan_needed=True, + pipeline_class=pipeline_class, from_safetensors=weights.suffix == ".safetensors", precision=torch_dtype(choose_torch_device()), **kwargs, From e080fd1e08a676d412108bb2ecf27e559c0bfc70 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 30 Jul 2023 19:18:05 -0400 Subject: [PATCH 20/33] blackify --- invokeai/backend/model_management/models/base.py | 2 +- invokeai/backend/model_management/models/stable_diffusion.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index 1219d4277d..d335b645c8 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -292,7 +292,7 @@ class DiffusersModel(ModelBase): ) break except Exception as e: - if not str(e).startswith('Error no file'): + if not str(e).startswith("Error no file"): print("====ERR LOAD====") print(f"{variant}: {e}") pass diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index 9e0c130e6a..a112e8bc96 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -22,6 +22,7 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from omegaconf import OmegaConf + class StableDiffusion1ModelFormat(str, Enum): Checkpoint = "checkpoint" Diffusers = "diffusers" @@ -264,7 +265,7 @@ def _convert_ckpt_and_cache( config_file = app_config.root_path / model_config.config output_path = Path(output_path) variant = model_config.variant - pipeline_class = StableDiffusionInpaintPipeline if variant=='inpaint' else StableDiffusionPipeline + pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline # return cached version if it exists if output_path.exists(): From bf94412d1499e48d6bd9dfa415ff7f1aeead6710 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 31 Jul 2023 18:16:52 +1000 Subject: [PATCH 21/33] feat: add multi-select to gallery multi-select actions include: 
- drag to board to move all to that board
- right click to add all to board or delete all

backend changes:

- add routes for changing board for list of image names, deleting list of images
- change image-specific routes to `images/i/{image_name}` to not clobber other routes (like `images/upload`, `images/delete`)
- subclass pydantic `BaseModel` as `BaseModelExcludeNull`, which excludes null values when calling `dict()` on the model. this fixes inconsistent types related to JSON parsing null values into `null` instead of `undefined`
- remove `board_id` from `remove_image_from_board`

frontend changes:

- multi-selection stuff uses `ImageDTO[]` as payloads, for dnd and other mutations. this gives us access to image `board_id`s when hitting routes, and enables efficient cache updates.
- consolidate change board and delete image modals to handle single and multiples
- board totals are now re-fetched on mutation and not kept in sync manually - was way too tedious to do this
- fixed warning about nested `
` elements - closes #4088 , need to handle case when `autoAddBoardId` is `"none"` - add option to show gallery image delete button on every gallery image frontend refactors/organisation: - make typegen script js instead of ts - enable `noUncheckedIndexedAccess` to help avoid bugs when indexing into arrays, many small changes needed to satisfy TS after this - move all image-related endpoints into `endpoints/images.ts`, its a big file now, but this fixes a number of circular dependency issues that were otherwise felt impossible to resolve --- invokeai/app/api/routers/board_images.py | 92 +- invokeai/app/api/routers/images.py | 38 +- invokeai/app/invocations/metadata.py | 9 +- .../services/board_image_record_storage.py | 6 +- invokeai/app/services/board_images.py | 4 +- invokeai/app/services/models/board_image.py | 8 + invokeai/app/services/models/board_record.py | 5 +- invokeai/app/services/models/image_record.py | 15 +- invokeai/app/services/urls.py | 4 +- invokeai/app/util/model_exclude_null.py | 23 + invokeai/frontend/web/package.json | 2 +- .../web/scripts/{typegen.ts => typegen.js} | 0 .../frontend/web/src/app/components/App.tsx | 6 +- .../app/components/ImageDnd/DragPreview.tsx | 4 +- .../components/ImageDnd/ImageDndContext.tsx | 25 +- .../app/components/ImageDnd/typesafeDnd.tsx | 69 +- .../web/src/app/components/InvokeAIUI.tsx | 13 +- .../app/contexts/AddImageToBoardContext.tsx | 91 -- .../contexts/ImageUploaderTriggerContext.ts | 8 - .../enhancers/reduxRemember/serialize.ts | 2 +- .../middleware/listenerMiddleware/index.ts | 6 +- .../addFirstListImagesListener.ts.ts | 11 +- .../listeners/appConfigReceived.ts | 4 +- .../listeners/boardAndImagesDeleted.ts | 6 +- .../listeners/boardIdSelected.ts | 6 +- .../listeners/canvasSavedToGallery.ts | 4 +- .../listeners/controlNetAutoProcess.ts | 11 +- .../listeners/controlNetImageProcessed.ts | 2 +- .../listeners/imageDeleted.ts | 113 ++- .../listeners/imageDropped.ts | 216 ++--- .../listeners/imageToDeleteSelected.ts | 29 +- .../listeners/imageUploaded.ts | 14 +- .../listeners/modelsLoaded.ts | 7 +- .../socketio/socketInvocationComplete.ts | 16 +- invokeai/frontend/web/src/app/store/store.ts | 8 +- .../src/common/components/IAIDropOverlay.tsx | 6 +- .../components/IAIMantineSearchableSelect.tsx | 4 +- .../src/common/components/ImageUploader.tsx | 2 +- .../src/common/hooks/useImageUploadButton.tsx | 2 +- .../canvas/hooks/useColorUnderCursor.ts | 4 + .../src/features/canvas/store/canvasSlice.ts | 9 +- .../components/ChangeBoardModal.tsx | 132 +++ .../changeBoardModal/store/initialState.ts | 6 + .../features/changeBoardModal/store/slice.ts | 25 + .../features/changeBoardModal/store/types.ts | 6 + .../controlNet/components/ControlNet.tsx | 38 +- .../components/ControlNetImagePreview.tsx | 63 +- .../ControlNetProcessorComponent.tsx | 28 +- .../ParamControlNetShouldAutoConfig.tsx | 29 +- .../parameters/ParamControlNetBeginEnd.tsx | 41 +- .../parameters/ParamControlNetControlMode.tsx | 27 +- .../parameters/ParamControlNetModel.tsx | 35 +- .../ParamControlNetProcessorSelect.tsx | 26 +- .../parameters/ParamControlNetResizeMode.tsx | 27 +- .../parameters/ParamControlNetWeight.tsx | 30 +- .../features/controlNet/store/constants.ts | 2 +- .../controlNet/store/controlNetSlice.ts | 159 ++-- .../components/DeleteImageButton.tsx | 0 .../components/DeleteImageModal.tsx | 51 +- .../components/ImageUsageMessage.tsx | 0 .../store/actions.ts | 6 +- .../deleteImageModal/store/initialState.ts | 6 + .../store/selectors.ts} | 14 +- 
.../features/deleteImageModal/store/slice.ts | 28 + .../features/deleteImageModal/store/types.ts | 13 + .../components/Boards/BoardAutoAddSelect.tsx | 2 +- .../components/Boards/BoardContextMenu.tsx | 3 +- .../Boards/BoardsList/BatchBoard.tsx | 43 - .../Boards/BoardsList/BoardsList.tsx | 15 +- .../Boards/BoardsList/BoardsSearch.tsx | 18 +- .../Boards/BoardsList/GalleryBoard.tsx | 241 +++--- .../Boards/BoardsList/NoBoardBoard.tsx | 43 +- .../components/Boards/DeleteBoardModal.tsx | 25 +- .../Boards/UpdateImageBoardModal.tsx | 93 --- .../CurrentImage/CurrentImageButtons.tsx | 12 +- .../CurrentImage/CurrentImagePreview.tsx | 8 +- .../components/GallerySettingsPopover.tsx | 58 +- .../ImageContextMenu/ImageContextMenu.tsx | 55 +- .../MultipleSelectionMenuItems.tsx | 36 +- .../SingleSelectionMenuItems.tsx | 80 +- .../components/ImageGalleryContent.tsx | 13 +- .../components/ImageGrid/BatchImage.tsx | 122 --- .../components/ImageGrid/BatchImageGrid.tsx | 87 -- .../components/ImageGrid/GalleryImage.tsx | 63 +- .../ImageMetadataActions.tsx | 2 +- .../gallery/hooks/useMultiselect.ts.ts | 93 +++ .../gallery/hooks/useNextPrevImage.ts | 62 +- .../web/src/features/gallery/store/actions.ts | 2 +- .../src/features/gallery/store/boardSlice.ts | 29 - .../gallery/store/gallerySelectors.ts | 4 +- .../features/gallery/store/gallerySlice.ts | 111 +-- .../web/src/features/gallery/store/types.ts | 13 +- .../imageDeletion/store/imageDeletionSlice.ts | 37 - .../src/features/imageDeletion/store/types.ts | 6 - .../web/src/features/lora/store/loraSlice.ts | 12 +- .../nodes/components/search/NodeSearch.tsx | 10 +- .../src/features/nodes/store/nodesSlice.ts | 24 +- .../nodes/util/fieldTemplateBuilders.ts | 25 +- .../ControlNet/ParamControlNetCollapse.tsx | 11 +- .../parameters/hooks/useRecallParameters.ts | 2 +- .../AddModelsPanel/AdvancedAddCheckpoint.tsx | 4 +- .../AddModelsPanel/AdvancedAddDiffusers.tsx | 2 +- .../subpanels/MergeModelsPanel.tsx | 9 +- .../ModelManagerPanel/ModelConvert.tsx | 16 +- .../web/src/features/ui/store/uiSelectors.ts | 4 +- .../src/services/api/endpoints/boardImages.ts | 36 - .../web/src/services/api/endpoints/boards.ts | 223 +---- .../web/src/services/api/endpoints/images.ts | 786 +++++++++++++----- .../web/src/services/api/endpoints/models.ts | 15 +- .../src/services/api/hooks/useBoardName.ts | 2 +- .../src/services/api/hooks/useBoardTotal.ts | 2 +- .../frontend/web/src/services/api/schema.d.ts | 302 +++++-- .../src/services/api/{types.d.ts => types.ts} | 34 +- .../frontend/web/src/services/api/util.ts | 56 ++ .../src/theme/util/generateColorPalette.ts | 2 +- invokeai/frontend/web/tsconfig.json | 2 + 116 files changed, 2470 insertions(+), 2181 deletions(-) create mode 100644 invokeai/app/services/models/board_image.py create mode 100644 invokeai/app/util/model_exclude_null.py rename invokeai/frontend/web/scripts/{typegen.ts => typegen.js} (100%) delete mode 100644 invokeai/frontend/web/src/app/contexts/AddImageToBoardContext.tsx delete mode 100644 invokeai/frontend/web/src/app/contexts/ImageUploaderTriggerContext.ts create mode 100644 invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx create mode 100644 invokeai/frontend/web/src/features/changeBoardModal/store/initialState.ts create mode 100644 invokeai/frontend/web/src/features/changeBoardModal/store/slice.ts create mode 100644 invokeai/frontend/web/src/features/changeBoardModal/store/types.ts rename invokeai/frontend/web/src/features/{imageDeletion => deleteImageModal}/components/DeleteImageButton.tsx 
(100%) rename invokeai/frontend/web/src/features/{imageDeletion => deleteImageModal}/components/DeleteImageModal.tsx (70%) rename invokeai/frontend/web/src/features/{imageDeletion => deleteImageModal}/components/ImageUsageMessage.tsx (100%) rename invokeai/frontend/web/src/features/{imageDeletion => deleteImageModal}/store/actions.ts (65%) create mode 100644 invokeai/frontend/web/src/features/deleteImageModal/store/initialState.ts rename invokeai/frontend/web/src/features/{imageDeletion/store/imageDeletionSelectors.ts => deleteImageModal/store/selectors.ts} (84%) create mode 100644 invokeai/frontend/web/src/features/deleteImageModal/store/slice.ts create mode 100644 invokeai/frontend/web/src/features/deleteImageModal/store/types.ts delete mode 100644 invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BatchBoard.tsx delete mode 100644 invokeai/frontend/web/src/features/gallery/components/Boards/UpdateImageBoardModal.tsx delete mode 100644 invokeai/frontend/web/src/features/gallery/components/ImageGrid/BatchImage.tsx delete mode 100644 invokeai/frontend/web/src/features/gallery/components/ImageGrid/BatchImageGrid.tsx create mode 100644 invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts.ts delete mode 100644 invokeai/frontend/web/src/features/gallery/store/boardSlice.ts delete mode 100644 invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts delete mode 100644 invokeai/frontend/web/src/features/imageDeletion/store/types.ts delete mode 100644 invokeai/frontend/web/src/services/api/endpoints/boardImages.ts rename invokeai/frontend/web/src/services/api/{types.d.ts => types.ts} (89%) create mode 100644 invokeai/frontend/web/src/services/api/util.ts diff --git a/invokeai/app/api/routers/board_images.py b/invokeai/app/api/routers/board_images.py index 6cb073ca7c..73607ecb7d 100644 --- a/invokeai/app/api/routers/board_images.py +++ b/invokeai/app/api/routers/board_images.py @@ -1,24 +1,30 @@ -from fastapi import Body, HTTPException, Path, Query +from fastapi import Body, HTTPException from fastapi.routing import APIRouter -from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges -from invokeai.app.services.image_record_storage import OffsetPaginatedResults -from invokeai.app.services.models.board_record import BoardDTO -from invokeai.app.services.models.image_record import ImageDTO +from pydantic import BaseModel, Field from ..dependencies import ApiDependencies board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"]) +class AddImagesToBoardResult(BaseModel): + board_id: str = Field(description="The id of the board the images were added to") + added_image_names: list[str] = Field(description="The image names that were added to the board") + + +class RemoveImagesFromBoardResult(BaseModel): + removed_image_names: list[str] = Field(description="The image names that were removed from their board") + + @board_images_router.post( "/", - operation_id="create_board_image", + operation_id="add_image_to_board", responses={ 201: {"description": "The image was added to a board successfully"}, }, status_code=201, ) -async def create_board_image( +async def add_image_to_board( board_id: str = Body(description="The id of the board to add to"), image_name: str = Body(description="The name of the image to add"), ): @@ -29,26 +35,78 @@ async def create_board_image( ) return result except Exception as e: - raise HTTPException(status_code=500, detail="Failed to add to board") + raise HTTPException(status_code=500, 
detail="Failed to add image to board") @board_images_router.delete( "/", - operation_id="remove_board_image", + operation_id="remove_image_from_board", responses={ 201: {"description": "The image was removed from the board successfully"}, }, status_code=201, ) -async def remove_board_image( - board_id: str = Body(description="The id of the board"), - image_name: str = Body(description="The name of the image to remove"), +async def remove_image_from_board( + image_name: str = Body(description="The name of the image to remove", embed=True), ): - """Deletes a board_image""" + """Removes an image from its board, if it had one""" try: - result = ApiDependencies.invoker.services.board_images.remove_image_from_board( - board_id=board_id, image_name=image_name - ) + result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name) return result except Exception as e: - raise HTTPException(status_code=500, detail="Failed to update board") + raise HTTPException(status_code=500, detail="Failed to remove image from board") + + +@board_images_router.post( + "/batch", + operation_id="add_images_to_board", + responses={ + 201: {"description": "Images were added to board successfully"}, + }, + status_code=201, + response_model=AddImagesToBoardResult, +) +async def add_images_to_board( + board_id: str = Body(description="The id of the board to add to"), + image_names: list[str] = Body(description="The names of the images to add", embed=True), +) -> AddImagesToBoardResult: + """Adds a list of images to a board""" + try: + added_image_names: list[str] = [] + for image_name in image_names: + try: + ApiDependencies.invoker.services.board_images.add_image_to_board( + board_id=board_id, image_name=image_name + ) + added_image_names.append(image_name) + except: + pass + return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names) + except Exception as e: + raise HTTPException(status_code=500, detail="Failed to add images to board") + + +@board_images_router.post( + "/batch/delete", + operation_id="remove_images_from_board", + responses={ + 201: {"description": "Images were removed from board successfully"}, + }, + status_code=201, + response_model=RemoveImagesFromBoardResult, +) +async def remove_images_from_board( + image_names: list[str] = Body(description="The names of the images to remove", embed=True), +) -> RemoveImagesFromBoardResult: + """Removes a list of images from their board, if they had one""" + try: + removed_image_names: list[str] = [] + for image_name in image_names: + try: + ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name) + removed_image_names.append(image_name) + except: + pass + return RemoveImagesFromBoardResult(removed_image_names=removed_image_names) + except Exception as e: + raise HTTPException(status_code=500, detail="Failed to remove images from board") diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 498a1139e4..aff409e9e5 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -5,6 +5,7 @@ from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadF from fastapi.responses import FileResponse from fastapi.routing import APIRouter from PIL import Image +from pydantic import BaseModel, Field from invokeai.app.invocations.metadata import ImageMetadata from invokeai.app.models.image import ImageCategory, ResourceOrigin @@ -25,7 +26,7 @@ IMAGE_MAX_AGE = 31536000 @images_router.post( - "/", 
+ "/upload", operation_id="upload_image", responses={ 201: {"description": "The image was uploaded successfully"}, @@ -77,7 +78,7 @@ async def upload_image( raise HTTPException(status_code=500, detail="Failed to create image") -@images_router.delete("/{image_name}", operation_id="delete_image") +@images_router.delete("/i/{image_name}", operation_id="delete_image") async def delete_image( image_name: str = Path(description="The name of the image to delete"), ) -> None: @@ -103,7 +104,7 @@ async def clear_intermediates() -> int: @images_router.patch( - "/{image_name}", + "/i/{image_name}", operation_id="update_image", response_model=ImageDTO, ) @@ -120,7 +121,7 @@ async def update_image( @images_router.get( - "/{image_name}", + "/i/{image_name}", operation_id="get_image_dto", response_model=ImageDTO, ) @@ -136,7 +137,7 @@ async def get_image_dto( @images_router.get( - "/{image_name}/metadata", + "/i/{image_name}/metadata", operation_id="get_image_metadata", response_model=ImageMetadata, ) @@ -152,7 +153,7 @@ async def get_image_metadata( @images_router.get( - "/{image_name}/full", + "/i/{image_name}/full", operation_id="get_image_full", response_class=Response, responses={ @@ -187,7 +188,7 @@ async def get_image_full( @images_router.get( - "/{image_name}/thumbnail", + "/i/{image_name}/thumbnail", operation_id="get_image_thumbnail", response_class=Response, responses={ @@ -216,7 +217,7 @@ async def get_image_thumbnail( @images_router.get( - "/{image_name}/urls", + "/i/{image_name}/urls", operation_id="get_image_urls", response_model=ImageUrlsDTO, ) @@ -265,3 +266,24 @@ async def list_image_dtos( ) return image_dtos + + +class DeleteImagesFromListResult(BaseModel): + deleted_images: list[str] + + +@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult) +async def delete_images_from_list( + image_names: list[str] = Body(description="The list of names of images to delete", embed=True), +) -> DeleteImagesFromListResult: + try: + deleted_images: list[str] = [] + for image_name in image_names: + try: + ApiDependencies.invoker.services.images.delete(image_name) + deleted_images.append(image_name) + except: + pass + return DeleteImagesFromListResult(deleted_images=deleted_images) + except Exception as e: + raise HTTPException(status_code=500, detail="Failed to delete images") diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 3588ef4ebe..f91e6cc4c7 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -1,6 +1,6 @@ from typing import Literal, Optional, Union -from pydantic import BaseModel, Field +from pydantic import Field from invokeai.app.invocations.baseinvocation import ( BaseInvocation, @@ -10,16 +10,17 @@ from invokeai.app.invocations.baseinvocation import ( ) from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField +from invokeai.app.util.model_exclude_null import BaseModelExcludeNull -class LoRAMetadataField(BaseModel): +class LoRAMetadataField(BaseModelExcludeNull): """LoRA metadata for an image generated in InvokeAI.""" lora: LoRAModelField = Field(description="The LoRA model") weight: float = Field(description="The weight of the LoRA model") -class CoreMetadata(BaseModel): +class CoreMetadata(BaseModelExcludeNull): """Core generation metadata for an image generated in InvokeAI.""" generation_mode: str = Field( @@ -70,7 +71,7 @@ class 
CoreMetadata(BaseModel): refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising") -class ImageMetadata(BaseModel): +class ImageMetadata(BaseModelExcludeNull): """An image's generation metadata""" metadata: Optional[dict] = Field( diff --git a/invokeai/app/services/board_image_record_storage.py b/invokeai/app/services/board_image_record_storage.py index f0007c8cef..03badf9866 100644 --- a/invokeai/app/services/board_image_record_storage.py +++ b/invokeai/app/services/board_image_record_storage.py @@ -25,7 +25,6 @@ class BoardImageRecordStorageBase(ABC): @abstractmethod def remove_image_from_board( self, - board_id: str, image_name: str, ) -> None: """Removes an image from a board.""" @@ -154,7 +153,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): def remove_image_from_board( self, - board_id: str, image_name: str, ) -> None: try: @@ -162,9 +160,9 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): self._cursor.execute( """--sql DELETE FROM board_images - WHERE board_id = ? AND image_name = ?; + WHERE image_name = ?; """, - (board_id, image_name), + (image_name,), ) self._conn.commit() except sqlite3.Error as e: diff --git a/invokeai/app/services/board_images.py b/invokeai/app/services/board_images.py index 22332d6c29..f41526bfa7 100644 --- a/invokeai/app/services/board_images.py +++ b/invokeai/app/services/board_images.py @@ -31,7 +31,6 @@ class BoardImagesServiceABC(ABC): @abstractmethod def remove_image_from_board( self, - board_id: str, image_name: str, ) -> None: """Removes an image from a board.""" @@ -93,10 +92,9 @@ class BoardImagesService(BoardImagesServiceABC): def remove_image_from_board( self, - board_id: str, image_name: str, ) -> None: - self._services.board_image_records.remove_image_from_board(board_id, image_name) + self._services.board_image_records.remove_image_from_board(image_name) def get_all_board_image_names_for_board( self, diff --git a/invokeai/app/services/models/board_image.py b/invokeai/app/services/models/board_image.py new file mode 100644 index 0000000000..fe585215f3 --- /dev/null +++ b/invokeai/app/services/models/board_image.py @@ -0,0 +1,8 @@ +from pydantic import Field + +from invokeai.app.util.model_exclude_null import BaseModelExcludeNull + + +class BoardImage(BaseModelExcludeNull): + board_id: str = Field(description="The id of the board") + image_name: str = Field(description="The name of the image") diff --git a/invokeai/app/services/models/board_record.py b/invokeai/app/services/models/board_record.py index 658698e794..53fa299faf 100644 --- a/invokeai/app/services/models/board_record.py +++ b/invokeai/app/services/models/board_record.py @@ -1,10 +1,11 @@ from typing import Optional, Union from datetime import datetime -from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr +from pydantic import Field from invokeai.app.util.misc import get_iso_timestamp +from invokeai.app.util.model_exclude_null import BaseModelExcludeNull -class BoardRecord(BaseModel): +class BoardRecord(BaseModelExcludeNull): """Deserialized board record.""" board_id: str = Field(description="The unique ID of the board.") diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index a105d03ba8..294b760630 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -1,13 +1,14 @@ import datetime from typing import Optional, Union -from pydantic import BaseModel, Extra, Field, 
StrictBool, StrictStr +from pydantic import Extra, Field, StrictBool, StrictStr from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.util.misc import get_iso_timestamp +from invokeai.app.util.model_exclude_null import BaseModelExcludeNull -class ImageRecord(BaseModel): +class ImageRecord(BaseModelExcludeNull): """Deserialized image record without metadata.""" image_name: str = Field(description="The unique name of the image.") @@ -40,7 +41,7 @@ class ImageRecord(BaseModel): """The node ID that generated this image, if it is a generated image.""" -class ImageRecordChanges(BaseModel, extra=Extra.forbid): +class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid): """A set of changes to apply to an image record. Only limited changes are valid: @@ -60,7 +61,7 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid): """The image's new `is_intermediate` flag.""" -class ImageUrlsDTO(BaseModel): +class ImageUrlsDTO(BaseModelExcludeNull): """The URLs for an image and its thumbnail.""" image_name: str = Field(description="The unique name of the image.") @@ -76,11 +77,15 @@ class ImageDTO(ImageRecord, ImageUrlsDTO): board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.") """The id of the board the image belongs to, if one exists.""" + pass def image_record_to_dto( - image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Optional[str] + image_record: ImageRecord, + image_url: str, + thumbnail_url: str, + board_id: Optional[str], ) -> ImageDTO: """Converts an image record to an image DTO.""" return ImageDTO( diff --git a/invokeai/app/services/urls.py b/invokeai/app/services/urls.py index 73d8ddadf4..7688b3bdd3 100644 --- a/invokeai/app/services/urls.py +++ b/invokeai/app/services/urls.py @@ -20,6 +20,6 @@ class LocalUrlService(UrlServiceBase): # These paths are determined by the routes in invokeai/app/api/routers/images.py if thumbnail: - return f"{self._base_url}/images/{image_basename}/thumbnail" + return f"{self._base_url}/images/i/{image_basename}/thumbnail" - return f"{self._base_url}/images/{image_basename}/full" + return f"{self._base_url}/images/i/{image_basename}/full" diff --git a/invokeai/app/util/model_exclude_null.py b/invokeai/app/util/model_exclude_null.py new file mode 100644 index 0000000000..d864b8fab8 --- /dev/null +++ b/invokeai/app/util/model_exclude_null.py @@ -0,0 +1,23 @@ +from typing import Any +from pydantic import BaseModel + + +""" +We want to exclude null values from objects that make their way to the client. + +Unfortunately there is no built-in way to do this in pydantic, so we need to override the default +dict method to do this. 
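A minimal sketch of the intended effect, assuming pydantic v1's `dict()` API as used in this codebase (the `Example` model here is a made-up illustration, not part of the patch):

from typing import Optional

from pydantic import BaseModel


class Example(BaseModel):
    # hypothetical model with an optional field, pydantic v1-style API
    image_name: str
    session_id: Optional[str] = None


print(Example(image_name="foo.png").dict())
# -> {'image_name': 'foo.png', 'session_id': None}
print(Example(image_name="foo.png").dict(exclude_none=True))
# -> {'image_name': 'foo.png'}  (what BaseModelExcludeNull forces on every call)
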
+ +From https://github.com/tiangolo/fastapi/discussions/8882#discussioncomment-5154541 +""" + + +class BaseModelExcludeNull(BaseModel): + def dict(self, *args, **kwargs) -> dict[str, Any]: + """ + Override the default dict method to exclude None values in the response + """ + kwargs.pop("exclude_none", None) + return super().dict(*args, exclude_none=True, **kwargs) + + pass diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index a76c2ecc02..8cc2c158be 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -23,7 +23,7 @@ "dev": "concurrently \"vite dev\" \"yarn run theme:watch\"", "dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"", "build": "yarn run lint && vite build", - "typegen": "npx ts-node scripts/typegen.ts", + "typegen": "node scripts/typegen.js", "preview": "vite preview", "lint:madge": "madge --circular src/main.tsx", "lint:eslint": "eslint --max-warnings=0 .", diff --git a/invokeai/frontend/web/scripts/typegen.ts b/invokeai/frontend/web/scripts/typegen.js similarity index 100% rename from invokeai/frontend/web/scripts/typegen.ts rename to invokeai/frontend/web/scripts/typegen.js diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index 963d285f72..fa45ae93cd 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -4,8 +4,9 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { PartialAppConfig } from 'app/types/invokeai'; import ImageUploader from 'common/components/ImageUploader'; +import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal'; +import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal'; import GalleryDrawer from 'features/gallery/components/GalleryPanel'; -import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal'; import SiteHeader from 'features/system/components/SiteHeader'; import { configChanged } from 'features/system/store/configSlice'; import { languageSelector } from 'features/system/store/systemSelectors'; @@ -16,7 +17,6 @@ import ParametersDrawer from 'features/ui/components/ParametersDrawer'; import i18n from 'i18n'; import { size } from 'lodash-es'; import { ReactNode, memo, useEffect } from 'react'; -import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; import GlobalHotkeys from './GlobalHotkeys'; import Toaster from './Toaster'; @@ -84,7 +84,7 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => { - + diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx index 82526900ad..c97778ffcd 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx +++ b/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx @@ -58,7 +58,7 @@ const DragPreview = (props: OverlayDragImageProps) => { ); } - if (props.dragData.payloadType === 'IMAGE_NAMES') { + if (props.dragData.payloadType === 'IMAGE_DTOS') { return ( { ...STYLES, }} > - {props.dragData.payload.image_names.length} + {props.dragData.payload.imageDTOs.length} Images ); diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx index 
24bdceac3a..56eeb9b5db 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx +++ b/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx @@ -18,27 +18,32 @@ import { DragStartEvent, TypesafeDraggableData, } from './typesafeDnd'; +import { logger } from 'app/logging/logger'; type ImageDndContextProps = PropsWithChildren; const ImageDndContext = (props: ImageDndContextProps) => { const [activeDragData, setActiveDragData] = useState(null); + const log = logger('images'); const dispatch = useAppDispatch(); - const handleDragStart = useCallback((event: DragStartEvent) => { - console.log('dragStart', event.active.data.current); - const activeData = event.active.data.current; - if (!activeData) { - return; - } - setActiveDragData(activeData); - }, []); + const handleDragStart = useCallback( + (event: DragStartEvent) => { + log.trace({ dragData: event.active.data.current }, 'Drag started'); + const activeData = event.active.data.current; + if (!activeData) { + return; + } + setActiveDragData(activeData); + }, + [log] + ); const handleDragEnd = useCallback( (event: DragEndEvent) => { - console.log('dragEnd', event.active.data.current); + log.trace({ dragData: event.active.data.current }, 'Drag ended'); const overData = event.over?.data.current; if (!activeDragData || !overData) { return; @@ -46,7 +51,7 @@ const ImageDndContext = (props: ImageDndContextProps) => { dispatch(dndDropped({ overData, activeData: activeDragData })); setActiveDragData(null); }, - [activeDragData, dispatch] + [activeDragData, dispatch, log] ); const mouseSensor = useSensor(MouseSensor, { diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx index 5f08466710..6f24302070 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx +++ b/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx @@ -11,7 +11,6 @@ import { useDraggable as useOriginalDraggable, useDroppable as useOriginalDroppable, } from '@dnd-kit/core'; -import { BoardId } from 'features/gallery/store/types'; import { ImageDTO } from 'services/api/types'; type BaseDropData = { @@ -54,9 +53,13 @@ export type AddToBatchDropData = BaseDropData & { actionType: 'ADD_TO_BATCH'; }; -export type MoveBoardDropData = BaseDropData & { - actionType: 'MOVE_BOARD'; - context: { boardId: BoardId }; +export type AddToBoardDropData = BaseDropData & { + actionType: 'ADD_TO_BOARD'; + context: { boardId: string }; +}; + +export type RemoveFromBoardDropData = BaseDropData & { + actionType: 'REMOVE_FROM_BOARD'; }; export type TypesafeDroppableData = @@ -67,7 +70,8 @@ export type TypesafeDroppableData = | NodesImageDropData | AddToBatchDropData | NodesMultiImageDropData - | MoveBoardDropData; + | AddToBoardDropData + | RemoveFromBoardDropData; type BaseDragData = { id: string; @@ -78,14 +82,12 @@ export type ImageDraggableData = BaseDragData & { payload: { imageDTO: ImageDTO }; }; -export type ImageNamesDraggableData = BaseDragData & { - payloadType: 'IMAGE_NAMES'; - payload: { image_names: string[] }; +export type ImageDTOsDraggableData = BaseDragData & { + payloadType: 'IMAGE_DTOS'; + payload: { imageDTOs: ImageDTO[] }; }; -export type TypesafeDraggableData = - | ImageDraggableData - | ImageNamesDraggableData; +export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData; interface UseDroppableTypesafeArguments extends Omit { @@ -156,14 +158,39 @@ export const isValidDrop = ( case 'SET_NODES_IMAGE': 
return payloadType === 'IMAGE_DTO'; case 'SET_MULTI_NODES_IMAGE': - return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; + return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; case 'ADD_TO_BATCH': - return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; - case 'MOVE_BOARD': { + return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; + case 'ADD_TO_BOARD': { // If the board is the same, don't allow the drop // Check the payload types - const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; + const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; + if (!isPayloadValid) { + return false; + } + + // Check if the image's board is the board we are dragging onto + if (payloadType === 'IMAGE_DTO') { + const { imageDTO } = active.data.current.payload; + const currentBoard = imageDTO.board_id ?? 'none'; + const destinationBoard = overData.context.boardId; + + return currentBoard !== destinationBoard; + } + + if (payloadType === 'IMAGE_DTOS') { + // TODO (multi-select) + return true; + } + + return false; + } + case 'REMOVE_FROM_BOARD': { + // If the board is the same, don't allow the drop + + // Check the payload types + const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; if (!isPayloadValid) { return false; } @@ -172,20 +199,16 @@ export const isValidDrop = ( if (payloadType === 'IMAGE_DTO') { const { imageDTO } = active.data.current.payload; const currentBoard = imageDTO.board_id; - const destinationBoard = overData.context.boardId; - const isSameBoard = currentBoard === destinationBoard; - const isDestinationValid = !currentBoard ? destinationBoard : true; - - return !isSameBoard && isDestinationValid; + return currentBoard !== 'none'; } - if (payloadType === 'IMAGE_NAMES') { + if (payloadType === 'IMAGE_DTOS') { // TODO (multi-select) - return false; + return true; } - return true; + return false; } default: return false; diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index cffbaa5574..93b7825db7 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -1,4 +1,6 @@ +import { Middleware } from '@reduxjs/toolkit'; import { store } from 'app/store/store'; +import { PartialAppConfig } from 'app/types/invokeai'; import React, { lazy, memo, @@ -7,16 +9,11 @@ import React, { useEffect, } from 'react'; import { Provider } from 'react-redux'; - -import { PartialAppConfig } from 'app/types/invokeai'; import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares'; -import Loading from '../../common/components/Loading/Loading'; - -import { Middleware } from '@reduxjs/toolkit'; import { $authToken, $baseUrl, $projectId } from 'services/api/client'; import { socketMiddleware } from 'services/events/middleware'; +import Loading from '../../common/components/Loading/Loading'; import '../../i18n'; -import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext'; import ImageDndContext from './ImageDnd/ImageDndContext'; const App = lazy(() => import('./App')); @@ -84,9 +81,7 @@ const InvokeAIUI = ({ }> - - - + diff --git a/invokeai/frontend/web/src/app/contexts/AddImageToBoardContext.tsx b/invokeai/frontend/web/src/app/contexts/AddImageToBoardContext.tsx deleted file mode 100644 index d5b3b746f1..0000000000 --- a/invokeai/frontend/web/src/app/contexts/AddImageToBoardContext.tsx +++ /dev/null @@ -1,91 +0,0 @@ -import { useDisclosure } from '@chakra-ui/react'; -import { PropsWithChildren, 
createContext, useCallback, useState } from 'react'; -import { ImageDTO } from 'services/api/types'; -import { imagesApi } from 'services/api/endpoints/images'; -import { useAppDispatch } from '../store/storeHooks'; - -export type ImageUsage = { - isInitialImage: boolean; - isCanvasImage: boolean; - isNodesImage: boolean; - isControlNetImage: boolean; -}; - -type AddImageToBoardContextValue = { - /** - * Whether the move image dialog is open. - */ - isOpen: boolean; - /** - * Closes the move image dialog. - */ - onClose: () => void; - /** - * The image pending movement - */ - image?: ImageDTO; - onClickAddToBoard: (image: ImageDTO) => void; - handleAddToBoard: (boardId: string) => void; -}; - -export const AddImageToBoardContext = - createContext({ - isOpen: false, - onClose: () => undefined, - onClickAddToBoard: () => undefined, - handleAddToBoard: () => undefined, - }); - -type Props = PropsWithChildren; - -export const AddImageToBoardContextProvider = (props: Props) => { - const [imageToMove, setImageToMove] = useState(); - const { isOpen, onOpen, onClose } = useDisclosure(); - const dispatch = useAppDispatch(); - - // Clean up after deleting or dismissing the modal - const closeAndClearImageToDelete = useCallback(() => { - setImageToMove(undefined); - onClose(); - }, [onClose]); - - const onClickAddToBoard = useCallback( - (image?: ImageDTO) => { - if (!image) { - return; - } - setImageToMove(image); - onOpen(); - }, - [setImageToMove, onOpen] - ); - - const handleAddToBoard = useCallback( - (boardId: string) => { - if (imageToMove) { - dispatch( - imagesApi.endpoints.addImageToBoard.initiate({ - imageDTO: imageToMove, - board_id: boardId, - }) - ); - closeAndClearImageToDelete(); - } - }, - [dispatch, closeAndClearImageToDelete, imageToMove] - ); - - return ( - - {props.children} - - ); -}; diff --git a/invokeai/frontend/web/src/app/contexts/ImageUploaderTriggerContext.ts b/invokeai/frontend/web/src/app/contexts/ImageUploaderTriggerContext.ts deleted file mode 100644 index 804e124625..0000000000 --- a/invokeai/frontend/web/src/app/contexts/ImageUploaderTriggerContext.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { createContext } from 'react'; - -type VoidFunc = () => void; - -type ImageUploaderTriggerContextType = VoidFunc | null; - -export const ImageUploaderTriggerContext = - createContext(null); diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts index 3407b3f7de..1b21770aa0 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts @@ -23,6 +23,6 @@ const serializationDenylist: { }; export const serialize: SerializeFunction = (data, key) => { - const result = omit(data, serializationDenylist[key]); + const result = omit(data, serializationDenylist[key] ?? 
[]); return JSON.stringify(result); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index f06c324bc6..c15b072a07 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -27,7 +27,8 @@ import { addImageDeletedFulfilledListener, addImageDeletedPendingListener, addImageDeletedRejectedListener, - addRequestedImageDeletionListener, + addRequestedSingleImageDeletionListener, + addRequestedMultipleImageDeletionListener, } from './listeners/imageDeleted'; import { addImageDroppedListener } from './listeners/imageDropped'; import { @@ -111,7 +112,8 @@ addImageUploadedRejectedListener(); addInitialImageSelectedListener(); // Image deleted -addRequestedImageDeletionListener(); +addRequestedSingleImageDeletionListener(); +addRequestedMultipleImageDeletionListener(); addImageDeletedPendingListener(); addImageDeletedFulfilledListener(); addImageDeletedRejectedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addFirstListImagesListener.ts.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addFirstListImagesListener.ts.ts index ee12f39a12..15e7d48708 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addFirstListImagesListener.ts.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addFirstListImagesListener.ts.ts @@ -1,12 +1,10 @@ import { createAction } from '@reduxjs/toolkit'; import { imageSelected } from 'features/gallery/store/gallerySlice'; import { IMAGE_CATEGORIES } from 'features/gallery/store/types'; -import { - ImageCache, - getListImagesUrl, - imagesApi, -} from 'services/api/endpoints/images'; +import { imagesApi } from 'services/api/endpoints/images'; import { startAppListening } from '..'; +import { getListImagesUrl, imagesAdapter } from 'services/api/util'; +import { ImageCache } from 'services/api/types'; export const appStarted = createAction('app/appStarted'); @@ -34,7 +32,8 @@ export const addFirstListImagesListener = () => { if (data.ids.length > 0) { // Select the first image - dispatch(imageSelected(data.ids[0] as string)); + const firstImage = imagesAdapter.getSelectors().selectAll(data)[0]; + dispatch(imageSelected(firstImage ?? null)); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/appConfigReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/appConfigReceived.ts index 2d0ece3595..700b4e7626 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/appConfigReceived.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/appConfigReceived.ts @@ -18,7 +18,9 @@ export const addAppConfigReceivedListener = () => { const infillMethod = getState().generation.infillMethod; if (!infill_methods.includes(infillMethod)) { - dispatch(setInfillMethod(infill_methods[0])); + // if there is no infill method, set it to the first one + // if there is no first one... 
god help us + dispatch(setInfillMethod(infill_methods[0] as string)); } if (!nsfw_methods.includes('nsfw_checker')) { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted.ts index f0af52ced6..d4a36d64dc 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted.ts @@ -1,14 +1,14 @@ import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; -import { getImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors'; +import { getImageUsage } from 'features/deleteImageModal/store/selectors'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice'; +import { imagesApi } from 'services/api/endpoints/images'; import { startAppListening } from '..'; -import { boardsApi } from '../../../../../services/api/endpoints/boards'; export const addDeleteBoardAndImagesFulfilledListener = () => { startAppListening({ - matcher: boardsApi.endpoints.deleteBoardAndImages.matchFulfilled, + matcher: imagesApi.endpoints.deleteBoardAndImages.matchFulfilled, effect: async (action, { dispatch, getState }) => { const { deleted_images } = action.payload; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts index f9c856d6cb..1b13181911 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts @@ -10,6 +10,7 @@ import { } from 'features/gallery/store/types'; import { imagesApi } from 'services/api/endpoints/images'; import { startAppListening } from '..'; +import { imagesSelectors } from 'services/api/util'; export const addBoardIdSelectedListener = () => { startAppListening({ @@ -52,8 +53,9 @@ export const addBoardIdSelectedListener = () => { queryArgs )(getState()); - if (boardImagesData?.ids.length) { - dispatch(imageSelected((boardImagesData.ids[0] as string) ?? null)); + if (boardImagesData) { + const firstImage = imagesSelectors.selectAll(boardImagesData)[0]; + dispatch(imageSelected(firstImage ?? 
null)); } else { // board has no images - deselect dispatch(imageSelected(null)); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts index 47f7aded27..dbadb72a52 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -26,6 +26,8 @@ export const addCanvasSavedToGalleryListener = () => { return; } + const { autoAddBoardId } = state.gallery; + dispatch( imagesApi.endpoints.uploadImage.initiate({ file: new File([blob], 'savedCanvas.png', { @@ -33,7 +35,7 @@ export const addCanvasSavedToGalleryListener = () => { }), image_category: 'general', is_intermediate: false, - board_id: state.gallery.autoAddBoardId, + board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, crop_visible: true, postUploadAction: { type: 'TOAST', diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts index 4a47e8d64e..61bcf28833 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts @@ -31,15 +31,20 @@ const predicate: AnyListenerPredicate = ( // do not process if the user just disabled auto-config if ( prevState.controlNet.controlNets[action.payload.controlNetId] - .shouldAutoConfig === true + ?.shouldAutoConfig === true ) { return false; } } - const { controlImage, processorType, shouldAutoConfig } = - state.controlNet.controlNets[action.payload.controlNetId]; + const cn = state.controlNet.controlNets[action.payload.controlNetId]; + if (!cn) { + // something is wrong, the controlNet should exist + return false; + } + + const { controlImage, processorType, shouldAutoConfig } = cn; if (controlNetModelChanged.match(action) && !shouldAutoConfig) { // do not process if the action is a model change but the processor settings are dirty return false; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts index 313b2a02d8..fa915ef21b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts @@ -17,7 +17,7 @@ export const addControlNetImageProcessedListener = () => { const { controlNetId } = action.payload; const controlNet = getState().controlNet.controlNets[controlNetId]; - if (!controlNet.controlImage) { + if (!controlNet?.controlImage) { log.error('Unable to process ControlNet image'); return; } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index 428ce53219..cdfae0095e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -1,57 +1,72 @@ 
import { logger } from 'app/logging/logger'; import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; +import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions'; +import { isModalOpenChanged } from 'features/deleteImageModal/store/slice'; import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors'; import { imageSelected } from 'features/gallery/store/gallerySlice'; -import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions'; -import { isModalOpenChanged } from 'features/imageDeletion/store/imageDeletionSlice'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clamp } from 'lodash-es'; import { api } from 'services/api'; import { imagesApi } from 'services/api/endpoints/images'; +import { imagesAdapter } from 'services/api/util'; import { startAppListening } from '..'; -/** - * Called when the user requests an image deletion - */ -export const addRequestedImageDeletionListener = () => { +export const addRequestedSingleImageDeletionListener = () => { startAppListening({ actionCreator: imageDeletionConfirmed, effect: async (action, { dispatch, getState, condition }) => { - const { imageDTO, imageUsage } = action.payload; + const { imageDTOs, imagesUsage } = action.payload; + + if (imageDTOs.length !== 1 || imagesUsage.length !== 1) { + // handle multiples in separate listener + return; + } + + const imageDTO = imageDTOs[0]; + const imageUsage = imagesUsage[0]; + + if (!imageDTO || !imageUsage) { + // satisfy noUncheckedIndexedAccess + return; + } dispatch(isModalOpenChanged(false)); - const { image_name } = imageDTO; - const state = getState(); const lastSelectedImage = - state.gallery.selection[state.gallery.selection.length - 1]; + state.gallery.selection[state.gallery.selection.length - 1]?.image_name; + + if (imageDTO && imageDTO?.image_name === lastSelectedImage) { + const { image_name } = imageDTO; - if (lastSelectedImage === image_name) { const baseQueryArgs = selectListImagesBaseQueryArgs(state); const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state); - const ids = data?.ids ?? []; + const cachedImageDTOs = data + ? 
imagesAdapter.getSelectors().selectAll(data) + : []; - const deletedImageIndex = ids.findIndex( - (result) => result.toString() === image_name + const deletedImageIndex = cachedImageDTOs.findIndex( + (i) => i.image_name === image_name ); - const filteredIds = ids.filter((id) => id.toString() !== image_name); + const filteredImageDTOs = cachedImageDTOs.filter( + (i) => i.image_name !== image_name + ); const newSelectedImageIndex = clamp( deletedImageIndex, 0, - filteredIds.length - 1 + filteredImageDTOs.length - 1 ); - const newSelectedImageId = filteredIds[newSelectedImageIndex]; + const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex]; - if (newSelectedImageId) { - dispatch(imageSelected(newSelectedImageId as string)); + if (newSelectedImageDTO) { + dispatch(imageSelected(newSelectedImageDTO)); } else { dispatch(imageSelected(null)); } @@ -97,6 +112,66 @@ export const addRequestedImageDeletionListener = () => { }); }; +/** + * Called when the user requests an image deletion + */ +export const addRequestedMultipleImageDeletionListener = () => { + startAppListening({ + actionCreator: imageDeletionConfirmed, + effect: async (action, { dispatch, getState }) => { + const { imageDTOs, imagesUsage } = action.payload; + + if (imageDTOs.length < 1 || imagesUsage.length < 1) { + // handle singles in separate listener + return; + } + + try { + // Delete from server + await dispatch( + imagesApi.endpoints.deleteImages.initiate({ imageDTOs }) + ).unwrap(); + const state = getState(); + const baseQueryArgs = selectListImagesBaseQueryArgs(state); + const { data } = + imagesApi.endpoints.listImages.select(baseQueryArgs)(state); + + const newSelectedImageDTO = data + ? imagesAdapter.getSelectors().selectAll(data)[0] + : undefined; + + if (newSelectedImageDTO) { + dispatch(imageSelected(newSelectedImageDTO)); + } else { + dispatch(imageSelected(null)); + } + + dispatch(isModalOpenChanged(false)); + + // We need to reset the features where the image is in use - none of these work if their image(s) don't exist + + if (imagesUsage.some((i) => i.isCanvasImage)) { + dispatch(resetCanvas()); + } + + if (imagesUsage.some((i) => i.isControlNetImage)) { + dispatch(controlNetReset()); + } + + if (imagesUsage.some((i) => i.isInitialImage)) { + dispatch(clearInitialImage()); + } + + if (imagesUsage.some((i) => i.isNodesImage)) { + dispatch(nodeEditorReset()); + } + } catch { + // no-op + } + }, + }); +}; + /** * Called when the actual delete request is sent to the server */ diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts index fdf0849a12..043105cb66 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts @@ -6,10 +6,7 @@ import { import { logger } from 'app/logging/logger'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; -import { - imageSelected, - imagesAddedToBatch, -} from 'features/gallery/store/gallerySlice'; +import { imageSelected } from 'features/gallery/store/gallerySlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { imagesApi } from 'services/api/endpoints/images'; @@ -27,19 
+24,32 @@ export const addImageDroppedListener = () => { const log = logger('images'); const { activeData, overData } = action.payload; - log.debug({ activeData, overData }, 'Image or selection dropped'); + if (activeData.payloadType === 'IMAGE_DTO') { + log.debug({ activeData, overData }, 'Image dropped'); + } else if (activeData.payloadType === 'IMAGE_DTOS') { + log.debug( + { activeData, overData }, + `Images (${activeData.payload.imageDTOs.length}) dropped` + ); + } else { + log.debug({ activeData, overData }, `Unknown payload dropped`); + } - // set current image + /** + * Image dropped on current image + */ if ( overData.actionType === 'SET_CURRENT_IMAGE' && activeData.payloadType === 'IMAGE_DTO' && activeData.payload.imageDTO ) { - dispatch(imageSelected(activeData.payload.imageDTO.image_name)); + dispatch(imageSelected(activeData.payload.imageDTO)); return; } - // set initial image + /** + * Image dropped on initial image + */ if ( overData.actionType === 'SET_INITIAL_IMAGE' && activeData.payloadType === 'IMAGE_DTO' && @@ -49,27 +59,9 @@ export const addImageDroppedListener = () => { return; } - // add image to batch - if ( - overData.actionType === 'ADD_TO_BATCH' && - activeData.payloadType === 'IMAGE_DTO' && - activeData.payload.imageDTO - ) { - dispatch(imagesAddedToBatch([activeData.payload.imageDTO.image_name])); - return; - } - - // add multiple images to batch - if ( - overData.actionType === 'ADD_TO_BATCH' && - activeData.payloadType === 'IMAGE_NAMES' - ) { - dispatch(imagesAddedToBatch(activeData.payload.image_names)); - - return; - } - - // set control image + /** + * Image dropped on ControlNet + */ if ( overData.actionType === 'SET_CONTROLNET_IMAGE' && activeData.payloadType === 'IMAGE_DTO' && @@ -85,7 +77,9 @@ export const addImageDroppedListener = () => { return; } - // set canvas image + /** + * Image dropped on Canvas + */ if ( overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' && activeData.payloadType === 'IMAGE_DTO' && @@ -95,7 +89,9 @@ export const addImageDroppedListener = () => { return; } - // set nodes image + /** + * Image dropped on node image field + */ if ( overData.actionType === 'SET_NODES_IMAGE' && activeData.payloadType === 'IMAGE_DTO' && @@ -112,61 +108,36 @@ export const addImageDroppedListener = () => { return; } - // set multiple nodes images (single image handler) - if ( - overData.actionType === 'SET_MULTI_NODES_IMAGE' && - activeData.payloadType === 'IMAGE_DTO' && - activeData.payload.imageDTO - ) { - const { fieldName, nodeId } = overData.context; - dispatch( - fieldValueChanged({ - nodeId, - fieldName, - value: [activeData.payload.imageDTO], - }) - ); - return; - } - - // // set multiple nodes images (multiple images handler) + /** + * TODO + * Image selection dropped on node image collection field + */ // if ( // overData.actionType === 'SET_MULTI_NODES_IMAGE' && - // activeData.payloadType === 'IMAGE_NAMES' + // activeData.payloadType === 'IMAGE_DTO' && + // activeData.payload.imageDTO // ) { // const { fieldName, nodeId } = overData.context; // dispatch( - // imageCollectionFieldValueChanged({ + // fieldValueChanged({ // nodeId, // fieldName, - // value: activeData.payload.image_names.map((image_name) => ({ - // image_name, - // })), + // value: [activeData.payload.imageDTO], // }) // ); // return; // } - // add image to board + /** + * Image dropped on user board + */ if ( - overData.actionType === 'MOVE_BOARD' && + overData.actionType === 'ADD_TO_BOARD' && activeData.payloadType === 'IMAGE_DTO' && activeData.payload.imageDTO ) { const { 
imageDTO } = activeData.payload; const { boardId } = overData.context; - - // image was droppe on the "NoBoardBoard" - if (!boardId) { - dispatch( - imagesApi.endpoints.removeImageFromBoard.initiate({ - imageDTO, - }) - ); - return; - } - - // image was dropped on a user board dispatch( imagesApi.endpoints.addImageToBoard.initiate({ imageDTO, @@ -176,67 +147,58 @@ export const addImageDroppedListener = () => { return; } - // // add gallery selection to board - // if ( - // overData.actionType === 'MOVE_BOARD' && - // activeData.payloadType === 'IMAGE_NAMES' && - // overData.context.boardId - // ) { - // console.log('adding gallery selection to board'); - // const board_id = overData.context.boardId; - // dispatch( - // boardImagesApi.endpoints.addManyBoardImages.initiate({ - // board_id, - // image_names: activeData.payload.image_names, - // }) - // ); - // return; - // } + /** + * Image dropped on 'none' board + */ + if ( + overData.actionType === 'REMOVE_FROM_BOARD' && + activeData.payloadType === 'IMAGE_DTO' && + activeData.payload.imageDTO + ) { + const { imageDTO } = activeData.payload; + dispatch( + imagesApi.endpoints.removeImageFromBoard.initiate({ + imageDTO, + }) + ); + return; + } - // // remove gallery selection from board - // if ( - // overData.actionType === 'MOVE_BOARD' && - // activeData.payloadType === 'IMAGE_NAMES' && - // overData.context.boardId === null - // ) { - // console.log('removing gallery selection to board'); - // dispatch( - // boardImagesApi.endpoints.deleteManyBoardImages.initiate({ - // image_names: activeData.payload.image_names, - // }) - // ); - // return; - // } + /** + * Multiple images dropped on user board + */ + if ( + overData.actionType === 'ADD_TO_BOARD' && + activeData.payloadType === 'IMAGE_DTOS' && + activeData.payload.imageDTOs + ) { + const { imageDTOs } = activeData.payload; + const { boardId } = overData.context; + dispatch( + imagesApi.endpoints.addImagesToBoard.initiate({ + imageDTOs, + board_id: boardId, + }) + ); + return; + } - // // add batch selection to board - // if ( - // overData.actionType === 'MOVE_BOARD' && - // activeData.payloadType === 'IMAGE_NAMES' && - // overData.context.boardId - // ) { - // const board_id = overData.context.boardId; - // dispatch( - // boardImagesApi.endpoints.addManyBoardImages.initiate({ - // board_id, - // image_names: activeData.payload.image_names, - // }) - // ); - // return; - // } - - // // remove batch selection from board - // if ( - // overData.actionType === 'MOVE_BOARD' && - // activeData.payloadType === 'IMAGE_NAMES' && - // overData.context.boardId === null - // ) { - // dispatch( - // boardImagesApi.endpoints.deleteManyBoardImages.initiate({ - // image_names: activeData.payload.image_names, - // }) - // ); - // return; - // } + /** + * Multiple images dropped on 'none' board + */ + if ( + overData.actionType === 'REMOVE_FROM_BOARD' && + activeData.payloadType === 'IMAGE_DTOS' && + activeData.payload.imageDTOs + ) { + const { imageDTOs } = activeData.payload; + dispatch( + imagesApi.endpoints.removeImagesFromBoard.initiate({ + imageDTOs, + }) + ); + return; + } }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageToDeleteSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageToDeleteSelected.ts index 3a5eed95db..88a4e773d5 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageToDeleteSelected.ts +++ 
b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageToDeleteSelected.ts @@ -1,37 +1,32 @@ -import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions'; -import { selectImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors'; +import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions'; +import { selectImageUsage } from 'features/deleteImageModal/store/selectors'; import { - imageToDeleteSelected, + imagesToDeleteSelected, isModalOpenChanged, -} from 'features/imageDeletion/store/imageDeletionSlice'; +} from 'features/deleteImageModal/store/slice'; import { startAppListening } from '..'; export const addImageToDeleteSelectedListener = () => { startAppListening({ - actionCreator: imageToDeleteSelected, + actionCreator: imagesToDeleteSelected, effect: async (action, { dispatch, getState }) => { - const imageDTO = action.payload; + const imageDTOs = action.payload; const state = getState(); const { shouldConfirmOnDelete } = state.system; - const imageUsage = selectImageUsage(getState()); - - if (!imageUsage) { - // should never happen - return; - } + const imagesUsage = selectImageUsage(getState()); const isImageInUse = - imageUsage.isCanvasImage || - imageUsage.isInitialImage || - imageUsage.isControlNetImage || - imageUsage.isNodesImage; + imagesUsage.some((i) => i.isCanvasImage) || + imagesUsage.some((i) => i.isInitialImage) || + imagesUsage.some((i) => i.isControlNetImage) || + imagesUsage.some((i) => i.isNodesImage); if (shouldConfirmOnDelete || isImageInUse) { dispatch(isModalOpenChanged(true)); return; } - dispatch(imageDeletionConfirmed({ imageDTO, imageUsage })); + dispatch(imageDeletionConfirmed({ imageDTOs, imagesUsage })); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index dd581d893c..f488259eb7 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -2,14 +2,13 @@ import { UseToastOptions } from '@chakra-ui/react'; import { logger } from 'app/logging/logger'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; -import { imagesAddedToBatch } from 'features/gallery/store/gallerySlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { addToast } from 'features/system/store/systemSlice'; +import { omit } from 'lodash-es'; import { boardsApi } from 'services/api/endpoints/boards'; import { startAppListening } from '..'; import { imagesApi } from '../../../../../services/api/endpoints/images'; -import { omit } from 'lodash-es'; const DEFAULT_UPLOADED_TOAST: UseToastOptions = { title: 'Image Uploaded', @@ -121,17 +120,6 @@ export const addImageUploadedFulfilledListener = () => { ); return; } - - if (postUploadAction?.type === 'ADD_TO_BATCH') { - dispatch(imagesAddedToBatch([imageDTO.image_name])); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: 'Added to batch', - }) - ); - return; - } }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts 
b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index 325e843900..436a58aa8e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -15,7 +15,7 @@ import { setShouldUseSDXLRefiner, } from 'features/sdxl/store/sdxlSlice'; import { forEach, some } from 'lodash-es'; -import { modelsApi } from 'services/api/endpoints/models'; +import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models'; import { startAppListening } from '..'; export const addModelsLoadedListener = () => { @@ -144,8 +144,9 @@ export const addModelsLoadedListener = () => { return; } - const firstModelId = action.payload.ids[0]; - const firstModel = action.payload.entities[firstModelId]; + const firstModel = vaeModelsAdapter + .getSelectors() + .selectAll(action.payload)[0]; if (!firstModel) { // No custom VAEs loaded at all; use the default diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index e36c49be63..30e0bedb54 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -8,9 +8,10 @@ import { } from 'features/gallery/store/gallerySlice'; import { IMAGE_CATEGORIES } from 'features/gallery/store/types'; import { progressImageSet } from 'features/system/store/systemSlice'; -import { imagesAdapter, imagesApi } from 'services/api/endpoints/images'; +import { imagesApi } from 'services/api/endpoints/images'; import { isImageOutput } from 'services/api/guards'; import { sessionCanceled } from 'services/api/thunks/session'; +import { imagesAdapter } from 'services/api/util'; import { appSocketInvocationComplete, socketInvocationComplete, @@ -67,7 +68,7 @@ export const addInvocationCompleteEventListener = () => { */ const { autoAddBoardId } = gallery; - if (autoAddBoardId) { + if (autoAddBoardId && autoAddBoardId !== 'none') { dispatch( imagesApi.endpoints.addImageToBoard.initiate({ board_id: autoAddBoardId, @@ -83,10 +84,7 @@ export const addInvocationCompleteEventListener = () => { categories: IMAGE_CATEGORIES, }, (draft) => { - const oldTotal = draft.total; - const newState = imagesAdapter.addOne(draft, imageDTO); - const delta = newState.total - oldTotal; - draft.total = draft.total + delta; + imagesAdapter.addOne(draft, imageDTO); } ) ); @@ -94,8 +92,8 @@ export const addInvocationCompleteEventListener = () => { dispatch( imagesApi.util.invalidateTags([ - { type: 'BoardImagesTotal', id: autoAddBoardId ?? 'none' }, - { type: 'BoardAssetsTotal', id: autoAddBoardId ?? 
'none' }, + { type: 'BoardImagesTotal', id: autoAddBoardId }, + { type: 'BoardAssetsTotal', id: autoAddBoardId }, ]) ); @@ -110,7 +108,7 @@ export const addInvocationCompleteEventListener = () => { } else if (!autoAddBoardId) { dispatch(galleryViewChanged('images')); } - dispatch(imageSelected(imageDTO.image_name)); + dispatch(imageSelected(imageDTO)); } } diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index d71a147913..6b544252db 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -8,9 +8,9 @@ import { import canvasReducer from 'features/canvas/store/canvasSlice'; import controlNetReducer from 'features/controlNet/store/controlNetSlice'; import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice'; -import boardsReducer from 'features/gallery/store/boardSlice'; import galleryReducer from 'features/gallery/store/gallerySlice'; -import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice'; +import deleteImageModalReducer from 'features/deleteImageModal/store/slice'; +import changeBoardModalReducer from 'features/changeBoardModal/store/slice'; import loraReducer from 'features/lora/store/loraSlice'; import nodesReducer from 'features/nodes/store/nodesSlice'; import generationReducer from 'features/parameters/store/generationSlice'; @@ -43,9 +43,9 @@ const allReducers = { ui: uiReducer, hotkeys: hotkeysReducer, controlNet: controlNetReducer, - boards: boardsReducer, dynamicPrompts: dynamicPromptsReducer, - imageDeletion: imageDeletionReducer, + deleteImageModal: deleteImageModalReducer, + changeBoardModal: changeBoardModalReducer, lora: loraReducer, modelmanager: modelmanagerReducer, sdxl: sdxlReducer, diff --git a/invokeai/frontend/web/src/common/components/IAIDropOverlay.tsx b/invokeai/frontend/web/src/common/components/IAIDropOverlay.tsx index 7601758409..f9bb36cc50 100644 --- a/invokeai/frontend/web/src/common/components/IAIDropOverlay.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDropOverlay.tsx @@ -1,4 +1,4 @@ -import { Flex, Text, useColorMode } from '@chakra-ui/react'; +import { Box, Flex, useColorMode } from '@chakra-ui/react'; import { motion } from 'framer-motion'; import { ReactNode, memo, useRef } from 'react'; import { mode } from 'theme/util/mode'; @@ -74,7 +74,7 @@ export const IAIDropOverlay = (props: Props) => { justifyContent: 'center', }} > - { }} > {label} - + diff --git a/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx index 2c3f5434ad..079421d4e5 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx @@ -53,7 +53,9 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => { // wrap onChange to clear search value on select const handleChange = useCallback( (v: string | null) => { - setSearchValue(''); + // cannot figure out why we were doing this, but it was causing an issue where if you + // select the currently-selected item, it reset the search value to empty + // setSearchValue(''); if (!onChange) { return; diff --git a/invokeai/frontend/web/src/common/components/ImageUploader.tsx b/invokeai/frontend/web/src/common/components/ImageUploader.tsx index de347b8381..c990a9a24e 100644 --- a/invokeai/frontend/web/src/common/components/ImageUploader.tsx +++ 
b/invokeai/frontend/web/src/common/components/ImageUploader.tsx @@ -78,7 +78,7 @@ const ImageUploader = (props: ImageUploaderProps) => { image_category: 'user', is_intermediate: false, postUploadAction, - board_id: autoAddBoardId, + board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, }); }, [autoAddBoardId, postUploadAction, uploadImage] diff --git a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx index c04c0182cd..dcbd81b2dd 100644 --- a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx +++ b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx @@ -49,7 +49,7 @@ export const useImageUploadButton = ({ image_category: 'user', is_intermediate: false, postUploadAction: postUploadAction ?? { type: 'TOAST' }, - board_id: autoAddBoardId, + board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, }); }, [autoAddBoardId, postUploadAction, uploadImage] diff --git a/invokeai/frontend/web/src/features/canvas/hooks/useColorUnderCursor.ts b/invokeai/frontend/web/src/features/canvas/hooks/useColorUnderCursor.ts index 1356b24416..64289a1fd3 100644 --- a/invokeai/frontend/web/src/features/canvas/hooks/useColorUnderCursor.ts +++ b/invokeai/frontend/web/src/features/canvas/hooks/useColorUnderCursor.ts @@ -33,6 +33,10 @@ const useColorPicker = () => { 1 ).data; + if (!(a && r && g && b)) { + return; + } + dispatch(setColorPickerColor({ r, g, b, a })); }, commitColorUnderCursor: () => { diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index 3163e513e9..f63ab2fd67 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -727,10 +727,13 @@ export const canvasSlice = createSlice({ state.pastLayerStates.shift(); } - state.layerState.objects.push({ - ...images[selectedImageIndex], - }); + const imageToCommit = images[selectedImageIndex]; + if (imageToCommit) { + state.layerState.objects.push({ + ...imageToCommit, + }); + } state.layerState.stagingArea = { ...initialLayerState.stagingArea, }; diff --git a/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx b/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx new file mode 100644 index 0000000000..2443fa6081 --- /dev/null +++ b/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx @@ -0,0 +1,132 @@ +import { + AlertDialog, + AlertDialogBody, + AlertDialogContent, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogOverlay, + Flex, + Text, +} from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIButton from 'common/components/IAIButton'; +import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; +import { memo, useCallback, useMemo, useRef, useState } from 'react'; +import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; +import { + useAddImagesToBoardMutation, + useRemoveImagesFromBoardMutation, +} from 'services/api/endpoints/images'; +import { changeBoardReset, isModalOpenChanged } from '../store/slice'; + +const selector = createSelector( + [stateSelector], + ({ changeBoardModal }) => 
{ + const { isModalOpen, imagesToChange } = changeBoardModal; + + return { + isModalOpen, + imagesToChange, + }; + }, + defaultSelectorOptions +); + +const ChangeBoardModal = () => { + const dispatch = useAppDispatch(); + const [selectedBoard, setSelectedBoard] = useState(); + const { data: boards, isFetching } = useListAllBoardsQuery(); + const { imagesToChange, isModalOpen } = useAppSelector(selector); + const [addImagesToBoard] = useAddImagesToBoardMutation(); + const [removeImagesFromBoard] = useRemoveImagesFromBoardMutation(); + + const data = useMemo(() => { + const data: { label: string; value: string }[] = [ + { label: 'Uncategorized', value: 'none' }, + ]; + (boards ?? []).forEach((board) => + data.push({ + label: board.board_name, + value: board.board_id, + }) + ); + + return data; + }, [boards]); + + const handleClose = useCallback(() => { + dispatch(changeBoardReset()); + dispatch(isModalOpenChanged(false)); + }, [dispatch]); + + const handleChangeBoard = useCallback(() => { + if (!imagesToChange.length || !selectedBoard) { + return; + } + + if (selectedBoard === 'none') { + removeImagesFromBoard({ imageDTOs: imagesToChange }); + } else { + addImagesToBoard({ + imageDTOs: imagesToChange, + board_id: selectedBoard, + }); + } + setSelectedBoard(null); + dispatch(changeBoardReset()); + }, [ + addImagesToBoard, + dispatch, + imagesToChange, + removeImagesFromBoard, + selectedBoard, + ]); + + const cancelRef = useRef(null); + + return ( + + + + + Change Board + + + + + + Moving {`${imagesToChange.length}`} image + {`${imagesToChange.length > 1 ? 's' : ''}`} to board: + + setSelectedBoard(v)} + value={selectedBoard} + data={data} + /> + + + + + Cancel + + + Move + + + + + + ); +}; + +export default memo(ChangeBoardModal); diff --git a/invokeai/frontend/web/src/features/changeBoardModal/store/initialState.ts b/invokeai/frontend/web/src/features/changeBoardModal/store/initialState.ts new file mode 100644 index 0000000000..d737d0cdcd --- /dev/null +++ b/invokeai/frontend/web/src/features/changeBoardModal/store/initialState.ts @@ -0,0 +1,6 @@ +import { ChangeBoardModalState } from './types'; + +export const initialState: ChangeBoardModalState = { + isModalOpen: false, + imagesToChange: [], +}; diff --git a/invokeai/frontend/web/src/features/changeBoardModal/store/slice.ts b/invokeai/frontend/web/src/features/changeBoardModal/store/slice.ts new file mode 100644 index 0000000000..9855e2d7dd --- /dev/null +++ b/invokeai/frontend/web/src/features/changeBoardModal/store/slice.ts @@ -0,0 +1,25 @@ +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; +import { ImageDTO } from 'services/api/types'; +import { initialState } from './initialState'; + +const changeBoardModal = createSlice({ + name: 'changeBoardModal', + initialState, + reducers: { + isModalOpenChanged: (state, action: PayloadAction) => { + state.isModalOpen = action.payload; + }, + imagesToChangeSelected: (state, action: PayloadAction) => { + state.imagesToChange = action.payload; + }, + changeBoardReset: (state) => { + state.imagesToChange = []; + state.isModalOpen = false; + }, + }, +}); + +export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } = + changeBoardModal.actions; + +export default changeBoardModal.reducer; diff --git a/invokeai/frontend/web/src/features/changeBoardModal/store/types.ts b/invokeai/frontend/web/src/features/changeBoardModal/store/types.ts new file mode 100644 index 0000000000..6ce13331d0 --- /dev/null +++ b/invokeai/frontend/web/src/features/changeBoardModal/store/types.ts 
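// --- Illustrative sketch (not part of the patch) ----------------------------
// Rough usage of the new changeBoardModal slice introduced above: a caller
// stages the selected images, opens the modal, and resets the state once the
// move has been confirmed or cancelled. The store wiring mirrors the store.ts
// change in this patch; the empty ImageDTO[] is a placeholder for the real
// gallery selection.
import { configureStore } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api/types';
import changeBoardModalReducer, {
  changeBoardReset,
  imagesToChangeSelected,
  isModalOpenChanged,
} from 'features/changeBoardModal/store/slice';

const exampleStore = configureStore({
  reducer: { changeBoardModal: changeBoardModalReducer },
});

const selectedImages: ImageDTO[] = []; // placeholder for the current selection

exampleStore.dispatch(imagesToChangeSelected(selectedImages));
exampleStore.dispatch(isModalOpenChanged(true));
// ...ChangeBoardModal then dispatches addImagesToBoard / removeImagesFromBoard
// and clears the staged images, closing the modal:
exampleStore.dispatch(changeBoardReset());
// -----------------------------------------------------------------------------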
@@ -0,0 +1,6 @@ +import { ImageDTO } from 'services/api/types'; + +export type ChangeBoardModalState = { + isModalOpen: boolean; + imagesToChange: ImageDTO[]; +}; diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx index d858e46fdb..3252207edc 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx @@ -3,6 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { memo, useCallback } from 'react'; import { FaCopy, FaTrash } from 'react-icons/fa'; import { + ControlNetConfig, controlNetDuplicated, controlNetRemoved, controlNetToggled, @@ -27,18 +28,27 @@ import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcesso import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode'; type ControlNetProps = { - controlNetId: string; + controlNet: ControlNetConfig; }; const ControlNet = (props: ControlNetProps) => { - const { controlNetId } = props; + const { controlNet } = props; + const { controlNetId } = controlNet; const dispatch = useAppDispatch(); const selector = createSelector( stateSelector, ({ controlNet }) => { - const { isEnabled, shouldAutoConfig } = - controlNet.controlNets[controlNetId]; + const cn = controlNet.controlNets[controlNetId]; + + if (!cn) { + return { + isEnabled: false, + shouldAutoConfig: false, + }; + } + + const { isEnabled, shouldAutoConfig } = cn; return { isEnabled, shouldAutoConfig }; }, @@ -96,7 +106,7 @@ const ControlNet = (props: ControlNetProps) => { transitionDuration: '0.1s', }} > - + { justifyContent: 'space-between', }} > - - + + {!isExpanded && ( { aspectRatio: '1/1', }} > - + )} - - + + - + {isExpanded && ( <> - - - + + + )} diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx index 859495a941..cdab176cd2 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx @@ -12,50 +12,41 @@ import IAIDndImage from 'common/components/IAIDndImage'; import { memo, useCallback, useMemo, useState } from 'react'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { PostUploadAction } from 'services/api/types'; -import { controlNetImageChanged } from '../store/controlNetSlice'; +import { + ControlNetConfig, + controlNetImageChanged, +} from '../store/controlNetSlice'; type Props = { - controlNetId: string; + controlNet: ControlNetConfig; height: SystemStyleObject['h']; }; +const selector = createSelector( + stateSelector, + ({ controlNet }) => { + const { pendingControlImages } = controlNet; + + return { + pendingControlImages, + }; + }, + defaultSelectorOptions +); + const ControlNetImagePreview = (props: Props) => { - const { height, controlNetId } = props; + const { height } = props; + const { + controlImage: controlImageName, + processedControlImage: processedControlImageName, + processorType, + isEnabled, + controlNetId, + } = props.controlNet; + const dispatch = useAppDispatch(); - const selector = useMemo( - () => - createSelector( - stateSelector, - ({ controlNet }) => { - const { pendingControlImages } = controlNet; - const { - controlImage, - processedControlImage, - processorType, - isEnabled, - } = 
controlNet.controlNets[controlNetId]; - - return { - controlImageName: controlImage, - processedControlImageName: processedControlImage, - processorType, - isEnabled, - pendingControlImages, - }; - }, - defaultSelectorOptions - ), - [controlNetId] - ); - - const { - controlImageName, - processedControlImageName, - processorType, - pendingControlImages, - isEnabled, - } = useAppSelector(selector); + const { pendingControlImages } = useAppSelector(selector); const [isMouseOverImage, setIsMouseOverImage] = useState(false); diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetProcessorComponent.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetProcessorComponent.tsx index b7fa329eac..681838ef27 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetProcessorComponent.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetProcessorComponent.tsx @@ -1,8 +1,5 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; -import { useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { memo, useMemo } from 'react'; +import { memo } from 'react'; +import { ControlNetConfig } from '../store/controlNetSlice'; import CannyProcessor from './processors/CannyProcessor'; import ContentShuffleProcessor from './processors/ContentShuffleProcessor'; import HedProcessor from './processors/HedProcessor'; @@ -17,28 +14,11 @@ import PidiProcessor from './processors/PidiProcessor'; import ZoeDepthProcessor from './processors/ZoeDepthProcessor'; export type ControlNetProcessorProps = { - controlNetId: string; + controlNet: ControlNetConfig; }; const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { - const { controlNetId } = props; - - const selector = useMemo( - () => - createSelector( - stateSelector, - ({ controlNet }) => { - const { isEnabled, processorNode } = - controlNet.controlNets[controlNetId]; - - return { isEnabled, processorNode }; - }, - defaultSelectorOptions - ), - [controlNetId] - ); - - const { isEnabled, processorNode } = useAppSelector(selector); + const { controlNetId, isEnabled, processorNode } = props.controlNet; if (processorNode.type === 'canny_image_processor') { return ( diff --git a/invokeai/frontend/web/src/features/controlNet/components/ParamControlNetShouldAutoConfig.tsx b/invokeai/frontend/web/src/features/controlNet/components/ParamControlNetShouldAutoConfig.tsx index 285fcf7b80..0e044d4575 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ParamControlNetShouldAutoConfig.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ParamControlNetShouldAutoConfig.tsx @@ -1,34 +1,19 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISwitch from 'common/components/IAISwitch'; -import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice'; +import { + ControlNetConfig, + controlNetAutoConfigToggled, +} from 'features/controlNet/store/controlNetSlice'; import { selectIsBusy } from 'features/system/store/systemSelectors'; -import { memo, useCallback, useMemo } from 'react'; +import { memo, useCallback } from 'react'; type Props = { - controlNetId: string; + controlNet: ControlNetConfig; }; 
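// --- Illustrative sketch (not part of the patch) ----------------------------
// The ControlNet param components above switch from receiving a controlNetId
// (and looking the config up with a per-component createSelector) to receiving
// the whole ControlNetConfig as a prop. A stripped-down .tsx version of that
// shape; the abbreviated config type and component name are assumptions made
// only for this example.
import { memo } from 'react';

type ExampleControlNetConfig = {
  controlNetId: string;
  isEnabled: boolean;
  weight: number;
};

type ExampleProps = { controlNet: ExampleControlNetConfig };

const ParamExampleWeight = (props: ExampleProps) => {
  // No useAppSelector / useMemo(createSelector(...)) lookup is needed any more;
  // everything comes straight off the prop, so a missing controlNet entry can
  // no longer surface as an undefined indexed access inside the component.
  const { controlNetId, isEnabled, weight } = props.controlNet;

  if (!isEnabled) {
    return null;
  }

  return <span data-controlnet-id={controlNetId}>{weight}</span>;
};

export default memo(ParamExampleWeight);
// -----------------------------------------------------------------------------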
const ParamControlNetShouldAutoConfig = (props: Props) => { - const { controlNetId } = props; + const { controlNetId, isEnabled, shouldAutoConfig } = props.controlNet; const dispatch = useAppDispatch(); - const selector = useMemo( - () => - createSelector( - stateSelector, - ({ controlNet }) => { - const { isEnabled, shouldAutoConfig } = - controlNet.controlNets[controlNetId]; - return { isEnabled, shouldAutoConfig }; - }, - defaultSelectorOptions - ), - [controlNetId] - ); - - const { isEnabled, shouldAutoConfig } = useAppSelector(selector); const isBusy = useAppSelector(selectIsBusy); const handleShouldAutoConfigChanged = useCallback(() => { diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx index 3dd420e7c9..1219239e5d 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetBeginEnd.tsx @@ -9,48 +9,39 @@ import { RangeSliderTrack, Tooltip, } from '@chakra-ui/react'; -import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { useAppDispatch } from 'app/store/storeHooks'; import { + ControlNetConfig, controlNetBeginStepPctChanged, controlNetEndStepPctChanged, } from 'features/controlNet/store/controlNetSlice'; -import { memo, useCallback, useMemo } from 'react'; +import { memo, useCallback } from 'react'; type Props = { - controlNetId: string; + controlNet: ControlNetConfig; }; const formatPct = (v: number) => `${Math.round(v * 100)}%`; const ParamControlNetBeginEnd = (props: Props) => { - const { controlNetId } = props; + const { beginStepPct, endStepPct, isEnabled, controlNetId } = + props.controlNet; const dispatch = useAppDispatch(); - const selector = useMemo( - () => - createSelector( - stateSelector, - ({ controlNet }) => { - const { beginStepPct, endStepPct, isEnabled } = - controlNet.controlNets[controlNetId]; - return { beginStepPct, endStepPct, isEnabled }; - }, - defaultSelectorOptions - ), - [controlNetId] - ); - - const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector); - const handleStepPctChanged = useCallback( (v: number[]) => { dispatch( - controlNetBeginStepPctChanged({ controlNetId, beginStepPct: v[0] }) + controlNetBeginStepPctChanged({ + controlNetId, + beginStepPct: v[0] as number, + }) + ); + dispatch( + controlNetEndStepPctChanged({ + controlNetId, + endStepPct: v[1] as number, + }) ); - dispatch(controlNetEndStepPctChanged({ controlNetId, endStepPct: v[1] })); }, [controlNetId, dispatch] ); diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetControlMode.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetControlMode.tsx index e644e24a02..761edde42b 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetControlMode.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetControlMode.tsx @@ -1,16 +1,14 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { 
defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { ControlModes, + ControlNetConfig, controlNetControlModeChanged, } from 'features/controlNet/store/controlNetSlice'; -import { useCallback, useMemo } from 'react'; +import { useCallback } from 'react'; type ParamControlNetControlModeProps = { - controlNetId: string; + controlNet: ControlNetConfig; }; const CONTROL_MODE_DATA = [ @@ -23,23 +21,8 @@ const CONTROL_MODE_DATA = [ export default function ParamControlNetControlMode( props: ParamControlNetControlModeProps ) { - const { controlNetId } = props; + const { controlMode, isEnabled, controlNetId } = props.controlNet; const dispatch = useAppDispatch(); - const selector = useMemo( - () => - createSelector( - stateSelector, - ({ controlNet }) => { - const { controlMode, isEnabled } = - controlNet.controlNets[controlNetId]; - return { controlMode, isEnabled }; - }, - defaultSelectorOptions - ), - [controlNetId] - ); - - const { controlMode, isEnabled } = useAppSelector(selector); const handleControlModeChange = useCallback( (controlMode: ControlModes) => { diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx index 8392bdd2e3..5d7db854d8 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx @@ -5,7 +5,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; -import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice'; +import { + ControlNetConfig, + controlNetModelChanged, +} from 'features/controlNet/store/controlNetSlice'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam'; import { selectIsBusy } from 'features/system/store/systemSelectors'; @@ -14,30 +17,24 @@ import { memo, useCallback, useMemo } from 'react'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; type ParamControlNetModelProps = { - controlNetId: string; + controlNet: ControlNetConfig; }; +const selector = createSelector( + stateSelector, + ({ generation }) => { + const { model } = generation; + return { mainModel: model }; + }, + defaultSelectorOptions +); + const ParamControlNetModel = (props: ParamControlNetModelProps) => { - const { controlNetId } = props; + const { controlNetId, model: controlNetModel, isEnabled } = props.controlNet; const dispatch = useAppDispatch(); const isBusy = useAppSelector(selectIsBusy); - const selector = useMemo( - () => - createSelector( - stateSelector, - ({ generation, controlNet }) => { - const { model } = generation; - const controlNetModel = controlNet.controlNets[controlNetId]?.model; - const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled; - return { mainModel: model, controlNetModel, isEnabled }; - }, - defaultSelectorOptions - ), - [controlNetId] - ); - - const { mainModel, 
controlNetModel, isEnabled } = useAppSelector(selector); + const { mainModel } = useAppSelector(selector); const { data: controlNetModels } = useGetControlNetModelsQuery(); diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx index 83c66363ac..190b1bc012 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx @@ -1,7 +1,6 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSearchableSelect, { IAISelectDataType, @@ -9,13 +8,16 @@ import IAIMantineSearchableSelect, { import { configSelector } from 'features/system/store/configSelectors'; import { selectIsBusy } from 'features/system/store/systemSelectors'; import { map } from 'lodash-es'; -import { memo, useCallback, useMemo } from 'react'; +import { memo, useCallback } from 'react'; import { CONTROLNET_PROCESSORS } from '../../store/constants'; -import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice'; +import { + ControlNetConfig, + controlNetProcessorTypeChanged, +} from '../../store/controlNetSlice'; import { ControlNetProcessorType } from '../../store/types'; type ParamControlNetProcessorSelectProps = { - controlNetId: string; + controlNet: ControlNetConfig; }; const selector = createSelector( @@ -52,23 +54,9 @@ const ParamControlNetProcessorSelect = ( props: ParamControlNetProcessorSelectProps ) => { const dispatch = useAppDispatch(); - const { controlNetId } = props; - const processorNodeSelector = useMemo( - () => - createSelector( - stateSelector, - ({ controlNet }) => { - const { isEnabled, processorNode } = - controlNet.controlNets[controlNetId]; - return { isEnabled, processorNode }; - }, - defaultSelectorOptions - ), - [controlNetId] - ); + const { controlNetId, isEnabled, processorNode } = props.controlNet; const isBusy = useAppSelector(selectIsBusy); const controlNetProcessors = useAppSelector(selector); - const { isEnabled, processorNode } = useAppSelector(processorNodeSelector); const handleProcessorTypeChanged = useCallback( (v: string | null) => { diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetResizeMode.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetResizeMode.tsx index ee04b8077f..72f15fb178 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetResizeMode.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetResizeMode.tsx @@ -1,16 +1,14 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { + ControlNetConfig, ResizeModes, controlNetResizeModeChanged, } from 'features/controlNet/store/controlNetSlice'; -import { useCallback, useMemo } from 'react'; +import { 
useCallback } from 'react'; type ParamControlNetResizeModeProps = { - controlNetId: string; + controlNet: ControlNetConfig; }; const RESIZE_MODE_DATA = [ @@ -22,23 +20,8 @@ const RESIZE_MODE_DATA = [ export default function ParamControlNetResizeMode( props: ParamControlNetResizeModeProps ) { - const { controlNetId } = props; + const { resizeMode, isEnabled, controlNetId } = props.controlNet; const dispatch = useAppDispatch(); - const selector = useMemo( - () => - createSelector( - stateSelector, - ({ controlNet }) => { - const { resizeMode, isEnabled } = - controlNet.controlNets[controlNetId]; - return { resizeMode, isEnabled }; - }, - defaultSelectorOptions - ), - [controlNetId] - ); - - const { resizeMode, isEnabled } = useAppSelector(selector); const handleResizeModeChange = useCallback( (resizeMode: ResizeModes) => { diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx index 8643fd7dad..c08283e1f9 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetWeight.tsx @@ -1,32 +1,18 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { useAppDispatch } from 'app/store/storeHooks'; import IAISlider from 'common/components/IAISlider'; -import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice'; -import { memo, useCallback, useMemo } from 'react'; +import { + ControlNetConfig, + controlNetWeightChanged, +} from 'features/controlNet/store/controlNetSlice'; +import { memo, useCallback } from 'react'; type ParamControlNetWeightProps = { - controlNetId: string; + controlNet: ControlNetConfig; }; const ParamControlNetWeight = (props: ParamControlNetWeightProps) => { - const { controlNetId } = props; + const { weight, isEnabled, controlNetId } = props.controlNet; const dispatch = useAppDispatch(); - const selector = useMemo( - () => - createSelector( - stateSelector, - ({ controlNet }) => { - const { weight, isEnabled } = controlNet.controlNets[controlNetId]; - return { weight, isEnabled }; - }, - defaultSelectorOptions - ), - [controlNetId] - ); - - const { weight, isEnabled } = useAppSelector(selector); const handleWeightChanged = useCallback( (weight: number) => { dispatch(controlNetWeightChanged({ controlNetId, weight })); diff --git a/invokeai/frontend/web/src/features/controlNet/store/constants.ts b/invokeai/frontend/web/src/features/controlNet/store/constants.ts index 00f5377e00..f8f9c38619 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/constants.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/constants.ts @@ -4,7 +4,7 @@ import { } from './types'; type ControlNetProcessorsDict = Record< - string, + ControlNetProcessorType, { type: ControlNetProcessorType | 'none'; label: string; diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index 0df907d463..8f391521d6 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -96,8 +96,11 @@ 
export const controlNetSlice = createSlice({ }> ) => { const { sourceControlNetId, newControlNetId } = action.payload; - - const newControlnet = cloneDeep(state.controlNets[sourceControlNetId]); + const oldControlNet = state.controlNets[sourceControlNetId]; + if (!oldControlNet) { + return; + } + const newControlnet = cloneDeep(oldControlNet); newControlnet.controlNetId = newControlNetId; state.controlNets[newControlNetId] = newControlnet; }, @@ -124,8 +127,11 @@ export const controlNetSlice = createSlice({ action: PayloadAction<{ controlNetId: string }> ) => { const { controlNetId } = action.payload; - state.controlNets[controlNetId].isEnabled = - !state.controlNets[controlNetId].isEnabled; + const cn = state.controlNets[controlNetId]; + if (!cn) { + return; + } + cn.isEnabled = !cn.isEnabled; }, controlNetImageChanged: ( state, @@ -135,12 +141,14 @@ export const controlNetSlice = createSlice({ }> ) => { const { controlNetId, controlImage } = action.payload; - state.controlNets[controlNetId].controlImage = controlImage; - state.controlNets[controlNetId].processedControlImage = null; - if ( - controlImage !== null && - state.controlNets[controlNetId].processorType !== 'none' - ) { + const cn = state.controlNets[controlNetId]; + if (!cn) { + return; + } + + cn.controlImage = controlImage; + cn.processedControlImage = null; + if (controlImage !== null && cn.processorType !== 'none') { state.pendingControlImages.push(controlNetId); } }, @@ -152,8 +160,12 @@ export const controlNetSlice = createSlice({ }> ) => { const { controlNetId, processedControlImage } = action.payload; - state.controlNets[controlNetId].processedControlImage = - processedControlImage; + const cn = state.controlNets[controlNetId]; + if (!cn) { + return; + } + + cn.processedControlImage = processedControlImage; state.pendingControlImages = state.pendingControlImages.filter( (id) => id !== controlNetId ); @@ -166,10 +178,15 @@ export const controlNetSlice = createSlice({ }> ) => { const { controlNetId, model } = action.payload; - state.controlNets[controlNetId].model = model; - state.controlNets[controlNetId].processedControlImage = null; + const cn = state.controlNets[controlNetId]; + if (!cn) { + return; + } - if (state.controlNets[controlNetId].shouldAutoConfig) { + cn.model = model; + cn.processedControlImage = null; + + if (cn.shouldAutoConfig) { let processorType: ControlNetProcessorType | undefined = undefined; for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { @@ -180,14 +197,13 @@ export const controlNetSlice = createSlice({ } if (processorType) { - state.controlNets[controlNetId].processorType = processorType; - state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ - processorType - ].default as RequiredControlNetProcessorNode; + cn.processorType = processorType; + cn.processorNode = CONTROLNET_PROCESSORS[processorType] + .default as RequiredControlNetProcessorNode; } else { - state.controlNets[controlNetId].processorType = 'none'; - state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS - .none.default as RequiredControlNetProcessorNode; + cn.processorType = 'none'; + cn.processorNode = CONTROLNET_PROCESSORS.none + .default as RequiredControlNetProcessorNode; } } }, @@ -196,28 +212,48 @@ export const controlNetSlice = createSlice({ action: PayloadAction<{ controlNetId: string; weight: number }> ) => { const { controlNetId, weight } = action.payload; - state.controlNets[controlNetId].weight = weight; + const cn = state.controlNets[controlNetId]; + if (!cn) { + 
return; + } + + cn.weight = weight; }, controlNetBeginStepPctChanged: ( state, action: PayloadAction<{ controlNetId: string; beginStepPct: number }> ) => { const { controlNetId, beginStepPct } = action.payload; - state.controlNets[controlNetId].beginStepPct = beginStepPct; + const cn = state.controlNets[controlNetId]; + if (!cn) { + return; + } + + cn.beginStepPct = beginStepPct; }, controlNetEndStepPctChanged: ( state, action: PayloadAction<{ controlNetId: string; endStepPct: number }> ) => { const { controlNetId, endStepPct } = action.payload; - state.controlNets[controlNetId].endStepPct = endStepPct; + const cn = state.controlNets[controlNetId]; + if (!cn) { + return; + } + + cn.endStepPct = endStepPct; }, controlNetControlModeChanged: ( state, action: PayloadAction<{ controlNetId: string; controlMode: ControlModes }> ) => { const { controlNetId, controlMode } = action.payload; - state.controlNets[controlNetId].controlMode = controlMode; + const cn = state.controlNets[controlNetId]; + if (!cn) { + return; + } + + cn.controlMode = controlMode; }, controlNetResizeModeChanged: ( state, @@ -227,7 +263,12 @@ export const controlNetSlice = createSlice({ }> ) => { const { controlNetId, resizeMode } = action.payload; - state.controlNets[controlNetId].resizeMode = resizeMode; + const cn = state.controlNets[controlNetId]; + if (!cn) { + return; + } + + cn.resizeMode = resizeMode; }, controlNetProcessorParamsChanged: ( state, @@ -240,12 +281,17 @@ export const controlNetSlice = createSlice({ }> ) => { const { controlNetId, changes } = action.payload; - const processorNode = state.controlNets[controlNetId].processorNode; - state.controlNets[controlNetId].processorNode = { + const cn = state.controlNets[controlNetId]; + if (!cn) { + return; + } + + const processorNode = cn.processorNode; + cn.processorNode = { ...processorNode, ...changes, }; - state.controlNets[controlNetId].shouldAutoConfig = false; + cn.shouldAutoConfig = false; }, controlNetProcessorTypeChanged: ( state, @@ -255,12 +301,16 @@ export const controlNetSlice = createSlice({ }> ) => { const { controlNetId, processorType } = action.payload; - state.controlNets[controlNetId].processedControlImage = null; - state.controlNets[controlNetId].processorType = processorType; - state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ - processorType - ].default as RequiredControlNetProcessorNode; - state.controlNets[controlNetId].shouldAutoConfig = false; + const cn = state.controlNets[controlNetId]; + if (!cn) { + return; + } + + cn.processedControlImage = null; + cn.processorType = processorType; + cn.processorNode = CONTROLNET_PROCESSORS[processorType] + .default as RequiredControlNetProcessorNode; + cn.shouldAutoConfig = false; }, controlNetAutoConfigToggled: ( state, @@ -269,37 +319,36 @@ export const controlNetSlice = createSlice({ }> ) => { const { controlNetId } = action.payload; - const newShouldAutoConfig = - !state.controlNets[controlNetId].shouldAutoConfig; + const cn = state.controlNets[controlNetId]; + if (!cn) { + return; + } + + const newShouldAutoConfig = !cn.shouldAutoConfig; if (newShouldAutoConfig) { // manage the processor for the user let processorType: ControlNetProcessorType | undefined = undefined; for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { - if ( - state.controlNets[controlNetId].model?.model_name.includes( - modelSubstring - ) - ) { + if (cn.model?.model_name.includes(modelSubstring)) { processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring]; break; } } if 
(processorType) { - state.controlNets[controlNetId].processorType = processorType; - state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ - processorType - ].default as RequiredControlNetProcessorNode; + cn.processorType = processorType; + cn.processorNode = CONTROLNET_PROCESSORS[processorType] + .default as RequiredControlNetProcessorNode; } else { - state.controlNets[controlNetId].processorType = 'none'; - state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS - .none.default as RequiredControlNetProcessorNode; + cn.processorType = 'none'; + cn.processorNode = CONTROLNET_PROCESSORS.none + .default as RequiredControlNetProcessorNode; } } - state.controlNets[controlNetId].shouldAutoConfig = newShouldAutoConfig; + cn.shouldAutoConfig = newShouldAutoConfig; }, controlNetReset: () => { return { ...initialControlNetState }; @@ -307,9 +356,11 @@ export const controlNetSlice = createSlice({ }, extraReducers: (builder) => { builder.addCase(controlNetImageProcessed, (state, action) => { - if ( - state.controlNets[action.payload.controlNetId].controlImage !== null - ) { + const cn = state.controlNets[action.payload.controlNetId]; + if (!cn) { + return; + } + if (cn.controlImage !== null) { state.pendingControlImages.push(action.payload.controlNetId); } }); diff --git a/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageButton.tsx b/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx similarity index 100% rename from invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageButton.tsx rename to invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageButton.tsx diff --git a/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx b/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageModal.tsx similarity index 70% rename from invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx rename to invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageModal.tsx index 0e72ea96ad..0d8ecfbae6 100644 --- a/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx +++ b/invokeai/frontend/web/src/features/deleteImageModal/components/DeleteImageModal.tsx @@ -15,30 +15,42 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIButton from 'common/components/IAIButton'; import IAISwitch from 'common/components/IAISwitch'; import { setShouldConfirmOnDelete } from 'features/system/store/systemSlice'; - import { stateSelector } from 'app/store/store'; +import { some } from 'lodash-es'; import { ChangeEvent, memo, useCallback, useRef } from 'react'; import { useTranslation } from 'react-i18next'; import { imageDeletionConfirmed } from '../store/actions'; -import { selectImageUsage } from '../store/imageDeletionSelectors'; -import { - imageToDeleteCleared, - isModalOpenChanged, -} from '../store/imageDeletionSlice'; +import { getImageUsage, selectImageUsage } from '../store/selectors'; +import { imageDeletionCanceled, isModalOpenChanged } from '../store/slice'; import ImageUsageMessage from './ImageUsageMessage'; +import { ImageUsage } from '../store/types'; const selector = createSelector( [stateSelector, selectImageUsage], - ({ system, config, imageDeletion }, imageUsage) => { + (state, imagesUsage) => { + const { system, config, deleteImageModal } = state; const { shouldConfirmOnDelete } = system; const { canRestoreDeletedImagesFromBin } = config; - const { 
imageToDelete, isModalOpen } = imageDeletion; + const { imagesToDelete, isModalOpen } = deleteImageModal; + + const allImageUsage = (imagesToDelete ?? []).map(({ image_name }) => + getImageUsage(state, image_name) + ); + + const imageUsageSummary: ImageUsage = { + isInitialImage: some(allImageUsage, (i) => i.isInitialImage), + isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage), + isNodesImage: some(allImageUsage, (i) => i.isNodesImage), + isControlNetImage: some(allImageUsage, (i) => i.isControlNetImage), + }; + return { shouldConfirmOnDelete, canRestoreDeletedImagesFromBin, - imageToDelete, - imageUsage, + imagesToDelete, + imagesUsage, isModalOpen, + imageUsageSummary, }; }, defaultSelectorOptions @@ -51,9 +63,10 @@ const DeleteImageModal = () => { const { shouldConfirmOnDelete, canRestoreDeletedImagesFromBin, - imageToDelete, - imageUsage, + imagesToDelete, + imagesUsage, isModalOpen, + imageUsageSummary, } = useAppSelector(selector); const handleChangeShouldConfirmOnDelete = useCallback( @@ -63,17 +76,19 @@ const DeleteImageModal = () => { ); const handleClose = useCallback(() => { - dispatch(imageToDeleteCleared()); + dispatch(imageDeletionCanceled()); dispatch(isModalOpenChanged(false)); }, [dispatch]); const handleDelete = useCallback(() => { - if (!imageToDelete || !imageUsage) { + if (!imagesToDelete.length || !imagesUsage.length) { return; } - dispatch(imageToDeleteCleared()); - dispatch(imageDeletionConfirmed({ imageDTO: imageToDelete, imageUsage })); - }, [dispatch, imageToDelete, imageUsage]); + dispatch(imageDeletionCanceled()); + dispatch( + imageDeletionConfirmed({ imageDTOs: imagesToDelete, imagesUsage }) + ); + }, [dispatch, imagesToDelete, imagesUsage]); const cancelRef = useRef(null); @@ -92,7 +107,7 @@ const DeleteImageModal = () => { - + {canRestoreDeletedImagesFromBin diff --git a/invokeai/frontend/web/src/features/imageDeletion/components/ImageUsageMessage.tsx b/invokeai/frontend/web/src/features/deleteImageModal/components/ImageUsageMessage.tsx similarity index 100% rename from invokeai/frontend/web/src/features/imageDeletion/components/ImageUsageMessage.tsx rename to invokeai/frontend/web/src/features/deleteImageModal/components/ImageUsageMessage.tsx diff --git a/invokeai/frontend/web/src/features/imageDeletion/store/actions.ts b/invokeai/frontend/web/src/features/deleteImageModal/store/actions.ts similarity index 65% rename from invokeai/frontend/web/src/features/imageDeletion/store/actions.ts rename to invokeai/frontend/web/src/features/deleteImageModal/store/actions.ts index c67d7d944d..def27c9954 100644 --- a/invokeai/frontend/web/src/features/imageDeletion/store/actions.ts +++ b/invokeai/frontend/web/src/features/deleteImageModal/store/actions.ts @@ -3,6 +3,6 @@ import { ImageDTO } from 'services/api/types'; import { ImageUsage } from './types'; export const imageDeletionConfirmed = createAction<{ - imageDTO: ImageDTO; - imageUsage: ImageUsage; -}>('imageDeletion/imageDeletionConfirmed'); + imageDTOs: ImageDTO[]; + imagesUsage: ImageUsage[]; +}>('deleteImageModal/imageDeletionConfirmed'); diff --git a/invokeai/frontend/web/src/features/deleteImageModal/store/initialState.ts b/invokeai/frontend/web/src/features/deleteImageModal/store/initialState.ts new file mode 100644 index 0000000000..198d4ca51f --- /dev/null +++ b/invokeai/frontend/web/src/features/deleteImageModal/store/initialState.ts @@ -0,0 +1,6 @@ +import { DeleteImageState } from './types'; + +export const initialDeleteImageState: DeleteImageState = { + imagesToDelete: [], + isModalOpen: 
false, +}; diff --git a/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSelectors.ts b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts similarity index 84% rename from invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSelectors.ts rename to invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts index bd8e117496..310521f32a 100644 --- a/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSelectors.ts +++ b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts @@ -39,17 +39,17 @@ export const getImageUsage = (state: RootState, image_name: string) => { export const selectImageUsage = createSelector( [(state: RootState) => state], (state) => { - const { imageToDelete } = state.imageDeletion; + const { imagesToDelete } = state.deleteImageModal; - if (!imageToDelete) { - return; + if (!imagesToDelete.length) { + return []; } - const { image_name } = imageToDelete; + const imagesUsage = imagesToDelete.map((i) => + getImageUsage(state, i.image_name) + ); - const imageUsage = getImageUsage(state, image_name); - - return imageUsage; + return imagesUsage; }, defaultSelectorOptions ); diff --git a/invokeai/frontend/web/src/features/deleteImageModal/store/slice.ts b/invokeai/frontend/web/src/features/deleteImageModal/store/slice.ts new file mode 100644 index 0000000000..6569009666 --- /dev/null +++ b/invokeai/frontend/web/src/features/deleteImageModal/store/slice.ts @@ -0,0 +1,28 @@ +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; +import { ImageDTO } from 'services/api/types'; +import { initialDeleteImageState } from './initialState'; + +const deleteImageModal = createSlice({ + name: 'deleteImageModal', + initialState: initialDeleteImageState, + reducers: { + isModalOpenChanged: (state, action: PayloadAction) => { + state.isModalOpen = action.payload; + }, + imagesToDeleteSelected: (state, action: PayloadAction) => { + state.imagesToDelete = action.payload; + }, + imageDeletionCanceled: (state) => { + state.imagesToDelete = []; + state.isModalOpen = false; + }, + }, +}); + +export const { + isModalOpenChanged, + imagesToDeleteSelected, + imageDeletionCanceled, +} = deleteImageModal.actions; + +export default deleteImageModal.reducer; diff --git a/invokeai/frontend/web/src/features/deleteImageModal/store/types.ts b/invokeai/frontend/web/src/features/deleteImageModal/store/types.ts new file mode 100644 index 0000000000..2beaa8ca2e --- /dev/null +++ b/invokeai/frontend/web/src/features/deleteImageModal/store/types.ts @@ -0,0 +1,13 @@ +import { ImageDTO } from 'services/api/types'; + +export type DeleteImageState = { + imagesToDelete: ImageDTO[]; + isModalOpen: boolean; +}; + +export type ImageUsage = { + isInitialImage: boolean; + isCanvasImage: boolean; + isNodesImage: boolean; + isControlNetImage: boolean; +}; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardAutoAddSelect.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardAutoAddSelect.tsx index 9f02a29f10..96d17b548e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardAutoAddSelect.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardAutoAddSelect.tsx @@ -56,7 +56,7 @@ const BoardAutoAddSelect = () => { return; } - dispatch(autoAddBoardIdChanged(v === 'none' ? 
undefined : v)); + dispatch(autoAddBoardIdChanged(v)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx index 2774288612..0667c05435 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx @@ -11,10 +11,11 @@ import { BoardDTO } from 'services/api/types'; import { menuListMotionProps } from 'theme/components/menu'; import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems'; import NoBoardContextMenuItems from './NoBoardContextMenuItems'; +import { BoardId } from 'features/gallery/store/types'; type Props = { board?: BoardDTO; - board_id?: string; + board_id: BoardId; children: ContextMenuProps['children']; setBoardToDelete?: (board?: BoardDTO) => void; }; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BatchBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BatchBoard.tsx deleted file mode 100644 index a7a3040cce..0000000000 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BatchBoard.tsx +++ /dev/null @@ -1,43 +0,0 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { AddToBatchDropData } from 'app/components/ImageDnd/typesafeDnd'; -import { stateSelector } from 'app/store/store'; -import { useAppSelector } from 'app/store/storeHooks'; -import { boardIdSelected } from 'features/gallery/store/gallerySlice'; -import { useCallback } from 'react'; -import { FaLayerGroup } from 'react-icons/fa'; -import { useDispatch } from 'react-redux'; -import GenericBoard from './GenericBoard'; - -const selector = createSelector(stateSelector, (state) => { - return { - count: state.gallery.batchImageNames.length, - }; -}); - -const BatchBoard = ({ isSelected }: { isSelected: boolean }) => { - const dispatch = useDispatch(); - const { count } = useAppSelector(selector); - - const handleBatchBoardClick = useCallback(() => { - dispatch(boardIdSelected('batch')); - }, [dispatch]); - - const droppableData: AddToBatchDropData = { - id: 'batch-board', - actionType: 'ADD_TO_BATCH', - }; - - return ( - - ); -}; - -export default BatchBoard; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx index 512fced67c..cb3474f6bd 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsList.tsx @@ -15,10 +15,9 @@ import NoBoardBoard from './NoBoardBoard'; const selector = createSelector( [stateSelector], - ({ boards, gallery }) => { - const { searchText } = boards; - const { selectedBoardId } = gallery; - return { selectedBoardId, searchText }; + ({ gallery }) => { + const { selectedBoardId, boardSearchText } = gallery; + return { selectedBoardId, boardSearchText }; }, defaultSelectorOptions ); @@ -29,11 +28,11 @@ type Props = { const BoardsList = (props: Props) => { const { isOpen } = props; - const { selectedBoardId, searchText } = useAppSelector(selector); + const { selectedBoardId, boardSearchText } = useAppSelector(selector); const { data: boards } = useListAllBoardsQuery(); - const filteredBoards = searchText + const filteredBoards = boardSearchText ? 
boards?.filter((board) => - board.board_name.toLowerCase().includes(searchText.toLowerCase()) + board.board_name.toLowerCase().includes(boardSearchText.toLowerCase()) ) : boards; const [boardToDelete, setBoardToDelete] = useState(); @@ -75,7 +74,7 @@ const BoardsList = (props: Props) => { }} > - + {filteredBoards && filteredBoards.map((board) => ( diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsSearch.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsSearch.tsx index 800ffc651f..d7db96a938 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsSearch.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardsSearch.tsx @@ -9,7 +9,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { setBoardSearchText } from 'features/gallery/store/boardSlice'; +import { boardSearchTextChanged } from 'features/gallery/store/gallerySlice'; import { ChangeEvent, KeyboardEvent, @@ -21,27 +21,27 @@ import { const selector = createSelector( [stateSelector], - ({ boards }) => { - const { searchText } = boards; - return { searchText }; + ({ gallery }) => { + const { boardSearchText } = gallery; + return { boardSearchText }; }, defaultSelectorOptions ); const BoardsSearch = () => { const dispatch = useAppDispatch(); - const { searchText } = useAppSelector(selector); + const { boardSearchText } = useAppSelector(selector); const inputRef = useRef(null); const handleBoardSearch = useCallback( (searchTerm: string) => { - dispatch(setBoardSearchText(searchTerm)); + dispatch(boardSearchTextChanged(searchTerm)); }, [dispatch] ); const clearBoardSearch = useCallback(() => { - dispatch(setBoardSearchText('')); + dispatch(boardSearchTextChanged('')); }, [dispatch]); const handleKeydown = useCallback( @@ -74,11 +74,11 @@ const BoardsSearch = () => { - {searchText && searchText.length && ( + {boardSearchText && boardSearchText.length && ( { setIsHovered(false); }, []); + + const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id); + const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id); + const tooltip = useMemo(() => { + if (!imagesTotal || !assetsTotal) { + return undefined; + } + return `${imagesTotal} image${ + imagesTotal > 1 ? 's' : '' + }, ${assetsTotal} asset${assetsTotal > 1 ? 's' : ''}`; + }, [assetsTotal, imagesTotal]); + const { currentData: coverImage } = useGetImageDTOQuery( board.cover_image_name ?? skipToken ); @@ -84,10 +101,10 @@ const GalleryBoard = memo( const [updateBoard, { isLoading: isUpdateBoardLoading }] = useUpdateBoardMutation(); - const droppableData: MoveBoardDropData = useMemo( + const droppableData: AddToBoardDropData = useMemo( () => ({ id: board_id, - actionType: 'MOVE_BOARD', + actionType: 'ADD_TO_BOARD', context: { boardId: board_id }, }), [board_id] @@ -148,60 +165,61 @@ const GalleryBoard = memo( setBoardToDelete={setBoardToDelete} > {(ref) => ( - - {coverImage?.thumbnail_url ? ( - - ) : ( - - + + {coverImage?.thumbnail_url ? 
( + - - )} - {/* + + + )} + {/* */} - {isSelectedForAutoAdd && } - - - } + + - - + - - + overflow: 'hidden', + textOverflow: 'ellipsis', + }} + noOfLines={1} + /> + + + - Move} - /> - + Move} + /> + + )} diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx index 118b2108f7..f1341b1146 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx @@ -1,6 +1,6 @@ import { Box, Flex, Image, Text } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -import { MoveBoardDropData } from 'app/components/ImageDnd/typesafeDnd'; +import { RemoveFromBoardDropData } from 'app/components/ImageDnd/typesafeDnd'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; @@ -15,6 +15,7 @@ import { memo, useCallback, useMemo, useState } from 'react'; import { useBoardName } from 'services/api/hooks/useBoardName'; import AutoAddIcon from '../AutoAddIcon'; import BoardContextMenu from '../BoardContextMenu'; + interface Props { isSelected: boolean; } @@ -33,26 +34,27 @@ const NoBoardBoard = memo(({ isSelected }: Props) => { const dispatch = useAppDispatch(); const { autoAddBoardId, autoAssignBoardOnClick, isProcessing } = useAppSelector(selector); - const boardName = useBoardName(undefined); + const boardName = useBoardName('none'); const handleSelectBoard = useCallback(() => { - dispatch(boardIdSelected(undefined)); + dispatch(boardIdSelected('none')); if (autoAssignBoardOnClick && !isProcessing) { - dispatch(autoAddBoardIdChanged(undefined)); + dispatch(autoAddBoardIdChanged('none')); } }, [dispatch, autoAssignBoardOnClick, isProcessing]); const [isHovered, setIsHovered] = useState(false); + const handleMouseOver = useCallback(() => { setIsHovered(true); }, []); + const handleMouseOut = useCallback(() => { setIsHovered(false); }, []); - const droppableData: MoveBoardDropData = useMemo( + const droppableData: RemoveFromBoardDropData = useMemo( () => ({ id: 'no_board', - actionType: 'MOVE_BOARD', - context: { boardId: undefined }, + actionType: 'REMOVE_FROM_BOARD', }), [] ); @@ -72,7 +74,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => { h: 'full', }} > - + {(ref) => ( { alignItems: 'center', }} > - {/* */} invoke-ai-logo { }} /> - {/* - - {totalImages}/{totalAssets} - - */} - {!autoAddBoardId && } + {autoAddBoardId === 'none' && } void; }; -const DeleteImageModal = (props: Props) => { +const DeleteBoardModal = (props: Props) => { const { boardToDelete, setBoardToDelete } = props; const { t } = useTranslation(); const canRestoreDeletedImagesFromBin = useAppSelector( @@ -49,13 +49,10 @@ const DeleteImageModal = (props: Props) => { ); const imageUsageSummary: ImageUsage = { - isInitialImage: some(allImageUsage, (usage) => usage.isInitialImage), - isCanvasImage: some(allImageUsage, (usage) => usage.isCanvasImage), - isNodesImage: some(allImageUsage, (usage) => usage.isNodesImage), - isControlNetImage: some( - allImageUsage, - (usage) => usage.isControlNetImage - ), + isInitialImage: some(allImageUsage, (i) => i.isInitialImage), + isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage), + isNodesImage: some(allImageUsage, (i) => i.isNodesImage), + isControlNetImage: 
some(allImageUsage, (i) => i.isControlNetImage), }; return { imageUsageSummary }; }), @@ -176,4 +173,4 @@ const DeleteImageModal = (props: Props) => { ); }; -export default memo(DeleteImageModal); +export default memo(DeleteBoardModal); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/UpdateImageBoardModal.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/UpdateImageBoardModal.tsx deleted file mode 100644 index 49eb1502f3..0000000000 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/UpdateImageBoardModal.tsx +++ /dev/null @@ -1,93 +0,0 @@ -import { - AlertDialog, - AlertDialogBody, - AlertDialogContent, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogOverlay, - Box, - Flex, - Spinner, - Text, -} from '@chakra-ui/react'; -import IAIButton from 'common/components/IAIButton'; - -import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; -import { memo, useContext, useRef, useState } from 'react'; -import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; -import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext'; - -const UpdateImageBoardModal = () => { - // const boards = useSelector(selectBoardsAll); - const { data: boards, isFetching } = useListAllBoardsQuery(); - const { isOpen, onClose, handleAddToBoard, image } = useContext( - AddImageToBoardContext - ); - const [selectedBoard, setSelectedBoard] = useState(); - - const cancelRef = useRef(null); - - const currentBoard = boards?.find( - (board) => board.board_id === image?.board_id - ); - - return ( - - - - - {currentBoard ? 'Move Image to Board' : 'Add Image to Board'} - - - - - - {currentBoard && ( - - Moving this image from{' '} - {currentBoard.board_name} to - - )} - {isFetching ? ( - - ) : ( - setSelectedBoard(v)} - value={selectedBoard} - data={(boards ?? []).map((board) => ({ - label: board.board_name, - value: board.board_id, - }))} - /> - )} - - - - - Cancel - { - if (selectedBoard) { - handleAddToBoard(selectedBoard); - } - }} - ml={3} - > - {currentBoard ? 
'Move' : 'Add'} - - - - - - ); -}; - -export default memo(UpdateImageBoardModal); diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx index 7d25d6bc05..d62027769b 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx @@ -9,16 +9,14 @@ import { MenuButton, MenuList, } from '@chakra-ui/react'; -// import { runESRGAN, runFacetool } from 'app/socketio/actions'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIIconButton from 'common/components/IAIIconButton'; - import { skipToken } from '@reduxjs/toolkit/dist/query'; import { useAppToaster } from 'app/components/Toaster'; import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested'; import { stateSelector } from 'app/store/store'; -import { DeleteImageButton } from 'features/imageDeletion/components/DeleteImageButton'; -import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice'; +import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton'; +import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { initialImageSelected } from 'features/parameters/store/actions'; @@ -109,13 +107,13 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { ); const { currentData: imageDTO } = useGetImageDTOQuery( - lastSelectedImage ?? skipToken + lastSelectedImage?.image_name ?? skipToken ); const { currentData: metadataData } = useGetImageMetadataQuery( debounceState.isPending() ? skipToken - : debouncedMetadataQueryArg ?? skipToken + : debouncedMetadataQueryArg?.image_name ?? 
skipToken ); const metadata = metadataData?.metadata; @@ -173,7 +171,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { if (!imageDTO) { return; } - dispatch(imageToDeleteSelected(imageDTO)); + dispatch(imagesToDeleteSelected([imageDTO])); }, [dispatch, imageDTO]); useHotkeys( diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx index fd7eaef46a..f78ee286ef 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx @@ -32,7 +32,7 @@ export const imagesSelector = createSelector( return { shouldShowImageDetails, shouldHidePreview, - imageName: lastSelectedImage, + imageName: lastSelectedImage?.image_name, progressImage, shouldShowProgressInViewer, shouldAntialiasProgressImage, @@ -57,8 +57,6 @@ const CurrentImagePreview = () => { const { handlePrevImage, handleNextImage, - prevImageId, - nextImageId, isOnLastImage, handleLoadMoreImages, areMoreImagesAvailable, @@ -70,7 +68,7 @@ const CurrentImagePreview = () => { () => { handlePrevImage(); }, - [prevImageId] + [handlePrevImage] ); useHotkeys( @@ -85,11 +83,11 @@ const CurrentImagePreview = () => { } }, [ - nextImageId, isOnLastImage, areMoreImagesAvailable, handleLoadMoreImages, isFetching, + handleNextImage, ] ); diff --git a/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx b/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx index 796cc542e7..5c32cc788e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/GallerySettingsPopover.tsx @@ -5,17 +5,19 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIIconButton from 'common/components/IAIIconButton'; import IAIPopover from 'common/components/IAIPopover'; -import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAISlider from 'common/components/IAISlider'; +import IAISwitch from 'common/components/IAISwitch'; import { autoAssignBoardOnClickChanged, setGalleryImageMinimumWidth, shouldAutoSwitchChanged, + shouldShowDeleteButtonChanged, } from 'features/gallery/store/gallerySlice'; -import { ChangeEvent } from 'react'; +import { ChangeEvent, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { FaWrench } from 'react-icons/fa'; import BoardAutoAddSelect from './Boards/BoardAutoAddSelect'; +import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; const selector = createSelector( [stateSelector], @@ -24,12 +26,14 @@ const selector = createSelector( galleryImageMinimumWidth, shouldAutoSwitch, autoAssignBoardOnClick, + shouldShowDeleteButton, } = state.gallery; return { galleryImageMinimumWidth, shouldAutoSwitch, autoAssignBoardOnClick, + shouldShowDeleteButton, }; }, defaultSelectorOptions @@ -39,12 +43,37 @@ const GallerySettingsPopover = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { galleryImageMinimumWidth, shouldAutoSwitch, autoAssignBoardOnClick } = - useAppSelector(selector); + const { + galleryImageMinimumWidth, + shouldAutoSwitch, + autoAssignBoardOnClick, + shouldShowDeleteButton, + } = useAppSelector(selector); - const 
handleChangeGalleryImageMinimumWidth = (v: number) => { - dispatch(setGalleryImageMinimumWidth(v)); - }; + const handleChangeGalleryImageMinimumWidth = useCallback( + (v: number) => { + dispatch(setGalleryImageMinimumWidth(v)); + }, + [dispatch] + ); + + const handleResetGalleryImageMinimumWidth = useCallback(() => { + dispatch(setGalleryImageMinimumWidth(64)); + }, [dispatch]); + + const handleChangeAutoSwitch = useCallback( + (e: ChangeEvent) => { + dispatch(shouldAutoSwitchChanged(e.target.checked)); + }, + [dispatch] + ); + + const handleChangeShowDeleteButton = useCallback( + (e: ChangeEvent) => { + dispatch(shouldShowDeleteButtonChanged(e.target.checked)); + }, + [dispatch] + ); return ( { /> } > - + { hideTooltip={true} label={t('gallery.galleryImageSize')} withReset - handleReset={() => dispatch(setGalleryImageMinimumWidth(64))} + handleReset={handleResetGalleryImageMinimumWidth} /> - ) => - dispatch(shouldAutoSwitchChanged(e.target.checked)) - } + onChange={handleChangeAutoSwitch} + /> + ['children']; }; +const selector = createSelector( + [stateSelector], + ({ gallery }) => { + const selectionCount = gallery.selection.length; + + return { selectionCount }; + }, + defaultSelectorOptions +); + const ImageContextMenu = ({ imageDTO, children }: Props) => { - // const selector = useMemo( - // () => - // createSelector( - // [stateSelector], - // ({ gallery }) => { - // const selectionCount = gallery.selection.length; - - // return { selectionCount }; - // }, - // defaultSelectorOptions - // ), - // [] - // ); - - // const { selectionCount } = useAppSelector(selector); + const { selectionCount } = useAppSelector(selector); const skipEvent = useCallback((e: MouseEvent) => { e.preventDefault(); @@ -38,8 +39,24 @@ const ImageContextMenu = ({ imageDTO, children }: Props) => { bg: 'transparent', _hover: { bg: 'transparent' }, }} - renderMenu={() => - imageDTO ? 
( + renderMenu={() => { + if (!imageDTO) { + return null; + } + + if (selectionCount > 1) { + return ( + + + + ); + } + + return ( { > - ) : null - } + ); + }} > {children} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx index 62d2cb06f4..079fc43a4a 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx @@ -1,30 +1,30 @@ import { MenuItem } from '@chakra-ui/react'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { + imagesToChangeSelected, + isModalOpenChanged, +} from 'features/changeBoardModal/store/slice'; +import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import { useCallback } from 'react'; -import { FaFolder, FaFolderPlus, FaTrash } from 'react-icons/fa'; +import { FaFolder, FaTrash } from 'react-icons/fa'; const MultipleSelectionMenuItems = () => { - const handleAddSelectionToBoard = useCallback(() => { - // TODO: add selection to board - }, []); + const dispatch = useAppDispatch(); + const selection = useAppSelector((state) => state.gallery.selection); + + const handleChangeBoard = useCallback(() => { + dispatch(imagesToChangeSelected(selection)); + dispatch(isModalOpenChanged(true)); + }, [dispatch, selection]); const handleDeleteSelection = useCallback(() => { - // TODO: delete all selected images - }, []); - - const handleAddSelectionToBatch = useCallback(() => { - // TODO: add selection to batch - }, []); + dispatch(imagesToDeleteSelected(selection)); + }, [dispatch, selection]); return ( <> - } onClickCapture={handleAddSelectionToBoard}> - Move Selection to Board - - } - onClickCapture={handleAddSelectionToBatch} - > - Add Selection to Batch + } onClickCapture={handleChangeBoard}> + Change Board { const { imageDTO } = props; - const selector = useMemo( - () => - createSelector( - [stateSelector], - ({ gallery }) => { - const isInBatch = gallery.batchImageNames.includes( - imageDTO.image_name - ); - - return { isInBatch }; - }, - defaultSelectorOptions - ), - [imageDTO.image_name] - ); - - const { isInBatch } = useAppSelector(selector); const dispatch = useAppDispatch(); const { t } = useTranslation(); const toaster = useAppToaster(); const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; - const isBatchEnabled = useFeatureStatus('batches').isFeatureEnabled; - - const { onClickAddToBoard } = useContext(AddImageToBoardContext); const [debouncedMetadataQueryArg, debounceState] = useDebounce( imageDTO.image_name, @@ -92,14 +68,12 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { if (!imageDTO) { return; } - dispatch(imageToDeleteSelected(imageDTO)); + dispatch(imagesToDeleteSelected([imageDTO])); }, [dispatch, imageDTO]); const { recallBothPrompts, recallSeed, recallAllParameters } = useRecallParameters(); - const [removeFromBoard] = useRemoveImageFromBoardMutation(); - // Recall parameters handlers const handleRecallPrompt = useCallback(() => { recallBothPrompts( @@ -144,20 +118,10 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { recallAllParameters(metadata); }, [metadata, recallAllParameters]); - const handleAddToBoard = useCallback(() => { - onClickAddToBoard(imageDTO); - }, [imageDTO, 
onClickAddToBoard]); - - const handleRemoveFromBoard = useCallback(() => { - if (!imageDTO.board_id) { - return; - } - removeFromBoard({ imageDTO }); - }, [imageDTO, removeFromBoard]); - - const handleAddToBatch = useCallback(() => { - dispatch(imagesAddedToBatch([imageDTO.image_name])); - }, [dispatch, imageDTO.image_name]); + const handleChangeBoard = useCallback(() => { + dispatch(imagesToChangeSelected([imageDTO])); + dispatch(isModalOpenChanged(true)); + }, [dispatch, imageDTO]); const handleCopyImage = useCallback(() => { copyImageToClipboard(imageDTO.image_url); @@ -229,23 +193,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { {t('parameters.sendToUnifiedCanvas')} )} - {isBatchEnabled && ( - } - isDisabled={isInBatch} - onClickCapture={handleAddToBatch} - > - Add to Batch - - )} - } onClickCapture={handleAddToBoard}> - {imageDTO.board_id ? 'Change Board' : 'Add to Board'} + } onClickCapture={handleChangeBoard}> + Change Board - {imageDTO.board_id && ( - } onClickCapture={handleRemoveFromBoard}> - Remove from Board - - )} } diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx index 5b2072bfc4..f2ff2ad30b 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx @@ -20,16 +20,14 @@ import BoardsList from './Boards/BoardsList/BoardsList'; import GalleryBoardName from './GalleryBoardName'; import GalleryPinButton from './GalleryPinButton'; import GallerySettingsPopover from './GallerySettingsPopover'; -import BatchImageGrid from './ImageGrid/BatchImageGrid'; import GalleryImageGrid from './ImageGrid/GalleryImageGrid'; const selector = createSelector( [stateSelector], (state) => { - const { selectedBoardId, galleryView } = state.gallery; + const { galleryView } = state.gallery; return { - selectedBoardId, galleryView, }; }, @@ -39,7 +37,7 @@ const selector = createSelector( const ImageGalleryContent = () => { const resizeObserverRef = useRef(null); const galleryGridRef = useRef(null); - const { selectedBoardId, galleryView } = useAppSelector(selector); + const { galleryView } = useAppSelector(selector); const dispatch = useAppDispatch(); const { isOpen: isBoardListOpen, onToggle: onToggleBoardList } = useDisclosure(); @@ -130,12 +128,7 @@ const ImageGalleryContent = () => { - - {selectedBoardId === 'batch' ? 
( - - ) : ( - - )} + ); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/BatchImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/BatchImage.tsx deleted file mode 100644 index 528e8cc06f..0000000000 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/BatchImage.tsx +++ /dev/null @@ -1,122 +0,0 @@ -import { Box } from '@chakra-ui/react'; -import { createSelector } from '@reduxjs/toolkit'; -import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; -import { stateSelector } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import IAIDndImage from 'common/components/IAIDndImage'; -import IAIErrorLoadingImageFallback from 'common/components/IAIErrorLoadingImageFallback'; -import IAIFillSkeleton from 'common/components/IAIFillSkeleton'; -import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu'; -import { imagesRemovedFromBatch } from 'features/gallery/store/gallerySlice'; -import { memo, useCallback, useMemo } from 'react'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; - -const makeSelector = (image_name: string) => - createSelector( - [stateSelector], - (state) => ({ - selectionCount: state.gallery.selection.length, - selection: state.gallery.selection, - isSelected: state.gallery.selection.includes(image_name), - }), - defaultSelectorOptions - ); - -type BatchImageProps = { - imageName: string; -}; - -const BatchImage = (props: BatchImageProps) => { - const dispatch = useAppDispatch(); - const { imageName } = props; - const { - currentData: imageDTO, - isLoading, - isError, - } = useGetImageDTOQuery(imageName); - const selector = useMemo(() => makeSelector(imageName), [imageName]); - - const { isSelected, selectionCount, selection } = useAppSelector(selector); - - const handleClickRemove = useCallback(() => { - dispatch(imagesRemovedFromBatch([imageName])); - }, [dispatch, imageName]); - - // const handleClick = useCallback( - // (e: MouseEvent) => { - // if (e.shiftKey) { - // dispatch(imageRangeEndSelected(imageName)); - // } else if (e.ctrlKey || e.metaKey) { - // dispatch(imageSelectionToggled(imageName)); - // } else { - // dispatch(imageSelected(imageName)); - // } - // }, - // [dispatch, imageName] - // ); - - const draggableData = useMemo(() => { - if (selectionCount > 1) { - return { - id: 'batch', - payloadType: 'IMAGE_NAMES', - payload: { image_names: selection }, - }; - } - - if (imageDTO) { - return { - id: 'batch', - payloadType: 'IMAGE_DTO', - payload: { imageDTO }, - }; - } - }, [imageDTO, selection, selectionCount]); - - if (isLoading) { - return ; - } - - if (isError || !imageDTO) { - return ; - } - - return ( - - - {(ref) => ( - - - - )} - - - ); -}; - -export default memo(BatchImage); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/BatchImageGrid.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/BatchImageGrid.tsx deleted file mode 100644 index feaa47403d..0000000000 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/BatchImageGrid.tsx +++ /dev/null @@ -1,87 +0,0 @@ -import { Box } from '@chakra-ui/react'; -import { useAppSelector } from 'app/store/storeHooks'; -import { useOverlayScrollbars } from 'overlayscrollbars-react'; - -import { memo, useEffect, useRef, useState } from 'react'; -import { useTranslation } from 'react-i18next'; 
-import { FaImage } from 'react-icons/fa'; - -import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { IAINoContentFallback } from 'common/components/IAIImageFallback'; -import { VirtuosoGrid } from 'react-virtuoso'; -import BatchImage from './BatchImage'; -import ItemContainer from './ImageGridItemContainer'; -import ListContainer from './ImageGridListContainer'; - -const selector = createSelector( - [stateSelector], - (state) => { - return { - imageNames: state.gallery.batchImageNames, - }; - }, - defaultSelectorOptions -); - -const BatchImageGrid = () => { - const { t } = useTranslation(); - const rootRef = useRef(null); - const [scroller, setScroller] = useState(null); - const [initialize, osInstance] = useOverlayScrollbars({ - defer: true, - options: { - scrollbars: { - visibility: 'auto', - autoHide: 'leave', - autoHideDelay: 1300, - theme: 'os-theme-dark', - }, - overflow: { x: 'hidden' }, - }, - }); - - const { imageNames } = useAppSelector(selector); - - useEffect(() => { - const { current: root } = rootRef; - if (scroller && root) { - initialize({ - target: root, - elements: { - viewport: scroller, - }, - }); - } - return () => osInstance()?.destroy(); - }, [scroller, initialize, osInstance]); - - if (imageNames.length) { - return ( - - ( - - )} - /> - - ); - } - - return ( - - ); -}; - -export default memo(BatchImageGrid); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx index 6a5d28a9ba..c9eee5f1f5 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx @@ -1,27 +1,18 @@ import { Box, Flex } from '@chakra-ui/react'; -import { createSelector } from '@reduxjs/toolkit'; -import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; -import { stateSelector } from 'app/store/store'; +import { + ImageDTOsDraggableData, + ImageDraggableData, + TypesafeDraggableData, +} from 'app/components/ImageDnd/typesafeDnd'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIDndImage from 'common/components/IAIDndImage'; import IAIFillSkeleton from 'common/components/IAIFillSkeleton'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; -import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice'; +import { useMultiselect } from 'features/gallery/hooks/useMultiselect.ts'; +import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import { MouseEvent, memo, useCallback, useMemo } from 'react'; +import { FaTrash } from 'react-icons/fa'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; -export const makeSelector = (image_name: string) => - createSelector( - [stateSelector], - ({ gallery }) => ({ - isSelected: gallery.selection.includes(image_name), - selectionCount: gallery.selection.length, - selection: gallery.selection, - }), - defaultSelectorOptions - ); - interface HoverableImageProps { imageName: string; } @@ -30,22 +21,12 @@ const GalleryImage = (props: HoverableImageProps) => { const dispatch = useAppDispatch(); const { imageName } = props; const { currentData: imageDTO } = useGetImageDTOQuery(imageName); 
- const localSelector = useMemo(() => makeSelector(imageName), [imageName]); + const shouldShowDeleteButton = useAppSelector( + (state) => state.gallery.shouldShowDeleteButton + ); - const { isSelected, selectionCount, selection } = - useAppSelector(localSelector); - - const handleClick = useCallback(() => { - // disable multiselect for now - // if (e.shiftKey) { - // dispatch(imageRangeEndSelected(imageName)); - // } else if (e.ctrlKey || e.metaKey) { - // dispatch(imageSelectionToggled(imageName)); - // } else { - // dispatch(imageSelected(imageName)); - // } - dispatch(imageSelected(imageName)); - }, [dispatch, imageName]); + const { handleClick, isSelected, selection, selectionCount } = + useMultiselect(imageDTO); const handleDelete = useCallback( (e: MouseEvent) => { @@ -53,26 +34,28 @@ const GalleryImage = (props: HoverableImageProps) => { if (!imageDTO) { return; } - dispatch(imageToDeleteSelected(imageDTO)); + dispatch(imagesToDeleteSelected([imageDTO])); }, [dispatch, imageDTO] ); const draggableData = useMemo(() => { if (selectionCount > 1) { - return { + const data: ImageDTOsDraggableData = { id: 'gallery-image', - payloadType: 'IMAGE_NAMES', - payload: { image_names: selection }, + payloadType: 'IMAGE_DTOS', + payload: { imageDTOs: selection }, }; + return data; } if (imageDTO) { - return { + const data: ImageDraggableData = { id: 'gallery-image', payloadType: 'IMAGE_DTO', payload: { imageDTO }, }; + return data; } }, [imageDTO, selection, selectionCount]); @@ -103,9 +86,9 @@ const GalleryImage = (props: HoverableImageProps) => { isUploadDisabled={true} thumbnail={true} withHoverOverlay - // resetIcon={} - // resetTooltip="Delete image" - // withResetIcon // removed bc it's too easy to accidentally delete images + resetIcon={} + resetTooltip="Delete image" + withResetIcon={shouldShowDeleteButton} // removed bc it's too easy to accidentally delete images /> diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index df574c860b..c0821c2226 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -1,6 +1,6 @@ import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useCallback } from 'react'; -import { UnsafeImageMetadata } from 'services/api/endpoints/images'; +import { UnsafeImageMetadata } from 'services/api/types'; import ImageMetadataItem from './ImageMetadataItem'; type Props = { diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts.ts b/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts.ts new file mode 100644 index 0000000000..b59a2f3d6f --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts.ts @@ -0,0 +1,93 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors'; +import { uniq } from 'lodash-es'; +import { MouseEvent, useCallback, useMemo } from 'react'; +import { useListImagesQuery } from 'services/api/endpoints/images'; +import { ImageDTO } from 
'services/api/types'; +import { selectionChanged } from '../store/gallerySlice'; +import { imagesSelectors } from 'services/api/util'; + +const selector = createSelector( + [stateSelector, selectListImagesBaseQueryArgs], + ({ gallery }, queryArgs) => { + const selection = gallery.selection; + + return { + queryArgs, + selection, + }; + }, + defaultSelectorOptions +); + +export const useMultiselect = (imageDTO?: ImageDTO) => { + const dispatch = useAppDispatch(); + const { queryArgs, selection } = useAppSelector(selector); + + const { imageDTOs } = useListImagesQuery(queryArgs, { + selectFromResult: (result) => ({ + imageDTOs: result.data ? imagesSelectors.selectAll(result.data) : [], + }), + }); + + const handleClick = useCallback( + (e: MouseEvent) => { + if (!imageDTO) { + return; + } + if (e.shiftKey) { + const rangeEndImageName = imageDTO.image_name; + const lastSelectedImage = selection[selection.length - 1]?.image_name; + const lastClickedIndex = imageDTOs.findIndex( + (n) => n.image_name === lastSelectedImage + ); + const currentClickedIndex = imageDTOs.findIndex( + (n) => n.image_name === rangeEndImageName + ); + if (lastClickedIndex > -1 && currentClickedIndex > -1) { + // We have a valid range! + const start = Math.min(lastClickedIndex, currentClickedIndex); + const end = Math.max(lastClickedIndex, currentClickedIndex); + const imagesToSelect = imageDTOs.slice(start, end + 1); + dispatch(selectionChanged(uniq(selection.concat(imagesToSelect)))); + } + } else if (e.ctrlKey || e.metaKey) { + if ( + selection.some((i) => i.image_name === imageDTO.image_name) && + selection.length > 1 + ) { + dispatch( + selectionChanged( + selection.filter((n) => n.image_name !== imageDTO.image_name) + ) + ); + } else { + dispatch(selectionChanged(uniq(selection.concat(imageDTO)))); + } + } else { + dispatch(selectionChanged([imageDTO])); + } + }, + [dispatch, imageDTO, imageDTOs, selection] + ); + + const isSelected = useMemo( + () => + imageDTO + ? 
selection.some((i) => i.image_name === imageDTO.image_name) + : false, + [imageDTO, selection] + ); + + const selectionCount = useMemo(() => selection.length, [selection.length]); + + return { + selection, + selectionCount, + isSelected, + handleClick, + }; +}; diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useNextPrevImage.ts b/invokeai/frontend/web/src/features/gallery/hooks/useNextPrevImage.ts index f2572a23b5..670dd7ee9f 100644 --- a/invokeai/frontend/web/src/features/gallery/hooks/useNextPrevImage.ts +++ b/invokeai/frontend/web/src/features/gallery/hooks/useNextPrevImage.ts @@ -4,14 +4,15 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { imageSelected } from 'features/gallery/store/gallerySlice'; import { clamp, isEqual } from 'lodash-es'; import { useCallback } from 'react'; +import { boardsApi } from 'services/api/endpoints/boards'; import { - ListImagesArgs, - imagesAdapter, imagesApi, useLazyListImagesQuery, } from 'services/api/endpoints/images'; import { selectListImagesBaseQueryArgs } from '../store/gallerySelectors'; import { IMAGE_LIMIT } from '../store/types'; +import { ListImagesArgs } from 'services/api/types'; +import { imagesAdapter } from 'services/api/util'; export const nextPrevImageButtonsSelector = createSelector( [stateSelector, selectListImagesBaseQueryArgs], @@ -19,12 +20,21 @@ export const nextPrevImageButtonsSelector = createSelector( const { data, status } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state); + const { data: total } = + state.gallery.galleryView === 'images' + ? boardsApi.endpoints.getBoardImagesTotal.select( + baseQueryArgs.board_id ?? 'none' + )(state) + : boardsApi.endpoints.getBoardAssetsTotal.select( + baseQueryArgs.board_id ?? 'none' + )(state); + const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]; const isFetching = status === 'pending'; - if (!data || !lastSelectedImage || data.total === 0) { + if (!data || !lastSelectedImage || total === 0) { return { isFetching, queryArgs: baseQueryArgs, @@ -44,30 +54,30 @@ export const nextPrevImageButtonsSelector = createSelector( const images = selectors.selectAll(data); const currentImageIndex = images.findIndex( - (i) => i.image_name === lastSelectedImage + (i) => i.image_name === lastSelectedImage.image_name ); const nextImageIndex = clamp(currentImageIndex + 1, 0, images.length - 1); - const prevImageIndex = clamp(currentImageIndex - 1, 0, images.length - 1); const nextImageId = images[nextImageIndex]?.image_name; const prevImageId = images[prevImageIndex]?.image_name; - const nextImage = selectors.selectById(data, nextImageId); - const prevImage = selectors.selectById(data, prevImageId); + const nextImage = nextImageId + ? selectors.selectById(data, nextImageId) + : undefined; + const prevImage = prevImageId + ? selectors.selectById(data, prevImageId) + : undefined; const imagesLength = images.length; return { - isOnFirstImage: currentImageIndex === 0, - isOnLastImage: - !isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1, - areMoreImagesAvailable: (data?.total ?? 0) > imagesLength, + loadedImagesCount: images.length, + currentImageIndex, + areMoreImagesAvailable: (total ?? 
0) > imagesLength, isFetching: status === 'pending', nextImage, prevImage, - nextImageId, - prevImageId, queryArgs, }; }, @@ -82,22 +92,22 @@ export const useNextPrevImage = () => { const dispatch = useAppDispatch(); const { - isOnFirstImage, - isOnLastImage, - nextImageId, - prevImageId, + nextImage, + prevImage, areMoreImagesAvailable, isFetching, queryArgs, + loadedImagesCount, + currentImageIndex, } = useAppSelector(nextPrevImageButtonsSelector); const handlePrevImage = useCallback(() => { - prevImageId && dispatch(imageSelected(prevImageId)); - }, [dispatch, prevImageId]); + prevImage && dispatch(imageSelected(prevImage)); + }, [dispatch, prevImage]); const handleNextImage = useCallback(() => { - nextImageId && dispatch(imageSelected(nextImageId)); - }, [dispatch, nextImageId]); + nextImage && dispatch(imageSelected(nextImage)); + }, [dispatch, nextImage]); const [listImages] = useLazyListImagesQuery(); @@ -108,10 +118,12 @@ export const useNextPrevImage = () => { return { handlePrevImage, handleNextImage, - isOnFirstImage, - isOnLastImage, - nextImageId, - prevImageId, + isOnFirstImage: currentImageIndex === 0, + isOnLastImage: + currentImageIndex !== undefined && + currentImageIndex === loadedImagesCount - 1, + nextImage, + prevImage, areMoreImagesAvailable, handleLoadMoreImages, isFetching, diff --git a/invokeai/frontend/web/src/features/gallery/store/actions.ts b/invokeai/frontend/web/src/features/gallery/store/actions.ts index 0e1b1ef2a0..9368fe6cf6 100644 --- a/invokeai/frontend/web/src/features/gallery/store/actions.ts +++ b/invokeai/frontend/web/src/features/gallery/store/actions.ts @@ -1,5 +1,5 @@ import { createAction } from '@reduxjs/toolkit'; -import { ImageUsage } from 'app/contexts/AddImageToBoardContext'; +import { ImageUsage } from 'features/deleteImageModal/store/types'; import { BoardDTO } from 'services/api/types'; export type RequestedBoardImagesDeletionArg = { diff --git a/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts b/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts deleted file mode 100644 index ad43498e51..0000000000 --- a/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts +++ /dev/null @@ -1,29 +0,0 @@ -import { PayloadAction, createSlice } from '@reduxjs/toolkit'; - -type BoardsState = { - searchText: string; - updateBoardModalOpen: boolean; -}; - -export const initialBoardsState: BoardsState = { - updateBoardModalOpen: false, - searchText: '', -}; - -const boardsSlice = createSlice({ - name: 'boards', - initialState: initialBoardsState, - reducers: { - setBoardSearchText: (state, action: PayloadAction) => { - state.searchText = action.payload; - }, - setUpdateBoardModalOpen: (state, action: PayloadAction) => { - state.updateBoardModalOpen = action.payload; - }, - }, -}); - -export const { setBoardSearchText, setUpdateBoardModalOpen } = - boardsSlice.actions; - -export default boardsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts index b589550157..47e29456a0 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts @@ -1,7 +1,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { ListImagesArgs } from 'services/api/endpoints/images'; +import { ListImagesArgs } from 
'services/api/types'; import { ASSETS_CATEGORIES, IMAGE_CATEGORIES, @@ -24,7 +24,7 @@ export const selectListImagesBaseQueryArgs = createSelector( galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES; const listImagesBaseQueryArgs: ListImagesArgs = { - board_id: selectedBoardId ?? 'none', + board_id: selectedBoardId, categories, offset: 0, limit: INITIAL_IMAGE_LIMIT, diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index 9c65e818f4..3b0dd233f1 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -1,66 +1,32 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice, isAnyOf } from '@reduxjs/toolkit'; -import { uniq } from 'lodash-es'; import { boardsApi } from 'services/api/endpoints/boards'; +import { imagesApi } from 'services/api/endpoints/images'; +import { ImageDTO } from 'services/api/types'; import { BoardId, GalleryState, GalleryView } from './types'; export const initialGalleryState: GalleryState = { selection: [], shouldAutoSwitch: true, - autoAddBoardId: undefined, autoAssignBoardOnClick: true, + autoAddBoardId: 'none', galleryImageMinimumWidth: 96, - selectedBoardId: undefined, + selectedBoardId: 'none', galleryView: 'images', - batchImageNames: [], - isBatchEnabled: false, + shouldShowDeleteButton: false, + boardSearchText: '', }; export const gallerySlice = createSlice({ name: 'gallery', initialState: initialGalleryState, reducers: { - imageRangeEndSelected: () => { - // TODO - }, - // imageRangeEndSelected: (state, action: PayloadAction) => { - // const rangeEndImageName = action.payload; - // const lastSelectedImage = state.selection[state.selection.length - 1]; - // const filteredImages = selectFilteredImagesLocal(state); - // const lastClickedIndex = filteredImages.findIndex( - // (n) => n.image_name === lastSelectedImage - // ); - // const currentClickedIndex = filteredImages.findIndex( - // (n) => n.image_name === rangeEndImageName - // ); - // if (lastClickedIndex > -1 && currentClickedIndex > -1) { - // // We have a valid range! - // const start = Math.min(lastClickedIndex, currentClickedIndex); - // const end = Math.max(lastClickedIndex, currentClickedIndex); - // const imagesToSelect = filteredImages - // .slice(start, end + 1) - // .map((i) => i.image_name); - // state.selection = uniq(state.selection.concat(imagesToSelect)); - // } - // }, - imageSelectionToggled: () => { - // TODO - }, - // imageSelectionToggled: (state, action: PayloadAction) => { - // TODO: multiselect - // if ( - // state.selection.includes(action.payload) && - // state.selection.length > 1 - // ) { - // state.selection = state.selection.filter( - // (imageName) => imageName !== action.payload - // ); - // } else { - // state.selection = uniq(state.selection.concat(action.payload)); - // } - imageSelected: (state, action: PayloadAction) => { + imageSelected: (state, action: PayloadAction) => { state.selection = action.payload ? 
[action.payload] : []; }, + selectionChanged: (state, action: PayloadAction) => { + state.selection = action.payload; + }, shouldAutoSwitchChanged: (state, action: PayloadAction) => { state.shouldAutoSwitch = action.payload; }, @@ -74,53 +40,28 @@ export const gallerySlice = createSlice({ state.selectedBoardId = action.payload; state.galleryView = 'images'; }, - isBatchEnabledChanged: (state, action: PayloadAction) => { - state.isBatchEnabled = action.payload; - }, - imagesAddedToBatch: (state, action: PayloadAction) => { - state.batchImageNames = uniq( - state.batchImageNames.concat(action.payload) - ); - }, - imagesRemovedFromBatch: (state, action: PayloadAction) => { - state.batchImageNames = state.batchImageNames.filter( - (imageName) => !action.payload.includes(imageName) - ); - - const newSelection = state.selection.filter( - (imageName) => !action.payload.includes(imageName) - ); - - if (newSelection.length) { - state.selection = newSelection; - return; - } - - state.selection = [state.batchImageNames[0]] ?? []; - }, - batchReset: (state) => { - state.batchImageNames = []; - state.selection = []; - }, - autoAddBoardIdChanged: ( - state, - action: PayloadAction - ) => { + autoAddBoardIdChanged: (state, action: PayloadAction) => { state.autoAddBoardId = action.payload; }, galleryViewChanged: (state, action: PayloadAction) => { state.galleryView = action.payload; }, + shouldShowDeleteButtonChanged: (state, action: PayloadAction) => { + state.shouldShowDeleteButton = action.payload; + }, + boardSearchTextChanged: (state, action: PayloadAction) => { + state.boardSearchText = action.payload; + }, }, extraReducers: (builder) => { builder.addMatcher(isAnyBoardDeleted, (state, action) => { const deletedBoardId = action.meta.arg.originalArgs; if (deletedBoardId === state.selectedBoardId) { - state.selectedBoardId = undefined; + state.selectedBoardId = 'none'; state.galleryView = 'images'; } if (deletedBoardId === state.autoAddBoardId) { - state.autoAddBoardId = undefined; + state.autoAddBoardId = 'none'; } }); builder.addMatcher( @@ -132,7 +73,7 @@ export const gallerySlice = createSlice({ } if (!boards.map((b) => b.board_id).includes(state.autoAddBoardId)) { - state.autoAddBoardId = undefined; + state.autoAddBoardId = 'none'; } } ); @@ -140,23 +81,21 @@ export const gallerySlice = createSlice({ }); export const { - imageRangeEndSelected, - imageSelectionToggled, imageSelected, shouldAutoSwitchChanged, autoAssignBoardOnClickChanged, setGalleryImageMinimumWidth, boardIdSelected, - isBatchEnabledChanged, - imagesAddedToBatch, - imagesRemovedFromBatch, autoAddBoardIdChanged, galleryViewChanged, + selectionChanged, + shouldShowDeleteButtonChanged, + boardSearchTextChanged, } = gallerySlice.actions; export default gallerySlice.reducer; const isAnyBoardDeleted = isAnyOf( - boardsApi.endpoints.deleteBoard.matchFulfilled, - boardsApi.endpoints.deleteBoardAndImages.matchFulfilled + imagesApi.endpoints.deleteBoard.matchFulfilled, + imagesApi.endpoints.deleteBoardAndImages.matchFulfilled ); diff --git a/invokeai/frontend/web/src/features/gallery/store/types.ts b/invokeai/frontend/web/src/features/gallery/store/types.ts index 298b792362..6860f6ea7b 100644 --- a/invokeai/frontend/web/src/features/gallery/store/types.ts +++ b/invokeai/frontend/web/src/features/gallery/store/types.ts @@ -1,4 +1,4 @@ -import { ImageCategory } from 'services/api/types'; +import { ImageCategory, ImageDTO } from 'services/api/types'; export const IMAGE_CATEGORIES: ImageCategory[] = ['general']; export const ASSETS_CATEGORIES: 
ImageCategory[] = [ @@ -11,17 +11,16 @@ export const INITIAL_IMAGE_LIMIT = 100; export const IMAGE_LIMIT = 20; export type GalleryView = 'images' | 'assets'; -// export type BoardId = 'no_board' | (string & Record); -export type BoardId = string | undefined; +export type BoardId = 'none' | (string & Record); export type GalleryState = { - selection: string[]; + selection: ImageDTO[]; shouldAutoSwitch: boolean; - autoAddBoardId: string | undefined; autoAssignBoardOnClick: boolean; + autoAddBoardId: BoardId; galleryImageMinimumWidth: number; selectedBoardId: BoardId; galleryView: GalleryView; - batchImageNames: string[]; - isBatchEnabled: boolean; + shouldShowDeleteButton: boolean; + boardSearchText: string; }; diff --git a/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts b/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts deleted file mode 100644 index 0bfd9a537d..0000000000 --- a/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts +++ /dev/null @@ -1,37 +0,0 @@ -import { PayloadAction, createSlice } from '@reduxjs/toolkit'; -import { ImageDTO } from 'services/api/types'; - -type DeleteImageState = { - imageToDelete: ImageDTO | null; - isModalOpen: boolean; -}; - -export const initialDeleteImageState: DeleteImageState = { - imageToDelete: null, - isModalOpen: false, -}; - -const imageDeletion = createSlice({ - name: 'imageDeletion', - initialState: initialDeleteImageState, - reducers: { - isModalOpenChanged: (state, action: PayloadAction) => { - state.isModalOpen = action.payload; - }, - imageToDeleteSelected: (state, action: PayloadAction) => { - state.imageToDelete = action.payload; - }, - imageToDeleteCleared: (state) => { - state.imageToDelete = null; - state.isModalOpen = false; - }, - }, -}); - -export const { - isModalOpenChanged, - imageToDeleteSelected, - imageToDeleteCleared, -} = imageDeletion.actions; - -export default imageDeletion.reducer; diff --git a/invokeai/frontend/web/src/features/imageDeletion/store/types.ts b/invokeai/frontend/web/src/features/imageDeletion/store/types.ts deleted file mode 100644 index b3f4dc9c8d..0000000000 --- a/invokeai/frontend/web/src/features/imageDeletion/store/types.ts +++ /dev/null @@ -1,6 +0,0 @@ -export type ImageUsage = { - isInitialImage: boolean; - isCanvasImage: boolean; - isNodesImage: boolean; - isControlNetImage: boolean; -}; diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index f0067a85a2..10a1671933 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -39,11 +39,19 @@ export const loraSlice = createSlice({ action: PayloadAction<{ id: string; weight: number }> ) => { const { id, weight } = action.payload; - state.loras[id].weight = weight; + const lora = state.loras[id]; + if (!lora) { + return; + } + lora.weight = weight; }, loraWeightReset: (state, action: PayloadAction) => { const id = action.payload; - state.loras[id].weight = defaultLoRAConfig.weight; + const lora = state.loras[id]; + if (!lora) { + return; + } + lora.weight = defaultLoRAConfig.weight; }, }, }); diff --git a/invokeai/frontend/web/src/features/nodes/components/search/NodeSearch.tsx b/invokeai/frontend/web/src/features/nodes/components/search/NodeSearch.tsx index 669110fa54..d4a4f8d31f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/search/NodeSearch.tsx +++ 
b/invokeai/frontend/web/src/features/nodes/components/search/NodeSearch.tsx @@ -170,15 +170,17 @@ const NodeSearch = () => { // } if (key === 'Enter') { - let selectedNodeType: AnyInvocationType; + let selectedNodeType: AnyInvocationType | undefined; if (searchText.length > 0) { - selectedNodeType = filteredNodes[focusedIndex].item.type; + selectedNodeType = filteredNodes[focusedIndex]?.item.type; } else { - selectedNodeType = nodes[focusedIndex].type; + selectedNodeType = nodes[focusedIndex]?.type; } - addNode(selectedNodeType); + if (selectedNodeType) { + addNode(selectedNodeType); + } setShowNodeList(false); } diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 436396fb38..2e41081e95 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -79,9 +79,12 @@ const nodesSlice = createSlice({ ) => { const { nodeId, fieldName, value } = action.payload; const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); - + const input = state.nodes?.[nodeIndex]?.data?.inputs[fieldName]; + if (!input) { + return; + } if (nodeIndex > -1) { - state.nodes[nodeIndex].data.inputs[fieldName].value = value; + input.value = value; } }, imageCollectionFieldValueChanged: ( @@ -99,16 +102,19 @@ const nodesSlice = createSlice({ return; } - const currentValue = cloneDeep( - state.nodes[nodeIndex].data.inputs[fieldName].value - ); - - if (!currentValue) { - state.nodes[nodeIndex].data.inputs[fieldName].value = value; + const input = state.nodes?.[nodeIndex]?.data?.inputs[fieldName]; + if (!input) { return; } - state.nodes[nodeIndex].data.inputs[fieldName].value = uniqBy( + const currentValue = cloneDeep(input.value); + + if (!currentValue) { + input.value = value; + return; + } + + input.value = uniqBy( (currentValue as ImageField[]).concat(value), 'image_name' ); diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index 83692533f7..de7d798c69 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -29,6 +29,8 @@ import { VaeInputFieldTemplate, VaeModelInputFieldTemplate, } from '../types/types'; +import { logger } from 'app/logging/logger'; +import { parseify } from 'common/util/serialize'; export type BaseFieldProperties = 'name' | 'title' | 'description'; @@ -50,7 +52,13 @@ export type BuildInputFieldArg = { */ export const refObjectToFieldType = ( refObject: OpenAPIV3.ReferenceObject -): keyof typeof FIELD_TYPE_MAP => refObject.$ref.split('/').slice(-1)[0]; +): keyof typeof FIELD_TYPE_MAP => { + const name = refObject.$ref.split('/').slice(-1)[0]; + if (!name) { + return 'UNKNOWN FIELD TYPE'; + } + return name; +}; const buildIntegerInputFieldTemplate = ({ schemaObject, @@ -428,7 +436,7 @@ export const getFieldType = ( let rawFieldType = ''; if (typeHints && name in typeHints) { - rawFieldType = typeHints[name]; + rawFieldType = typeHints[name] ?? 
'UNKNOWN FIELD TYPE'; } else if (!schemaObject.type) { // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf if (schemaObject.allOf) { @@ -568,10 +576,23 @@ export const buildOutputFieldTemplates = ( // extract output schema name from ref const outputSchemaName = refObject.$ref.split('/').slice(-1)[0]; + if (!outputSchemaName) { + logger('nodes').error( + { refObject: parseify(refObject) }, + 'No output schema name found in ref object' + ); + throw 'No output schema name found in ref object'; + } + // get the output schema itself // eslint-disable-next-line @typescript-eslint/no-non-null-assertion const outputSchema = openAPI.components!.schemas![outputSchemaName]; + if (!outputSchema) { + logger('nodes').error({ outputSchemaName }, 'Output schema not found'); + throw 'Output schema not found'; + } + if (isSchemaObject(outputSchema)) { const outputFields = reduce( outputSchema.properties as OpenAPIV3.SchemaObject, diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx index 418ed9278f..c4d2d35f8f 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx @@ -16,7 +16,10 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { map } from 'lodash-es'; import { Fragment, memo, useCallback } from 'react'; import { FaPlus } from 'react-icons/fa'; -import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; +import { + controlNetModelsAdapter, + useGetControlNetModelsQuery, +} from 'services/api/endpoints/models'; import { v4 as uuidv4 } from 'uuid'; const selector = createSelector( @@ -42,7 +45,9 @@ const ParamControlNetCollapse = () => { const dispatch = useAppDispatch(); const { firstModel } = useGetControlNetModelsQuery(undefined, { selectFromResult: (result) => { - const firstModel = result.data?.entities[result.data?.ids[0]]; + const firstModel = result.data + ? 
controlNetModelsAdapter.getSelectors().selectAll(result.data)[0] + : undefined; return { firstModel, }; @@ -95,7 +100,7 @@ const ParamControlNetCollapse = () => { {controlNetsArray.map((c, i) => ( {i > 0 && } - + ))} diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index cb2361524d..907107e95e 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -12,7 +12,7 @@ import { } from 'features/sdxl/store/sdxlSlice'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { UnsafeImageMetadata } from 'services/api/endpoints/images'; +import { UnsafeImageMetadata } from 'services/api/types'; import { ImageDTO } from 'services/api/types'; import { initialImageSelected, modelSelected } from '../store/actions'; import { diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddCheckpoint.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddCheckpoint.tsx index fd5106b289..5f82483cd3 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddCheckpoint.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddCheckpoint.tsx @@ -28,9 +28,7 @@ export default function AdvancedAddCheckpoint( const advancedAddCheckpointForm = useForm({ initialValues: { - model_name: model_path - ? model_path.split('\\').splice(-1)[0].split('.')[0] - : '', + model_name: model_path?.split('\\').splice(-1)[0]?.split('.')[0] ?? '', base_model: 'sd-1', model_type: 'main', path: model_path ? model_path : '', diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddDiffusers.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddDiffusers.tsx index 376631bd1f..ec2d3f037a 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddDiffusers.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddDiffusers.tsx @@ -25,7 +25,7 @@ export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) { const advancedAddDiffusersForm = useForm({ initialValues: { - model_name: model_path ? model_path.split('\\').splice(-1)[0] : '', + model_name: model_path?.split('\\').splice(-1)[0] ?? '', base_model: 'sd-1', model_type: 'main', path: model_path ? model_path : '', diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx index 4ad8fbaba6..6837a2e853 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx @@ -59,10 +59,10 @@ export default function MergeModelsPanel() { }, [sd1DiffusersModels, sd2DiffusersModels]); const [modelOne, setModelOne] = useState( - Object.keys(modelsMap[baseModel as keyof typeof modelsMap])[0] + Object.keys(modelsMap[baseModel as keyof typeof modelsMap])?.[0] ?? 
null ); const [modelTwo, setModelTwo] = useState( - Object.keys(modelsMap[baseModel as keyof typeof modelsMap])[1] + Object.keys(modelsMap[baseModel as keyof typeof modelsMap])?.[1] ?? null ); const [modelThree, setModelThree] = useState(null); @@ -106,8 +106,9 @@ export default function MergeModelsPanel() { let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree]; modelsToMerge = modelsToMerge.filter((model) => model !== null); modelsToMerge.forEach((model) => { - if (model) { - models_names.push(model?.split('/')[2]); + const n = model?.split('/')?.[2]; + if (n) { + models_names.push(n); } }); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx index 1aec7d5c05..045745e206 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx @@ -47,13 +47,11 @@ export default function ModelConvert(props: ModelConvertProps) { }; const modelConvertHandler = () => { - const responseBody = { + const queryArg = { base_model: model.base_model, model_name: model.model_name, - params: { - convert_dest_directory: - saveLocation === 'Custom' ? customSaveLocation : undefined, - }, + convert_dest_directory: + saveLocation === 'Custom' ? customSaveLocation : undefined, }; if (saveLocation === 'Custom' && customSaveLocation === '') { @@ -74,14 +72,14 @@ export default function ModelConvert(props: ModelConvertProps) { title: `${t('modelManager.convertingModelBegin')}: ${ model.model_name }`, - status: 'success', + status: 'info', }) ) ); - convertModel(responseBody) + convertModel(queryArg) .unwrap() - .then((_) => { + .then(() => { dispatch( addToast( makeToast({ @@ -91,7 +89,7 @@ export default function ModelConvert(props: ModelConvertProps) { ) ); }) - .catch((_) => { + .catch(() => { dispatch( addToast( makeToast({ diff --git a/invokeai/frontend/web/src/features/ui/store/uiSelectors.ts b/invokeai/frontend/web/src/features/ui/store/uiSelectors.ts index fa152e9ce5..5427fa9d3b 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSelectors.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSelectors.ts @@ -2,12 +2,12 @@ import { createSelector } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; import { isEqual } from 'lodash-es'; -import { tabMap } from './tabMap'; +import { InvokeTabName, tabMap } from './tabMap'; import { UIState } from './uiTypes'; export const activeTabNameSelector = createSelector( (state: RootState) => state.ui, - (ui: UIState) => tabMap[ui.activeTab], + (ui: UIState) => tabMap[ui.activeTab] as InvokeTabName, { memoizeOptions: { equalityCheck: isEqual, diff --git a/invokeai/frontend/web/src/services/api/endpoints/boardImages.ts b/invokeai/frontend/web/src/services/api/endpoints/boardImages.ts deleted file mode 100644 index 2dc292321e..0000000000 --- a/invokeai/frontend/web/src/services/api/endpoints/boardImages.ts +++ /dev/null @@ -1,36 +0,0 @@ -import { api } from '..'; - -export const boardImagesApi = api.injectEndpoints({ - endpoints: (build) => ({ - /** - * Board Images Queries - */ - // listBoardImages: build.query< - // OffsetPaginatedResults_ImageDTO_, - // ListBoardImagesArg - // >({ - // query: ({ board_id, offset, limit }) => ({ - // url: `board_images/${board_id}`, - // 
method: 'GET', - // }), - // providesTags: (result, error, arg) => { - // // any list of boardimages - // const tags: ApiFullTagDescription[] = [ - // { type: 'BoardImage', id: `${arg.board_id}_${LIST_TAG}` }, - // ]; - // if (result) { - // // and individual tags for each boardimage - // tags.push( - // ...result.items.map(({ board_id, image_name }) => ({ - // type: 'BoardImage' as const, - // id: `${board_id}_${image_name}`, - // })) - // ); - // } - // return tags; - // }, - // }), - }), -}); - -// export const { useListBoardImagesQuery } = boardImagesApi; diff --git a/invokeai/frontend/web/src/services/api/endpoints/boards.ts b/invokeai/frontend/web/src/services/api/endpoints/boards.ts index 73b894b492..9d9fa11da8 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/boards.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/boards.ts @@ -1,28 +1,16 @@ -import { Update } from '@reduxjs/toolkit'; import { ASSETS_CATEGORIES, IMAGE_CATEGORIES, } from 'features/gallery/store/types'; import { BoardDTO, - ImageDTO, + ListBoardsArg, OffsetPaginatedResults_BoardDTO_, + OffsetPaginatedResults_ImageDTO_, + UpdateBoardArg, } from 'services/api/types'; import { ApiFullTagDescription, LIST_TAG, api } from '..'; -import { paths } from '../schema'; -import { getListImagesUrl, imagesAdapter, imagesApi } from './images'; - -type ListBoardsArg = NonNullable< - paths['/api/v1/boards/']['get']['parameters']['query'] ->; - -type UpdateBoardArg = - paths['/api/v1/boards/{board_id}']['patch']['parameters']['path'] & { - changes: paths['/api/v1/boards/{board_id}']['patch']['requestBody']['content']['application/json']; - }; - -type DeleteBoardResult = - paths['/api/v1/boards/{board_id}']['delete']['responses']['200']['content']['application/json']; +import { getListImagesUrl } from '../util'; export const boardsApi = api.injectEndpoints({ endpoints: (build) => ({ @@ -82,6 +70,44 @@ export const boardsApi = api.injectEndpoints({ keepUnusedDataFor: 0, }), + getBoardImagesTotal: build.query({ + query: (board_id) => ({ + url: getListImagesUrl({ + board_id: board_id ?? 'none', + categories: IMAGE_CATEGORIES, + is_intermediate: false, + limit: 0, + offset: 0, + }), + method: 'GET', + }), + providesTags: (result, error, arg) => [ + { type: 'BoardImagesTotal', id: arg ?? 'none' }, + ], + transformResponse: (response: OffsetPaginatedResults_ImageDTO_) => { + return response.total; + }, + }), + + getBoardAssetsTotal: build.query({ + query: (board_id) => ({ + url: getListImagesUrl({ + board_id: board_id ?? 'none', + categories: ASSETS_CATEGORIES, + is_intermediate: false, + limit: 0, + offset: 0, + }), + method: 'GET', + }), + providesTags: (result, error, arg) => [ + { type: 'BoardAssetsTotal', id: arg ?? 
'none' }, + ], + transformResponse: (response: OffsetPaginatedResults_ImageDTO_) => { + return response.total; + }, + }), + /** * Boards Mutations */ @@ -105,176 +131,15 @@ export const boardsApi = api.injectEndpoints({ { type: 'Board', id: arg.board_id }, ], }), - - deleteBoard: build.mutation({ - query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }), - invalidatesTags: (result, error, board_id) => [ - { type: 'Board', id: LIST_TAG }, - // invalidate the 'No Board' cache - { - type: 'ImageList', - id: getListImagesUrl({ - board_id: 'none', - categories: IMAGE_CATEGORIES, - }), - }, - { - type: 'ImageList', - id: getListImagesUrl({ - board_id: 'none', - categories: ASSETS_CATEGORIES, - }), - }, - { type: 'BoardImagesTotal', id: 'none' }, - { type: 'BoardAssetsTotal', id: 'none' }, - ], - async onQueryStarted(board_id, { dispatch, queryFulfilled, getState }) { - /** - * Cache changes for deleteBoard: - * - Update every image in the 'getImageDTO' cache that has the board_id - * - Update every image in the 'All Images' cache that has the board_id - * - Update every image in the 'All Assets' cache that has the board_id - * - Invalidate the 'No Board' cache: - * Ideally we'd be able to insert all deleted images into the cache, but we don't - * have access to the deleted images DTOs - only the names, and a network request - * for all of a board's DTOs could be very large. Instead, we invalidate the 'No Board' - * cache. - */ - - try { - const { data } = await queryFulfilled; - const { deleted_board_images } = data; - - // update getImageDTO caches - deleted_board_images.forEach((image_id) => { - dispatch( - imagesApi.util.updateQueryData( - 'getImageDTO', - image_id, - (draft) => { - draft.board_id = undefined; - } - ) - ); - }); - - // update 'All Images' & 'All Assets' caches - const queryArgsToUpdate = [ - { - categories: IMAGE_CATEGORIES, - }, - { - categories: ASSETS_CATEGORIES, - }, - ]; - - const updates: Update[] = deleted_board_images.map( - (image_name) => ({ - id: image_name, - changes: { board_id: undefined }, - }) - ); - - queryArgsToUpdate.forEach((queryArgs) => { - dispatch( - imagesApi.util.updateQueryData( - 'listImages', - queryArgs, - (draft) => { - const oldTotal = draft.total; - const newState = imagesAdapter.updateMany(draft, updates); - const delta = newState.total - oldTotal; - draft.total = draft.total + delta; - } - ) - ); - }); - } catch { - //no-op - } - }, - }), - - deleteBoardAndImages: build.mutation({ - query: (board_id) => ({ - url: `boards/${board_id}`, - method: 'DELETE', - params: { include_images: true }, - }), - invalidatesTags: (result, error, board_id) => [ - { type: 'Board', id: LIST_TAG }, - { - type: 'ImageList', - id: getListImagesUrl({ - board_id: 'none', - categories: IMAGE_CATEGORIES, - }), - }, - { - type: 'ImageList', - id: getListImagesUrl({ - board_id: 'none', - categories: ASSETS_CATEGORIES, - }), - }, - { type: 'BoardImagesTotal', id: 'none' }, - { type: 'BoardAssetsTotal', id: 'none' }, - ], - async onQueryStarted(board_id, { dispatch, queryFulfilled, getState }) { - /** - * Cache changes for deleteBoardAndImages: - * - ~~Remove every image in the 'getImageDTO' cache that has the board_id~~ - * This isn't actually possible, you cannot remove cache entries with RTK Query. - * Instead, we rely on the UI to remove all components that use the deleted images. 
- * - Remove every image in the 'All Images' cache that has the board_id - * - Remove every image in the 'All Assets' cache that has the board_id - */ - - try { - const { data } = await queryFulfilled; - const { deleted_images } = data; - - // update 'All Images' & 'All Assets' caches - const queryArgsToUpdate = [ - { - categories: IMAGE_CATEGORIES, - }, - { - categories: ASSETS_CATEGORIES, - }, - ]; - - queryArgsToUpdate.forEach((queryArgs) => { - dispatch( - imagesApi.util.updateQueryData( - 'listImages', - queryArgs, - (draft) => { - const oldTotal = draft.total; - const newState = imagesAdapter.removeMany( - draft, - deleted_images - ); - const delta = newState.total - oldTotal; - draft.total = draft.total + delta; - } - ) - ); - }); - } catch { - //no-op - } - }, - }), }), }); export const { useListBoardsQuery, useListAllBoardsQuery, + useGetBoardImagesTotalQuery, + useGetBoardAssetsTotalQuery, useCreateBoardMutation, useUpdateBoardMutation, - useDeleteBoardMutation, - useDeleteBoardAndImagesMutation, useListAllImageNamesForBoardQuery, } = boardsApi; diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index e8740a418b..e093c1c33a 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -1,93 +1,37 @@ -import { EntityState, createEntityAdapter } from '@reduxjs/toolkit'; +import { EntityState, Update } from '@reduxjs/toolkit'; import { PatchCollection } from '@reduxjs/toolkit/dist/query/core/buildThunks'; -import { dateComparator } from 'common/util/dateComparator'; import { ASSETS_CATEGORIES, BoardId, IMAGE_CATEGORIES, } from 'features/gallery/store/types'; -import queryString from 'query-string'; -import { ApiFullTagDescription, api } from '..'; -import { components, paths } from '../schema'; +import { keyBy } from 'lodash'; +import { ApiFullTagDescription, LIST_TAG, api } from '..'; +import { components } from '../schema'; import { + DeleteBoardResult, ImageCategory, ImageDTO, + ListImagesArgs, OffsetPaginatedResults_ImageDTO_, PostUploadAction, + UnsafeImageMetadata, } from '../types'; - -const getIsImageInDateRange = ( - data: ImageCache | undefined, - imageDTO: ImageDTO -) => { - if (!data) { - return false; - } - const cacheImageDTOS = imagesSelectors.selectAll(data); - - if (cacheImageDTOS.length > 1) { - // Images are sorted by `created_at` DESC - // check if the image is newer than the oldest image in the cache - const createdDate = new Date(imageDTO.created_at); - const oldestDate = new Date( - cacheImageDTOS[cacheImageDTOS.length - 1].created_at - ); - return createdDate >= oldestDate; - } else if ([0, 1].includes(cacheImageDTOS.length)) { - // if there are only 1 or 0 images in the cache, we consider the image to be in the date range - return true; - } - return false; -}; - -const getCategories = (imageDTO: ImageDTO) => { - if (IMAGE_CATEGORIES.includes(imageDTO.image_category)) { - return IMAGE_CATEGORIES; - } - return ASSETS_CATEGORIES; -}; - -export type ListImagesArgs = NonNullable< - paths['/api/v1/images/']['get']['parameters']['query'] ->; - -/** - * This is an unsafe type; the object inside is not guaranteed to be valid. 
- */ -export type UnsafeImageMetadata = { - metadata: components['schemas']['CoreMetadata']; - graph: NonNullable; -}; - -export type ImageCache = EntityState & { total: number }; - -// The adapter is not actually the data store - it just provides helper functions to interact -// with some other store of data. We will use the RTK Query cache as that store. -export const imagesAdapter = createEntityAdapter({ - selectId: (image) => image.image_name, - sortComparer: (a, b) => dateComparator(b.updated_at, a.updated_at), -}); - -// We want to also store the images total in the cache. When we initialize the cache state, -// we will provide this type arg so the adapter knows we want the total. -export type AdditionalImagesAdapterState = { total: number }; - -// Create selectors for the adapter. -export const imagesSelectors = imagesAdapter.getSelectors(); - -// Helper to create the url for the listImages endpoint. Also we use it to create the cache key. -export const getListImagesUrl = (queryArgs: ListImagesArgs) => - `images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`; +import { + getCategories, + getIsImageInDateRange, + getListImagesUrl, + imagesAdapter, + imagesSelectors, +} from '../util'; +import { boardsApi } from './boards'; export const imagesApi = api.injectEndpoints({ endpoints: (build) => ({ /** * Image Queries */ - listImages: build.query< - EntityState & { total: number }, - ListImagesArgs - >({ + listImages: build.query, ListImagesArgs>({ query: (queryArgs) => ({ // Use the helper to create the URL. url: getListImagesUrl(queryArgs), @@ -110,23 +54,17 @@ export const imagesApi = api.injectEndpoints({ return cacheKey; }, transformResponse(response: OffsetPaginatedResults_ImageDTO_) { - const { total, items: images } = response; - // Use the adapter to convert the response to the right shape, and adding the new total. + const { items: images } = response; + // Use the adapter to convert the response to the right shape. // The trick is to just provide an empty state and add the images array to it. This returns // a properly shaped EntityState. - return imagesAdapter.addMany( - imagesAdapter.getInitialState({ - total, - }), - images - ); + return imagesAdapter.addMany(imagesAdapter.getInitialState(), images); }, merge: (cache, response) => { // Here we actually update the cache. `response` here is the output of `transformResponse` // above. In a similar vein to `transformResponse`, we can use the imagesAdapter to get - // things in the right shape. Also update the total image count. + // things in the right shape. imagesAdapter.addMany(cache, imagesSelectors.selectAll(response)); - cache.total = response.total; }, forceRefetch({ currentArg, previousArg }) { // Refetch when the offset changes (which means we are on a new page). 
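A minimal, self-contained sketch of the caching pattern the reworked listImages endpoint above relies on: one cache entry per board, each fetched page merged into an entity adapter, and a refetch forced when the offset changes. Names here (exampleApi, a trimmed ImageDTO, ListImagesArgs) are illustrative rather than this repo's real modules, and RTK Query 1.9-era APIs (serializeQueryArgs, merge, forceRefetch) are assumed.

import { createEntityAdapter, EntityState } from '@reduxjs/toolkit';
import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';

type ImageDTO = { image_name: string; updated_at: string };
type ListImagesArgs = { board_id: string; offset: number; limit: number };
type Page = { items: ImageDTO[]; total: number };

// Entities are keyed by image_name and kept sorted newest-first.
const adapter = createEntityAdapter<ImageDTO>({
  selectId: (image) => image.image_name,
  sortComparer: (a, b) => b.updated_at.localeCompare(a.updated_at),
});

export const exampleApi = createApi({
  reducerPath: 'exampleApi',
  baseQuery: fetchBaseQuery({ baseUrl: '/api/v1/' }),
  endpoints: (build) => ({
    listImages: build.query<EntityState<ImageDTO>, ListImagesArgs>({
      query: (queryArgs) => ({ url: 'images/', params: queryArgs }),
      // One cache entry per board, regardless of offset/limit.
      serializeQueryArgs: ({ queryArgs }) =>
        `images/board=${queryArgs.board_id}`,
      // Shape each page into an EntityState so pages can be merged cheaply.
      transformResponse: (response: Page) =>
        adapter.addMany(adapter.getInitialState(), response.items),
      // Append each newly fetched page into the existing cache entry.
      merge: (cache, incoming) => {
        adapter.addMany(cache, adapter.getSelectors().selectAll(incoming));
      },
      // Refetch when the offset changes, i.e. when another page is requested.
      forceRefetch: ({ currentArg, previousArg }) =>
        currentArg?.offset !== previousArg?.offset,
    }),
  }),
});

Because the total image count no longer lives inside this cache entry, "is the cache fully populated" checks are answered by the separate getBoardImagesTotal / getBoardAssetsTotal queries, as the later hunks show.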
@@ -161,69 +99,26 @@ export const imagesApi = api.injectEndpoints({ }, }), getImageDTO: build.query({ - query: (image_name) => ({ url: `images/${image_name}` }), - providesTags: (result, error, arg) => { - const tags: ApiFullTagDescription[] = [{ type: 'Image', id: arg }]; - if (result?.board_id) { - tags.push({ type: 'Board', id: result.board_id }); - } - return tags; - }, + query: (image_name) => ({ url: `images/i/${image_name}` }), + providesTags: (result, error, image_name) => [ + { type: 'Image', id: image_name }, + ], keepUnusedDataFor: 86400, // 24 hours }), getImageMetadata: build.query({ - query: (image_name) => ({ url: `images/${image_name}/metadata` }), - providesTags: (result, error, arg) => { - const tags: ApiFullTagDescription[] = [ - { type: 'ImageMetadata', id: arg }, - ]; - return tags; - }, + query: (image_name) => ({ url: `images/i/${image_name}/metadata` }), + providesTags: (result, error, image_name) => [ + { type: 'ImageMetadata', id: image_name }, + ], keepUnusedDataFor: 86400, // 24 hours }), - getBoardImagesTotal: build.query({ - query: (board_id) => ({ - url: getListImagesUrl({ - board_id: board_id ?? 'none', - categories: IMAGE_CATEGORIES, - is_intermediate: false, - limit: 0, - offset: 0, - }), - method: 'GET', - }), - providesTags: (result, error, arg) => [ - { type: 'BoardImagesTotal', id: arg ?? 'none' }, - ], - transformResponse: (response: OffsetPaginatedResults_ImageDTO_) => { - return response.total; - }, - }), - getBoardAssetsTotal: build.query({ - query: (board_id) => ({ - url: getListImagesUrl({ - board_id: board_id ?? 'none', - categories: ASSETS_CATEGORIES, - is_intermediate: false, - limit: 0, - offset: 0, - }), - method: 'GET', - }), - providesTags: (result, error, arg) => [ - { type: 'BoardAssetsTotal', id: arg ?? 'none' }, - ], - transformResponse: (response: OffsetPaginatedResults_ImageDTO_) => { - return response.total; - }, - }), clearIntermediates: build.mutation({ query: () => ({ url: `images/clear-intermediates`, method: 'POST' }), invalidatesTags: ['IntermediatesCount'], }), deleteImage: build.mutation({ query: ({ image_name }) => ({ - url: `images/${image_name}`, + url: `images/i/${image_name}`, method: 'DELETE', }), invalidatesTags: (result, error, { board_id }) => [ @@ -240,33 +135,77 @@ export const imagesApi = api.injectEndpoints({ const { image_name, board_id } = imageDTO; - // Store patches so we can undo if the query fails - const patches: PatchCollection[] = []; + const queryArg = { + board_id: board_id ?? 'none', + categories: getCategories(imageDTO), + }; - // determine `categories`, i.e. do we update "All Images" or "All Assets" - // $cache = [board_id|no_board]/[images|assets] - const categories = getCategories(imageDTO); - - // *remove* from $cache - patches.push( - dispatch( - imagesApi.util.updateQueryData( - 'listImages', - { board_id: board_id ?? 
'none', categories }, - (draft) => { - const oldTotal = draft.total; - const newState = imagesAdapter.removeOne(draft, image_name); - const delta = newState.total - oldTotal; - draft.total = draft.total + delta; - } - ) - ) + const patch = dispatch( + imagesApi.util.updateQueryData('listImages', queryArg, (draft) => { + imagesAdapter.removeOne(draft, image_name); + }) ); try { await queryFulfilled; } catch { - patches.forEach((patchResult) => patchResult.undo()); + patch.undo(); + } + }, + }), + deleteImages: build.mutation< + components['schemas']['DeleteImagesFromListResult'], + { imageDTOs: ImageDTO[] } + >({ + query: ({ imageDTOs }) => { + const image_names = imageDTOs.map((imageDTO) => imageDTO.image_name); + return { + url: `images/delete`, + method: 'POST', + body: { + image_names, + }, + }; + }, + invalidatesTags: (result, error, imageDTOs) => [], + async onQueryStarted({ imageDTOs }, { dispatch, queryFulfilled }) { + /** + * Cache changes for `deleteImages`: + * - *remove* the deleted images from their boards + * + * Unfortunately we cannot do an optimistic update here due to how immer handles patching + * arrays. You have to undo *all* patches, else the entity adapter's `ids` array is borked. + * So we have to wait for the query to complete before updating the cache. + */ + try { + const { data } = await queryFulfilled; + + // convert to an object so we can access the successfully delete image DTOs by name + const groupedImageDTOs = keyBy(imageDTOs, 'image_name'); + + data.deleted_images.forEach((image_name) => { + const imageDTO = groupedImageDTOs[image_name]; + + // should never be undefined + if (imageDTO) { + const queryArg = { + board_id: imageDTO.board_id ?? 'none', + categories: getCategories(imageDTO), + }; + // remove all deleted images from their boards + dispatch( + imagesApi.util.updateQueryData( + 'listImages', + queryArg, + (draft) => { + imagesAdapter.removeOne(draft, image_name); + } + ) + ); + } + }); + } catch { + // } }, }), @@ -278,7 +217,7 @@ export const imagesApi = api.injectEndpoints({ { imageDTO: ImageDTO; is_intermediate: boolean } >({ query: ({ imageDTO, is_intermediate }) => ({ - url: `images/${imageDTO.image_name}`, + url: `images/i/${imageDTO.image_name}`, method: 'PATCH', body: { is_intermediate }, }), @@ -329,20 +268,13 @@ export const imagesApi = api.injectEndpoints({ 'listImages', { board_id: imageDTO.board_id ?? 'none', categories }, (draft) => { - const oldTotal = draft.total; - const newState = imagesAdapter.removeOne( - draft, - imageDTO.image_name - ); - const delta = newState.total - oldTotal; - draft.total = draft.total + delta; + imagesAdapter.removeOne(draft, imageDTO.image_name); } ) ) ); } else { // ELSE (it is being changed to a non-intermediate): - console.log(imageDTO); const queryArgs = { board_id: imageDTO.board_id ?? 'none', categories, @@ -352,6 +284,16 @@ export const imagesApi = api.injectEndpoints({ getState() ); + const { data: total } = IMAGE_CATEGORIES.includes( + imageDTO.image_category + ) + ? boardsApi.endpoints.getBoardImagesTotal.select( + imageDTO.board_id ?? 'none' + )(getState()) + : boardsApi.endpoints.getBoardAssetsTotal.select( + imageDTO.board_id ?? 
'none' + )(getState()); + // IF it eligible for insertion into existing $cache // "eligible" means either: // - The cache is fully populated, with all images in the db cached @@ -359,8 +301,7 @@ export const imagesApi = api.injectEndpoints({ // - The image's `created_at` is within the range of the cached images const isCacheFullyPopulated = - currentCache.data && - currentCache.data.ids.length >= currentCache.data.total; + currentCache.data && currentCache.data.ids.length >= (total ?? 0); const isInDateRange = getIsImageInDateRange( currentCache.data, @@ -375,10 +316,7 @@ export const imagesApi = api.injectEndpoints({ 'listImages', queryArgs, (draft) => { - const oldTotal = draft.total; - const newState = imagesAdapter.upsertOne(draft, imageDTO); - const delta = newState.total - oldTotal; - draft.total = draft.total + delta; + imagesAdapter.upsertOne(draft, imageDTO); } ) ) @@ -401,7 +339,7 @@ export const imagesApi = api.injectEndpoints({ { imageDTO: ImageDTO; session_id: string } >({ query: ({ imageDTO, session_id }) => ({ - url: `images/${imageDTO.image_name}`, + url: `images/i/${imageDTO.image_name}`, method: 'PATCH', body: { session_id }, }), @@ -464,14 +402,14 @@ export const imagesApi = api.injectEndpoints({ const formData = new FormData(); formData.append('file', file); return { - url: `images/`, + url: `images/upload`, method: 'POST', body: formData, params: { image_category, is_intermediate, session_id, - board_id, + board_id: board_id === 'none' ? undefined : board_id, crop_visible, }, }; @@ -524,10 +462,7 @@ export const imagesApi = api.injectEndpoints({ categories, }, (draft) => { - const oldTotal = draft.total; - const newState = imagesAdapter.addOne(draft, imageDTO); - const delta = newState.total - oldTotal; - draft.total = draft.total + delta; + imagesAdapter.addOne(draft, imageDTO); } ) ); @@ -543,6 +478,158 @@ export const imagesApi = api.injectEndpoints({ } }, }), + + deleteBoard: build.mutation({ + query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }), + invalidatesTags: (result, error, board_id) => [ + { type: 'Board', id: LIST_TAG }, + // invalidate the 'No Board' cache + { + type: 'ImageList', + id: getListImagesUrl({ + board_id: 'none', + categories: IMAGE_CATEGORIES, + }), + }, + { + type: 'ImageList', + id: getListImagesUrl({ + board_id: 'none', + categories: ASSETS_CATEGORIES, + }), + }, + { type: 'BoardImagesTotal', id: 'none' }, + { type: 'BoardAssetsTotal', id: 'none' }, + ], + async onQueryStarted(board_id, { dispatch, queryFulfilled, getState }) { + /** + * Cache changes for deleteBoard: + * - Update every image in the 'getImageDTO' cache that has the board_id + * - Update every image in the 'All Images' cache that has the board_id + * - Update every image in the 'All Assets' cache that has the board_id + * - Invalidate the 'No Board' cache: + * Ideally we'd be able to insert all deleted images into the cache, but we don't + * have access to the deleted images DTOs - only the names, and a network request + * for all of a board's DTOs could be very large. Instead, we invalidate the 'No Board' + * cache. 
+ */ + + try { + const { data } = await queryFulfilled; + const { deleted_board_images } = data; + + // update getImageDTO caches + deleted_board_images.forEach((image_id) => { + dispatch( + imagesApi.util.updateQueryData( + 'getImageDTO', + image_id, + (draft) => { + draft.board_id = undefined; + } + ) + ); + }); + + // update 'All Images' & 'All Assets' caches + const queryArgsToUpdate = [ + { + categories: IMAGE_CATEGORIES, + }, + { + categories: ASSETS_CATEGORIES, + }, + ]; + + const updates: Update[] = deleted_board_images.map( + (image_name) => ({ + id: image_name, + changes: { board_id: undefined }, + }) + ); + + queryArgsToUpdate.forEach((queryArgs) => { + dispatch( + imagesApi.util.updateQueryData( + 'listImages', + queryArgs, + (draft) => { + imagesAdapter.updateMany(draft, updates); + } + ) + ); + }); + } catch { + //no-op + } + }, + }), + + deleteBoardAndImages: build.mutation({ + query: (board_id) => ({ + url: `boards/${board_id}`, + method: 'DELETE', + params: { include_images: true }, + }), + invalidatesTags: (result, error, board_id) => [ + { type: 'Board', id: LIST_TAG }, + { + type: 'ImageList', + id: getListImagesUrl({ + board_id: 'none', + categories: IMAGE_CATEGORIES, + }), + }, + { + type: 'ImageList', + id: getListImagesUrl({ + board_id: 'none', + categories: ASSETS_CATEGORIES, + }), + }, + { type: 'BoardImagesTotal', id: 'none' }, + { type: 'BoardAssetsTotal', id: 'none' }, + ], + async onQueryStarted(board_id, { dispatch, queryFulfilled, getState }) { + /** + * Cache changes for deleteBoardAndImages: + * - ~~Remove every image in the 'getImageDTO' cache that has the board_id~~ + * This isn't actually possible, you cannot remove cache entries with RTK Query. + * Instead, we rely on the UI to remove all components that use the deleted images. + * - Remove every image in the 'All Images' cache that has the board_id + * - Remove every image in the 'All Assets' cache that has the board_id + */ + + try { + const { data } = await queryFulfilled; + const { deleted_images } = data; + + // update 'All Images' & 'All Assets' caches + const queryArgsToUpdate = [ + { + categories: IMAGE_CATEGORIES, + }, + { + categories: ASSETS_CATEGORIES, + }, + ]; + + queryArgsToUpdate.forEach((queryArgs) => { + dispatch( + imagesApi.util.updateQueryData( + 'listImages', + queryArgs, + (draft) => { + imagesAdapter.removeMany(draft, deleted_images); + } + ) + ); + }); + } catch { + //no-op + } + }, + }), addImageToBoard: build.mutation< void, { board_id: BoardId; imageDTO: ImageDTO } @@ -556,10 +643,13 @@ export const imagesApi = api.injectEndpoints({ }; }, invalidatesTags: (result, error, { board_id, imageDTO }) => [ + // refresh the board itself { type: 'Board', id: board_id }, + // update old board totals { type: 'BoardImagesTotal', id: board_id }, - { type: 'BoardImagesTotal', id: imageDTO.board_id ?? 'none' }, { type: 'BoardAssetsTotal', id: board_id }, + // update new board totals + { type: 'BoardImagesTotal', id: imageDTO.board_id ?? 'none' }, { type: 'BoardAssetsTotal', id: imageDTO.board_id ?? 
'none' }, ], async onQueryStarted( @@ -589,7 +679,7 @@ export const imagesApi = api.injectEndpoints({ 'getImageDTO', imageDTO.image_name, (draft) => { - Object.assign(draft, { board_id }); + draft.board_id = board_id; } ) ) @@ -606,13 +696,7 @@ export const imagesApi = api.injectEndpoints({ categories, }, (draft) => { - const oldTotal = draft.total; - const newState = imagesAdapter.removeOne( - draft, - imageDTO.image_name - ); - const delta = newState.total - oldTotal; - draft.total = draft.total + delta; + imagesAdapter.removeOne(draft, imageDTO.image_name); } ) ) @@ -630,9 +714,18 @@ export const imagesApi = api.injectEndpoints({ // OR // - The image's `created_at` is within the range of the cached images + const { data: total } = IMAGE_CATEGORIES.includes( + imageDTO.image_category + ) + ? boardsApi.endpoints.getBoardImagesTotal.select( + imageDTO.board_id ?? 'none' + )(getState()) + : boardsApi.endpoints.getBoardAssetsTotal.select( + imageDTO.board_id ?? 'none' + )(getState()); + const isCacheFullyPopulated = - currentCache.data && - currentCache.data.ids.length >= currentCache.data.total; + currentCache.data && currentCache.data.ids.length >= (total ?? 0); const isInDateRange = getIsImageInDateRange( currentCache.data, @@ -647,10 +740,7 @@ export const imagesApi = api.injectEndpoints({ 'listImages', queryArgs, (draft) => { - const oldTotal = draft.total; - const newState = imagesAdapter.addOne(draft, imageDTO); - const delta = newState.total - oldTotal; - draft.total = draft.total + delta; + imagesAdapter.addOne(draft, imageDTO); } ) ) @@ -667,20 +757,26 @@ export const imagesApi = api.injectEndpoints({ }), removeImageFromBoard: build.mutation({ query: ({ imageDTO }) => { - const { board_id, image_name } = imageDTO; + const { image_name } = imageDTO; return { url: `board_images/`, method: 'DELETE', - body: { board_id, image_name }, + body: { image_name }, }; }, - invalidatesTags: (result, error, { imageDTO }) => [ - { type: 'Board', id: imageDTO.board_id }, - { type: 'BoardImagesTotal', id: imageDTO.board_id }, - { type: 'BoardImagesTotal', id: 'none' }, - { type: 'BoardAssetsTotal', id: imageDTO.board_id }, - { type: 'BoardAssetsTotal', id: 'none' }, - ], + invalidatesTags: (result, error, { imageDTO }) => { + const { board_id } = imageDTO; + return [ + // invalidate the image's old board + { type: 'Board', id: board_id ?? 'none' }, + // update old board totals + { type: 'BoardImagesTotal', id: board_id ?? 'none' }, + { type: 'BoardAssetsTotal', id: board_id ?? 'none' }, + // update the no_board totals + { type: 'BoardImagesTotal', id: 'none' }, + { type: 'BoardAssetsTotal', id: 'none' }, + ]; + }, async onQueryStarted( { imageDTO }, { dispatch, queryFulfilled, getState } @@ -704,7 +800,7 @@ export const imagesApi = api.injectEndpoints({ 'getImageDTO', imageDTO.image_name, (draft) => { - Object.assign(draft, { board_id: undefined }); + draft.board_id = undefined; } ) ) @@ -720,13 +816,7 @@ export const imagesApi = api.injectEndpoints({ categories, }, (draft) => { - const oldTotal = draft.total; - const newState = imagesAdapter.removeOne( - draft, - imageDTO.image_name - ); - const delta = newState.total - oldTotal; - draft.total = draft.total + delta; + imagesAdapter.removeOne(draft, imageDTO.image_name); } ) ) @@ -744,9 +834,18 @@ export const imagesApi = api.injectEndpoints({ // OR // - The image's `created_at` is within the range of the cached images + const { data: total } = IMAGE_CATEGORIES.includes( + imageDTO.image_category + ) + ? 
boardsApi.endpoints.getBoardImagesTotal.select( + imageDTO.board_id ?? 'none' + )(getState()) + : boardsApi.endpoints.getBoardAssetsTotal.select( + imageDTO.board_id ?? 'none' + )(getState()); + const isCacheFullyPopulated = - currentCache.data && - currentCache.data.ids.length >= currentCache.data.total; + currentCache.data && currentCache.data.ids.length >= (total ?? 0); const isInDateRange = getIsImageInDateRange( currentCache.data, @@ -761,10 +860,7 @@ export const imagesApi = api.injectEndpoints({ 'listImages', queryArgs, (draft) => { - const oldTotal = draft.total; - const newState = imagesAdapter.upsertOne(draft, imageDTO); - const delta = newState.total - oldTotal; - draft.total = draft.total + delta; + imagesAdapter.upsertOne(draft, imageDTO); } ) ) @@ -778,6 +874,255 @@ export const imagesApi = api.injectEndpoints({ } }, }), + addImagesToBoard: build.mutation< + components['schemas']['AddImagesToBoardResult'], + { + board_id: string; + imageDTOs: ImageDTO[]; + } + >({ + query: ({ board_id, imageDTOs }) => ({ + url: `board_images/batch`, + method: 'POST', + body: { + image_names: imageDTOs.map((i) => i.image_name), + board_id, + }, + }), + invalidatesTags: (result, error, { board_id }) => [ + // update the destination board + { type: 'Board', id: board_id ?? 'none' }, + // update old board totals + { type: 'BoardImagesTotal', id: board_id ?? 'none' }, + { type: 'BoardAssetsTotal', id: board_id ?? 'none' }, + // update the no_board totals + { type: 'BoardImagesTotal', id: 'none' }, + { type: 'BoardAssetsTotal', id: 'none' }, + ], + async onQueryStarted( + { board_id, imageDTOs }, + { dispatch, queryFulfilled, getState } + ) { + try { + const { data } = await queryFulfilled; + const { added_image_names } = data; + + /** + * Cache changes for addImagesToBoard: + * - *update* getImageDTO for each image + * - *add* to board_id/[images|assets] + * - *remove* from [old_board_id|no_board]/[images|assets] + */ + + added_image_names.forEach((image_name) => { + dispatch( + imagesApi.util.updateQueryData( + 'getImageDTO', + image_name, + (draft) => { + draft.board_id = board_id; + } + ) + ); + + const imageDTO = imageDTOs.find((i) => i.image_name === image_name); + + if (!imageDTO) { + return; + } + + const categories = getCategories(imageDTO); + const old_board_id = imageDTO.board_id; + + // remove from the old board + dispatch( + imagesApi.util.updateQueryData( + 'listImages', + { board_id: old_board_id ?? 'none', categories }, + (draft) => { + imagesAdapter.removeOne(draft, imageDTO.image_name); + } + ) + ); + + const queryArgs = { + board_id, + categories, + }; + + const currentCache = imagesApi.endpoints.listImages.select( + queryArgs + )(getState()); + + const { data: total } = IMAGE_CATEGORIES.includes( + imageDTO.image_category + ) + ? boardsApi.endpoints.getBoardImagesTotal.select( + imageDTO.board_id ?? 'none' + )(getState()) + : boardsApi.endpoints.getBoardAssetsTotal.select( + imageDTO.board_id ?? 'none' + )(getState()); + + const isCacheFullyPopulated = + currentCache.data && currentCache.data.ids.length >= (total ?? 
0); + + const isInDateRange = getIsImageInDateRange( + currentCache.data, + imageDTO + ); + + if (isCacheFullyPopulated || isInDateRange) { + // *upsert* to $cache + dispatch( + imagesApi.util.updateQueryData( + 'listImages', + queryArgs, + (draft) => { + imagesAdapter.upsertOne(draft, { + ...imageDTO, + board_id, + }); + } + ) + ); + } + }); + } catch { + // no-op + } + }, + }), + removeImagesFromBoard: build.mutation< + components['schemas']['RemoveImagesFromBoardResult'], + { + imageDTOs: ImageDTO[]; + } + >({ + query: ({ imageDTOs }) => ({ + url: `board_images/batch/delete`, + method: 'POST', + body: { + image_names: imageDTOs.map((i) => i.image_name), + }, + }), + invalidatesTags: (result, error, { imageDTOs }) => { + const touchedBoardIds: string[] = []; + const tags: ApiFullTagDescription[] = [ + { type: 'BoardImagesTotal', id: 'none' }, + { type: 'BoardAssetsTotal', id: 'none' }, + ]; + + result?.removed_image_names.forEach((image_name) => { + const board_id = imageDTOs.find( + (i) => i.image_name === image_name + )?.board_id; + + if (!board_id || touchedBoardIds.includes(board_id)) { + return; + } + + tags.push({ type: 'Board', id: board_id }); + tags.push({ type: 'BoardImagesTotal', id: board_id }); + tags.push({ type: 'BoardAssetsTotal', id: board_id }); + }); + + return tags; + }, + async onQueryStarted( + { imageDTOs }, + { dispatch, queryFulfilled, getState } + ) { + try { + const { data } = await queryFulfilled; + const { removed_image_names } = data; + + /** + * Cache changes for removeImagesFromBoard: + * - *update* getImageDTO for each image + * - *remove* from old_board_id/[images|assets] + * - *add* to no_board/[images|assets] + */ + + removed_image_names.forEach((image_name) => { + dispatch( + imagesApi.util.updateQueryData( + 'getImageDTO', + image_name, + (draft) => { + draft.board_id = undefined; + } + ) + ); + + const imageDTO = imageDTOs.find((i) => i.image_name === image_name); + + if (!imageDTO) { + return; + } + + const categories = getCategories(imageDTO); + + // remove from the old board + dispatch( + imagesApi.util.updateQueryData( + 'listImages', + { board_id: imageDTO.board_id ?? 'none', categories }, + (draft) => { + imagesAdapter.removeOne(draft, imageDTO.image_name); + } + ) + ); + + // add to `no_board` + const queryArgs = { + board_id: 'none', + categories, + }; + + const currentCache = imagesApi.endpoints.listImages.select( + queryArgs + )(getState()); + + const { data: total } = IMAGE_CATEGORIES.includes( + imageDTO.image_category + ) + ? boardsApi.endpoints.getBoardImagesTotal.select( + imageDTO.board_id ?? 'none' + )(getState()) + : boardsApi.endpoints.getBoardAssetsTotal.select( + imageDTO.board_id ?? 'none' + )(getState()); + + const isCacheFullyPopulated = + currentCache.data && currentCache.data.ids.length >= (total ?? 
0); + + const isInDateRange = getIsImageInDateRange( + currentCache.data, + imageDTO + ); + + if (isCacheFullyPopulated || isInDateRange) { + // *upsert* to $cache + dispatch( + imagesApi.util.updateQueryData( + 'listImages', + queryArgs, + (draft) => { + imagesAdapter.upsertOne(draft, { + ...imageDTO, + board_id: undefined, + }); + } + ) + ); + } + }); + } catch { + // no-op + } + }, + }), }), }); @@ -788,10 +1133,15 @@ export const { useGetImageDTOQuery, useGetImageMetadataQuery, useDeleteImageMutation, - useGetBoardImagesTotalQuery, - useGetBoardAssetsTotalQuery, + useDeleteImagesMutation, useUploadImageMutation, + useClearIntermediatesMutation, + useAddImagesToBoardMutation, + useRemoveImagesFromBoardMutation, useAddImageToBoardMutation, useRemoveImageFromBoardMutation, - useClearIntermediatesMutation, + useChangeImageIsIntermediateMutation, + useChangeImageSessionIdMutation, + useDeleteBoardAndImagesMutation, + useDeleteBoardMutation, } = imagesApi; diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index a7b1323f36..33eb1fbdc2 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -5,7 +5,6 @@ import { BaseModelType, CheckpointModelConfig, ControlNetModelConfig, - ConvertModelConfig, DiffusersModelConfig, ImportModelConfig, LoRAModelConfig, @@ -83,7 +82,7 @@ type DeleteLoRAModelResponse = void; type ConvertMainModelArg = { base_model: BaseModelType; model_name: string; - params: ConvertModelConfig; + convert_dest_directory?: string; }; type ConvertMainModelResponse = @@ -122,7 +121,7 @@ type CheckpointConfigsResponse = type SearchFolderArg = operations['search_for_models']['parameters']['query']; -const mainModelsAdapter = createEntityAdapter({ +export const mainModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); @@ -132,15 +131,15 @@ const onnxModelsAdapter = createEntityAdapter({ const loraModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); -const controlNetModelsAdapter = +export const controlNetModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); -const textualInversionModelsAdapter = +export const textualInversionModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); -const vaeModelsAdapter = createEntityAdapter({ +export const vaeModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), }); @@ -320,11 +319,11 @@ export const modelsApi = api.injectEndpoints({ ConvertMainModelResponse, ConvertMainModelArg >({ - query: ({ base_model, model_name, params }) => { + query: ({ base_model, model_name, convert_dest_directory }) => { return { url: `models/convert/${base_model}/main/${model_name}`, method: 'PUT', - params: params, + params: { convert_dest_directory }, }; }, invalidatesTags: [ diff --git a/invokeai/frontend/web/src/services/api/hooks/useBoardName.ts b/invokeai/frontend/web/src/services/api/hooks/useBoardName.ts index 748f2c8f6e..ce0cff7b8a 100644 --- a/invokeai/frontend/web/src/services/api/hooks/useBoardName.ts +++ b/invokeai/frontend/web/src/services/api/hooks/useBoardName.ts @@ -1,7 +1,7 @@ import { BoardId } from 'features/gallery/store/types'; import { useListAllBoardsQuery } from '../endpoints/boards'; -export const useBoardName = (board_id: 
BoardId | null | undefined) => { +export const useBoardName = (board_id: BoardId) => { const { boardName } = useListAllBoardsQuery(undefined, { selectFromResult: ({ data }) => { const selectedBoard = data?.find((b) => b.board_id === board_id); diff --git a/invokeai/frontend/web/src/services/api/hooks/useBoardTotal.ts b/invokeai/frontend/web/src/services/api/hooks/useBoardTotal.ts index dd144ffe00..a350979b89 100644 --- a/invokeai/frontend/web/src/services/api/hooks/useBoardTotal.ts +++ b/invokeai/frontend/web/src/services/api/hooks/useBoardTotal.ts @@ -4,7 +4,7 @@ import { useMemo } from 'react'; import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery, -} from '../endpoints/images'; +} from '../endpoints/boards'; export const useBoardTotal = (board_id: BoardId) => { const galleryView = useAppSelector((state) => state.gallery.galleryView); diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 80f0933f37..6574ec4909 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -135,19 +135,14 @@ export type paths = { */ put: operations["merge_models"]; }; - "/api/v1/images/": { - /** - * List Image Dtos - * @description Gets a list of image DTOs - */ - get: operations["list_image_dtos"]; + "/api/v1/images/upload": { /** * Upload Image * @description Uploads an image */ post: operations["upload_image"]; }; - "/api/v1/images/{image_name}": { + "/api/v1/images/i/{image_name}": { /** * Get Image Dto * @description Gets an image's DTO @@ -171,34 +166,45 @@ export type paths = { */ post: operations["clear_intermediates"]; }; - "/api/v1/images/{image_name}/metadata": { + "/api/v1/images/i/{image_name}/metadata": { /** * Get Image Metadata * @description Gets an image's metadata */ get: operations["get_image_metadata"]; }; - "/api/v1/images/{image_name}/full": { + "/api/v1/images/i/{image_name}/full": { /** * Get Image Full * @description Gets a full-resolution image file */ get: operations["get_image_full"]; }; - "/api/v1/images/{image_name}/thumbnail": { + "/api/v1/images/i/{image_name}/thumbnail": { /** * Get Image Thumbnail * @description Gets a thumbnail image file */ get: operations["get_image_thumbnail"]; }; - "/api/v1/images/{image_name}/urls": { + "/api/v1/images/i/{image_name}/urls": { /** * Get Image Urls * @description Gets an image and thumbnail URL */ get: operations["get_image_urls"]; }; + "/api/v1/images/": { + /** + * List Image Dtos + * @description Gets a list of image DTOs + */ + get: operations["list_image_dtos"]; + }; + "/api/v1/images/delete": { + /** Delete Images From List */ + post: operations["delete_images_from_list"]; + }; "/api/v1/boards/": { /** * List Boards @@ -237,15 +243,29 @@ export type paths = { }; "/api/v1/board_images/": { /** - * Create Board Image + * Add Image To Board * @description Creates a board_image */ - post: operations["create_board_image"]; + post: operations["add_image_to_board"]; /** - * Remove Board Image - * @description Deletes a board_image + * Remove Image From Board + * @description Removes an image from its board, if it had one */ - delete: operations["remove_board_image"]; + delete: operations["remove_image_from_board"]; + }; + "/api/v1/board_images/batch": { + /** + * Add Images To Board + * @description Adds a list of images to a board + */ + post: operations["add_images_to_board"]; + }; + "/api/v1/board_images/batch/delete": { + /** + * Remove Images From Board + * @description Removes a list of 
images from their board, if they had one + */ + post: operations["remove_images_from_board"]; }; "/api/v1/app/version": { /** Get Version */ @@ -273,6 +293,19 @@ export type webhooks = Record; export type components = { schemas: { + /** AddImagesToBoardResult */ + AddImagesToBoardResult: { + /** + * Board Id + * @description The id of the board the images were added to + */ + board_id: string; + /** + * Added Image Names + * @description The image names that were added to the board + */ + added_image_names: (string)[]; + }; /** * AddInvocation * @description Adds two numbers @@ -405,8 +438,8 @@ export type components = { */ image_count: number; }; - /** Body_create_board_image */ - Body_create_board_image: { + /** Body_add_image_to_board */ + Body_add_image_to_board: { /** * Board Id * @description The id of the board to add to @@ -418,6 +451,27 @@ export type components = { */ image_name: string; }; + /** Body_add_images_to_board */ + Body_add_images_to_board: { + /** + * Board Id + * @description The id of the board to add to + */ + board_id: string; + /** + * Image Names + * @description The names of the images to add + */ + image_names: (string)[]; + }; + /** Body_delete_images_from_list */ + Body_delete_images_from_list: { + /** + * Image Names + * @description The list of names of images to delete + */ + image_names: (string)[]; + }; /** Body_import_model */ Body_import_model: { /** @@ -465,19 +519,22 @@ export type components = { */ merge_dest_directory?: string; }; - /** Body_remove_board_image */ - Body_remove_board_image: { - /** - * Board Id - * @description The id of the board - */ - board_id: string; + /** Body_remove_image_from_board */ + Body_remove_image_from_board: { /** * Image Name * @description The name of the image to remove */ image_name: string; }; + /** Body_remove_images_from_board */ + Body_remove_images_from_board: { + /** + * Image Names + * @description The names of the images to remove + */ + image_names: (string)[]; + }; /** Body_upload_image */ Body_upload_image: { /** @@ -1157,6 +1214,11 @@ export type components = { */ deleted_images: (string)[]; }; + /** DeleteImagesFromListResult */ + DeleteImagesFromListResult: { + /** Deleted Images */ + deleted_images: (string)[]; + }; /** * DivideInvocation * @description Divides two numbers @@ -4627,6 +4689,14 @@ export type components = { */ step?: number; }; + /** RemoveImagesFromBoardResult */ + RemoveImagesFromBoardResult: { + /** + * Removed Image Names + * @description The image names that were removed from their board + */ + removed_image_names: (string)[]; + }; /** * ResizeLatentsInvocation * @description Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8. @@ -5891,18 +5961,6 @@ export type components = { */ image?: components["schemas"]["ImageField"]; }; - /** - * ControlNetModelFormat - * @description An enumeration. - * @enum {string} - */ - ControlNetModelFormat: "checkpoint" | "diffusers"; - /** - * StableDiffusionXLModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusionOnnxModelFormat * @description An enumeration. @@ -5921,6 +5979,18 @@ export type components = { * @enum {string} */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusionXLModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; + /** + * ControlNetModelFormat + * @description An enumeration. 
+ * @enum {string} + */ + ControlNetModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; @@ -6547,42 +6617,6 @@ export type operations = { }; }; }; - /** - * List Image Dtos - * @description Gets a list of image DTOs - */ - list_image_dtos: { - parameters: { - query?: { - /** @description The origin of images to list. */ - image_origin?: components["schemas"]["ResourceOrigin"]; - /** @description The categories of image to include. */ - categories?: (components["schemas"]["ImageCategory"])[]; - /** @description Whether to list intermediate images. */ - is_intermediate?: boolean; - /** @description The board id to filter by. Use 'none' to find images without a board. */ - board_id?: string; - /** @description The page offset */ - offset?: number; - /** @description The number of images per page */ - limit?: number; - }; - }; - responses: { - /** @description Successful Response */ - 200: { - content: { - "application/json": components["schemas"]["OffsetPaginatedResults_ImageDTO_"]; - }; - }; - /** @description Validation Error */ - 422: { - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; /** * Upload Image * @description Uploads an image @@ -6829,6 +6863,64 @@ export type operations = { }; }; }; + /** + * List Image Dtos + * @description Gets a list of image DTOs + */ + list_image_dtos: { + parameters: { + query?: { + /** @description The origin of images to list. */ + image_origin?: components["schemas"]["ResourceOrigin"]; + /** @description The categories of image to include. */ + categories?: (components["schemas"]["ImageCategory"])[]; + /** @description Whether to list intermediate images. */ + is_intermediate?: boolean; + /** @description The board id to filter by. Use 'none' to find images without a board. 
*/ + board_id?: string; + /** @description The page offset */ + offset?: number; + /** @description The number of images per page */ + limit?: number; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + content: { + "application/json": components["schemas"]["OffsetPaginatedResults_ImageDTO_"]; + }; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + /** Delete Images From List */ + delete_images_from_list: { + requestBody: { + content: { + "application/json": components["schemas"]["Body_delete_images_from_list"]; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + content: { + "application/json": components["schemas"]["DeleteImagesFromListResult"]; + }; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; /** * List Boards * @description Gets a list of boards @@ -6999,13 +7091,13 @@ export type operations = { }; }; /** - * Create Board Image + * Add Image To Board * @description Creates a board_image */ - create_board_image: { + add_image_to_board: { requestBody: { content: { - "application/json": components["schemas"]["Body_create_board_image"]; + "application/json": components["schemas"]["Body_add_image_to_board"]; }; }; responses: { @@ -7024,13 +7116,13 @@ export type operations = { }; }; /** - * Remove Board Image - * @description Deletes a board_image + * Remove Image From Board + * @description Removes an image from its board, if it had one */ - remove_board_image: { + remove_image_from_board: { requestBody: { content: { - "application/json": components["schemas"]["Body_remove_board_image"]; + "application/json": components["schemas"]["Body_remove_image_from_board"]; }; }; responses: { @@ -7048,6 +7140,56 @@ export type operations = { }; }; }; + /** + * Add Images To Board + * @description Adds a list of images to a board + */ + add_images_to_board: { + requestBody: { + content: { + "application/json": components["schemas"]["Body_add_images_to_board"]; + }; + }; + responses: { + /** @description Images were added to board successfully */ + 201: { + content: { + "application/json": components["schemas"]["AddImagesToBoardResult"]; + }; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + /** + * Remove Images From Board + * @description Removes a list of images from their board, if they had one + */ + remove_images_from_board: { + requestBody: { + content: { + "application/json": components["schemas"]["Body_remove_images_from_board"]; + }; + }; + responses: { + /** @description Images were removed from board successfully */ + 201: { + content: { + "application/json": components["schemas"]["RemoveImagesFromBoardResult"]; + }; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; /** Get Version */ app_version: { responses: { diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.ts similarity index 89% rename from invokeai/frontend/web/src/services/api/types.d.ts rename to invokeai/frontend/web/src/services/api/types.ts index 2ee508fe48..ca9dbb3aeb 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -1,13 +1,40 @@ 
import { UseToastOptions } from '@chakra-ui/react'; +import { EntityState } from '@reduxjs/toolkit'; import { O } from 'ts-toolbelt'; -import { components } from './schema'; +import { components, paths } from './schema'; -type schemas = components['schemas']; +export type ImageCache = EntityState; + +export type ListImagesArgs = NonNullable< + paths['/api/v1/images/']['get']['parameters']['query'] +>; + +export type DeleteBoardResult = + paths['/api/v1/boards/{board_id}']['delete']['responses']['200']['content']['application/json']; + +export type ListBoardsArg = NonNullable< + paths['/api/v1/boards/']['get']['parameters']['query'] +>; + +export type UpdateBoardArg = + paths['/api/v1/boards/{board_id}']['patch']['parameters']['path'] & { + changes: paths['/api/v1/boards/{board_id}']['patch']['requestBody']['content']['application/json']; + }; + +/** + * This is an unsafe type; the object inside is not guaranteed to be valid. + */ +export type UnsafeImageMetadata = { + metadata: components['schemas']['CoreMetadata']; + graph: NonNullable; +}; /** * Marks the `type` property as required. Use for nodes. */ -type TypeReq = O.Required; +type TypeReq = O.Required; + +// Extracted types from API schema // App Info export type AppVersion = components['schemas']['AppVersion']; @@ -72,7 +99,6 @@ export type AnyModelConfig = | OnnxModelConfig; export type MergeModelConfig = components['schemas']['Body_merge_models']; -export type ConvertModelConfig = components['schemas']['Body_convert_model']; export type ImportModelConfig = components['schemas']['Body_import_model']; // Graphs diff --git a/invokeai/frontend/web/src/services/api/util.ts b/invokeai/frontend/web/src/services/api/util.ts new file mode 100644 index 0000000000..20c9baedbb --- /dev/null +++ b/invokeai/frontend/web/src/services/api/util.ts @@ -0,0 +1,56 @@ +import { + ASSETS_CATEGORIES, + IMAGE_CATEGORIES, +} from 'features/gallery/store/types'; +import { ImageCache, ImageDTO, ListImagesArgs } from './types'; +import { createEntityAdapter } from '@reduxjs/toolkit'; +import { dateComparator } from 'common/util/dateComparator'; +import queryString from 'query-string'; + +export const getIsImageInDateRange = ( + data: ImageCache | undefined, + imageDTO: ImageDTO +) => { + if (!data) { + return false; + } + const cacheImageDTOS = imagesSelectors.selectAll(data); + + if (cacheImageDTOS.length > 1) { + // Images are sorted by `created_at` DESC + // check if the image is newer than the oldest image in the cache + const createdDate = new Date(imageDTO.created_at); + const oldestImage = cacheImageDTOS[cacheImageDTOS.length - 1]; + if (!oldestImage) { + // satisfy TS gods, we already confirmed the array has more than one image + return false; + } + const oldestDate = new Date(oldestImage.created_at); + return createdDate >= oldestDate; + } else if ([0, 1].includes(cacheImageDTOS.length)) { + // if there are only 1 or 0 images in the cache, we consider the image to be in the date range + return true; + } + return false; +}; + +export const getCategories = (imageDTO: ImageDTO) => { + if (IMAGE_CATEGORIES.includes(imageDTO.image_category)) { + return IMAGE_CATEGORIES; + } + return ASSETS_CATEGORIES; +}; + +// The adapter is not actually the data store - it just provides helper functions to interact +// with some other store of data. We will use the RTK Query cache as that store. 
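// A minimal usage sketch (assuming the `listImages` endpoint and `imagesApi` from
// '../endpoints/images', as used elsewhere in this patch): the adapter's helpers are
// applied to the cached EntityState from inside an `onQueryStarted` handler, e.g.
//
//   dispatch(
//     imagesApi.util.updateQueryData('listImages', queryArgs, (draft) => {
//       imagesAdapter.upsertOne(draft, imageDTO); // draft is the cached EntityState<ImageDTO>
//     })
//   );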
+export const imagesAdapter = createEntityAdapter({ + selectId: (image) => image.image_name, + sortComparer: (a, b) => dateComparator(b.updated_at, a.updated_at), +}); + +// Create selectors for the adapter. +export const imagesSelectors = imagesAdapter.getSelectors(); + +// Helper to create the url for the listImages endpoint. Also we use it to create the cache key. +export const getListImagesUrl = (queryArgs: ListImagesArgs) => + `images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`; diff --git a/invokeai/frontend/web/src/theme/util/generateColorPalette.ts b/invokeai/frontend/web/src/theme/util/generateColorPalette.ts index 6d90a070c0..63a5c06219 100644 --- a/invokeai/frontend/web/src/theme/util/generateColorPalette.ts +++ b/invokeai/frontend/web/src/theme/util/generateColorPalette.ts @@ -22,7 +22,7 @@ export function generateColorPalette( ]; const p = colorSteps.reduce((palette, step, index) => { - const A = alpha ? lightnessSteps[index] / 100 : 1; + const A = alpha ? (lightnessSteps[index] as number) / 100 : 1; // Lightness should be 50% for alpha colors const L = alpha ? 50 : lightnessSteps[colorSteps.length - 1 - index]; diff --git a/invokeai/frontend/web/tsconfig.json b/invokeai/frontend/web/tsconfig.json index e722e2f9a8..c43d5dd86d 100644 --- a/invokeai/frontend/web/tsconfig.json +++ b/invokeai/frontend/web/tsconfig.json @@ -13,6 +13,8 @@ "moduleResolution": "Node", // TODO: Disabled for IDE performance issues with our translation JSON // "resolveJsonModule": true, + "noUncheckedIndexedAccess": true, + "strictNullChecks": true, "isolatedModules": true, "noEmit": true, "jsx": "react-jsx", From d2bddf7f9161d8c18dda60ebebeedc35a0dd1a78 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 3 Aug 2023 08:47:56 -0400 Subject: [PATCH 22/33] tweak formatting to accommodate longer runtimes --- invokeai/app/services/invocation_stats.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py index aca1dba550..50320a6611 100644 --- a/invokeai/app/services/invocation_stats.py +++ b/invokeai/app/services/invocation_stats.py @@ -208,12 +208,12 @@ class InvocationStatsService(InvocationStatsServiceBase): total_time = 0 logger.info(f"Graph stats: {graph_id}") - logger.info("Node Calls Seconds VRAM Used") + logger.info("Node Calls Seconds VRAM Used") for node_type, stats in self._stats[graph_id].nodes.items(): - logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:4.3f}s {stats.max_vram:4.2f}G") + logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G") total_time += stats.time_used - logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:4.3f}s") + logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s") if torch.cuda.is_available(): logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)) From cfc3a20565810a72cb572019f1ad9ce99b4e1177 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Thu, 3 Aug 2023 15:47:59 -0400 Subject: [PATCH 23/33] autoAddBoardId should always be defined --- .../middleware/listenerMiddleware/listeners/imageUploaded.ts | 2 +- .../frontend/web/src/features/gallery/store/gallerySlice.ts | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index f488259eb7..6dc2d482a9 100644 
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -40,7 +40,7 @@ export const addImageUploadedFulfilledListener = () => { // default action - just upload and alert user if (postUploadAction?.type === 'TOAST') { const { toastOptions } = postUploadAction; - if (!autoAddBoardId) { + if (!autoAddBoardId || autoAddBoardId === 'none') { dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions })); } else { // Add this image to the board diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index 3b0dd233f1..bc7acff6f4 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -41,6 +41,10 @@ export const gallerySlice = createSlice({ state.galleryView = 'images'; }, autoAddBoardIdChanged: (state, action: PayloadAction) => { + if (!action.payload) { + state.autoAddBoardId = 'none'; + return; + } state.autoAddBoardId = action.payload; }, galleryViewChanged: (state, action: PayloadAction) => { From 1ac14a1e43f5943435a8a28a1bf57d059b532af4 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Mon, 31 Jul 2023 23:18:02 +0300 Subject: [PATCH 24/33] add sdxl lora support --- invokeai/app/invocations/compel.py | 30 +- invokeai/app/invocations/latent.py | 2 +- invokeai/app/invocations/model.py | 97 ++++ invokeai/app/invocations/sdxl.py | 16 +- invokeai/backend/model_management/__init__.py | 1 + invokeai/backend/model_management/lora.py | 436 +------------- .../backend/model_management/model_cache.py | 2 - .../backend/model_management/models/lora.py | 537 +++++++++++++++++- 8 files changed, 683 insertions(+), 438 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index ada7a06a57..bbe372ff57 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -173,7 +173,7 @@ class CompelInvocation(BaseInvocation): class SDXLPromptInvocationBase: - def run_clip_raw(self, context, clip_field, prompt, get_pooled): + def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix): tokenizer_info = context.services.model_manager.get_model( **clip_field.tokenizer.dict(), context=context, @@ -210,8 +210,8 @@ class SDXLPromptInvocationBase: # print(traceback.format_exc()) print(f'Warn: trigger: "{trigger}" not found') - with ModelPatcher.apply_lora_text_encoder( - text_encoder_info.context.model, _lora_loader() + with ModelPatcher.apply_lora( + text_encoder_info.context.model, _lora_loader(), lora_prefix ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( tokenizer, ti_manager, @@ -247,7 +247,7 @@ class SDXLPromptInvocationBase: return c, c_pooled, None - def run_clip_compel(self, context, clip_field, prompt, get_pooled): + def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix): tokenizer_info = context.services.model_manager.get_model( **clip_field.tokenizer.dict(), context=context, @@ -284,8 +284,8 @@ class SDXLPromptInvocationBase: # print(traceback.format_exc()) print(f'Warn: trigger: "{trigger}" not found') - with ModelPatcher.apply_lora_text_encoder( - text_encoder_info.context.model, _lora_loader() + with ModelPatcher.apply_lora( + text_encoder_info.context.model, _lora_loader(), lora_prefix ), 
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( tokenizer, ti_manager, @@ -357,11 +357,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False) + c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_") if self.style.strip() == "": - c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True) + c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_") else: - c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True) + c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_") original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) @@ -415,7 +415,8 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True) + # TODO: if there will appear lora for refiner - write proper prefix + c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "") original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) @@ -467,11 +468,11 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False) + c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_") if self.style.strip() == "": - c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True) + c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_") else: - c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True) + c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_") original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) @@ -525,7 +526,8 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True) + # TODO: if there will appear lora for refiner - write proper prefix + c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "") original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 3edbe86150..6e2e0838bc 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -14,7 +14,7 @@ from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_management.models import ModelType, SilenceWarnings -from ...backend.model_management.lora import ModelPatcher +from ...backend.model_management import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( 
ConditioningData, diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index c19e5c5c9a..d215d500a6 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -262,6 +262,103 @@ class LoraLoaderInvocation(BaseInvocation): return output +class SDXLLoraLoaderOutput(BaseInvocationOutput): + """Model loader output""" + + # fmt: off + type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output" + + unet: Optional[UNetField] = Field(default=None, description="UNet submodel") + clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") + clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels") + # fmt: on + + +class SDXLLoraLoaderInvocation(BaseInvocation): + """Apply selected lora to unet and text_encoder.""" + + type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader" + + lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name") + weight: float = Field(default=0.75, description="With what weight to apply lora") + + unet: Optional[UNetField] = Field(description="UNet model for applying lora") + clip: Optional[ClipField] = Field(description="Clip model for applying lora") + clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora") + + class Config(InvocationConfig): + schema_extra = { + "ui": { + "title": "SDXL Lora Loader", + "tags": ["lora", "loader"], + "type_hints": {"lora": "lora_model"}, + }, + } + + def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: + if self.lora is None: + raise Exception("No LoRA provided") + + base_model = self.lora.base_model + lora_name = self.lora.model_name + + if not context.services.model_manager.model_exists( + base_model=base_model, + model_name=lora_name, + model_type=ModelType.Lora, + ): + raise Exception(f"Unkown lora name: {lora_name}!") + + if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): + raise Exception(f'Lora "{lora_name}" already applied to unet') + + if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): + raise Exception(f'Lora "{lora_name}" already applied to clip') + + if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras): + raise Exception(f'Lora "{lora_name}" already applied to clip2') + + output = SDXLLoraLoaderOutput() + + if self.unet is not None: + output.unet = copy.deepcopy(self.unet) + output.unet.loras.append( + LoraInfo( + base_model=base_model, + model_name=lora_name, + model_type=ModelType.Lora, + submodel=None, + weight=self.weight, + ) + ) + + if self.clip is not None: + output.clip = copy.deepcopy(self.clip) + output.clip.loras.append( + LoraInfo( + base_model=base_model, + model_name=lora_name, + model_type=ModelType.Lora, + submodel=None, + weight=self.weight, + ) + ) + + if self.clip2 is not None: + output.clip2 = copy.deepcopy(self.clip2) + output.clip2.loras.append( + LoraInfo( + base_model=base_model, + model_name=lora_name, + model_type=ModelType.Lora, + submodel=None, + weight=self.weight, + ) + ) + + return output + + class VAEModelField(BaseModel): """Vae model field""" diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 7dfceba853..faa6b59782 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union from pydantic import Field, validator -from 
...backend.model_management import ModelType, SubModelType +from ...backend.model_management import ModelType, SubModelType, ModelPatcher from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext @@ -293,10 +293,22 @@ class SDXLTextToLatentsInvocation(BaseInvocation): num_inference_steps = self.steps + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"}), + context=context, + ) + yield (lora_info.context.model, lora.weight) + del lora_info + return + unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) do_classifier_free_guidance = True cross_attention_kwargs = None - with unet_info as unet: + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + unet_info as unet: + scheduler.set_timesteps(num_inference_steps, device=unet.device) timesteps = scheduler.timesteps diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index cf057f3a89..8e083c1045 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -13,3 +13,4 @@ from .models import ( DuplicateModelException, ) from .model_merge import ModelMerger, MergeInterpolationMethod +from .lora import ModelPatcher diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 4287072a65..56f7a648c9 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -23,421 +23,6 @@ from transformers import CLIPTextModel, CLIPTokenizer # TODO: rename and split this file -class LoRALayerBase: - # rank: Optional[int] - # alpha: Optional[float] - # bias: Optional[torch.Tensor] - # layer_key: str - - # @property - # def scale(self): - # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 - - def __init__( - self, - layer_key: str, - values: dict, - ): - if "alpha" in values: - self.alpha = values["alpha"].item() - else: - self.alpha = None - - if "bias_indices" in values and "bias_values" in values and "bias_size" in values: - self.bias = torch.sparse_coo_tensor( - values["bias_indices"], - values["bias_values"], - tuple(values["bias_size"]), - ) - - else: - self.bias = None - - self.rank = None # set in layer implementation - self.layer_key = layer_key - - def forward( - self, - module: torch.nn.Module, - input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure - multiplier: float, - ): - if type(module) == torch.nn.Conv2d: - op = torch.nn.functional.conv2d - extra_args = dict( - stride=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) - - else: - op = torch.nn.functional.linear - extra_args = {} - - weight = self.get_weight() - - bias = self.bias if self.bias is not None else 0 - scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0 - return ( - op( - *input_h, - (weight + bias).view(module.weight.shape), - None, - **extra_args, - ) - * multiplier - * scale - ) - - def get_weight(self): - raise NotImplementedError() - - def calc_size(self) -> int: - model_size = 0 - for val in [self.bias]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - if self.bias 
is not None: - self.bias = self.bias.to(device=device, dtype=dtype) - - -# TODO: find and debug lora/locon with bias -class LoRALayer(LoRALayerBase): - # up: torch.Tensor - # mid: Optional[torch.Tensor] - # down: torch.Tensor - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.up = values["lora_up.weight"] - self.down = values["lora_down.weight"] - if "lora_mid.weight" in values: - self.mid = values["lora_mid.weight"] - else: - self.mid = None - - self.rank = self.down.shape[0] - - def get_weight(self): - if self.mid is not None: - up = self.up.reshape(self.up.shape[0], self.up.shape[1]) - down = self.down.reshape(self.down.shape[0], self.down.shape[1]) - weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) - else: - weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.up, self.mid, self.down]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.up = self.up.to(device=device, dtype=dtype) - self.down = self.down.to(device=device, dtype=dtype) - - if self.mid is not None: - self.mid = self.mid.to(device=device, dtype=dtype) - - -class LoHALayer(LoRALayerBase): - # w1_a: torch.Tensor - # w1_b: torch.Tensor - # w2_a: torch.Tensor - # w2_b: torch.Tensor - # t1: Optional[torch.Tensor] = None - # t2: Optional[torch.Tensor] = None - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.w1_a = values["hada_w1_a"] - self.w1_b = values["hada_w1_b"] - self.w2_a = values["hada_w2_a"] - self.w2_b = values["hada_w2_b"] - - if "hada_t1" in values: - self.t1 = values["hada_t1"] - else: - self.t1 = None - - if "hada_t2" in values: - self.t2 = values["hada_t2"] - else: - self.t2 = None - - self.rank = self.w1_b.shape[0] - - def get_weight(self): - if self.t1 is None: - weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) - - else: - rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) - rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a) - weight = rebuild1 * rebuild2 - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.w1_a = self.w1_a.to(device=device, dtype=dtype) - self.w1_b = self.w1_b.to(device=device, dtype=dtype) - if self.t1 is not None: - self.t1 = self.t1.to(device=device, dtype=dtype) - - self.w2_a = self.w2_a.to(device=device, dtype=dtype) - self.w2_b = self.w2_b.to(device=device, dtype=dtype) - if self.t2 is not None: - self.t2 = self.t2.to(device=device, dtype=dtype) - - -class LoKRLayer(LoRALayerBase): - # w1: Optional[torch.Tensor] = None - # w1_a: Optional[torch.Tensor] = None - # w1_b: Optional[torch.Tensor] = None - # w2: Optional[torch.Tensor] = None - # w2_a: Optional[torch.Tensor] = None - # w2_b: Optional[torch.Tensor] = None - # t2: Optional[torch.Tensor] = None - - def __init__( - self, - 
layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - if "lokr_w1" in values: - self.w1 = values["lokr_w1"] - self.w1_a = None - self.w1_b = None - else: - self.w1 = None - self.w1_a = values["lokr_w1_a"] - self.w1_b = values["lokr_w1_b"] - - if "lokr_w2" in values: - self.w2 = values["lokr_w2"] - self.w2_a = None - self.w2_b = None - else: - self.w2 = None - self.w2_a = values["lokr_w2_a"] - self.w2_b = values["lokr_w2_b"] - - if "lokr_t2" in values: - self.t2 = values["lokr_t2"] - else: - self.t2 = None - - if "lokr_w1_b" in values: - self.rank = values["lokr_w1_b"].shape[0] - elif "lokr_w2_b" in values: - self.rank = values["lokr_w2_b"].shape[0] - else: - self.rank = None # unscaled - - def get_weight(self): - w1 = self.w1 - if w1 is None: - w1 = self.w1_a @ self.w1_b - - w2 = self.w2 - if w2 is None: - if self.t2 is None: - w2 = self.w2_a @ self.w2_b - else: - w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) - - if len(w2.shape) == 4: - w1 = w1.unsqueeze(2).unsqueeze(2) - w2 = w2.contiguous() - weight = torch.kron(w1, w2) - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - if self.w1 is not None: - self.w1 = self.w1.to(device=device, dtype=dtype) - else: - self.w1_a = self.w1_a.to(device=device, dtype=dtype) - self.w1_b = self.w1_b.to(device=device, dtype=dtype) - - if self.w2 is not None: - self.w2 = self.w2.to(device=device, dtype=dtype) - else: - self.w2_a = self.w2_a.to(device=device, dtype=dtype) - self.w2_b = self.w2_b.to(device=device, dtype=dtype) - - if self.t2 is not None: - self.t2 = self.t2.to(device=device, dtype=dtype) - - -class LoRAModel: # (torch.nn.Module): - _name: str - layers: Dict[str, LoRALayer] - _device: torch.device - _dtype: torch.dtype - - def __init__( - self, - name: str, - layers: Dict[str, LoRALayer], - device: torch.device, - dtype: torch.dtype, - ): - self._name = name - self._device = device or torch.cpu - self._dtype = dtype or torch.float32 - self.layers = layers - - @property - def name(self): - return self._name - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> LoRAModel: - # TODO: try revert if exception? 
- for key, layer in self.layers.items(): - layer.to(device=device, dtype=dtype) - self._device = device - self._dtype = dtype - - def calc_size(self) -> int: - model_size = 0 - for _, layer in self.layers.items(): - model_size += layer.calc_size() - return model_size - - @classmethod - def from_checkpoint( - cls, - file_path: Union[str, Path], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or torch.device("cpu") - dtype = dtype or torch.float32 - - if isinstance(file_path, str): - file_path = Path(file_path) - - model = cls( - device=device, - dtype=dtype, - name=file_path.stem, # TODO: - layers=dict(), - ) - - if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") - else: - state_dict = torch.load(file_path, map_location="cpu") - - state_dict = cls._group_state(state_dict) - - for layer_key, values in state_dict.items(): - # lora and locon - if "lora_down.weight" in values: - layer = LoRALayer(layer_key, values) - - # loha - elif "hada_w1_b" in values: - layer = LoHALayer(layer_key, values) - - # lokr - elif "lokr_w1_b" in values or "lokr_w1" in values: - layer = LoKRLayer(layer_key, values) - - else: - # TODO: diff/ia3/... format - print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}") - return - - # lower memory consumption by removing already parsed layer values - state_dict[layer_key].clear() - - layer.to(device=device, dtype=dtype) - model.layers[layer_key] = layer - - return model - - @staticmethod - def _group_state(state_dict: dict): - state_dict_groupped = dict() - - for key, value in state_dict.items(): - stem, leaf = key.split(".", 1) - if stem not in state_dict_groupped: - state_dict_groupped[stem] = dict() - state_dict_groupped[stem][leaf] = value - - return state_dict_groupped - - """ loras = [ (lora_model1, 0.7), @@ -516,6 +101,27 @@ class ModelPatcher: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield + + @classmethod + @contextmanager + def apply_sdxl_lora_text_encoder( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModel, float]], + ): + with cls.apply_lora(text_encoder, loras, "lora_te1_"): + yield + + @classmethod + @contextmanager + def apply_sdxl_lora_text_encoder2( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModel, float]], + ): + with cls.apply_lora(text_encoder, loras, "lora_te2_"): + yield + @classmethod @contextmanager def apply_lora( diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index b4c3e48a48..71e1ebc0d4 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -28,8 +28,6 @@ import torch import logging import invokeai.backend.util.logging as logger -from invokeai.app.services.config import get_invokeai_config -from .lora import LoRAModel, TextualInversionModel from .models import BaseModelType, ModelType, SubModelType, ModelBase # Maximum size of the cache, in gigs diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 642f8bbeec..0351bf2652 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -1,7 +1,9 @@ import os import torch from enum import Enum -from typing import Optional, Union, Literal +from typing import Optional, Dict, Union, Literal, Any +from pathlib import Path +from safetensors.torch import load_file from .base 
import ( ModelBase, ModelConfigBase, @@ -13,9 +15,6 @@ from .base import ( ModelNotFoundException, ) -# TODO: naming -from ..lora import LoRAModel as LoRAModelRaw - class LoRAModelFormat(str, Enum): LyCORIS = "lycoris" @@ -50,6 +49,7 @@ class LoRAModel(ModelBase): model = LoRAModelRaw.from_checkpoint( file_path=self.model_path, dtype=torch_dtype, + base_model=self.base_model, ) self.model_size = model.calc_size() @@ -87,3 +87,532 @@ class LoRAModel(ModelBase): raise NotImplementedError("Diffusers lora not supported") else: return model_path + +class LoRALayerBase: + # rank: Optional[int] + # alpha: Optional[float] + # bias: Optional[torch.Tensor] + # layer_key: str + + # @property + # def scale(self): + # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 + + def __init__( + self, + layer_key: str, + values: dict, + ): + if "alpha" in values: + self.alpha = values["alpha"].item() + else: + self.alpha = None + + if "bias_indices" in values and "bias_values" in values and "bias_size" in values: + self.bias = torch.sparse_coo_tensor( + values["bias_indices"], + values["bias_values"], + tuple(values["bias_size"]), + ) + + else: + self.bias = None + + self.rank = None # set in layer implementation + self.layer_key = layer_key + + def forward( + self, + module: torch.nn.Module, + input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure + multiplier: float, + ): + if type(module) == torch.nn.Conv2d: + op = torch.nn.functional.conv2d + extra_args = dict( + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + else: + op = torch.nn.functional.linear + extra_args = {} + + weight = self.get_weight() + + bias = self.bias if self.bias is not None else 0 + scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0 + return ( + op( + *input_h, + (weight + bias).view(module.weight.shape), + None, + **extra_args, + ) + * multiplier + * scale + ) + + def get_weight(self): + raise NotImplementedError() + + def calc_size(self) -> int: + model_size = 0 + for val in [self.bias]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + if self.bias is not None: + self.bias = self.bias.to(device=device, dtype=dtype) + +# TODO: find and debug lora/locon with bias +class LoRALayer(LoRALayerBase): + # up: torch.Tensor + # mid: Optional[torch.Tensor] + # down: torch.Tensor + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + self.up = values["lora_up.weight"] + self.down = values["lora_down.weight"] + if "lora_mid.weight" in values: + self.mid = values["lora_mid.weight"] + else: + self.mid = None + + self.rank = self.down.shape[0] + + def get_weight(self): + if self.mid is not None: + up = self.up.reshape(self.up.shape[0], self.up.shape[1]) + down = self.down.reshape(self.down.shape[0], self.down.shape[1]) + weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) + else: + weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.up, self.mid, self.down]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + 
super().to(device=device, dtype=dtype) + + self.up = self.up.to(device=device, dtype=dtype) + self.down = self.down.to(device=device, dtype=dtype) + + if self.mid is not None: + self.mid = self.mid.to(device=device, dtype=dtype) + +class LoHALayer(LoRALayerBase): + # w1_a: torch.Tensor + # w1_b: torch.Tensor + # w2_a: torch.Tensor + # w2_b: torch.Tensor + # t1: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + self.w1_a = values["hada_w1_a"] + self.w1_b = values["hada_w1_b"] + self.w2_a = values["hada_w2_a"] + self.w2_b = values["hada_w2_b"] + + if "hada_t1" in values: + self.t1 = values["hada_t1"] + else: + self.t1 = None + + if "hada_t2" in values: + self.t2 = values["hada_t2"] + else: + self.t2 = None + + self.rank = self.w1_b.shape[0] + + def get_weight(self): + if self.t1 is None: + weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) + + else: + rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) + rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a) + weight = rebuild1 * rebuild2 + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.w1_a = self.w1_a.to(device=device, dtype=dtype) + self.w1_b = self.w1_b.to(device=device, dtype=dtype) + if self.t1 is not None: + self.t1 = self.t1.to(device=device, dtype=dtype) + + self.w2_a = self.w2_a.to(device=device, dtype=dtype) + self.w2_b = self.w2_b.to(device=device, dtype=dtype) + if self.t2 is not None: + self.t2 = self.t2.to(device=device, dtype=dtype) + +class LoKRLayer(LoRALayerBase): + # w1: Optional[torch.Tensor] = None + # w1_a: Optional[torch.Tensor] = None + # w1_b: Optional[torch.Tensor] = None + # w2: Optional[torch.Tensor] = None + # w2_a: Optional[torch.Tensor] = None + # w2_b: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + if "lokr_w1" in values: + self.w1 = values["lokr_w1"] + self.w1_a = None + self.w1_b = None + else: + self.w1 = None + self.w1_a = values["lokr_w1_a"] + self.w1_b = values["lokr_w1_b"] + + if "lokr_w2" in values: + self.w2 = values["lokr_w2"] + self.w2_a = None + self.w2_b = None + else: + self.w2 = None + self.w2_a = values["lokr_w2_a"] + self.w2_b = values["lokr_w2_b"] + + if "lokr_t2" in values: + self.t2 = values["lokr_t2"] + else: + self.t2 = None + + if "lokr_w1_b" in values: + self.rank = values["lokr_w1_b"].shape[0] + elif "lokr_w2_b" in values: + self.rank = values["lokr_w2_b"].shape[0] + else: + self.rank = None # unscaled + + def get_weight(self): + w1 = self.w1 + if w1 is None: + w1 = self.w1_a @ self.w1_b + + w2 = self.w2 + if w2 is None: + if self.t2 is None: + w2 = self.w2_a @ self.w2_b + else: + w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + w2 = w2.contiguous() + weight = torch.kron(w1, w2) + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.w1, self.w1_a, self.w1_b, 
self.w2, self.w2_a, self.w2_b, self.t2]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + if self.w1 is not None: + self.w1 = self.w1.to(device=device, dtype=dtype) + else: + self.w1_a = self.w1_a.to(device=device, dtype=dtype) + self.w1_b = self.w1_b.to(device=device, dtype=dtype) + + if self.w2 is not None: + self.w2 = self.w2.to(device=device, dtype=dtype) + else: + self.w2_a = self.w2_a.to(device=device, dtype=dtype) + self.w2_b = self.w2_b.to(device=device, dtype=dtype) + + if self.t2 is not None: + self.t2 = self.t2.to(device=device, dtype=dtype) + +# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix +class LoRAModelRaw: # (torch.nn.Module): + _name: str + layers: Dict[str, LoRALayer] + _device: torch.device + _dtype: torch.dtype + + def __init__( + self, + name: str, + layers: Dict[str, LoRALayer], + device: torch.device, + dtype: torch.dtype, + ): + self._name = name + self._device = device or torch.cpu + self._dtype = dtype or torch.float32 + self.layers = layers + + @property + def name(self): + return self._name + + @property + def device(self): + return self._device + + @property + def dtype(self): + return self._dtype + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + # TODO: try revert if exception? + for key, layer in self.layers.items(): + layer.to(device=device, dtype=dtype) + self._device = device + self._dtype = dtype + + def calc_size(self) -> int: + model_size = 0 + for _, layer in self.layers.items(): + model_size += layer.calc_size() + return model_size + + @classmethod + def _convert_sdxl_compvis_keys(cls, state_dict): + new_state_dict = dict() + for full_key, value in state_dict.items(): + if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): + continue # clip same + + if not full_key.startswith("lora_unet_"): + raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}") + src_key = full_key.replace("lora_unet_", "") + try: + dst_key = None + while "_" in src_key: + if src_key in SDXL_UNET_COMPVIS_MAP: + dst_key = SDXL_UNET_COMPVIS_MAP[src_key] + break + src_key = "_".join(src_key.split('_')[:-1]) + + if dst_key is None: + raise Exception(f"Unknown sdxl lora key - {full_key}") + new_key = full_key.replace(src_key, dst_key) + except: + print(SDXL_UNET_COMPVIS_MAP) + raise + new_state_dict[new_key] = value + return new_state_dict + + @classmethod + def from_checkpoint( + cls, + file_path: Union[str, Path], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + base_model: Optional[BaseModelType] = None, + ): + device = device or torch.device("cpu") + dtype = dtype or torch.float32 + + if isinstance(file_path, str): + file_path = Path(file_path) + + model = cls( + device=device, + dtype=dtype, + name=file_path.stem, # TODO: + layers=dict(), + ) + + if file_path.suffix == ".safetensors": + state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(file_path, map_location="cpu") + + state_dict = cls._group_state(state_dict) + + if base_model == BaseModelType.StableDiffusionXL: + state_dict = cls._convert_sdxl_compvis_keys(state_dict) + + for layer_key, values in state_dict.items(): + # lora and locon + if "lora_down.weight" in values: + layer = LoRALayer(layer_key, values) + + # 
loha + elif "hada_w1_b" in values: + layer = LoHALayer(layer_key, values) + + # lokr + elif "lokr_w1_b" in values or "lokr_w1" in values: + layer = LoKRLayer(layer_key, values) + + else: + # TODO: diff/ia3/... format + print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}") + return + + # lower memory consumption by removing already parsed layer values + state_dict[layer_key].clear() + + layer.to(device=device, dtype=dtype) + model.layers[layer_key] = layer + + return model + + @staticmethod + def _group_state(state_dict: dict): + state_dict_groupped = dict() + + for key, value in state_dict.items(): + stem, leaf = key.split(".", 1) + if stem not in state_dict_groupped: + state_dict_groupped[stem] = dict() + state_dict_groupped[stem][leaf] = value + + return state_dict_groupped + + +def make_sdxl_unet_conversion_map(): + unet_conversion_map_layer = [] + + for i in range(3): # num_blocks is 3 in sdxl + # loop over downblocks/upblocks + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + # if i > 0: commentout for sdxl + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0.", "norm1."), + ("in_layers.2.", "conv1."), + ("out_layers.0.", "norm2."), + ("out_layers.3.", "conv2."), + ("emb_layers.1.", "time_emb_proj."), + ("skip_connection.", "conv_shortcut."), + ] + + unet_conversion_map = [] + for sd, hf in unet_conversion_map_layer: + if "resnets" in hf: + for sd_res, hf_res in unet_conversion_map_resnet: + unet_conversion_map.append((sd + sd_res, hf + hf_res)) + else: + unet_conversion_map.append((sd, hf)) + + for j in range(2): + hf_time_embed_prefix = f"time_embedding.linear_{j+1}." + sd_time_embed_prefix = f"time_embed.{j*2}." 
+ unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) + + for j in range(2): + hf_label_embed_prefix = f"add_embedding.linear_{j+1}." + sd_label_embed_prefix = f"label_emb.0.{j*2}." + unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) + + unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) + unet_conversion_map.append(("out.0.", "conv_norm_out.")) + unet_conversion_map.append(("out.2.", "conv_out.")) + + return unet_conversion_map + +#_sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} +SDXL_UNET_COMPVIS_MAP = {f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} From 1d5d187ba10f4d2dab3d2bb1213e41cae37d701c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 3 Aug 2023 10:26:52 -0400 Subject: [PATCH 25/33] model probe detects sdxl lora models --- invokeai/app/invocations/sdxl.py | 4 +-- invokeai/backend/model_management/lora.py | 1 - .../backend/model_management/model_probe.py | 25 ++++++++++++++++--- .../backend/model_management/models/lora.py | 17 ++++++++++--- scripts/probe-model.py | 6 +++-- 5 files changed, 39 insertions(+), 14 deletions(-) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index faa6b59782..aaa616a378 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -306,9 +306,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) do_classifier_free_guidance = True cross_attention_kwargs = None - with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ - unet_info as unet: - + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet: scheduler.set_timesteps(num_inference_steps, device=unet.device) timesteps = scheduler.timesteps diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 56f7a648c9..0a0ab3d629 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -101,7 +101,6 @@ class ModelPatcher: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield - @classmethod @contextmanager def apply_sdxl_lora_text_encoder( diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index c3964d760c..21462cf6e6 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -315,21 +315,38 @@ class LoRACheckpointProbe(CheckpointProbeBase): def get_base_type(self) -> BaseModelType: checkpoint = self.checkpoint + + # SD-2 models are very hard to probe. 
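# --- Illustrative sketch (not part of the patch series) ---------------------
# make_sdxl_unet_conversion_map() yields dotted module-path prefixes such as
# "input_blocks.4.1." <-> "down_blocks.1.attentions.0.", while kohya-style
# LoRA keys use underscores and no trailing dot. The SDXL_UNET_COMPVIS_MAP
# comprehension above normalizes both sides so LoRA keys can be matched
# directly. The same normalization, shown on one pair from the map:

def normalize(prefix: str) -> str:
    return prefix.rstrip(".").replace(".", "_")

sd_prefix, hf_prefix = "input_blocks.4.1.", "down_blocks.1.attentions.0."
assert normalize(sd_prefix) == "input_blocks_4_1"
assert normalize(hf_prefix) == "down_blocks_1_attentions_0"
# -----------------------------------------------------------------------------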
These probes are brittle and likely to fail in the future + # There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be + # misclassified as SD-1 + key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" + if key in checkpoint and checkpoint[key].shape[0] == 320: + return BaseModelType.StableDiffusion2 + + key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight" + if key in checkpoint: + return BaseModelType.StableDiffusionXL + key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" - key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a" + key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a" + lora_token_vector_length = ( checkpoint[key1].shape[1] if key1 in checkpoint - else checkpoint[key2].shape[0] + else checkpoint[key2].shape[1] if key2 in checkpoint - else 768 + else checkpoint[key3].shape[0] + if key3 in checkpoint + else None ) + if lora_token_vector_length == 768: return BaseModelType.StableDiffusion1 elif lora_token_vector_length == 1024: return BaseModelType.StableDiffusion2 else: - return None + raise InvalidModelException(f"Unknown LoRA type") class TextualInversionCheckpointProbe(CheckpointProbeBase): diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 0351bf2652..0870e78469 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -88,6 +88,7 @@ class LoRAModel(ModelBase): else: return model_path + class LoRALayerBase: # rank: Optional[int] # alpha: Optional[float] @@ -173,6 +174,7 @@ class LoRALayerBase: if self.bias is not None: self.bias = self.bias.to(device=device, dtype=dtype) + # TODO: find and debug lora/locon with bias class LoRALayer(LoRALayerBase): # up: torch.Tensor @@ -225,6 +227,7 @@ class LoRALayer(LoRALayerBase): if self.mid is not None: self.mid = self.mid.to(device=device, dtype=dtype) + class LoHALayer(LoRALayerBase): # w1_a: torch.Tensor # w1_b: torch.Tensor @@ -292,6 +295,7 @@ class LoHALayer(LoRALayerBase): if self.t2 is not None: self.t2 = self.t2.to(device=device, dtype=dtype) + class LoKRLayer(LoRALayerBase): # w1: Optional[torch.Tensor] = None # w1_a: Optional[torch.Tensor] = None @@ -386,6 +390,7 @@ class LoKRLayer(LoRALayerBase): if self.t2 is not None: self.t2 = self.t2.to(device=device, dtype=dtype) + # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix class LoRAModelRaw: # (torch.nn.Module): _name: str @@ -439,7 +444,7 @@ class LoRAModelRaw: # (torch.nn.Module): new_state_dict = dict() for full_key, value in state_dict.items(): if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): - continue # clip same + continue # clip same if not full_key.startswith("lora_unet_"): raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}") @@ -450,7 +455,7 @@ class LoRAModelRaw: # (torch.nn.Module): if src_key in SDXL_UNET_COMPVIS_MAP: dst_key = SDXL_UNET_COMPVIS_MAP[src_key] break - src_key = "_".join(src_key.split('_')[:-1]) + src_key = "_".join(src_key.split("_")[:-1]) if dst_key is None: raise Exception(f"Unknown sdxl lora key - {full_key}") @@ -614,5 +619,9 @@ def make_sdxl_unet_conversion_map(): return unet_conversion_map -#_sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf 
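# --- Illustrative sketch (not part of the patch series) ---------------------
# The LoRACheckpointProbe hunk above decides the base model from the shapes of
# a few well-known keys: special-case keys for SDXL and for hard-to-probe SD-2
# checkpoints first, then the CLIP hidden size recovered from a text-encoder
# LoRA matrix (768 -> StableDiffusion1, 1024 -> StableDiffusion2). A trimmed
# sketch of that final fallback only, with an illustrative tensor shape:

import torch

def probe_token_vector_length(checkpoint: dict):
    key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
    key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
    if key1 in checkpoint:
        return checkpoint[key1].shape[1]  # in-features of the lora_down matrix
    if key3 in checkpoint:
        return checkpoint[key3].shape[0]  # LoHA w1_a rows
    return None

sd1_like = {
    "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight": torch.zeros(8, 768)
}
assert probe_token_vector_length(sd1_like) == 768  # -> StableDiffusion1
# -----------------------------------------------------------------------------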
in make_sdxl_unet_conversion_map()} -SDXL_UNET_COMPVIS_MAP = {f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} + +# _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} +SDXL_UNET_COMPVIS_MAP = { + f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") + for sd, hf in make_sdxl_unet_conversion_map() +} diff --git a/scripts/probe-model.py b/scripts/probe-model.py index 7281dafc3f..4cf2c50263 100755 --- a/scripts/probe-model.py +++ b/scripts/probe-model.py @@ -9,8 +9,10 @@ parser = argparse.ArgumentParser(description="Probe model type") parser.add_argument( "model_path", type=Path, + nargs="+", ) args = parser.parse_args() -info = ModelProbe().probe(args.model_path) -print(info) +for path in args.model_path: + info = ModelProbe().probe(path) + print(f"{path}: {info}") From cff91f06d36376e7936db97761df9aaea4395d53 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Thu, 3 Aug 2023 19:04:44 +0300 Subject: [PATCH 26/33] Add lora apply in sdxl l2l node --- invokeai/app/invocations/sdxl.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index aaa616a378..5bcd85db28 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -553,9 +553,19 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): context=context, ) + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"}), + context=context, + ) + yield (lora_info.context.model, lora.weight) + del lora_info + return + do_classifier_free_guidance = True cross_attention_kwargs = None - with unet_info as unet: + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet: # apply denoising_start num_inference_steps = self.steps scheduler.set_timesteps(num_inference_steps, device=unet.device) From 0d3c27f46c3276d86a51191ad3d036a2fe51f13f Mon Sep 17 00:00:00 2001 From: StAlKeR7779 Date: Fri, 4 Aug 2023 03:07:21 +0300 Subject: [PATCH 27/33] Fix typo Co-authored-by: Ryan Dick --- invokeai/app/invocations/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index d215d500a6..0d21f8f0ce 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -307,7 +307,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): model_name=lora_name, model_type=ModelType.Lora, ): - raise Exception(f"Unkown lora name: {lora_name}!") + raise Exception(f"Unknown lora name: {lora_name}!") if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): raise Exception(f'Lora "{lora_name}" already applied to unet') From 2f8b928486eb8d4e7c94a7eca122e1ab08fbc0e4 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 1 Aug 2023 17:02:57 +0300 Subject: [PATCH 28/33] Add support for diff/full lora layers --- invokeai/backend/model_management/lora.py | 46 +++++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 4287072a65..14a78693ea 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -325,6 +325,43 @@ class 
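# --- Illustrative sketch (not part of the patch series) ---------------------
# The `_lora_loader` added to the SDXL latents-to-latents node above is a
# generator: each (model, weight) pair is materialized only when
# ModelPatcher.apply_lora_unet iterates it, and the local reference is dropped
# right after the yield, so the loader itself never pins more than one freshly
# loaded LoRA. The shape of that pattern with a stand-in loader in place of
# the model manager; the LoRA names below are made up:

def lora_loader(lora_specs, load_fn):
    """Yield (model, weight) pairs lazily; load_fn stands in for the
    model manager's get_model()."""
    for name, weight in lora_specs:
        lora_model = load_fn(name)
        yield (lora_model, weight)
        del lora_model  # drop our reference before loading the next one

loaded_order = []

def fake_load(name):
    loaded_order.append(name)  # record when each LoRA is actually loaded
    return name

pairs = list(lora_loader([("style_lora", 0.75), ("detail_lora", 0.4)], fake_load))
assert pairs == [("style_lora", 0.75), ("detail_lora", 0.4)]
assert loaded_order == ["style_lora", "detail_lora"]
# -----------------------------------------------------------------------------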
LoKRLayer(LoRALayerBase): self.t2 = self.t2.to(device=device, dtype=dtype) +class FullLayer(LoRALayerBase): + # weight: torch.Tensor + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + self.weight = values["diff"] + + if len(values.keys()) > 1: + _keys = list(values.keys()) + _keys.remove("diff") + raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}") + + self.rank = None # unscaled + + def get_weight(self): + return self.weight + + def calc_size(self) -> int: + model_size = super().calc_size() + model_size += self.weight.nelement() * self.weight.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.weight = self.weight.to(device=device, dtype=dtype) + + class LoRAModel: # (torch.nn.Module): _name: str layers: Dict[str, LoRALayer] @@ -412,10 +449,13 @@ class LoRAModel: # (torch.nn.Module): elif "lokr_w1_b" in values or "lokr_w1" in values: layer = LoKRLayer(layer_key, values) + elif "diff" in values: + layer = FullLayer(layer_key, values) + else: - # TODO: diff/ia3/... format - print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}") - return + # TODO: ia3/... format + print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") + raise Exception("Unknown lora format!") # lower memory consumption by removing already parsed layer values state_dict[layer_key].clear() From 7d0cc6ec3f97a4d0c8b61c6af6e3bc4e29c4f1b9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 3 Aug 2023 11:18:22 +1000 Subject: [PATCH 29/33] chore: black --- invokeai/backend/model_management/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 14a78693ea..9f196d659d 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -359,7 +359,7 @@ class FullLayer(LoRALayerBase): ): super().to(device=device, dtype=dtype) - self.weight = self.weight.to(device=device, dtype=dtype) + self.weight = self.weight.to(device=device, dtype=dtype) class LoRAModel: # (torch.nn.Module): From f0613bb0ef1642e3c34ba7bc3146294d5de15ad2 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 4 Aug 2023 19:53:27 +0300 Subject: [PATCH 30/33] Fix merge conflict resolve - restore full/diff layer support --- .../backend/model_management/models/lora.py | 49 +++++++++++++++++-- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 0870e78469..1983c05503 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -391,6 +391,43 @@ class LoKRLayer(LoRALayerBase): self.t2 = self.t2.to(device=device, dtype=dtype) +class FullLayer(LoRALayerBase): + # weight: torch.Tensor + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + self.weight = values["diff"] + + if len(values.keys()) > 1: + _keys = list(values.keys()) + _keys.remove("diff") + raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}") + + self.rank = None # unscaled + + def get_weight(self): + return self.weight + + def calc_size(self) -> int: + model_size = 
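# --- Illustrative sketch (not part of the patch series) ---------------------
# With FullLayer added, the checkpoint loader above dispatches on which keys a
# grouped layer dict contains: classic LoRA/LoCon, LoHA, LoKR, or a plain
# "diff" full weight delta, raising on anything else. The same decision table,
# reduced to strings:

def classify_layer(values: dict) -> str:
    if "lora_down.weight" in values:
        return "lora"  # LoRALayer (also covers LoCon)
    if "hada_w1_b" in values:
        return "loha"  # LoHALayer
    if "lokr_w1_b" in values or "lokr_w1" in values:
        return "lokr"  # LoKRLayer
    if "diff" in values:
        return "full"  # FullLayer, added by this patch
    raise ValueError(f"Unknown lora format, keys: {list(values.keys())}")

assert classify_layer({"lora_down.weight": 0, "lora_up.weight": 0}) == "lora"
assert classify_layer({"diff": 0}) == "full"
# -----------------------------------------------------------------------------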
super().calc_size() + model_size += self.weight.nelement() * self.weight.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.weight = self.weight.to(device=device, dtype=dtype) + + # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix class LoRAModelRaw: # (torch.nn.Module): _name: str @@ -510,10 +547,13 @@ class LoRAModelRaw: # (torch.nn.Module): elif "lokr_w1_b" in values or "lokr_w1" in values: layer = LoKRLayer(layer_key, values) + elif "diff" in values: + layer = FullLayer(layer_key, values) + else: - # TODO: diff/ia3/... format - print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}") - return + # TODO: ia3/... format + print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") + raise Exception("Unknown lora format!") # lower memory consumption by removing already parsed layer values state_dict[layer_key].clear() @@ -536,6 +576,8 @@ class LoRAModelRaw: # (torch.nn.Module): return state_dict_groupped +# code from +# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 def make_sdxl_unet_conversion_map(): unet_conversion_map_layer = [] @@ -620,7 +662,6 @@ def make_sdxl_unet_conversion_map(): return unet_conversion_map -# _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} SDXL_UNET_COMPVIS_MAP = { f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map() From 04229082d61f10efd70b2ff289b03c46ae4571a8 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 1 Aug 2023 18:04:10 +0300 Subject: [PATCH 31/33] Provide ti name from model manager, not from ti itself --- invokeai/app/invocations/compel.py | 15 +++++++----- invokeai/app/invocations/onnx.py | 13 +++------- invokeai/backend/model_management/lora.py | 30 +++++++++++------------ 3 files changed, 26 insertions(+), 32 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index bbe372ff57..c11ebd3f56 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -108,14 +108,15 @@ class CompelInvocation(BaseInvocation): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): name = trigger[1:-1] try: - ti_list.append( + ti_list.append(( + name, context.services.model_manager.get_model( model_name=name, base_model=self.clip.text_encoder.base_model, model_type=ModelType.TextualInversion, context=context, ).context.model - ) + )) except ModelNotFoundException: # print(e) # import traceback @@ -196,14 +197,15 @@ class SDXLPromptInvocationBase: for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): name = trigger[1:-1] try: - ti_list.append( + ti_list.append(( + name, context.services.model_manager.get_model( model_name=name, base_model=clip_field.text_encoder.base_model, model_type=ModelType.TextualInversion, context=context, ).context.model - ) + )) except ModelNotFoundException: # print(e) # import traceback @@ -270,14 +272,15 @@ class SDXLPromptInvocationBase: for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): name = trigger[1:-1] try: - ti_list.append( + ti_list.append(( + name, context.services.model_manager.get_model( model_name=name, 
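# --- Illustrative sketch (not part of the patch series) ---------------------
# The compel.py hunks above now carry the embedding name alongside the loaded
# model as a (name, model) tuple, where the name comes from the "<trigger>"
# tokens found in the prompt. The extraction itself is the small regex used in
# those hunks; the prompt below is made up:

import re

prompt = "a portrait in the style of <my-embedding>, <other_style> lighting"
names = [trigger[1:-1] for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt)]
assert names == ["my-embedding", "other_style"]
# -----------------------------------------------------------------------------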
base_model=clip_field.text_encoder.base_model, model_type=ModelType.TextualInversion, context=context, ).context.model - ) + )) except ModelNotFoundException: # print(e) # import traceback diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 2bec128b87..dec5b939a0 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -65,7 +65,6 @@ class ONNXPromptInvocation(BaseInvocation): **self.clip.text_encoder.dict(), ) with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack: - # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras] loras = [ (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras @@ -75,20 +74,14 @@ class ONNXPromptInvocation(BaseInvocation): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): name = trigger[1:-1] try: - ti_list.append( - # stack.enter_context( - # context.services.model_manager.get_model( - # model_name=name, - # base_model=self.clip.text_encoder.base_model, - # model_type=ModelType.TextualInversion, - # ) - # ) + ti_list.append(( + name, context.services.model_manager.get_model( model_name=name, base_model=self.clip.text_encoder.base_model, model_type=ModelType.TextualInversion, ).context.model - ) + )) except Exception: # print(e) # import traceback diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 7ccf5e57ae..e8e2b3f51f 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -164,7 +164,7 @@ class ModelPatcher: cls, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, - ti_list: List[Any], + ti_list: List[Tuple[str, Any]], ) -> Tuple[CLIPTokenizer, TextualInversionManager]: init_tokens_count = None new_tokens_added = None @@ -174,27 +174,27 @@ class ModelPatcher: ti_manager = TextualInversionManager(ti_tokenizer) init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings - def _get_trigger(ti, index): - trigger = ti.name + def _get_trigger(ti_name, index): + trigger = ti_name if index > 0: trigger += f"-!pad-{i}" return f"<{trigger}>" # modify tokenizer new_tokens_added = 0 - for ti in ti_list: + for ti_name, ti in ti_list: for i in range(ti.embedding.shape[0]): - new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i)) + new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) # modify text_encoder text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added) model_embeddings = text_encoder.get_input_embeddings() - for ti in ti_list: + for ti_name, ti in ti_list: ti_tokens = [] for i in range(ti.embedding.shape[0]): embedding = ti.embedding[i] - trigger = _get_trigger(ti, i) + trigger = _get_trigger(ti_name, i) token_id = ti_tokenizer.convert_tokens_to_ids(trigger) if token_id == ti_tokenizer.unk_token_id: @@ -239,7 +239,6 @@ class ModelPatcher: class TextualInversionModel: - name: str embedding: torch.Tensor # [n, 768]|[n, 1280] @classmethod @@ -253,7 +252,6 @@ class TextualInversionModel: file_path = Path(file_path) result = cls() # TODO: - result.name = file_path.stem # TODO: if file_path.suffix == ".safetensors": state_dict = load_file(file_path.absolute().as_posix(), device="cpu") @@ -430,7 +428,7 @@ class ONNXModelPatcher: cls, tokenizer: CLIPTokenizer, text_encoder: IAIOnnxRuntimeModel, - ti_list: List[Any], + 
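# --- Illustrative sketch (not part of the patch series) ---------------------
# `apply_ti` above registers one tokenizer token per embedding vector: the
# first vector keeps the plain "<name>" trigger and each additional vector
# gets a "<name-!pad-i>" alias, so an n-vector embedding expands to n token
# ids. The effective naming scheme in isolation (the embedding name below is
# just an example):

def trigger_tokens(ti_name: str, num_vectors: int) -> list:
    tokens = []
    for i in range(num_vectors):
        trigger = ti_name if i == 0 else f"{ti_name}-!pad-{i}"
        tokens.append(f"<{trigger}>")
    return tokens

assert trigger_tokens("easynegative", 3) == [
    "<easynegative>",
    "<easynegative-!pad-1>",
    "<easynegative-!pad-2>",
]
# -----------------------------------------------------------------------------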
ti_list: List[Tuple[str, Any]], ) -> Tuple[CLIPTokenizer, TextualInversionManager]: from .models.base import IAIOnnxRuntimeModel @@ -443,17 +441,17 @@ class ONNXModelPatcher: ti_tokenizer = copy.deepcopy(tokenizer) ti_manager = TextualInversionManager(ti_tokenizer) - def _get_trigger(ti, index): - trigger = ti.name + def _get_trigger(ti_name, index): + trigger = ti_name if index > 0: trigger += f"-!pad-{i}" return f"<{trigger}>" # modify tokenizer new_tokens_added = 0 - for ti in ti_list: + for ti_name, ti in ti_list: for i in range(ti.embedding.shape[0]): - new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i)) + new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) # modify text_encoder orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"] @@ -463,11 +461,11 @@ class ONNXModelPatcher: axis=0, ) - for ti in ti_list: + for ti_name, ti in ti_list: ti_tokens = [] for i in range(ti.embedding.shape[0]): embedding = ti.embedding[i].detach().numpy() - trigger = _get_trigger(ti, i) + trigger = _get_trigger(ti_name, i) token_id = ti_tokenizer.convert_tokens_to_ids(trigger) if token_id == ti_tokenizer.unk_token_id: From 6ad565d84c22243838cd5fca1c579267ac3f4de5 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 3 Aug 2023 19:01:05 -0400 Subject: [PATCH 32/33] folded in changes from 4099 --- invokeai/app/invocations/compel.py | 60 ++++++++++--------- invokeai/app/invocations/onnx.py | 18 +++--- .../backend/model_management/model_cache.py | 2 +- .../backend/model_management/model_manager.py | 6 +- 4 files changed, 47 insertions(+), 39 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index c11ebd3f56..7c3ce7a819 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -108,15 +108,17 @@ class CompelInvocation(BaseInvocation): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): name = trigger[1:-1] try: - ti_list.append(( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - context=context, - ).context.model - )) + ti_list.append( + ( + name, + context.services.model_manager.get_model( + model_name=name, + base_model=self.clip.text_encoder.base_model, + model_type=ModelType.TextualInversion, + context=context, + ).context.model, + ) + ) except ModelNotFoundException: # print(e) # import traceback @@ -197,15 +199,17 @@ class SDXLPromptInvocationBase: for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): name = trigger[1:-1] try: - ti_list.append(( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=clip_field.text_encoder.base_model, - model_type=ModelType.TextualInversion, - context=context, - ).context.model - )) + ti_list.append( + ( + name, + context.services.model_manager.get_model( + model_name=name, + base_model=clip_field.text_encoder.base_model, + model_type=ModelType.TextualInversion, + context=context, + ).context.model, + ) + ) except ModelNotFoundException: # print(e) # import traceback @@ -272,15 +276,17 @@ class SDXLPromptInvocationBase: for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): name = trigger[1:-1] try: - ti_list.append(( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=clip_field.text_encoder.base_model, - model_type=ModelType.TextualInversion, - context=context, - ).context.model - )) + ti_list.append( + ( + name, + 
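# --- Illustrative sketch (not part of the patch series) ---------------------
# On the ONNX path the token-embedding table is a plain numpy array, so the
# patcher above grows it with np.concatenate and then writes each textual
# inversion vector into the row addressed by its newly added token id. A toy
# version with a 4-token vocabulary and one 2-vector embedding; all shapes and
# numbers are made up:

import numpy as np

orig_embeddings = np.zeros((4, 8), dtype=np.float32)  # (vocab_size, hidden)
ti_embedding = np.ones((2, 8), dtype=np.float32)      # a 2-vector embedding
new_token_ids = [4, 5]                                # ids handed out by add_tokens

embeddings = np.concatenate(
    (np.copy(orig_embeddings), np.zeros((len(new_token_ids), 8), dtype=np.float32)),
    axis=0,
)
for row, token_id in zip(ti_embedding, new_token_ids):
    embeddings[token_id] = row

assert embeddings.shape == (6, 8)
assert np.allclose(embeddings[4], 1.0) and np.allclose(embeddings[5], 1.0)
# -----------------------------------------------------------------------------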
context.services.model_manager.get_model( + model_name=name, + base_model=clip_field.text_encoder.base_model, + model_type=ModelType.TextualInversion, + context=context, + ).context.model, + ) + ) except ModelNotFoundException: # print(e) # import traceback diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index dec5b939a0..fe9a64552e 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -74,14 +74,16 @@ class ONNXPromptInvocation(BaseInvocation): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): name = trigger[1:-1] try: - ti_list.append(( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model - )) + ti_list.append( + ( + name, + context.services.model_manager.get_model( + model_name=name, + base_model=self.clip.text_encoder.base_model, + model_type=ModelType.TextualInversion, + ).context.model, + ) + ) except Exception: # print(e) # import traceback diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 71e1ebc0d4..2b8d020269 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -186,7 +186,7 @@ class ModelCache(object): cache_entry = self._cached_models.get(key, None) if cache_entry is None: self.logger.info( - f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}" + f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}" ) # this will remove older cached models until diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 832a96e18f..3fd59d1533 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -670,7 +670,7 @@ class ModelManager(object): # TODO: if path changed and old_model.path inside models folder should we delete this too? 
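# --- Illustrative sketch (not part of the patch series) ---------------------
# The model_cache.py hunk above only changes the log line: the submodel part
# is now emitted as ":<submodel>" when a submodel is requested and dropped
# entirely otherwise, so the message no longer ends in a dangling ":".
# Behavior of the new f-string expression, with plain strings standing in for
# the enum values:

def describe(base_model: str, model_type: str, submodel=None) -> str:
    return f"type {base_model}:{model_type}{':' + submodel if submodel else ''}"

assert describe("sd-1", "main", "vae") == "type sd-1:main:vae"
assert describe("sd-1", "lora") == "type sd-1:lora"  # no trailing ":"
# -----------------------------------------------------------------------------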
# remove conversion cache as config changed - old_model_path = self.app_config.root_path / old_model.path + old_model_path = self.resolve_model_path(old_model.path) old_model_cache = self._get_model_cache_path(old_model_path) if old_model_cache.exists(): if old_model_cache.is_dir(): @@ -780,7 +780,7 @@ class ModelManager(object): model_type, **submodel, ) - checkpoint_path = self.app_config.root_path / info["path"] + checkpoint_path = self.resolve_model_path(info["path"]) old_diffusers_path = self.resolve_model_path(model.location) new_diffusers_path = ( dest_directory or self.app_config.models_path / base_model.value / model_type.value @@ -992,7 +992,7 @@ class ModelManager(object): model_manager=self, prediction_type_helper=ask_user_for_prediction_type, ) - known_paths = {config.root_path / x["path"] for x in self.list_models()} + known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()} directories = { config.root_path / x for x in [ From 1b158f62c4d4a98da1d5aac2288ad724b4b8a87a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 3 Aug 2023 19:26:42 -0400 Subject: [PATCH 33/33] resolve vae overrides correctly --- invokeai/backend/model_management/model_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 3fd59d1533..a3b0d4e04a 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -472,7 +472,7 @@ class ModelManager(object): if submodel_type is not None and hasattr(model_config, submodel_type): override_path = getattr(model_config, submodel_type) if override_path: - model_path = self.app_config.root_path / override_path + model_path = self.resolve_path(override_path) model_type = submodel_type submodel_type = None model_class = MODEL_CLASSES[base_model][model_type]
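# --- Illustrative sketch (not part of the patch series) ---------------------
# Patch 33 above touches the submodel-override branch of ModelManager: when a
# model's config names a standalone replacement for the requested submodel
# (a VAE override, for instance), that override path is now resolved through
# the manager's path helper instead of being joined onto root_path directly.
# A reduced sketch of the branch's control flow; the Cfg class, field name,
# paths, and the resolve() stand-in below are made up for illustration:

from pathlib import Path

def pick_model_path(model_config, model_path, submodel_type, resolve):
    """Return (path, model_type, submodel_type) after applying any override."""
    model_type = "main"
    if submodel_type is not None and getattr(model_config, submodel_type, None):
        model_path = resolve(getattr(model_config, submodel_type))
        model_type, submodel_type = submodel_type, None
    return model_path, model_type, submodel_type

class Cfg:
    vae = "sd-1/vae/better-vae"  # hypothetical override entry

def resolve(p):
    # stand-in for the manager's own path-resolution helper
    return Path("/opt/invokeai") / p

path, mtype, sub = pick_model_path(
    Cfg(), Path("/opt/invokeai/sd-1/main/model"), "vae", resolve
)
assert path == Path("/opt/invokeai/sd-1/vae/better-vae")
assert mtype == "vae" and sub is None
# -----------------------------------------------------------------------------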