Merge branch 'lstein/feat/multi-gpu' of github.com:invoke-ai/InvokeAI into lstein/feat/multi-gpu

This commit is contained in:
Lincoln Stein 2024-04-16 16:27:08 -04:00
commit 89f8326c0b
19 changed files with 377 additions and 70 deletions

View File

@ -4,20 +4,8 @@ from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
BaseInvocation, from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
OutputField,
TensorField,
UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@ -36,6 +24,7 @@ class IPAdapterField(BaseModel):
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.") ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.") image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.") weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
target_blocks: List[str] = Field(default=[], description="The IP Adapter blocks to apply")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)" default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
) )
@ -69,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"} CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.3.0") @invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
class IPAdapterInvocation(BaseInvocation): class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes.""" """Collects IP-Adapter info to pass to other nodes."""
@ -90,6 +79,9 @@ class IPAdapterInvocation(BaseInvocation):
weight: Union[float, List[float]] = InputField( weight: Union[float, List[float]] = InputField(
default=1, description="The weight given to the IP-Adapter", title="Weight" default=1, description="The weight given to the IP-Adapter", title="Weight"
) )
method: Literal["full", "style", "composition"] = InputField(
default="full", description="The method to apply the IP-Adapter"
)
begin_step_percent: float = InputField( begin_step_percent: float = InputField(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)" default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
) )
@ -124,12 +116,32 @@ class IPAdapterInvocation(BaseInvocation):
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name) image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
if self.method == "style":
if ip_adapter_info.base == "sd-1":
target_blocks = ["up_blocks.1"]
elif ip_adapter_info.base == "sdxl":
target_blocks = ["up_blocks.0.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "composition":
if ip_adapter_info.base == "sd-1":
target_blocks = ["down_blocks.2", "mid_block"]
elif ip_adapter_info.base == "sdxl":
target_blocks = ["down_blocks.2.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "full":
target_blocks = ["block"]
else:
raise ValueError(f"Unexpected IP-Adapter method: '{self.method}'.")
return IPAdapterOutput( return IPAdapterOutput(
ip_adapter=IPAdapterField( ip_adapter=IPAdapterField(
image=self.image, image=self.image,
ip_adapter_model=self.ip_adapter_model, ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model), image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
weight=self.weight, weight=self.weight,
target_blocks=target_blocks,
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,
mask=self.mask, mask=self.mask,

View File

@ -679,6 +679,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
IPAdapterData( IPAdapterData(
ip_adapter_model=ip_adapter_model, ip_adapter_model=ip_adapter_model,
weight=single_ip_adapter.weight, weight=single_ip_adapter.weight,
target_blocks=single_ip_adapter.target_blocks,
begin_step_percent=single_ip_adapter.begin_step_percent, begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent, end_step_percent=single_ip_adapter.end_step_percent,
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds), ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),

View File

@ -36,6 +36,7 @@ class IPAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.") image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.") ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model") clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter") weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)") begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)") end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")

View File

@ -754,6 +754,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache[download_job.source] = install_job # matches a download job to an install job self._download_cache[download_job.source] = install_job # matches a download job to an install job
install_job.download_parts.add(download_job) install_job.download_parts.add(download_job)
# only start the jobs once install_job.download_parts is fully populated
for download_job in install_job.download_parts:
self._download_queue.submit_download_job( self._download_queue.submit_download_job(
download_job, download_job,
on_start=self._download_started_callback, on_start=self._download_started_callback,
@ -762,6 +764,7 @@ class ModelInstallService(ModelInstallServiceBase):
on_error=self._download_error_callback, on_error=self._download_error_callback,
on_cancelled=self._download_cancelled_callback, on_cancelled=self._download_cancelled_callback,
) )
return install_job return install_job
def _stat_size(self, path: Path) -> int: def _stat_size(self, path: Path) -> int:

View File

@ -21,12 +21,9 @@ from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
IPAdapterData,
TextConditioningData,
)
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.util.attention import auto_detect_slice_size from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
@ -394,8 +391,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
unet_attention_patcher = None unet_attention_patcher = None
self.use_ip_adapter = use_ip_adapter self.use_ip_adapter = use_ip_adapter
attn_ctx = nullcontext() attn_ctx = nullcontext()
if use_ip_adapter or use_regional_prompting: if use_ip_adapter or use_regional_prompting:
ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None ip_adapters: Optional[List[UNetIPAdapterData]] = (
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
if use_ip_adapter
else None
)
unet_attention_patcher = UNetAttentionPatcher(ip_adapters) unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)

View File

@ -53,6 +53,7 @@ class IPAdapterData:
ip_adapter_model: IPAdapter ip_adapter_model: IPAdapter
ip_adapter_conditioning: IPAdapterConditioningInfo ip_adapter_conditioning: IPAdapterConditioningInfo
mask: torch.Tensor mask: torch.Tensor
target_blocks: List[str]
# Either a single weight applied to all steps, or a list of weights for each step. # Either a single weight applied to all steps, or a list of weights for each step.
weight: Union[float, List[float]] = 1.0 weight: Union[float, List[float]] = 1.0

View File

@ -1,4 +1,5 @@
from typing import Optional from dataclasses import dataclass
from typing import List, Optional, cast
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -9,6 +10,12 @@ from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import Regiona
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@dataclass
class IPAdapterAttentionWeights:
ip_adapter_weights: IPAttentionProcessorWeights
skip: bool
class CustomAttnProcessor2_0(AttnProcessor2_0): class CustomAttnProcessor2_0(AttnProcessor2_0):
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features. """A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
This implementation is based on This implementation is based on
@ -20,7 +27,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
def __init__( def __init__(
self, self,
ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None, ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
): ):
"""Initialize a CustomAttnProcessor2_0. """Initialize a CustomAttnProcessor2_0.
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
@ -30,23 +37,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
for the i'th IP-Adapter. for the i'th IP-Adapter.
""" """
super().__init__() super().__init__()
self._ip_adapter_weights = ip_adapter_weights self._ip_adapter_attention_weights = ip_adapter_attention_weights
def _is_ip_adapter_enabled(self) -> bool:
return self._ip_adapter_weights is not None
def __call__( def __call__(
self, self,
attn: Attention, attn: Attention,
hidden_states: torch.FloatTensor, hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.Tensor] = None,
# For regional prompting: # For Regional Prompting:
regional_prompt_data: Optional[RegionalPromptData] = None, regional_prompt_data: Optional[RegionalPromptData] = None,
percent_through: Optional[torch.FloatTensor] = None, percent_through: Optional[torch.Tensor] = None,
# For IP-Adapter: # For IP-Adapter:
regional_ip_data: Optional[RegionalIPData] = None, regional_ip_data: Optional[RegionalIPData] = None,
*args,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
"""Apply attention. """Apply attention.
Args: Args:
@ -130,17 +136,19 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Apply IP-Adapter conditioning. # Apply IP-Adapter conditioning.
if is_cross_attention: if is_cross_attention:
if self._is_ip_adapter_enabled(): if self._ip_adapter_attention_weights:
assert regional_ip_data is not None assert regional_ip_data is not None
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len) ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
assert ( assert (
len(regional_ip_data.image_prompt_embeds) len(regional_ip_data.image_prompt_embeds)
== len(self._ip_adapter_weights) == len(self._ip_adapter_attention_weights)
== len(regional_ip_data.scales) == len(regional_ip_data.scales)
== ip_masks.shape[1] == ip_masks.shape[1]
) )
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds): for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
ipa_weights = self._ip_adapter_weights[ipa_index] ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
ipa_scale = regional_ip_data.scales[ipa_index] ipa_scale = regional_ip_data.scales[ipa_index]
ip_mask = ip_masks[0, ipa_index, ...] ip_mask = ip_masks[0, ipa_index, ...]
@ -153,15 +161,18 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
if not self._ip_adapter_attention_weights[ipa_index].skip:
ip_key = ipa_weights.to_k_ip(ip_hidden_states) ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states) ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads) # Expected ip_key and ip_value shape:
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim) # Expected ip_key and ip_value shape:
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1 # TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention( ip_hidden_states = F.scaled_dot_product_attention(
@ -169,12 +180,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
) )
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype) ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
else: else:
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in. # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
@ -188,11 +200,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
hidden_states = attn.to_out[1](hidden_states) hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4: if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection: if attn.residual_connection:
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor hidden_states = hidden_states / attn.rescale_output_factor
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End of unmodified block from AttnProcessor2_0
return hidden_states # casting torch.Tensor to torch.FloatTensor to avoid type issues
return cast(torch.FloatTensor, hidden_states)

