diff --git a/docs/installation/INSTALLATION.md b/docs/installation/INSTALLATION.md index b6f251fe48..ec5e2492b6 100644 --- a/docs/installation/INSTALLATION.md +++ b/docs/installation/INSTALLATION.md @@ -25,10 +25,10 @@ This method is recommended for experienced users and developers #### [Docker Installation](040_INSTALL_DOCKER.md) This method is recommended for those familiar with running Docker containers ### Other Installation Guides - - [PyPatchMatch](installation/060_INSTALL_PATCHMATCH.md) - - [XFormers](installation/070_INSTALL_XFORMERS.md) - - [CUDA and ROCm Drivers](installation/030_INSTALL_CUDA_AND_ROCM.md) - - [Installing New Models](installation/050_INSTALLING_MODELS.md) + - [PyPatchMatch](060_INSTALL_PATCHMATCH.md) + - [XFormers](070_INSTALL_XFORMERS.md) + - [CUDA and ROCm Drivers](030_INSTALL_CUDA_AND_ROCM.md) + - [Installing New Models](050_INSTALLING_MODELS.md) ## :fontawesome-solid-computer: Hardware Requirements diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py index 50320a6611..35c3a5e403 100644 --- a/invokeai/app/services/invocation_stats.py +++ b/invokeai/app/services/invocation_stats.py @@ -29,6 +29,7 @@ The abstract base class for this class is InvocationStatsServiceBase. An impleme writes to the system log is stored in InvocationServices.performance_statistics. """ +import psutil import time from abc import ABC, abstractmethod from contextlib import AbstractContextManager @@ -42,6 +43,11 @@ import invokeai.backend.util.logging as logger from ..invocations.baseinvocation import BaseInvocation from .graph import GraphExecutionState from .item_storage import ItemStorageABC +from .model_manager_service import ModelManagerService +from invokeai.backend.model_management.model_cache import CacheStats + +# size of GIG in bytes +GIG = 1073741824 class InvocationStatsServiceBase(ABC): @@ -89,6 +95,8 @@ class InvocationStatsServiceBase(ABC): invocation_type: str, time_used: float, vram_used: float, + ram_used: float, + ram_changed: float, ): """ Add timing information on execution of a node. 
Usually
@@ -97,6 +105,8 @@ class InvocationStatsServiceBase(ABC):
         :param invocation_type: String literal type of the node
         :param time_used: Time used by node's execution (sec)
         :param vram_used: Maximum VRAM used during execution (GB)
+        :param ram_used: Current RAM used (GB)
+        :param ram_changed: Change in RAM usage over course of the run (GB)
         """
         pass
 
@@ -115,6 +125,9 @@ class NodeStats:
     calls: int = 0
     time_used: float = 0.0  # seconds
     max_vram: float = 0.0  # GB
+    cache_hits: int = 0
+    cache_misses: int = 0
+    cache_high_watermark: int = 0
 
 
 @dataclass
@@ -133,31 +146,62 @@ class InvocationStatsService(InvocationStatsServiceBase):
         self.graph_execution_manager = graph_execution_manager
         # {graph_id => NodeLog}
         self._stats: Dict[str, NodeLog] = {}
+        self._cache_stats: Dict[str, CacheStats] = {}
+        self.ram_used: float = 0.0
+        self.ram_changed: float = 0.0
 
     class StatsContext:
-        def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
+        """Context manager for collecting statistics."""
+
+        invocation: BaseInvocation = None
+        collector: "InvocationStatsServiceBase" = None
+        graph_id: str = None
+        start_time: float = 0.0
+        ram_used: int = 0
+        model_manager: ModelManagerService = None
+
+        def __init__(
+            self,
+            invocation: BaseInvocation,
+            graph_id: str,
+            model_manager: ModelManagerService,
+            collector: "InvocationStatsServiceBase",
+        ):
+            """Initialize statistics for this run."""
             self.invocation = invocation
             self.collector = collector
             self.graph_id = graph_id
             self.start_time = 0
+            self.ram_used = 0
+            self.model_manager = model_manager
 
         def __enter__(self):
             self.start_time = time.time()
             if torch.cuda.is_available():
                 torch.cuda.reset_peak_memory_stats()
+            self.ram_used = psutil.Process().memory_info().rss
+            if self.model_manager:
+                self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
 
         def __exit__(self, *args):
+            """Called on exit from the context."""
+            ram_used = psutil.Process().memory_info().rss
+            self.collector.update_mem_stats(
+                ram_used=ram_used / GIG,
+                ram_changed=(ram_used - self.ram_used) / GIG,
+            )
             self.collector.update_invocation_stats(
-                self.graph_id,
-                self.invocation.type,
-                time.time() - self.start_time,
-                torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
+                graph_id=self.graph_id,
+                invocation_type=self.invocation.type,
+                time_used=time.time() - self.start_time,
+                vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
             )
 
     def collect_stats(
         self,
         invocation: BaseInvocation,
         graph_execution_state_id: str,
+        model_manager: ModelManagerService,
     ) -> StatsContext:
         """
         Return a context object that will capture the statistics.
@@ -166,7 +210,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
         """
         if not self._stats.get(graph_execution_state_id):  # first time we're seeing this
             self._stats[graph_execution_state_id] = NodeLog()
-        return self.StatsContext(invocation, graph_execution_state_id, self)
+        self._cache_stats[graph_execution_state_id] = CacheStats()
+        return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
 
     def reset_all_stats(self):
         """Zero all statistics"""
@@ -179,13 +224,34 @@ class InvocationStatsService(InvocationStatsServiceBase):
         except KeyError:
             logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
 
-    def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
+    def update_mem_stats(
+        self,
+        ram_used: float,
+        ram_changed: float,
+    ):
+        """
+        Update the collector with RAM memory usage info.
+
+        :param ram_used: How much RAM is currently in use.
+        :param ram_changed: How much RAM changed since last generation.
+        """
+        self.ram_used = ram_used
+        self.ram_changed = ram_changed
+
+    def update_invocation_stats(
+        self,
+        graph_id: str,
+        invocation_type: str,
+        time_used: float,
+        vram_used: float,
+    ):
         """
         Add timing information on execution of a node. Usually
         used internally.
         :param graph_id: ID of the graph that is currently executing
         :param invocation_type: String literal type of the node
-        :param time_used: Floating point seconds used by node's exection
+        :param time_used: Time used by node's execution (sec)
+        :param vram_used: Maximum VRAM used during execution (GB)
         """
         if not self._stats[graph_id].nodes.get(invocation_type):
             self._stats[graph_id].nodes[invocation_type] = NodeStats()
@@ -197,7 +265,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
     def log_stats(self):
         """
         Send the statistics to the system logger at the info level.
-        Stats will only be printed if when the execution of the graph
+        Stats will only be printed when the execution of the graph
         is complete.
""" completed = set() @@ -208,16 +276,30 @@ class InvocationStatsService(InvocationStatsServiceBase): total_time = 0 logger.info(f"Graph stats: {graph_id}") - logger.info("Node Calls Seconds VRAM Used") + logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}") for node_type, stats in self._stats[graph_id].nodes.items(): - logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G") + logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G") total_time += stats.time_used + cache_stats = self._cache_stats[graph_id] + hwm = cache_stats.high_watermark / GIG + tot = cache_stats.cache_size / GIG + loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG + logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s") + logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)") + logger.info(f"RAM used to load models: {loaded:4.2f}G") if torch.cuda.is_available(): - logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)) + logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG)) + logger.info("RAM cache statistics:") + logger.info(f" Model cache hits: {cache_stats.hits}") + logger.info(f" Model cache misses: {cache_stats.misses}") + logger.info(f" Models cached: {cache_stats.in_cache}") + logger.info(f" Models cleared from cache: {cache_stats.cleared}") + logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G") completed.add(graph_id) for graph_id in completed: del self._stats[graph_id] + del self._cache_stats[graph_id] diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index fd14e26364..675bc71257 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -22,6 +22,7 @@ from invokeai.backend.model_management import ( ModelNotFoundException, ) from invokeai.backend.model_management.model_search import FindModels +from invokeai.backend.model_management.model_cache import CacheStats import torch from invokeai.app.models.exceptions import CanceledException @@ -276,6 +277,13 @@ class ModelManagerServiceBase(ABC): """ pass + @abstractmethod + def collect_cache_stats(self, cache_stats: CacheStats): + """ + Reset model cache statistics for graph with graph_id. + """ + pass + @abstractmethod def commit(self, conf_file: Optional[Path] = None) -> None: """ @@ -500,6 +508,12 @@ class ModelManagerService(ModelManagerServiceBase): self.logger.debug(f"convert model {model_name}") return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory) + def collect_cache_stats(self, cache_stats: CacheStats): + """ + Reset model cache statistics for graph with graph_id. + """ + self.mgr.cache.stats = cache_stats + def commit(self, conf_file: Optional[Path] = None): """ Write current configuration out to the indicated file. 
diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py
index b8c2f93e93..37da17d318 100644
--- a/invokeai/app/services/processor.py
+++ b/invokeai/app/services/processor.py
@@ -86,7 +86,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
 
                 # Invoke
                 try:
-                    with statistics.collect_stats(invocation, graph_execution_state.id):
+                    graph_id = graph_execution_state.id
+                    model_manager = self.__invoker.services.model_manager
+                    with statistics.collect_stats(invocation, graph_id, model_manager):
                         # use the internal invoke_internal(), which wraps the node's invoke() method;
                         # this accommodates nodes which require a value, but get it only from a
                         # connection
diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index 2b8d020269..a11e0a8a8f 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -21,12 +21,12 @@ import os
 import sys
 import hashlib
 from contextlib import suppress
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Dict, Union, types, Optional, Type, Any
 
 import torch
 
-import logging
 import invokeai.backend.util.logging as logger
 
 from .models import BaseModelType, ModelType, SubModelType, ModelBase
@@ -41,6 +41,18 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
 GIG = 1073741824
 
 
+@dataclass
+class CacheStats(object):
+    hits: int = 0  # cache hits
+    misses: int = 0  # cache misses
+    high_watermark: int = 0  # high-water mark of cache usage, in bytes
+    in_cache: int = 0  # number of models in cache
+    cleared: int = 0  # number of models cleared to make space
+    cache_size: int = 0  # total size of cache, in bytes
+    # {submodel_key => size}
+    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
+
+
 class ModelLocker(object):
     "Forward declaration"
     pass
@@ -115,6 +127,9 @@ class ModelCache(object):
         self.sha_chunksize = sha_chunksize
         self.logger = logger
 
+        # used for stats collection
+        self.stats: Optional[CacheStats] = None
+
         self._cached_models = dict()
         self._cache_stack = list()
@@ -181,13 +196,14 @@ class ModelCache(object):
             model_type=model_type,
             submodel_type=submodel,
         )
 
-        # TODO: lock for no copies on simultaneous calls?
         cache_entry = self._cached_models.get(key, None)
         if cache_entry is None:
             self.logger.info(
                 f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
             )
+            if self.stats:
+                self.stats.misses += 1
 
             # this will remove older cached models until
             # there is sufficient room to load the requested model
@@ -201,6 +217,17 @@ class ModelCache(object):
 
             cache_entry = _CacheRecord(self, model, mem_used)
             self._cached_models[key] = cache_entry
+        else:
+            if self.stats:
+                self.stats.hits += 1
+
+        if self.stats:
+            self.stats.cache_size = self.max_cache_size * GIG
+            self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
+            self.stats.in_cache = len(self._cached_models)
+            self.stats.loaded_model_sizes[key] = max(
+                self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
+            )
 
         with suppress(Exception):
             self._cache_stack.remove(key)
@@ -280,14 +307,14 @@ class ModelCache(object):
         """
         Given the HF repo id or path to a model on disk, returns a unique hash.
         Works for legacy checkpoint files, HF models on disk, and HF repo IDs
+        :param model_path: Path to model file/directory on disk.
""" return self._local_model_hash(model_path) def cache_size(self) -> float: - "Return the current size of the cache, in GB" - current_cache_size = sum([m.size for m in self._cached_models.values()]) - return current_cache_size / GIG + """Return the current size of the cache, in GB.""" + return self._cache_size() / GIG def _has_cuda(self) -> bool: return self.execution_device.type == "cuda" @@ -310,12 +337,15 @@ class ModelCache(object): f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}" ) + def _cache_size(self) -> int: + return sum([m.size for m in self._cached_models.values()]) + def _make_cache_room(self, model_size): # calculate how much memory this model will require # multiplier = 2 if self.precision==torch.float32 else 1 bytes_needed = model_size maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes - current_size = sum([m.size for m in self._cached_models.values()]) + current_size = self._cache_size() if current_size + bytes_needed > maximum_size: self.logger.debug( @@ -364,6 +394,8 @@ class ModelCache(object): f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" ) current_size -= cache_entry.size + if self.stats: + self.stats.cleared += 1 del self._cache_stack[pos] del self._cached_models[model_key] del cache_entry diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index e739855b9e..f16855e775 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -240,6 +240,7 @@ class InvokeAIDiffuserComponent: controlnet_cond=control_datum.image_tensor, conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale encoder_attention_mask=encoder_attention_mask, + added_cond_kwargs=added_cond_kwargs, guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel return_dict=False, ) diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 4710682ac1..89b3da5a37 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -4,8 +4,15 @@ import torch from torch import nn from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalControlnetMixin from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor -from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.embeddings import ( + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) from diffusers.models.modeling_utils import ModelMixin from diffusers.models.unet_2d_blocks import ( CrossAttnDownBlock2D, @@ -18,10 +25,11 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel import diffusers from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module +# TODO: create PR to diffusers # Modified ControlNetModel with encoder_attention_mask argument added -class ControlNetModel(ModelMixin, ConfigMixin): +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): """ A ControlNet model. @@ -52,12 +60,25 @@ class ControlNetModel(ModelMixin, ConfigMixin): The epsilon to use for the normalization. 
cross_attention_dim (`int`, defaults to 1280):
             The dimension of the cross attention features.
+        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        encoder_hid_dim (`int`, *optional*, defaults to None):
+            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+            dimension to `cross_attention_dim`.
+        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
         attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
             The dimension of the attention heads.
         use_linear_projection (`bool`, defaults to `False`):
         class_embed_type (`str`, *optional*, defaults to `None`):
             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from
             None, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to `None`):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None`,
+            `"text"`, `"text_image"`, or `"text_time"`. `"text"` will use the `TextTimeEmbedding` layer.
         num_class_embeds (`int`, *optional*, defaults to 0):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
@@ -98,10 +119,16 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        encoder_hid_dim: Optional[int] = None,
+        encoder_hid_dim_type: Optional[str] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
         num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
+        projection_class_embeddings_input_dim: Optional[int] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
@@ -109,6 +135,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         controlnet_conditioning_channel_order: str = "rgb",
         conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
         global_pool_conditions: bool = False,
+        addition_embed_type_num_heads=64,
     ):
         super().__init__()
 
@@ -136,6 +163,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
                 f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
 
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
         # input
         conv_in_kernel = 3
         conv_in_padding = (conv_in_kernel - 1) // 2
@@ -145,16 +175,43 @@ class ControlNetModel(ModelMixin, ConfigMixin):
 
         # time
         time_embed_dim = block_out_channels[0] * 4
-
         self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
         timestep_input_dim = block_out_channels[0]
-
         self.time_embedding = TimestepEmbedding(
             timestep_input_dim,
             time_embed_dim,
             act_fn=act_fn,
         )
 
+        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+            encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+            )
+
+        if encoder_hid_dim_type == "text_proj":
+            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        elif encoder_hid_dim_type == "text_image_proj":
+            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
+            self.encoder_hid_proj = TextImageProjection(
+                text_embed_dim=encoder_hid_dim,
+                image_embed_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+
+        elif encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+            )
+        else:
+            self.encoder_hid_proj = None
+
         # class embedding
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@@ -178,6 +235,29 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         else:
             self.class_embedding = None
 
+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type == "text_image":
+            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
+            self.add_embedding = TextImageTimeEmbedding(
+                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+            )
+        elif addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+        elif addition_embed_type is not None:
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', or 'text_time'.")
+
         # control net conditioning embedding
         self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
             conditioning_embedding_channels=block_out_channels[0],
@@ -212,6 +292,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
             down_block = get_down_block(
                 down_block_type,
                 num_layers=layers_per_block,
+                transformer_layers_per_block=transformer_layers_per_block[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=time_embed_dim,
@@ -248,6 +329,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         self.controlnet_mid_block = controlnet_block
 
         self.mid_block = UNetMidBlock2DCrossAttn(
+            transformer_layers_per_block=transformer_layers_per_block[-1],
             in_channels=mid_block_channel,
             temb_channels=time_embed_dim,
             resnet_eps=norm_eps,
@@ -277,7 +359,26 @@ class ControlNetModel(ModelMixin, ConfigMixin):
                 The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
                 where applicable.
         """
+        transformer_layers_per_block = (
+            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+        )
+        encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+        addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+        addition_time_embed_dim = (
+            unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+        )
+        projection_class_embeddings_input_dim = (
+            unet.config.projection_class_embeddings_input_dim if "projection_class_embeddings_input_dim" in unet.config else None
+        )
+
         controlnet = cls(
+            encoder_hid_dim=encoder_hid_dim,
+            encoder_hid_dim_type=encoder_hid_dim_type,
+            addition_embed_type=addition_embed_type,
+            addition_time_embed_dim=addition_time_embed_dim,
+            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=unet.config.in_channels,
             flip_sin_to_cos=unet.config.flip_sin_to_cos,
             freq_shift=unet.config.freq_shift,
@@ -463,6 +560,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         guess_mode: bool = False,
@@ -486,7 +584,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
                 Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
             timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
             attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
-            cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
+            added_cond_kwargs (`dict`):
+                Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttnProcessor`. encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If @@ -549,6 +649,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None if self.class_embedding is not None: if class_labels is None: @@ -560,11 +661,34 @@ class ControlNetModel(ModelMixin, ConfigMixin): class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb + if "addition_embed_type" in self.config: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + # 2. pre-process sample = self.conv_in(sample) controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - sample = sample + controlnet_cond # 3. 
down
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index fbae5b4a30..f41da82e07 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -506,10 +506,14 @@
     "maskAdjustmentsHeader": "Mask Adjustments",
     "maskBlur": "Mask Blur",
     "maskBlurMethod": "Mask Blur Method",
+    "seamPaintingHeader": "Seam Painting",
     "seamSize": "Seam Size",
     "seamBlur": "Seam Blur",
-    "seamStrength": "Seam Strength",
     "seamSteps": "Seam Steps",
+    "seamStrength": "Seam Strength",
+    "seamThreshold": "Seam Threshold",
+    "seamLowThreshold": "Low",
+    "seamHighThreshold": "High",
     "scaleBeforeProcessing": "Scale Before Processing",
     "scaledWidth": "Scaled W",
     "scaledHeight": "Scaled H",
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts
index cdfae0095e..b419e98782 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts
@@ -121,7 +121,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
     effect: async (action, { dispatch, getState }) => {
       const { imageDTOs, imagesUsage } = action.payload;
 
-      if (imageDTOs.length < 1 || imagesUsage.length < 1) {
+      if (imageDTOs.length <= 1 || imagesUsage.length <= 1) {
         // handle singles in separate listener
         return;
       }
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts
index b9c3d5e28e..a949c88e5f 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts
@@ -32,6 +32,7 @@ import {
   MAIN_MODEL_LOADER,
   MASK_BLUR,
   MASK_COMBINE,
+  MASK_EDGE,
   MASK_FROM_ALPHA,
   MASK_RESIZE_DOWN,
   MASK_RESIZE_UP,
@@ -40,6 +41,10 @@ import {
   POSITIVE_CONDITIONING,
   RANDOM_INT,
   RANGE_OF_SIZE,
+  SEAM_FIX_DENOISE_LATENTS,
+  SEAM_MASK_COMBINE,
+  SEAM_MASK_RESIZE_DOWN,
+  SEAM_MASK_RESIZE_UP,
 } from './constants';
 
 /**
@@ -67,6 +72,12 @@ export const buildCanvasOutpaintGraph = (
     shouldUseCpuNoise,
     maskBlur,
     maskBlurMethod,
+    seamSize,
+    seamBlur,
+    seamSteps,
+    seamStrength,
+    seamLowThreshold,
+    seamHighThreshold,
     tileSize,
     infillMethod,
     clipSkip,
@@ -130,6 +141,11 @@ export const buildCanvasOutpaintGraph = (
       is_intermediate: true,
       mask2: canvasMaskImage,
     },
+    [SEAM_MASK_COMBINE]: {
+      type: 'mask_combine',
+      id: SEAM_MASK_COMBINE,
+      is_intermediate: true,
+    },
     [MASK_BLUR]: {
       type: 'img_blur',
       id: MASK_BLUR,
@@ -165,6 +181,25 @@ export const buildCanvasOutpaintGraph = (
       denoising_start: 1 - strength,
       denoising_end: 1,
     },
+    [MASK_EDGE]: {
+      type: 'mask_edge',
+      id: MASK_EDGE,
+      is_intermediate: true,
+      edge_size: seamSize,
+      edge_blur: seamBlur,
+      low_threshold: seamLowThreshold,
+      high_threshold: seamHighThreshold,
+    },
+    [SEAM_FIX_DENOISE_LATENTS]: {
+      type: 'denoise_latents',
+      id: SEAM_FIX_DENOISE_LATENTS,
+      is_intermediate: true,
+      steps: seamSteps,
+      cfg_scale: cfg_scale,
+      scheduler: scheduler,
+      denoising_start: 1 - seamStrength,
+      denoising_end: 1,
+    },
     [LATENTS_TO_IMAGE]: {
       type: 'l2i',
       id: LATENTS_TO_IMAGE,
@@ -333,12 +368,63 @@ export const buildCanvasOutpaintGraph = (
           field: 'seed',
         },
       },
-      // Decode the result from Inpaint
+      // Seam Paint
+ { + source: { + node_id: MAIN_MODEL_LOADER, + field: 'unet', + }, + destination: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'unet', + }, + }, + { + source: { + node_id: POSITIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'positive_conditioning', + }, + }, + { + source: { + node_id: NEGATIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'negative_conditioning', + }, + }, + { + source: { + node_id: NOISE, + field: 'noise', + }, + destination: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'noise', + }, + }, { source: { node_id: DENOISE_LATENTS, field: 'latents', }, + destination: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'latents', + }, + }, + // Decode the result from Inpaint + { + source: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'latents', + }, destination: { node_id: LATENTS_TO_IMAGE, field: 'latents', @@ -348,7 +434,6 @@ export const buildCanvasOutpaintGraph = ( }; // Add Infill Nodes - if (infillMethod === 'patchmatch') { graph.nodes[INPAINT_INFILL] = { type: 'infill_patchmatch', @@ -378,6 +463,13 @@ export const buildCanvasOutpaintGraph = ( width: scaledWidth, height: scaledHeight, }; + graph.nodes[SEAM_MASK_RESIZE_UP] = { + type: 'img_resize', + id: SEAM_MASK_RESIZE_UP, + is_intermediate: true, + width: scaledWidth, + height: scaledHeight, + }; graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = { type: 'img_resize', id: INPAINT_IMAGE_RESIZE_DOWN, @@ -399,6 +491,13 @@ export const buildCanvasOutpaintGraph = ( width: width, height: height, }; + graph.nodes[SEAM_MASK_RESIZE_DOWN] = { + type: 'img_resize', + id: SEAM_MASK_RESIZE_DOWN, + is_intermediate: true, + width: width, + height: height, + }; graph.nodes[NOISE] = { ...(graph.nodes[NOISE] as NoiseInvocation), @@ -440,6 +539,57 @@ export const buildCanvasOutpaintGraph = ( field: 'image', }, }, + // Seam Paint Mask + { + source: { + node_id: MASK_FROM_ALPHA, + field: 'image', + }, + destination: { + node_id: MASK_EDGE, + field: 'image', + }, + }, + { + source: { + node_id: MASK_EDGE, + field: 'image', + }, + destination: { + node_id: SEAM_MASK_RESIZE_UP, + field: 'image', + }, + }, + { + source: { + node_id: SEAM_MASK_RESIZE_UP, + field: 'image', + }, + destination: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'mask', + }, + }, + { + source: { + node_id: MASK_BLUR, + field: 'image', + }, + destination: { + node_id: SEAM_MASK_COMBINE, + field: 'mask1', + }, + }, + { + source: { + node_id: SEAM_MASK_RESIZE_UP, + field: 'image', + }, + destination: { + node_id: SEAM_MASK_COMBINE, + field: 'mask2', + }, + }, // Resize Results Down { source: { @@ -453,7 +603,7 @@ export const buildCanvasOutpaintGraph = ( }, { source: { - node_id: MASK_BLUR, + node_id: MASK_RESIZE_UP, field: 'image', }, destination: { @@ -461,6 +611,16 @@ export const buildCanvasOutpaintGraph = ( field: 'image', }, }, + { + source: { + node_id: SEAM_MASK_COMBINE, + field: 'image', + }, + destination: { + node_id: SEAM_MASK_RESIZE_DOWN, + field: 'image', + }, + }, { source: { node_id: INPAINT_INFILL, @@ -494,7 +654,7 @@ export const buildCanvasOutpaintGraph = ( }, { source: { - node_id: MASK_RESIZE_DOWN, + node_id: SEAM_MASK_RESIZE_DOWN, field: 'image', }, destination: { @@ -525,7 +685,7 @@ export const buildCanvasOutpaintGraph = ( }, { source: { - node_id: MASK_RESIZE_DOWN, + node_id: SEAM_MASK_RESIZE_DOWN, field: 'image', }, destination: { @@ -553,7 +713,6 @@ export const buildCanvasOutpaintGraph = ( }; graph.nodes[MASK_BLUR] = { 
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
-      image: canvasMaskImage,
     };
 
     graph.edges.push(
@@ -568,6 +727,47 @@ export const buildCanvasOutpaintGraph = (
           field: 'image',
         },
       },
+      // Seam Paint Mask
+      {
+        source: {
+          node_id: MASK_FROM_ALPHA,
+          field: 'image',
+        },
+        destination: {
+          node_id: MASK_EDGE,
+          field: 'image',
+        },
+      },
+      {
+        source: {
+          node_id: MASK_EDGE,
+          field: 'image',
+        },
+        destination: {
+          node_id: SEAM_FIX_DENOISE_LATENTS,
+          field: 'mask',
+        },
+      },
+      {
+        source: {
+          node_id: MASK_FROM_ALPHA,
+          field: 'image',
+        },
+        destination: {
+          node_id: SEAM_MASK_COMBINE,
+          field: 'mask1',
+        },
+      },
+      {
+        source: {
+          node_id: MASK_EDGE,
+          field: 'image',
+        },
+        destination: {
+          node_id: SEAM_MASK_COMBINE,
+          field: 'mask2',
+        },
+      },
       // Color Correct The Inpainted Result
       {
         source: {
@@ -591,7 +791,7 @@ export const buildCanvasOutpaintGraph = (
       },
       {
         source: {
-          node_id: MASK_BLUR,
+          node_id: SEAM_MASK_COMBINE,
           field: 'image',
         },
         destination: {
@@ -622,7 +822,7 @@ export const buildCanvasOutpaintGraph = (
       },
       {
         source: {
-          node_id: MASK_BLUR,
+          node_id: SEAM_MASK_COMBINE,
           field: 'image',
         },
        destination: {
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts
index 4d098f959f..1cc268c03d 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts
@@ -29,6 +29,7 @@ import {
   LATENTS_TO_IMAGE,
   MASK_BLUR,
   MASK_COMBINE,
+  MASK_EDGE,
   MASK_FROM_ALPHA,
   MASK_RESIZE_DOWN,
   MASK_RESIZE_UP,
@@ -40,6 +41,10 @@ import {
   SDXL_CANVAS_OUTPAINT_GRAPH,
   SDXL_DENOISE_LATENTS,
   SDXL_MODEL_LOADER,
+  SEAM_FIX_DENOISE_LATENTS,
+  SEAM_MASK_COMBINE,
+  SEAM_MASK_RESIZE_DOWN,
+  SEAM_MASK_RESIZE_UP,
 } from './constants';
 import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
 
@@ -67,6 +72,12 @@ export const buildCanvasSDXLOutpaintGraph = (
     shouldUseCpuNoise,
     maskBlur,
     maskBlurMethod,
+    seamSize,
+    seamBlur,
+    seamSteps,
+    seamStrength,
+    seamLowThreshold,
+    seamHighThreshold,
     tileSize,
     infillMethod,
   } = state.generation;
@@ -133,6 +144,11 @@ export const buildCanvasSDXLOutpaintGraph = (
       is_intermediate: true,
       mask2: canvasMaskImage,
     },
+    [SEAM_MASK_COMBINE]: {
+      type: 'mask_combine',
+      id: SEAM_MASK_COMBINE,
+      is_intermediate: true,
+    },
     [MASK_BLUR]: {
       type: 'img_blur',
       id: MASK_BLUR,
@@ -170,6 +186,25 @@ export const buildCanvasSDXLOutpaintGraph = (
         : 1 - strength,
       denoising_end: shouldUseSDXLRefiner ?
refinerStart : 1, }, + [MASK_EDGE]: { + type: 'mask_edge', + id: MASK_EDGE, + is_intermediate: true, + edge_size: seamSize, + edge_blur: seamBlur, + low_threshold: seamLowThreshold, + high_threshold: seamHighThreshold, + }, + [SEAM_FIX_DENOISE_LATENTS]: { + type: 'denoise_latents', + id: SEAM_FIX_DENOISE_LATENTS, + is_intermediate: true, + steps: seamSteps, + cfg_scale: cfg_scale, + scheduler: scheduler, + denoising_start: 1 - seamStrength, + denoising_end: 1, + }, [LATENTS_TO_IMAGE]: { type: 'l2i', id: LATENTS_TO_IMAGE, @@ -347,12 +382,63 @@ export const buildCanvasSDXLOutpaintGraph = ( field: 'seed', }, }, - // Decode inpainted latents to image + // Seam Paint + { + source: { + node_id: SDXL_MODEL_LOADER, + field: 'unet', + }, + destination: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'unet', + }, + }, + { + source: { + node_id: POSITIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'positive_conditioning', + }, + }, + { + source: { + node_id: NEGATIVE_CONDITIONING, + field: 'conditioning', + }, + destination: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'negative_conditioning', + }, + }, + { + source: { + node_id: NOISE, + field: 'noise', + }, + destination: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'noise', + }, + }, { source: { node_id: SDXL_DENOISE_LATENTS, field: 'latents', }, + destination: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'latents', + }, + }, + // Decode inpainted latents to image + { + source: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'latents', + }, destination: { node_id: LATENTS_TO_IMAGE, field: 'latents', @@ -392,6 +478,13 @@ export const buildCanvasSDXLOutpaintGraph = ( width: scaledWidth, height: scaledHeight, }; + graph.nodes[SEAM_MASK_RESIZE_UP] = { + type: 'img_resize', + id: SEAM_MASK_RESIZE_UP, + is_intermediate: true, + width: scaledWidth, + height: scaledHeight, + }; graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = { type: 'img_resize', id: INPAINT_IMAGE_RESIZE_DOWN, @@ -413,6 +506,13 @@ export const buildCanvasSDXLOutpaintGraph = ( width: width, height: height, }; + graph.nodes[SEAM_MASK_RESIZE_DOWN] = { + type: 'img_resize', + id: SEAM_MASK_RESIZE_DOWN, + is_intermediate: true, + width: width, + height: height, + }; graph.nodes[NOISE] = { ...(graph.nodes[NOISE] as NoiseInvocation), @@ -454,6 +554,57 @@ export const buildCanvasSDXLOutpaintGraph = ( field: 'image', }, }, + // Seam Paint Mask + { + source: { + node_id: MASK_FROM_ALPHA, + field: 'image', + }, + destination: { + node_id: MASK_EDGE, + field: 'image', + }, + }, + { + source: { + node_id: MASK_EDGE, + field: 'image', + }, + destination: { + node_id: SEAM_MASK_RESIZE_UP, + field: 'image', + }, + }, + { + source: { + node_id: SEAM_MASK_RESIZE_UP, + field: 'image', + }, + destination: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'mask', + }, + }, + { + source: { + node_id: MASK_BLUR, + field: 'image', + }, + destination: { + node_id: SEAM_MASK_COMBINE, + field: 'mask1', + }, + }, + { + source: { + node_id: SEAM_MASK_RESIZE_UP, + field: 'image', + }, + destination: { + node_id: SEAM_MASK_COMBINE, + field: 'mask2', + }, + }, // Resize Results Down { source: { @@ -467,7 +618,7 @@ export const buildCanvasSDXLOutpaintGraph = ( }, { source: { - node_id: MASK_BLUR, + node_id: MASK_RESIZE_UP, field: 'image', }, destination: { @@ -475,6 +626,16 @@ export const buildCanvasSDXLOutpaintGraph = ( field: 'image', }, }, + { + source: { + node_id: SEAM_MASK_COMBINE, + field: 'image', + }, + destination: { + node_id: 
SEAM_MASK_RESIZE_DOWN, + field: 'image', + }, + }, { source: { node_id: INPAINT_INFILL, @@ -508,7 +669,7 @@ export const buildCanvasSDXLOutpaintGraph = ( }, { source: { - node_id: MASK_RESIZE_DOWN, + node_id: SEAM_MASK_RESIZE_DOWN, field: 'image', }, destination: { @@ -539,7 +700,7 @@ export const buildCanvasSDXLOutpaintGraph = ( }, { source: { - node_id: MASK_RESIZE_DOWN, + node_id: SEAM_MASK_RESIZE_DOWN, field: 'image', }, destination: { @@ -567,7 +728,6 @@ export const buildCanvasSDXLOutpaintGraph = ( }; graph.nodes[MASK_BLUR] = { ...(graph.nodes[MASK_BLUR] as ImageBlurInvocation), - image: canvasMaskImage, }; graph.edges.push( @@ -582,6 +742,47 @@ export const buildCanvasSDXLOutpaintGraph = ( field: 'image', }, }, + // Seam Paint Mask + { + source: { + node_id: MASK_FROM_ALPHA, + field: 'image', + }, + destination: { + node_id: MASK_EDGE, + field: 'image', + }, + }, + { + source: { + node_id: MASK_EDGE, + field: 'image', + }, + destination: { + node_id: SEAM_FIX_DENOISE_LATENTS, + field: 'mask', + }, + }, + { + source: { + node_id: MASK_FROM_ALPHA, + field: 'image', + }, + destination: { + node_id: SEAM_MASK_COMBINE, + field: 'mask1', + }, + }, + { + source: { + node_id: MASK_EDGE, + field: 'image', + }, + destination: { + node_id: SEAM_MASK_COMBINE, + field: 'mask2', + }, + }, // Color Correct The Inpainted Result { source: { @@ -605,7 +806,7 @@ export const buildCanvasSDXLOutpaintGraph = ( }, { source: { - node_id: MASK_BLUR, + node_id: SEAM_MASK_COMBINE, field: 'image', }, destination: { @@ -636,7 +837,7 @@ export const buildCanvasSDXLOutpaintGraph = ( }, { source: { - node_id: MASK_BLUR, + node_id: SEAM_MASK_COMBINE, field: 'image', }, destination: { @@ -669,7 +870,7 @@ export const buildCanvasSDXLOutpaintGraph = ( // Add Refiner if enabled if (shouldUseSDXLRefiner) { - addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); + addSDXLRefinerToGraph(state, graph, SEAM_FIX_DENOISE_LATENTS); } // optionally add custom VAE diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts index 3e213120b3..1f6acd4e26 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts @@ -18,8 +18,6 @@ export const IMAGE_TO_LATENTS = 'image_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const RESIZE = 'resize_image'; export const CANVAS_OUTPUT = 'canvas_output'; -export const INPAINT = 'inpaint'; -export const INPAINT_SEAM_FIX = 'inpaint_seam_fix'; export const INPAINT_IMAGE = 'inpaint_image'; export const SCALED_INPAINT_IMAGE = 'scaled_inpaint_image'; export const INPAINT_IMAGE_RESIZE_UP = 'inpaint_image_resize_up'; @@ -27,10 +25,14 @@ export const INPAINT_IMAGE_RESIZE_DOWN = 'inpaint_image_resize_down'; export const INPAINT_INFILL = 'inpaint_infill'; export const INPAINT_INFILL_RESIZE_DOWN = 'inpaint_infill_resize_down'; export const INPAINT_FINAL_IMAGE = 'inpaint_final_image'; +export const SEAM_FIX_DENOISE_LATENTS = 'seam_fix_denoise_latents'; export const MASK_FROM_ALPHA = 'tomask'; export const MASK_EDGE = 'mask_edge'; export const MASK_BLUR = 'mask_blur'; export const MASK_COMBINE = 'mask_combine'; +export const SEAM_MASK_COMBINE = 'seam_mask_combine'; +export const SEAM_MASK_RESIZE_UP = 'seam_mask_resize_up'; +export const SEAM_MASK_RESIZE_DOWN = 'seam_mask_resize_down'; export const MASK_RESIZE_UP = 'mask_resize_up'; export const MASK_RESIZE_DOWN = 
'mask_resize_down';
 export const COLOR_CORRECT = 'color_correct';
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamBlur.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamBlur.tsx
new file mode 100644
index 0000000000..2ab048ce72
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamBlur.tsx
@@ -0,0 +1,36 @@
+import type { RootState } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import IAISlider from 'common/components/IAISlider';
+import { setSeamBlur } from 'features/parameters/store/generationSlice';
+import { memo } from 'react';
+import { useTranslation } from 'react-i18next';
+
+const ParamSeamBlur = () => {
+  const dispatch = useAppDispatch();
+  const seamBlur = useAppSelector(
+    (state: RootState) => state.generation.seamBlur
+  );
+  const { t } = useTranslation();
+
+  return (
+    <IAISlider
+      label={t('parameters.seamBlur')}
+      min={0}
+      max={64}
+      step={1}
+      sliderNumberInputProps={{ max: 512 }}
+      value={seamBlur}
+      onChange={(v) => {
+        dispatch(setSeamBlur(v));
+      }}
+      withInput
+      withSliderMarks
+      withReset
+      handleReset={() => {
+        dispatch(setSeamBlur(8));
+      }}
+    />
+  );
+};
+
+export default memo(ParamSeamBlur);
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse.tsx
new file mode 100644
index 0000000000..23e06797e5
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse.tsx
@@ -0,0 +1,27 @@
+import { Flex } from '@chakra-ui/react';
+import IAICollapse from 'common/components/IAICollapse';
+import { memo } from 'react';
+import { useTranslation } from 'react-i18next';
+import ParamSeamBlur from './ParamSeamBlur';
+import ParamSeamSize from './ParamSeamSize';
+import ParamSeamSteps from './ParamSeamSteps';
+import ParamSeamStrength from './ParamSeamStrength';
+import ParamSeamThreshold from './ParamSeamThreshold';
+
+const ParamSeamPaintingCollapse = () => {
+  const { t } = useTranslation();
+
+  return (
+    <IAICollapse label={t('parameters.seamPaintingHeader')}>
+      <Flex sx={{ flexDirection: 'column', gap: 2 }}>
+        <ParamSeamSize />
+        <ParamSeamBlur />
+        <ParamSeamThreshold />
+        <ParamSeamStrength />
+        <ParamSeamSteps />
+      </Flex>
+    </IAICollapse>
+  );
+};
+
+export default memo(ParamSeamPaintingCollapse);
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamSize.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamSize.tsx
new file mode 100644
index 0000000000..841e9555fd
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamSize.tsx
@@ -0,0 +1,36 @@
+import type { RootState } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import IAISlider from 'common/components/IAISlider';
+import { setSeamSize } from 'features/parameters/store/generationSlice';
+import { memo } from 'react';
+import { useTranslation } from 'react-i18next';
+
+const ParamSeamSize = () => {
+  const dispatch = useAppDispatch();
+  const seamSize = useAppSelector(
+    (state: RootState) => state.generation.seamSize
+  );
+  const { t } = useTranslation();
+
+  return (
+    <IAISlider
+      label={t('parameters.seamSize')}
+      min={1}
+      max={256}
+      step={1}
+      sliderNumberInputProps={{ max: 512 }}
+      value={seamSize}
+      onChange={(v) => {
+        dispatch(setSeamSize(v));
+      }}
+      withInput
+      withSliderMarks
+      withReset
+      handleReset={() => {
+        dispatch(setSeamSize(16));
+      }}
+    />
+  );
+};
+
+export default memo(ParamSeamSize);
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamSteps.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamSteps.tsx
new file mode 100644
index 0000000000..e69339dbfe
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamSteps.tsx
@@ -0,0 +1,36 @@
+import type { RootState } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import IAISlider from 'common/components/IAISlider';
+import { setSeamSteps } from 'features/parameters/store/generationSlice';
+import { memo } from 'react';
+import { useTranslation } from 'react-i18next';
+
+const ParamSeamSteps = () => {
+  const dispatch = useAppDispatch();
+  const seamSteps = useAppSelector(
+    (state: RootState) => state.generation.seamSteps
+  );
+  const { t } = useTranslation();
+
+  return (
+    <IAISlider
+      label={t('parameters.seamSteps')}
+      min={1}
+      max={100}
+      step={1}
+      sliderNumberInputProps={{ max: 999 }}
+      value={seamSteps}
+      onChange={(v) => {
+        dispatch(setSeamSteps(v));
+      }}
+      withInput
+      withSliderMarks
+      withReset
+      handleReset={() => {
+        dispatch(setSeamSteps(20));
+      }}
+    />
+  );
+};
+
+export default memo(ParamSeamSteps);
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamStrength.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamStrength.tsx
new file mode 100644
index 0000000000..3f0fa01fcb
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamStrength.tsx
@@ -0,0 +1,36 @@
+import type { RootState } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import IAISlider from 'common/components/IAISlider';
+import { setSeamStrength } from 'features/parameters/store/generationSlice';
+import { memo } from 'react';
+import { useTranslation } from 'react-i18next';
+
+const ParamSeamStrength = () => {
+  const dispatch = useAppDispatch();
+  const seamStrength = useAppSelector(
+    (state: RootState) => state.generation.seamStrength
+  );
+  const { t } = useTranslation();
+
+  return (
+    <IAISlider
+      label={t('parameters.seamStrength')}
+      min={0.01}
+      max={0.99}
+      step={0.01}
+      sliderNumberInputProps={{ max: 1 }}
+      value={seamStrength}
+      onChange={(v) => {
+        dispatch(setSeamStrength(v));
+      }}
+      withInput
+      withSliderMarks
+      withReset
+      handleReset={() => {
+        dispatch(setSeamStrength(0.7));
+      }}
+    />
+  );
+};
+
+export default memo(ParamSeamStrength);
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamThreshold.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamThreshold.tsx
new file mode 100644
index 0000000000..f40491db98
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamThreshold.tsx
@@ -0,0 +1,109 @@
+import {
+  FormControl,
+  FormLabel,
+  HStack,
+  RangeSlider,
+  RangeSliderFilledTrack,
+  RangeSliderMark,
+  RangeSliderThumb,
+  RangeSliderTrack,
+  Tooltip,
+} from '@chakra-ui/react';
+import type { RootState } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import IAIIconButton from 'common/components/IAIIconButton';
+import {
+  setSeamHighThreshold,
+  setSeamLowThreshold,
+} from 'features/parameters/store/generationSlice';
+import { memo, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { BiReset } from 'react-icons/bi';
+
+const ParamSeamThreshold = () => {
+  const dispatch = useAppDispatch();
+  const seamLowThreshold = useAppSelector(
+    (state: RootState) => state.generation.seamLowThreshold
+  );
+
+  const seamHighThreshold = useAppSelector(
+    (state: RootState) => state.generation.seamHighThreshold
+  );
+  const { t } = useTranslation();
+
+  const handleSeamThresholdChange = useCallback(
+    (v: number[]) => {
+      dispatch(setSeamLowThreshold(v[0] as number));
+      dispatch(setSeamHighThreshold(v[1] as number));
+    },
+    [dispatch]
+  );
+
+  const handleSeamThresholdReset = () => {
+    dispatch(setSeamLowThreshold(100));
+    dispatch(setSeamHighThreshold(200));
+  };
+
+  return (
+    <FormControl>
+      <FormLabel>{t('parameters.seamThreshold')}</FormLabel>
+      <HStack w="100%" gap={4} mt={-2}>
+        <RangeSlider
+          aria-label={[
+            t('parameters.seamLowThreshold'),
+            t('parameters.seamHighThreshold'),
+          ]}
+          value={[seamLowThreshold, seamHighThreshold]}
+          min={0}
+          max={255}
+          step={1}
+          minStepsBetweenThumbs={1}
+          onChange={handleSeamThresholdChange}
+        >
+          <RangeSliderTrack>
+            <RangeSliderFilledTrack />
+          </RangeSliderTrack>
+          <Tooltip label={seamLowThreshold} placement="top" hasArrow>
+            <RangeSliderThumb index={0} />
+          </Tooltip>
+          <Tooltip label={seamHighThreshold} placement="top" hasArrow>
+            <RangeSliderThumb index={1} />
+          </Tooltip>
+          <RangeSliderMark
+            value={0}
+            sx={{
+              insetInlineStart: '0 !important',
+              insetInlineEnd: 'unset !important',
+            }}
+          >
+            0
+          </RangeSliderMark>
+          <RangeSliderMark value={100}>
+            100
+          </RangeSliderMark>
+          <RangeSliderMark value={200}>
+            200
+          </RangeSliderMark>
+          <RangeSliderMark
+            value={255}
+            sx={{
+              insetInlineStart: 'unset !important',
+              insetInlineEnd: '0 !important',
+            }}
+          >
+            255
+          </RangeSliderMark>
+        </RangeSlider>
+        <IAIIconButton
+          size="sm"
+          aria-label={t('accessibility.reset')}
+          tooltip={t('accessibility.reset')}
+          icon={<BiReset />}
+          onClick={handleSeamThresholdReset}
+        />
+      </HStack>
+    </FormControl>
+  );
+};
+
+export default memo(ParamSeamThreshold);
diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
index 0173391833..d8495c5751 100644
--- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
+++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
@@ -37,6 +37,12 @@ export interface GenerationState {
   scheduler: SchedulerParam;
   maskBlur: number;
   maskBlurMethod: MaskBlurMethodParam;
+  seamSize: number;
+  seamBlur: number;
+  seamSteps: number;
+  seamStrength: StrengthParam;
+  seamLowThreshold: number;
+  seamHighThreshold: number;
   seed: SeedParam;
   seedWeights: string;
   shouldFitToWidthHeight: boolean;
@@ -74,6 +80,12 @@ export const initialGenerationState: GenerationState = {
   scheduler: 'euler',
   maskBlur: 16,
   maskBlurMethod: 'box',
+  seamSize: 16,
+  seamBlur: 8,
+  seamSteps: 20,
+  seamStrength: 0.7,
+  seamLowThreshold: 100,
+  seamHighThreshold: 200,
   seed: 0,
   seedWeights: '',
   shouldFitToWidthHeight: true,
@@ -200,6 +212,24 @@ export const generationSlice = createSlice({
     setMaskBlurMethod: (state, action: PayloadAction<MaskBlurMethodParam>) => {
       state.maskBlurMethod = action.payload;
     },
+    setSeamSize: (state, action: PayloadAction<number>) => {
+      state.seamSize = action.payload;
+    },
+    setSeamBlur: (state, action: PayloadAction<number>) => {
+      state.seamBlur = action.payload;
+    },
+    setSeamSteps: (state, action: PayloadAction<number>) => {
+      state.seamSteps = action.payload;
+    },
+    setSeamStrength: (state, action: PayloadAction<number>) => {
+      state.seamStrength = action.payload;
+    },
+    setSeamLowThreshold: (state, action: PayloadAction<number>) => {
+      state.seamLowThreshold = action.payload;
+    },
+    setSeamHighThreshold: (state, action: PayloadAction<number>) => {
+      state.seamHighThreshold = action.payload;
+    },
     setTileSize: (state, action: PayloadAction<number>) => {
       state.tileSize = action.payload;
     },
@@ -306,6 +336,12 @@ export const {
   setScheduler,
   setMaskBlur,
   setMaskBlurMethod,
+  setSeamSize,
+  setSeamBlur,
+  setSeamSteps,
+  setSeamStrength,
+  setSeamLowThreshold,
+  setSeamHighThreshold,
   setSeed,
   setSeedWeights,
   setShouldFitToWidthHeight,
diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLUnifiedCanvasTabParameters.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLUnifiedCanvasTabParameters.tsx
index c6af754ad9..74833ebd70 100644
--- a/invokeai/frontend/web/src/features/sdxl/components/SDXLUnifiedCanvasTabParameters.tsx
+++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLUnifiedCanvasTabParameters.tsx
@@ -2,6 +2,7 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
 import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
 import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
 import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse';
+import ParamSeamPaintingCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse';
 import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
 import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
 import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
@@ -22,6 +23,7 @@ export default function SDXLUnifiedCanvasTabParameters() {
       <ParamNoiseCollapse />
       <ParamMaskAdjustmentCollapse />
       <ParamInfillAndScalingCollapse />
+      <ParamSeamPaintingCollapse />
     </>
   );
 }
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx
index fcfffee48b..9e6dc8fef8 100644
--- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx
@@ -6,6 +6,7 @@ import ParamControlNetCollapse from 'features/parameters/components/Parameters/C
 import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
 // import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
 import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse';
+import ParamSeamPaintingCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse';
 import ParamPromptArea from 'features/parameters/components/Parameters/Prompt/ParamPromptArea';
 import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
 import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
@@ -23,6 +24,7 @@ const UnifiedCanvasParameters = () => {
       <ParamSymmetryCollapse />
       <ParamMaskAdjustmentCollapse />
       <ParamInfillAndScalingCollapse />
+      <ParamSeamPaintingCollapse />
     </>
   );
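
[Editor's note] For reference, the seam-painting pass that the two graph builders above assemble reduces to an edge-detected mask plus a second, shallow denoise over the already-denoised latents. Sketched as the node payloads the builders emit — the values shown are the new generationSlice defaults, and this is an illustrative reduction, not the complete graph:

    # hypothetical, minimal reconstruction of the seam-fix subgraph
    seam_fix_nodes = {
        "mask_edge": {                     # MASK_EDGE
            "type": "mask_edge",
            "edge_size": 16,               # seamSize
            "edge_blur": 8,                # seamBlur
            "low_threshold": 100,          # seamLowThreshold
            "high_threshold": 200,         # seamHighThreshold
        },
        "seam_fix_denoise_latents": {      # SEAM_FIX_DENOISE_LATENTS
            "type": "denoise_latents",
            "steps": 20,                   # seamSteps
            "denoising_start": 1 - 0.7,    # 1 - seamStrength
            "denoising_end": 1,
        },
    }
    # wiring: DENOISE_LATENTS.latents -> seam fix .latents,
    # MASK_EDGE.image (resized to match, when scaling is enabled) -> seam fix .mask,
    # seam fix .latents -> LATENTS_TO_IMAGE; with the SDXL refiner enabled,
    # the refiner now chains off SEAM_FIX_DENOISE_LATENTS instead of SDXL_DENOISE_LATENTS.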