Merge branch 'main' into refactor/rename-get-logger

This commit is contained in:
Lincoln Stein 2023-08-17 19:01:17 -04:00
commit 79084e9e20
21 changed files with 1043 additions and 49 deletions

View File

@@ -25,10 +25,10 @@ This method is recommended for experienced users and developers
#### [Docker Installation](040_INSTALL_DOCKER.md)
This method is recommended for those familiar with running Docker containers
### Other Installation Guides
- [PyPatchMatch](060_INSTALL_PATCHMATCH.md)
- [XFormers](070_INSTALL_XFORMERS.md)
- [CUDA and ROCm Drivers](030_INSTALL_CUDA_AND_ROCM.md)
- [Installing New Models](050_INSTALLING_MODELS.md)
## :fontawesome-solid-computer: Hardware Requirements

View File

@@ -29,6 +29,7 @@ The abstract base class for this class is InvocationStatsServiceBase. An impleme
writes to the system log is stored in InvocationServices.performance_statistics.
"""
import psutil
import time
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
@@ -42,6 +43,11 @@ import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
from .model_manager_service import ModelManagerService
from invokeai.backend.model_management.model_cache import CacheStats
# size of GIG in bytes
GIG = 1073741824
class InvocationStatsServiceBase(ABC):
@@ -89,6 +95,8 @@ class InvocationStatsServiceBase(ABC):
invocation_type: str,
time_used: float,
vram_used: float,
ram_used: float,
ram_changed: float,
):
"""
Add timing information on execution of a node. Usually
@@ -97,6 +105,8 @@ class InvocationStatsServiceBase(ABC):
:param invocation_type: String literal type of the node
:param time_used: Time used by node's execution (sec)
:param vram_used: Maximum VRAM used during execution (GB)
:param ram_used: Current RAM used by the InvokeAI process (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
""" """
pass pass
@@ -115,6 +125,9 @@ class NodeStats:
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
@dataclass
@@ -133,31 +146,62 @@ class InvocationStatsService(InvocationStatsServiceBase):
self.graph_execution_manager = graph_execution_manager
# {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {}
self._cache_stats: Dict[str, CacheStats] = {}
self.ram_used: float = 0.0
self.ram_changed: float = 0.0
class StatsContext:
"""Context manager for collecting statistics."""
invocation: BaseInvocation = None
collector: "InvocationStatsServiceBase" = None
graph_id: str = None
start_time: int = 0
ram_used: int = 0
model_manager: ModelManagerService = None
def __init__(
self,
invocation: BaseInvocation,
graph_id: str,
model_manager: ModelManagerService,
collector: "InvocationStatsServiceBase",
):
"""Initialize statistics for this run."""
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0
self.ram_used = 0
self.model_manager = model_manager
def __enter__(self):
self.start_time = time.time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
self.ram_used = psutil.Process().memory_info().rss
if self.model_manager:
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
def __exit__(self, *args):
"""Called on exit from the context."""
ram_used = psutil.Process().memory_info().rss
self.collector.update_mem_stats(
ram_used=ram_used / GIG,
ram_changed=(ram_used - self.ram_used) / GIG,
)
self.collector.update_invocation_stats(
graph_id=self.graph_id,
invocation_type=self.invocation.type,
time_used=time.time() - self.start_time,
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
)
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
model_manager: ModelManagerService,
) -> StatsContext:
"""
Return a context object that will capture the statistics.
@@ -166,7 +210,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
"""
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
self._cache_stats[graph_execution_state_id] = CacheStats()
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
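For orientation, a minimal usage sketch of the context manager above. Names such as stats_service, context, and the invoke call are assumed for illustration and are not part of the diff:

with stats_service.collect_stats(invocation, graph_execution_state.id, model_manager):
    outputs = invocation.invoke(context)  # hypothetical node execution
stats_service.log_stats()  # reports per-node times, VRAM/RAM use, and cache stats once the graph completes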
def reset_all_stats(self):
"""Zero all statistics"""
@@ -179,13 +224,36 @@ class InvocationStatsService(InvocationStatsServiceBase):
except KeyError:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
def update_mem_stats(
self,
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.
:param ram_used: How much RAM is currently in use.
:param ram_changed: How much RAM changed since last generation.
"""
self.ram_used = ram_used
self.ram_changed = ram_changed
def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
):
""" """
Add timing information on execution of a node. Usually Add timing information on execution of a node. Usually
used internally. used internally.
:param graph_id: ID of the graph that is currently executing :param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node :param invocation_type: String literal type of the node
:param time_used: Floating point seconds used by node's exection :param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
""" """
if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats()
@@ -197,7 +265,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
def log_stats(self):
"""
Send the statistics to the system logger at the info level.
Stats will only be printed when the execution of the graph
is complete.
"""
completed = set()
@@ -208,16 +276,30 @@
total_time = 0
logger.info(f"Graph stats: {graph_id}")
logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
for node_type, stats in self._stats[graph_id].nodes.items():
logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
total_time += stats.time_used
cache_stats = self._cache_stats[graph_id]
hwm = cache_stats.high_watermark / GIG
tot = cache_stats.cache_size / GIG
loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s") logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
logger.info(f"RAM used to load models: {loaded:4.2f}G")
if torch.cuda.is_available():
logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
logger.info("RAM cache statistics:")
logger.info(f" Model cache hits: {cache_stats.hits}")
logger.info(f" Model cache misses: {cache_stats.misses}")
logger.info(f" Models cached: {cache_stats.in_cache}")
logger.info(f" Models cleared from cache: {cache_stats.cleared}")
logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")
completed.add(graph_id)
for graph_id in completed:
del self._stats[graph_id]
del self._cache_stats[graph_id]

View File

@@ -22,6 +22,7 @@ from invokeai.backend.model_management import (
ModelNotFoundException,
)
from invokeai.backend.model_management.model_search import FindModels
from invokeai.backend.model_management.model_cache import CacheStats
import torch
from invokeai.app.models.exceptions import CanceledException
@@ -276,6 +277,13 @@ class ModelManagerServiceBase(ABC):
"""
pass
@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Attach a CacheStats object to the model cache, so that cache statistics for the current graph are recorded into it.
"""
pass
@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
@@ -500,6 +508,12 @@ class ModelManagerService(ModelManagerServiceBase):
self.logger.debug(f"convert model {model_name}")
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Attach a CacheStats object to the model cache, so that cache statistics for the current graph are recorded into it.
"""
self.mgr.cache.stats = cache_stats
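A brief sketch of how the statistics service and the model cache share state through this method. The setup lines are assumed; CacheStats and GIG are as defined elsewhere in this diff:

stats = CacheStats()
model_manager.collect_cache_stats(stats)  # the ModelCache now writes into this object
# ... graph executes; model loads update hits/misses/high_watermark ...
print(stats.hits, stats.misses, stats.high_watermark / GIG)  # read back afterwards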
def commit(self, conf_file: Optional[Path] = None):
"""
Write current configuration out to the indicated file.

View File

@@ -86,7 +86,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke
try:
graph_id = graph_execution_state.id
model_manager = self.__invoker.services.model_manager
with statistics.collect_stats(invocation, graph_id, model_manager):
# use the internal invoke_internal(), which wraps the node's invoke() method in
# this accommodates nodes which require a value, but get it only from a
# connection

View File

@@ -21,12 +21,12 @@ import os
import sys
import hashlib
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Union, types, Optional, Type, Any
import torch
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, SubModelType, ModelBase
@@ -41,6 +41,18 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
GIG = 1073741824
@dataclass
class CacheStats(object):
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
# {submodel_key => size}
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
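As an illustration of reading these fields, a hypothetical helper (not part of the commit) that summarizes a snapshot:

def summarize(stats: CacheStats) -> str:
    """Hypothetical helper: render a CacheStats snapshot as one line."""
    lookups = stats.hits + stats.misses
    hit_rate = stats.hits / lookups if lookups else 0.0
    return (
        f"hits={stats.hits} misses={stats.misses} hit_rate={hit_rate:.0%} "
        f"high_watermark={stats.high_watermark / GIG:.2f}/{stats.cache_size / GIG:.2f}G"
    )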
class ModelLocker(object):
"Forward declaration"
pass
@@ -115,6 +127,9 @@ class ModelCache(object):
self.sha_chunksize = sha_chunksize
self.logger = logger
# used for stats collection
self.stats = None
self._cached_models = dict()
self._cache_stack = list()
@@ -181,13 +196,14 @@
model_type=model_type,
submodel_type=submodel,
)
# TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
)
if self.stats:
self.stats.misses += 1
# this will remove older cached models until
# there is sufficient room to load the requested model
@@ -201,6 +217,17 @@
cache_entry = _CacheRecord(self, model, mem_used)
self._cached_models[key] = cache_entry
else:
if self.stats:
self.stats.hits += 1
if self.stats:
self.stats.cache_size = self.max_cache_size * GIG
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[key] = max(
self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
)
with suppress(Exception):
self._cache_stack.remove(key)
@@ -280,14 +307,14 @@
"""
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
:param model_path: Path to model file/directory on disk.
"""
return self._local_model_hash(model_path)
def cache_size(self) -> float:
"""Return the current size of the cache, in GB."""
return self._cache_size() / GIG
def _has_cuda(self) -> bool:
return self.execution_device.type == "cuda"
@@ -310,12 +337,15 @@
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}"
)
def _cache_size(self) -> int:
return sum([m.size for m in self._cached_models.values()])
def _make_cache_room(self, model_size):
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = self._cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
@@ -364,6 +394,8 @@
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
if self.stats:
self.stats.cleared += 1
del self._cache_stack[pos]
del self._cached_models[model_key]
del cache_entry

View File

@@ -240,6 +240,7 @@ class InvokeAIDiffuserComponent:
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
encoder_attention_mask=encoder_attention_mask,
added_cond_kwargs=added_cond_kwargs,
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)

View File

@@ -4,8 +4,15 @@ import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlnetMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import (
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
@@ -18,10 +25,11 @@ from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
# TODO: create PR to diffusers
# Modified ControlNetModel with encoder_attention_mask argument added
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
"""
A ControlNet model.
@@ -52,12 +60,25 @@ class ControlNetModel(ModelMixin, ConfigMixin):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
@@ -98,10 +119,15 @@ class ControlNetModel(ModelMixin, ConfigMixin):
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
@@ -109,6 +135,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads=64,
):
super().__init__()
@@ -136,6 +163,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
@@ -145,16 +175,43 @@ class ControlNetModel(ModelMixin, ConfigMixin):
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
else:
self.encoder_hid_proj = None
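A small shape sketch of the `text_proj` branch above; the widths are illustrative assumptions, not values from the diff:

import torch
from torch import nn

encoder_hid_dim, cross_attention_dim = 4096, 1280  # assumed example widths
encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
hidden_states = torch.randn(2, 77, encoder_hid_dim)  # (batch, tokens, width)
print(encoder_hid_proj(hidden_states).shape)  # torch.Size([2, 77, 1280])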
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@@ -178,6 +235,29 @@ class ControlNetModel(ModelMixin, ConfigMixin):
else:
self.class_embedding = None
if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
# control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
@@ -212,6 +292,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
@@ -248,6 +329,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
self.controlnet_mid_block = controlnet_block
self.mid_block = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
@@ -277,7 +359,22 @@ class ControlNetModel(ModelMixin, ConfigMixin):
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
where applicable.
"""
transformer_layers_per_block = (
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
)
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
addition_time_embed_dim = (
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
)
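The guards above all follow one pattern: read a newer config key only if the source UNet defines it, otherwise fall back to a default so older UNets still convert. A minimal sketch with an assumed dict-style config:

unet_config = {"in_channels": 4}  # an older UNet config without the newer keys
transformer_layers_per_block = (
    unet_config["transformer_layers_per_block"] if "transformer_layers_per_block" in unet_config else 1
)
assert transformer_layers_per_block == 1  # falls back when the key is absent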
controlnet = cls(
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
@@ -463,6 +560,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
guess_mode: bool = False,
@@ -486,7 +584,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
@@ -549,6 +649,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.class_embedding is not None:
if class_labels is None:
@@ -560,11 +661,34 @@ class ControlNetModel(ModelMixin, ConfigMixin):
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
if "addition_embed_type" in self.config:
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb if aug_emb is not None else emb
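A shape sketch of the `text_time` branch above, using SDXL-style sizes as assumptions (1280-dim pooled text_embeds, six time_ids per sample, addition_time_embed_dim of 256):

import torch

text_embeds = torch.randn(2, 1280)  # pooled prompt embedding (batch=2)
time_embeds = torch.randn(2 * 6, 256)  # stands in for add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))  # -> (2, 1536)
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)  # -> (2, 2816)
# add_embeds would then pass through self.add_embedding to give aug_emb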
# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
# 3. down

View File

@@ -506,10 +506,14 @@
"maskAdjustmentsHeader": "Mask Adjustments",
"maskBlur": "Mask Blur",
"maskBlurMethod": "Mask Blur Method",
"seamPaintingHeader": "Seam Painting",
"seamSize": "Seam Size",
"seamBlur": "Seam Blur",
"seamSteps": "Seam Steps",
"seamStrength": "Seam Strength",
"seamThreshold": "Seam Threshold",
"seamLowThreshold": "Low",
"seamHighThreshold": "High",
"scaleBeforeProcessing": "Scale Before Processing",
"scaledWidth": "Scaled W",
"scaledHeight": "Scaled H",

View File

@@ -121,7 +121,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
effect: async (action, { dispatch, getState }) => {
const { imageDTOs, imagesUsage } = action.payload;
if (imageDTOs.length <= 1 || imagesUsage.length <= 1) {
// handle singles in separate listener
return;
}

View File

@@ -32,6 +32,7 @@ import {
MAIN_MODEL_LOADER,
MASK_BLUR,
MASK_COMBINE,
MASK_EDGE,
MASK_FROM_ALPHA,
MASK_RESIZE_DOWN,
MASK_RESIZE_UP,
@@ -40,6 +41,10 @@ import {
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
SEAM_FIX_DENOISE_LATENTS,
SEAM_MASK_COMBINE,
SEAM_MASK_RESIZE_DOWN,
SEAM_MASK_RESIZE_UP,
} from './constants';
/**
@@ -67,6 +72,12 @@ export const buildCanvasOutpaintGraph = (
shouldUseCpuNoise,
maskBlur,
maskBlurMethod,
seamSize,
seamBlur,
seamSteps,
seamStrength,
seamLowThreshold,
seamHighThreshold,
tileSize,
infillMethod,
clipSkip,
@@ -130,6 +141,11 @@
is_intermediate: true,
mask2: canvasMaskImage,
},
[SEAM_MASK_COMBINE]: {
type: 'mask_combine',
id: SEAM_MASK_COMBINE,
is_intermediate: true,
},
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
@@ -165,6 +181,25 @@
denoising_start: 1 - strength,
denoising_end: 1,
},
[MASK_EDGE]: {
type: 'mask_edge',
id: MASK_EDGE,
is_intermediate: true,
edge_size: seamSize,
edge_blur: seamBlur,
low_threshold: seamLowThreshold,
high_threshold: seamHighThreshold,
},
[SEAM_FIX_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SEAM_FIX_DENOISE_LATENTS,
is_intermediate: true,
steps: seamSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
denoising_start: 1 - seamStrength,
denoising_end: 1,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
@@ -333,12 +368,63 @@
field: 'seed',
},
},
// Seam Paint
{
source: {
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
},
// Decode the result from Inpaint
{
source: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
@@ -348,7 +434,6 @@
};
// Add Infill Nodes
if (infillMethod === 'patchmatch') {
graph.nodes[INPAINT_INFILL] = {
type: 'infill_patchmatch',
@@ -378,6 +463,13 @@
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[SEAM_MASK_RESIZE_UP] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_UP,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
@@ -399,6 +491,13 @@
width: width,
height: height,
};
graph.nodes[SEAM_MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_DOWN,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[NOISE] = {
...(graph.nodes[NOISE] as NoiseInvocation),
@@ -440,6 +539,57 @@
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Resize Results Down
{
source: {
@@ -453,7 +603,7 @@
},
{
source: {
node_id: MASK_RESIZE_UP,
field: 'image',
},
destination: {
@@ -461,6 +611,16 @@
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
},
{
source: {
node_id: INPAINT_INFILL,
@@ -494,7 +654,7 @@
},
{
source: {
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@@ -525,7 +685,7 @@
},
{
source: {
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@@ -553,7 +713,6 @@
};
graph.nodes[MASK_BLUR] = {
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
};
graph.edges.push(
@@ -568,6 +727,47 @@
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Color Correct The Inpainted Result
{
source: {
@@ -591,7 +791,7 @@
},
{
source: {
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
@@ -622,7 +822,7 @@
},
{
source: {
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {

View File

@@ -29,6 +29,7 @@ import {
LATENTS_TO_IMAGE,
MASK_BLUR,
MASK_COMBINE,
MASK_EDGE,
MASK_FROM_ALPHA,
MASK_RESIZE_DOWN,
MASK_RESIZE_UP,
@@ -40,6 +41,10 @@
SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
SEAM_FIX_DENOISE_LATENTS,
SEAM_MASK_COMBINE,
SEAM_MASK_RESIZE_DOWN,
SEAM_MASK_RESIZE_UP,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@@ -67,6 +72,12 @@ export const buildCanvasSDXLOutpaintGraph = (
shouldUseCpuNoise,
maskBlur,
maskBlurMethod,
seamSize,
seamBlur,
seamSteps,
seamStrength,
seamLowThreshold,
seamHighThreshold,
tileSize,
infillMethod,
} = state.generation;
@@ -133,6 +144,11 @@
is_intermediate: true,
mask2: canvasMaskImage,
},
[SEAM_MASK_COMBINE]: {
type: 'mask_combine',
id: SEAM_MASK_COMBINE,
is_intermediate: true,
},
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
@@ -170,6 +186,25 @@
: 1 - strength,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
},
[MASK_EDGE]: {
type: 'mask_edge',
id: MASK_EDGE,
is_intermediate: true,
edge_size: seamSize,
edge_blur: seamBlur,
low_threshold: seamLowThreshold,
high_threshold: seamHighThreshold,
},
[SEAM_FIX_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SEAM_FIX_DENOISE_LATENTS,
is_intermediate: true,
steps: seamSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
denoising_start: 1 - seamStrength,
denoising_end: 1,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
@@ -347,12 +382,63 @@
field: 'seed',
},
},
// Seam Paint
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: SDXL_DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
},
// Decode inpainted latents to image
{
source: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
@@ -392,6 +478,13 @@
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[SEAM_MASK_RESIZE_UP] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_UP,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
@@ -413,6 +506,13 @@
width: width,
height: height,
};
graph.nodes[SEAM_MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: SEAM_MASK_RESIZE_DOWN,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[NOISE] = {
...(graph.nodes[NOISE] as NoiseInvocation),
@@ -454,6 +554,57 @@
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: SEAM_MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Resize Results Down
{
source: {
@@ -467,7 +618,7 @@
},
{
source: {
node_id: MASK_RESIZE_UP,
field: 'image',
},
destination: {
@@ -475,6 +626,16 @@
field: 'image',
},
},
{
source: {
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
},
{
source: {
node_id: INPAINT_INFILL,
@@ -508,7 +669,7 @@
},
{
source: {
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@@ -539,7 +700,7 @@
},
{
source: {
node_id: SEAM_MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
@@ -567,7 +728,6 @@
};
graph.nodes[MASK_BLUR] = {
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
};
graph.edges.push(
@@ -582,6 +742,47 @@
field: 'image',
},
},
// Seam Paint Mask
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: MASK_EDGE,
field: 'image',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_FIX_DENOISE_LATENTS,
field: 'mask',
},
},
{
source: {
node_id: MASK_FROM_ALPHA,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask1',
},
},
{
source: {
node_id: MASK_EDGE,
field: 'image',
},
destination: {
node_id: SEAM_MASK_COMBINE,
field: 'mask2',
},
},
// Color Correct The Inpainted Result
{
source: {
@@ -605,7 +806,7 @@
},
{
source: {
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
@@ -636,7 +837,7 @@
},
{
source: {
node_id: SEAM_MASK_COMBINE,
field: 'image',
},
destination: {
@@ -669,7 +870,7 @@
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SEAM_FIX_DENOISE_LATENTS);
}
// optionally add custom VAE

View File

@@ -18,8 +18,6 @@ export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image';
export const CANVAS_OUTPUT = 'canvas_output';
export const INPAINT_IMAGE = 'inpaint_image';
export const SCALED_INPAINT_IMAGE = 'scaled_inpaint_image';
export const INPAINT_IMAGE_RESIZE_UP = 'inpaint_image_resize_up';
@@ -27,10 +25,14 @@ export const INPAINT_IMAGE_RESIZE_DOWN = 'inpaint_image_resize_down';
export const INPAINT_INFILL = 'inpaint_infill';
export const INPAINT_INFILL_RESIZE_DOWN = 'inpaint_infill_resize_down';
export const INPAINT_FINAL_IMAGE = 'inpaint_final_image';
export const SEAM_FIX_DENOISE_LATENTS = 'seam_fix_denoise_latents';
export const MASK_FROM_ALPHA = 'tomask';
export const MASK_EDGE = 'mask_edge';
export const MASK_BLUR = 'mask_blur';
export const MASK_COMBINE = 'mask_combine';
export const SEAM_MASK_COMBINE = 'seam_mask_combine';
export const SEAM_MASK_RESIZE_UP = 'seam_mask_resize_up';
export const SEAM_MASK_RESIZE_DOWN = 'seam_mask_resize_down';
export const MASK_RESIZE_UP = 'mask_resize_up';
export const MASK_RESIZE_DOWN = 'mask_resize_down';
export const COLOR_CORRECT = 'color_correct';

View File

@@ -0,0 +1,36 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamBlur } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamSeamBlur = () => {
const dispatch = useAppDispatch();
const seamBlur = useAppSelector(
(state: RootState) => state.generation.seamBlur
);
const { t } = useTranslation();
return (
<IAISlider
label={t('parameters.seamBlur')}
min={0}
max={64}
step={8}
sliderNumberInputProps={{ max: 512 }}
value={seamBlur}
onChange={(v) => {
dispatch(setSeamBlur(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamBlur(8));
}}
/>
);
};
export default memo(ParamSeamBlur);

View File

@@ -0,0 +1,27 @@
import { Flex } from '@chakra-ui/react';
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamSeamBlur from './ParamSeamBlur';
import ParamSeamSize from './ParamSeamSize';
import ParamSeamSteps from './ParamSeamSteps';
import ParamSeamStrength from './ParamSeamStrength';
import ParamSeamThreshold from './ParamSeamThreshold';
const ParamSeamPaintingCollapse = () => {
const { t } = useTranslation();
return (
<IAICollapse label={t('parameters.seamPaintingHeader')}>
<Flex sx={{ flexDirection: 'column', gap: 2, paddingBottom: 2 }}>
<ParamSeamSize />
<ParamSeamBlur />
<ParamSeamSteps />
<ParamSeamStrength />
<ParamSeamThreshold />
</Flex>
</IAICollapse>
);
};
export default memo(ParamSeamPaintingCollapse);

View File

@@ -0,0 +1,36 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamSize } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamSeamSize = () => {
const dispatch = useAppDispatch();
const seamSize = useAppSelector(
(state: RootState) => state.generation.seamSize
);
const { t } = useTranslation();
return (
<IAISlider
label={t('parameters.seamSize')}
min={0}
max={128}
step={8}
sliderNumberInputProps={{ max: 512 }}
value={seamSize}
onChange={(v) => {
dispatch(setSeamSize(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamSize(16));
}}
/>
);
};
export default memo(ParamSeamSize);

View File

@@ -0,0 +1,36 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamSteps } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamSeamSteps = () => {
const dispatch = useAppDispatch();
const seamSteps = useAppSelector(
(state: RootState) => state.generation.seamSteps
);
const { t } = useTranslation();
return (
<IAISlider
label={t('parameters.seamSteps')}
min={0}
max={100}
step={1}
sliderNumberInputProps={{ max: 999 }}
value={seamSteps}
onChange={(v) => {
dispatch(setSeamSteps(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamSteps(20));
}}
/>
);
};
export default memo(ParamSeamSteps);

features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamStrength.tsx (new file)

@@ -0,0 +1,36 @@
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamStrength } from 'features/parameters/store/generationSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamSeamStrength = () => {
const dispatch = useAppDispatch();
const seamStrength = useAppSelector(
(state: RootState) => state.generation.seamStrength
);
const { t } = useTranslation();
return (
<IAISlider
label={t('parameters.seamStrength')}
min={0}
max={1}
step={0.01}
      sliderNumberInputProps={{ max: 1 }}
value={seamStrength}
onChange={(v) => {
dispatch(setSeamStrength(v));
}}
withInput
withSliderMarks
withReset
handleReset={() => {
dispatch(setSeamStrength(0.7));
}}
/>
);
};
export default memo(ParamSeamStrength);

features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamThreshold.tsx (new file)

@@ -0,0 +1,121 @@
import {
FormControl,
FormLabel,
HStack,
RangeSlider,
RangeSliderFilledTrack,
RangeSliderMark,
RangeSliderThumb,
RangeSliderTrack,
Tooltip,
} from '@chakra-ui/react';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import {
setSeamHighThreshold,
setSeamLowThreshold,
} from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi';
const ParamSeamThreshold = () => {
const dispatch = useAppDispatch();
const seamLowThreshold = useAppSelector(
(state: RootState) => state.generation.seamLowThreshold
);
const seamHighThreshold = useAppSelector(
(state: RootState) => state.generation.seamHighThreshold
);
const { t } = useTranslation();
const handleSeamThresholdChange = useCallback(
(v: number[]) => {
dispatch(setSeamLowThreshold(v[0] as number));
dispatch(setSeamHighThreshold(v[1] as number));
},
[dispatch]
);
const handleSeamThresholdReset = () => {
dispatch(setSeamLowThreshold(100));
dispatch(setSeamHighThreshold(200));
};
return (
<FormControl>
<FormLabel>{t('parameters.seamThreshold')}</FormLabel>
<HStack w="100%" gap={4} mt={-2}>
<RangeSlider
aria-label={[
t('parameters.seamLowThreshold'),
t('parameters.seamHighThreshold'),
]}
value={[seamLowThreshold, seamHighThreshold]}
min={0}
max={255}
step={1}
minStepsBetweenThumbs={1}
onChange={handleSeamThresholdChange}
>
<RangeSliderTrack>
<RangeSliderFilledTrack />
</RangeSliderTrack>
<Tooltip label={seamLowThreshold} placement="top" hasArrow>
<RangeSliderThumb index={0} />
</Tooltip>
<Tooltip label={seamHighThreshold} placement="top" hasArrow>
<RangeSliderThumb index={1} />
</Tooltip>
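          {/* Hand-positioned labels: the sx insetInlineStart values below override Chakra's default mark placement to keep each label centered near its value (e.g. 100/255 ≈ 39%). */}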
<RangeSliderMark
value={0}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
}}
>
0
</RangeSliderMark>
<RangeSliderMark
            value={100}
sx={{
insetInlineStart: '38.4% !important',
transform: 'translateX(-38.4%)',
}}
>
100
</RangeSliderMark>
<RangeSliderMark
            value={200}
sx={{
insetInlineStart: '79.8% !important',
transform: 'translateX(-79.8%)',
}}
>
200
</RangeSliderMark>
<RangeSliderMark
            value={255}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
}}
>
255
</RangeSliderMark>
</RangeSlider>
<IAIIconButton
size="sm"
aria-label={t('accessibility.reset')}
tooltip={t('accessibility.reset')}
icon={<BiReset />}
onClick={handleSeamThresholdReset}
/>
</HStack>
</FormControl>
);
};
export default memo(ParamSeamThreshold);
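The hard-coded insetInlineStart percentages in ParamSeamThreshold hand-tune label centering. A small helper could derive the offset from the mark's value instead of hard-coding it. A sketch, assuming the same 0–255 domain (not part of this PR):

```ts
// Sketch: compute a mark label's offset from its value.
const seamMarkStyle = (value: number, min = 0, max = 255) => {
  const pct = (((value - min) / (max - min)) * 100).toFixed(1);
  return {
    insetInlineStart: `${pct}% !important`,
    transform: `translateX(-${pct}%)`,
  };
};

// seamMarkStyle(100) yields ~39.2%, close to the hand-tuned 38.4% above.
```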

features/parameters/store/generationSlice.ts (modified)

@@ -37,6 +37,12 @@ export interface GenerationState {
scheduler: SchedulerParam;
maskBlur: number;
maskBlurMethod: MaskBlurMethodParam;
seamSize: number;
seamBlur: number;
seamSteps: number;
seamStrength: StrengthParam;
seamLowThreshold: number;
seamHighThreshold: number;
seed: SeedParam;
seedWeights: string;
shouldFitToWidthHeight: boolean;
@@ -74,6 +80,12 @@ export const initialGenerationState: GenerationState = {
scheduler: 'euler',
maskBlur: 16,
maskBlurMethod: 'box',
seamSize: 16,
seamBlur: 8,
seamSteps: 20,
seamStrength: 0.7,
seamLowThreshold: 100,
seamHighThreshold: 200,
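// These defaults mirror each slider's reset value in the ParamSeam* components above.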
seed: 0,
seedWeights: '',
shouldFitToWidthHeight: true,
@@ -200,6 +212,24 @@ export const generationSlice = createSlice({
setMaskBlurMethod: (state, action: PayloadAction<MaskBlurMethodParam>) => {
state.maskBlurMethod = action.payload;
},
setSeamSize: (state, action: PayloadAction<number>) => {
state.seamSize = action.payload;
},
setSeamBlur: (state, action: PayloadAction<number>) => {
state.seamBlur = action.payload;
},
setSeamSteps: (state, action: PayloadAction<number>) => {
state.seamSteps = action.payload;
},
setSeamStrength: (state, action: PayloadAction<number>) => {
state.seamStrength = action.payload;
},
setSeamLowThreshold: (state, action: PayloadAction<number>) => {
state.seamLowThreshold = action.payload;
},
setSeamHighThreshold: (state, action: PayloadAction<number>) => {
state.seamHighThreshold = action.payload;
},
setTileSize: (state, action: PayloadAction<number>) => {
state.tileSize = action.payload;
},
@@ -306,6 +336,12 @@ export const {
setScheduler,
setMaskBlur,
setMaskBlurMethod,
setSeamSize,
setSeamBlur,
setSeamSteps,
setSeamStrength,
setSeamLowThreshold,
setSeamHighThreshold,
setSeed,
setSeedWeights,
setShouldFitToWidthHeight,
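Reviewer note: with the setters exported, any canvas flow can read the seam values from generation state and adjust them programmatically. A hypothetical usage sketch (the preset values and the AppDispatch import path are assumptions, not from this PR):

```ts
import type { AppDispatch } from 'app/store/store'; // assumed export location
import {
  setSeamSize,
  setSeamBlur,
  setSeamSteps,
} from 'features/parameters/store/generationSlice';

// Widen and soften the seam for large inpaint regions (illustrative values).
const applyWideSeamPreset = (dispatch: AppDispatch) => {
  dispatch(setSeamSize(32));
  dispatch(setSeamBlur(16));
  dispatch(setSeamSteps(30));
};
```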

SDXLUnifiedCanvasTabParameters.tsx (modified)

@@ -2,6 +2,7 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse';
import ParamSeamPaintingCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
@@ -22,6 +23,7 @@ export default function SDXLUnifiedCanvasTabParameters() {
<ParamNoiseCollapse />
<ParamMaskAdjustmentCollapse />
<ParamInfillAndScalingCollapse />
<ParamSeamPaintingCollapse />
</>
);
}

UnifiedCanvasParameters.tsx (modified)

@@ -6,6 +6,7 @@ import ParamControlNetCollapse from 'features/parameters/components/Parameters/C
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ParamMaskAdjustmentCollapse from 'features/parameters/components/Parameters/Canvas/MaskAdjustment/ParamMaskAdjustmentCollapse';
import ParamSeamPaintingCollapse from 'features/parameters/components/Parameters/Canvas/SeamPainting/ParamSeamPaintingCollapse';
import ParamPromptArea from 'features/parameters/components/Parameters/Prompt/ParamPromptArea';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
@@ -23,6 +24,7 @@ const UnifiedCanvasParameters = () => {
<ParamSymmetryCollapse />
<ParamMaskAdjustmentCollapse />
<ParamInfillAndScalingCollapse />
<ParamSeamPaintingCollapse />
<ParamAdvancedCollapse />
</>
);