View File

@ -1,17 +1,25 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional from typing import List, Optional, TypedDict
from diffusers.models import UNet2DConditionModel from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
CustomAttnProcessor2_0,
IPAdapterAttentionWeights,
)
class UNetIPAdapterData(TypedDict):
ip_adapter: IPAdapter
target_blocks: List[str]
class UNetAttentionPatcher: class UNetAttentionPatcher:
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers.""" """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
def __init__(self, ip_adapters: Optional[list[IPAdapter]]): def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
self._ip_adapters = ip_adapters self._ip_adapters = ip_adapter_data
def _prepare_attention_processors(self, unet: UNet2DConditionModel): def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
@ -26,9 +34,22 @@ class UNetAttentionPatcher:
attn_procs[name] = CustomAttnProcessor2_0() attn_procs[name] = CustomAttnProcessor2_0()
else: else:
# Collect the weights from each IP Adapter for the idx'th attention processor. # Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = CustomAttnProcessor2_0( ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
for ip_adapter in self._ip_adapters:
ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
skip = True
for block in ip_adapter["target_blocks"]:
if block in name:
skip = False
break
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
ip_adapter_weights=ip_adapter_weights, skip=skip
) )
ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)
return attn_procs return attn_procs
@contextmanager @contextmanager

View File

@ -213,6 +213,10 @@
"resize": "Resize", "resize": "Resize",
"resizeSimple": "Resize (Simple)", "resizeSimple": "Resize (Simple)",
"resizeMode": "Resize Mode", "resizeMode": "Resize Mode",
"ipAdapterMethod": "Method",
"full": "Full",
"style": "Style Only",
"composition": "Composition Only",
"safe": "Safe", "safe": "Safe",
"saveControlImage": "Save Control Image", "saveControlImage": "Save Control Image",
"scribble": "scribble", "scribble": "scribble",

