mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add IP Adapter Style & Composition Modes (#6213)
## Summary Until now IP Adapter had complete control on the contents of the output. With this PR, users are now able to select "Style Only" or "Composition Only" to draw just the style or layout of the reference image. Based off: https://arxiv.org/abs/2404.02733 ### New IP Method Option - `Full` - Both style and layout of the refence image are used. - `Style Only` - Only the style of the image is used - `Composition Only` - Only the composition of the image is used. ![opera_0BkqZTwObO](https://github.com/invoke-ai/InvokeAI/assets/54517381/1b2fbbba-44c9-4c25-87cb-3711a17d13e3) ### Example Result ![demo](https://github.com/invoke-ai/InvokeAI/assets/54517381/703f3de5-e685-4691-acda-9338a4c10796) ### Notes - Supports both SDXL and SD1.5 ### Testing - Just check and test if it works as expected with all IP Adapter models - both SDXL and SD1.5 ## Merge Plan Good to merge once tested for all edge cases.
This commit is contained in:
commit
c828a4e59f
@ -4,20 +4,8 @@ from typing import List, Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
@ -36,6 +24,7 @@ class IPAdapterField(BaseModel):
|
||||
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
|
||||
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
|
||||
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
|
||||
target_blocks: List[str] = Field(default=[], description="The IP Adapter blocks to apply")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
@ -69,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
|
||||
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
|
||||
|
||||
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.3.0")
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
"""Collects IP-Adapter info to pass to other nodes."""
|
||||
|
||||
@ -90,6 +79,9 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
weight: Union[float, List[float]] = InputField(
|
||||
default=1, description="The weight given to the IP-Adapter", title="Weight"
|
||||
)
|
||||
method: Literal["full", "style", "composition"] = InputField(
|
||||
default="full", description="The method to apply the IP-Adapter"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
@ -124,12 +116,32 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
|
||||
|
||||
if self.method == "style":
|
||||
if ip_adapter_info.base == "sd-1":
|
||||
target_blocks = ["up_blocks.1"]
|
||||
elif ip_adapter_info.base == "sdxl":
|
||||
target_blocks = ["up_blocks.0.attentions.1"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
|
||||
elif self.method == "composition":
|
||||
if ip_adapter_info.base == "sd-1":
|
||||
target_blocks = ["down_blocks.2", "mid_block"]
|
||||
elif ip_adapter_info.base == "sdxl":
|
||||
target_blocks = ["down_blocks.2.attentions.1"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
|
||||
elif self.method == "full":
|
||||
target_blocks = ["block"]
|
||||
else:
|
||||
raise ValueError(f"Unexpected IP-Adapter method: '{self.method}'.")
|
||||
|
||||
return IPAdapterOutput(
|
||||
ip_adapter=IPAdapterField(
|
||||
image=self.image,
|
||||
ip_adapter_model=self.ip_adapter_model,
|
||||
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
|
||||
weight=self.weight,
|
||||
target_blocks=target_blocks,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
mask=self.mask,
|
||||
|
@ -679,6 +679,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
IPAdapterData(
|
||||
ip_adapter_model=ip_adapter_model,
|
||||
weight=single_ip_adapter.weight,
|
||||
target_blocks=single_ip_adapter.target_blocks,
|
||||
begin_step_percent=single_ip_adapter.begin_step_percent,
|
||||
end_step_percent=single_ip_adapter.end_step_percent,
|
||||
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
|
||||
|
@ -36,6 +36,7 @@ class IPAdapterMetadataField(BaseModel):
|
||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
|
||||
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
|
||||
method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
|
||||
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
|
||||
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||
|
@ -21,12 +21,9 @@ from pydantic import Field
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
IPAdapterData,
|
||||
TextConditioningData,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||
from invokeai.backend.util.attention import auto_detect_slice_size
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@ -394,8 +391,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
unet_attention_patcher = None
|
||||
self.use_ip_adapter = use_ip_adapter
|
||||
attn_ctx = nullcontext()
|
||||
|
||||
if use_ip_adapter or use_regional_prompting:
|
||||
ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
|
||||
ip_adapters: Optional[List[UNetIPAdapterData]] = (
|
||||
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
|
||||
if use_ip_adapter
|
||||
else None
|
||||
)
|
||||
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
|
||||
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||
|
||||
|
@ -53,6 +53,7 @@ class IPAdapterData:
|
||||
ip_adapter_model: IPAdapter
|
||||
ip_adapter_conditioning: IPAdapterConditioningInfo
|
||||
mask: torch.Tensor
|
||||
target_blocks: List[str]
|
||||
|
||||
# Either a single weight applied to all steps, or a list of weights for each step.
|
||||
weight: Union[float, List[float]] = 1.0
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, cast
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -9,6 +10,12 @@ from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import Regiona
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPAdapterAttentionWeights:
|
||||
ip_adapter_weights: IPAttentionProcessorWeights
|
||||
skip: bool
|
||||
|
||||
|
||||
class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
|
||||
This implementation is based on
|
||||
@ -20,7 +27,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
|
||||
ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
|
||||
):
|
||||
"""Initialize a CustomAttnProcessor2_0.
|
||||
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
|
||||
@ -30,23 +37,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
for the i'th IP-Adapter.
|
||||
"""
|
||||
super().__init__()
|
||||
self._ip_adapter_weights = ip_adapter_weights
|
||||
|
||||
def _is_ip_adapter_enabled(self) -> bool:
|
||||
return self._ip_adapter_weights is not None
|
||||
self._ip_adapter_attention_weights = ip_adapter_attention_weights
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
# For regional prompting:
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
# For Regional Prompting:
|
||||
regional_prompt_data: Optional[RegionalPromptData] = None,
|
||||
percent_through: Optional[torch.FloatTensor] = None,
|
||||
percent_through: Optional[torch.Tensor] = None,
|
||||
# For IP-Adapter:
|
||||
regional_ip_data: Optional[RegionalIPData] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""Apply attention.
|
||||
Args:
|
||||
@ -130,17 +136,19 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
|
||||
# Apply IP-Adapter conditioning.
|
||||
if is_cross_attention:
|
||||
if self._is_ip_adapter_enabled():
|
||||
if self._ip_adapter_attention_weights:
|
||||
assert regional_ip_data is not None
|
||||
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
||||
|
||||
assert (
|
||||
len(regional_ip_data.image_prompt_embeds)
|
||||
== len(self._ip_adapter_weights)
|
||||
== len(self._ip_adapter_attention_weights)
|
||||
== len(regional_ip_data.scales)
|
||||
== ip_masks.shape[1]
|
||||
)
|
||||
|
||||
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
|
||||
ipa_weights = self._ip_adapter_weights[ipa_index]
|
||||
ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
|
||||
ipa_scale = regional_ip_data.scales[ipa_index]
|
||||
ip_mask = ip_masks[0, ipa_index, ...]
|
||||
|
||||
@ -153,29 +161,33 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
|
||||
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
if not self._ip_adapter_attention_weights[ipa_index].skip:
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
# Expected ip_key and ip_value shape:
|
||||
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
# Expected ip_key and ip_value shape:
|
||||
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim
|
||||
)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
|
||||
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
|
||||
else:
|
||||
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
|
||||
assert regional_ip_data is None
|
||||
@ -188,11 +200,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End of unmodified block from AttnProcessor2_0
|
||||
|
||||
return hidden_states
|
||||
# casting torch.Tensor to torch.FloatTensor to avoid type issues
|
||||
return cast(torch.FloatTensor, hidden_states)
|
||||
|
@ -1,17 +1,25 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
from typing import List, Optional, TypedDict
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
|
||||
CustomAttnProcessor2_0,
|
||||
IPAdapterAttentionWeights,
|
||||
)
|
||||
|
||||
|
||||
class UNetIPAdapterData(TypedDict):
|
||||
ip_adapter: IPAdapter
|
||||
target_blocks: List[str]
|
||||
|
||||
|
||||
class UNetAttentionPatcher:
|
||||
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
|
||||
|
||||
def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
|
||||
self._ip_adapters = ip_adapters
|
||||
def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
|
||||
self._ip_adapters = ip_adapter_data
|
||||
|
||||
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
|
||||
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
|
||||
@ -26,9 +34,22 @@ class UNetAttentionPatcher:
|
||||
attn_procs[name] = CustomAttnProcessor2_0()
|
||||
else:
|
||||
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
||||
attn_procs[name] = CustomAttnProcessor2_0(
|
||||
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
|
||||
)
|
||||
ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
|
||||
|
||||
for ip_adapter in self._ip_adapters:
|
||||
ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
|
||||
skip = True
|
||||
for block in ip_adapter["target_blocks"]:
|
||||
if block in name:
|
||||
skip = False
|
||||
break
|
||||
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
|
||||
ip_adapter_weights=ip_adapter_weights, skip=skip
|
||||
)
|
||||
ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
|
||||
|
||||
attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)
|
||||
|
||||
return attn_procs
|
||||
|
||||
@contextmanager
|
||||
|
@ -213,6 +213,10 @@
|
||||
"resize": "Resize",
|
||||
"resizeSimple": "Resize (Simple)",
|
||||
"resizeMode": "Resize Mode",
|
||||
"ipAdapterMethod": "Method",
|
||||
"full": "Full",
|
||||
"style": "Style Only",
|
||||
"composition": "Composition Only",
|
||||
"safe": "Safe",
|
||||
"saveControlImage": "Save Control Image",
|
||||
"scribble": "scribble",
|
||||
|
@ -21,6 +21,7 @@ import ControlAdapterShouldAutoConfig from './ControlAdapterShouldAutoConfig';
|
||||
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
|
||||
import { ParamControlAdapterBeginEnd } from './parameters/ParamControlAdapterBeginEnd';
|
||||
import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode';
|
||||
import ParamControlAdapterIPMethod from './parameters/ParamControlAdapterIPMethod';
|
||||
import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect';
|
||||
import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode';
|
||||
import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight';
|
||||
@ -111,7 +112,8 @@ const ControlAdapterConfig = (props: { id: string; number: number }) => {
|
||||
|
||||
<Flex w="full" flexDir="column" gap={4}>
|
||||
<Flex gap={8} w="full" alignItems="center">
|
||||
<Flex flexDir="column" gap={2} h={32} w="full">
|
||||
<Flex flexDir="column" gap={4} h={controlAdapterType === 'ip_adapter' ? 40 : 32} w="full">
|
||||
<ParamControlAdapterIPMethod id={id} />
|
||||
<ParamControlAdapterWeight id={id} />
|
||||
<ParamControlAdapterBeginEnd id={id} />
|
||||
</Flex>
|
||||
|
@ -0,0 +1,63 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useControlAdapterIPMethod } from 'features/controlAdapters/hooks/useControlAdapterIPMethod';
|
||||
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||
import { controlAdapterIPMethodChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { IPMethod } from 'features/controlAdapters/store/types';
|
||||
import { isIPMethod } from 'features/controlAdapters/store/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = {
|
||||
id: string;
|
||||
};
|
||||
|
||||
const ParamControlAdapterIPMethod = ({ id }: Props) => {
|
||||
const isEnabled = useControlAdapterIsEnabled(id);
|
||||
const method = useControlAdapterIPMethod(id);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const options: { label: string; value: IPMethod }[] = useMemo(
|
||||
() => [
|
||||
{ label: t('controlnet.full'), value: 'full' },
|
||||
{ label: t('controlnet.style'), value: 'style' },
|
||||
{ label: t('controlnet.composition'), value: 'composition' },
|
||||
],
|
||||
[t]
|
||||
);
|
||||
|
||||
const handleIPMethodChanged = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!isIPMethod(v?.value)) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
controlAdapterIPMethodChanged({
|
||||
id,
|
||||
method: v.value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[id, dispatch]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === method), [options, method]);
|
||||
|
||||
if (!method) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<InformationalPopover feature="controlNetResizeMode">
|
||||
<FormLabel>{t('controlnet.ipAdapterMethod')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox value={value} options={options} isDisabled={!isEnabled} onChange={handleIPMethodChanged} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamControlAdapterIPMethod);
|
@ -0,0 +1,24 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectControlAdapterById,
|
||||
selectControlAdaptersSlice,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useControlAdapterIPMethod = (id: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
|
||||
const cn = selectControlAdapterById(controlAdapters, id);
|
||||
if (cn && cn?.type === 'ip_adapter') {
|
||||
return cn.method;
|
||||
}
|
||||
}),
|
||||
[id]
|
||||
);
|
||||
|
||||
const method = useAppSelector(selector);
|
||||
|
||||
return method;
|
||||
};
|
@ -21,6 +21,7 @@ import type {
|
||||
ControlAdapterType,
|
||||
ControlMode,
|
||||
ControlNetConfig,
|
||||
IPMethod,
|
||||
RequiredControlAdapterProcessorNode,
|
||||
ResizeMode,
|
||||
T2IAdapterConfig,
|
||||
@ -245,6 +246,10 @@ export const controlAdaptersSlice = createSlice({
|
||||
}
|
||||
caAdapter.updateOne(state, { id, changes: { controlMode } });
|
||||
},
|
||||
controlAdapterIPMethodChanged: (state, action: PayloadAction<{ id: string; method: IPMethod }>) => {
|
||||
const { id, method } = action.payload;
|
||||
caAdapter.updateOne(state, { id, changes: { method } });
|
||||
},
|
||||
controlAdapterCLIPVisionModelChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
|
||||
@ -390,6 +395,7 @@ export const {
|
||||
controlAdapterIsEnabledChanged,
|
||||
controlAdapterModelChanged,
|
||||
controlAdapterCLIPVisionModelChanged,
|
||||
controlAdapterIPMethodChanged,
|
||||
controlAdapterWeightChanged,
|
||||
controlAdapterBeginStepPctChanged,
|
||||
controlAdapterEndStepPctChanged,
|
||||
|
@ -210,6 +210,10 @@ const zResizeMode = z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_r
|
||||
export type ResizeMode = z.infer<typeof zResizeMode>;
|
||||
export const isResizeMode = (v: unknown): v is ResizeMode => zResizeMode.safeParse(v).success;
|
||||
|
||||
const zIPMethod = z.enum(['full', 'style', 'composition']);
|
||||
export type IPMethod = z.infer<typeof zIPMethod>;
|
||||
export const isIPMethod = (v: unknown): v is IPMethod => zIPMethod.safeParse(v).success;
|
||||
|
||||
export type ControlNetConfig = {
|
||||
type: 'controlnet';
|
||||
id: string;
|
||||
@ -253,6 +257,7 @@ export type IPAdapterConfig = {
|
||||
model: ParameterIPAdapterModel | null;
|
||||
clipVisionModel: CLIPVisionModel;
|
||||
weight: number;
|
||||
method: IPMethod;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
};
|
||||
|
@ -46,6 +46,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
|
||||
isEnabled: true,
|
||||
controlImage: null,
|
||||
model: null,
|
||||
method: 'full',
|
||||
clipVisionModel: 'ViT-H',
|
||||
weight: 1,
|
||||
beginStepPct: 0,
|
||||
|
@ -386,6 +386,10 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'weight'));
|
||||
const method = zIPAdapterField.shape.method
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'method'));
|
||||
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
@ -403,6 +407,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
|
||||
clipVisionModel: 'ViT-H',
|
||||
controlImage: image?.image_name ?? null,
|
||||
weight: weight ?? initialIPAdapter.weight,
|
||||
method: method ?? initialIPAdapter.method,
|
||||
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
||||
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
|
||||
};
|
||||
|
@ -109,6 +109,7 @@ export const zIPAdapterField = z.object({
|
||||
image: zImageField,
|
||||
ip_adapter_model: zModelIdentifierField,
|
||||
weight: z.number(),
|
||||
method: z.enum(['full', 'style', 'composition']),
|
||||
begin_step_percent: z.number().optional(),
|
||||
end_step_percent: z.number().optional(),
|
||||
});
|
||||
|
@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
|
||||
if (!ipAdapter.model) {
|
||||
return;
|
||||
}
|
||||
const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter;
|
||||
const { id, weight, model, clipVisionModel, method, beginStepPct, endStepPct, controlImage } = ipAdapter;
|
||||
|
||||
assert(controlImage, 'IP Adapter image is required');
|
||||
|
||||
@ -57,6 +57,7 @@ export const addIPAdapterToLinearGraph = async (
|
||||
type: 'ip_adapter',
|
||||
is_intermediate: true,
|
||||
weight: weight,
|
||||
method: method,
|
||||
ip_adapter_model: model,
|
||||
clip_vision_model: clipVisionModel,
|
||||
begin_step_percent: beginStepPct,
|
||||
@ -84,7 +85,7 @@ export const addIPAdapterToLinearGraph = async (
|
||||
};
|
||||
|
||||
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
|
||||
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter;
|
||||
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, method, weight } = ipAdapter;
|
||||
|
||||
assert(model, 'IP Adapter model is required');
|
||||
|
||||
@ -102,6 +103,7 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat
|
||||
ip_adapter_model: model,
|
||||
clip_vision_model: clipVisionModel,
|
||||
weight,
|
||||
method,
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
image,
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user