View File

@ -21,6 +21,7 @@ import ControlAdapterShouldAutoConfig from './ControlAdapterShouldAutoConfig';
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports'; import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
import { ParamControlAdapterBeginEnd } from './parameters/ParamControlAdapterBeginEnd'; import { ParamControlAdapterBeginEnd } from './parameters/ParamControlAdapterBeginEnd';
import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode'; import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode';
import ParamControlAdapterIPMethod from './parameters/ParamControlAdapterIPMethod';
import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect'; import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect';
import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode'; import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode';
import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight'; import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight';
@ -111,7 +112,8 @@ const ControlAdapterConfig = (props: { id: string; number: number }) => {
<Flex w="full" flexDir="column" gap={4}> <Flex w="full" flexDir="column" gap={4}>
<Flex gap={8} w="full" alignItems="center"> <Flex gap={8} w="full" alignItems="center">
<Flex flexDir="column" gap={2} h={32} w="full"> <Flex flexDir="column" gap={4} h={controlAdapterType === 'ip_adapter' ? 40 : 32} w="full">
<ParamControlAdapterIPMethod id={id} />
<ParamControlAdapterWeight id={id} /> <ParamControlAdapterWeight id={id} />
<ParamControlAdapterBeginEnd id={id} /> <ParamControlAdapterBeginEnd id={id} />
</Flex> </Flex>

View File

@ -0,0 +1,63 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useControlAdapterIPMethod } from 'features/controlAdapters/hooks/useControlAdapterIPMethod';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { controlAdapterIPMethodChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { IPMethod } from 'features/controlAdapters/store/types';
import { isIPMethod } from 'features/controlAdapters/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
id: string;
};
const ParamControlAdapterIPMethod = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const method = useControlAdapterIPMethod(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const options: { label: string; value: IPMethod }[] = useMemo(
() => [
{ label: t('controlnet.full'), value: 'full' },
{ label: t('controlnet.style'), value: 'style' },
{ label: t('controlnet.composition'), value: 'composition' },
],
[t]
);
const handleIPMethodChanged = useCallback<ComboboxOnChange>(
(v) => {
if (!isIPMethod(v?.value)) {
return;
}
dispatch(
controlAdapterIPMethodChanged({
id,
method: v.value,
})
);
},
[id, dispatch]
);
const value = useMemo(() => options.find((o) => o.value === method), [options, method]);
if (!method) {
return null;
}
return (
<FormControl>
<InformationalPopover feature="controlNetResizeMode">
<FormLabel>{t('controlnet.ipAdapterMethod')}</FormLabel>
</InformationalPopover>
<Combobox value={value} options={options} isDisabled={!isEnabled} onChange={handleIPMethodChanged} />
</FormControl>
);
};
export default memo(ParamControlAdapterIPMethod);

View File

@ -0,0 +1,24 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectControlAdapterById,
selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react';
export const useControlAdapterIPMethod = (id: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
const cn = selectControlAdapterById(controlAdapters, id);
if (cn && cn?.type === 'ip_adapter') {
return cn.method;
}
}),
[id]
);
const method = useAppSelector(selector);
return method;
};

View File

@ -21,6 +21,7 @@ import type {
ControlAdapterType, ControlAdapterType,
ControlMode, ControlMode,
ControlNetConfig, ControlNetConfig,
IPMethod,
RequiredControlAdapterProcessorNode, RequiredControlAdapterProcessorNode,
ResizeMode, ResizeMode,
T2IAdapterConfig, T2IAdapterConfig,
@ -245,6 +246,10 @@ export const controlAdaptersSlice = createSlice({
} }
caAdapter.updateOne(state, { id, changes: { controlMode } }); caAdapter.updateOne(state, { id, changes: { controlMode } });
}, },
controlAdapterIPMethodChanged: (state, action: PayloadAction<{ id: string; method: IPMethod }>) => {
const { id, method } = action.payload;
caAdapter.updateOne(state, { id, changes: { method } });
},
controlAdapterCLIPVisionModelChanged: ( controlAdapterCLIPVisionModelChanged: (
state, state,
action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }> action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
@ -390,6 +395,7 @@ export const {
controlAdapterIsEnabledChanged, controlAdapterIsEnabledChanged,
controlAdapterModelChanged, controlAdapterModelChanged,
controlAdapterCLIPVisionModelChanged, controlAdapterCLIPVisionModelChanged,
controlAdapterIPMethodChanged,
controlAdapterWeightChanged, controlAdapterWeightChanged,
controlAdapterBeginStepPctChanged, controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged, controlAdapterEndStepPctChanged,

View File

@ -210,6 +210,10 @@ const zResizeMode = z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_r
export type ResizeMode = z.infer<typeof zResizeMode>; export type ResizeMode = z.infer<typeof zResizeMode>;
export const isResizeMode = (v: unknown): v is ResizeMode => zResizeMode.safeParse(v).success; export const isResizeMode = (v: unknown): v is ResizeMode => zResizeMode.safeParse(v).success;
const zIPMethod = z.enum(['full', 'style', 'composition']);
export type IPMethod = z.infer<typeof zIPMethod>;
export const isIPMethod = (v: unknown): v is IPMethod => zIPMethod.safeParse(v).success;
export type ControlNetConfig = { export type ControlNetConfig = {
type: 'controlnet'; type: 'controlnet';
id: string; id: string;
@ -253,6 +257,7 @@ export type IPAdapterConfig = {
model: ParameterIPAdapterModel | null; model: ParameterIPAdapterModel | null;
clipVisionModel: CLIPVisionModel; clipVisionModel: CLIPVisionModel;
weight: number; weight: number;
method: IPMethod;
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
}; };

View File

@ -46,6 +46,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
isEnabled: true, isEnabled: true,
controlImage: null, controlImage: null,
model: null, model: null,
method: 'full',
clipVisionModel: 'ViT-H', clipVisionModel: 'ViT-H',
weight: 1, weight: 1,
beginStepPct: 0, beginStepPct: 0,

View File

@ -386,6 +386,10 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
.nullish() .nullish()
.catch(null) .catch(null)
.parse(await getProperty(metadataItem, 'weight')); .parse(await getProperty(metadataItem, 'weight'));
const method = zIPAdapterField.shape.method
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'method'));
const begin_step_percent = zIPAdapterField.shape.begin_step_percent const begin_step_percent = zIPAdapterField.shape.begin_step_percent
.nullish() .nullish()
.catch(null) .catch(null)
@ -403,6 +407,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
clipVisionModel: 'ViT-H', clipVisionModel: 'ViT-H',
controlImage: image?.image_name ?? null, controlImage: image?.image_name ?? null,
weight: weight ?? initialIPAdapter.weight, weight: weight ?? initialIPAdapter.weight,
method: method ?? initialIPAdapter.method,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct, endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
}; };

View File

@ -109,6 +109,7 @@ export const zIPAdapterField = z.object({
image: zImageField, image: zImageField,
ip_adapter_model: zModelIdentifierField, ip_adapter_model: zModelIdentifierField,
weight: z.number(), weight: z.number(),
method: z.enum(['full', 'style', 'composition']),
begin_step_percent: z.number().optional(), begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(), end_step_percent: z.number().optional(),
}); });

View File

@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
if (!ipAdapter.model) { if (!ipAdapter.model) {
return; return;
} }
const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter; const { id, weight, model, clipVisionModel, method, beginStepPct, endStepPct, controlImage } = ipAdapter;
assert(controlImage, 'IP Adapter image is required'); assert(controlImage, 'IP Adapter image is required');
@ -57,6 +57,7 @@ export const addIPAdapterToLinearGraph = async (
type: 'ip_adapter', type: 'ip_adapter',
is_intermediate: true, is_intermediate: true,
weight: weight, weight: weight,
method: method,
ip_adapter_model: model, ip_adapter_model: model,
clip_vision_model: clipVisionModel, clip_vision_model: clipVisionModel,
begin_step_percent: beginStepPct, begin_step_percent: beginStepPct,
@ -84,7 +85,7 @@ export const addIPAdapterToLinearGraph = async (
}; };
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => { const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter; const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, method, weight } = ipAdapter;
assert(model, 'IP Adapter model is required'); assert(model, 'IP Adapter model is required');
@ -102,6 +103,7 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat
ip_adapter_model: model, ip_adapter_model: model,
clip_vision_model: clipVisionModel, clip_vision_model: clipVisionModel,
weight, weight,
method,
begin_step_percent: beginStepPct, begin_step_percent: beginStepPct,
end_step_percent: endStepPct, end_step_percent: endStepPct,
image, image,

File diff suppressed because one or more lines are too long