Merge branch 'lstein/feat/multi-gpu' of github.com:invoke-ai/InvokeAI into lstein/feat/multi-gpu
commit 89f8326c0b
@@ -4,20 +4,8 @@ from typing import List, Literal, Optional, Union
 from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self
 
-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    OutputField,
-    TensorField,
-    UIType,
-)
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@@ -36,6 +24,7 @@ class IPAdapterField(BaseModel):
     ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
     image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
     weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
+    target_blocks: List[str] = Field(default=[], description="The IP Adapter blocks to apply")
     begin_step_percent: float = Field(
         default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
     )
@@ -69,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
 CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
 
 
-@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.3.0")
+@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
 class IPAdapterInvocation(BaseInvocation):
     """Collects IP-Adapter info to pass to other nodes."""
 
@@ -90,6 +79,9 @@ class IPAdapterInvocation(BaseInvocation):
     weight: Union[float, List[float]] = InputField(
         default=1, description="The weight given to the IP-Adapter", title="Weight"
     )
+    method: Literal["full", "style", "composition"] = InputField(
+        default="full", description="The method to apply the IP-Adapter"
+    )
     begin_step_percent: float = InputField(
         default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
     )
@@ -124,12 +116,32 @@ class IPAdapterInvocation(BaseInvocation):
 
         image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
 
+        if self.method == "style":
+            if ip_adapter_info.base == "sd-1":
+                target_blocks = ["up_blocks.1"]
+            elif ip_adapter_info.base == "sdxl":
+                target_blocks = ["up_blocks.0.attentions.1"]
+            else:
+                raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
+        elif self.method == "composition":
+            if ip_adapter_info.base == "sd-1":
+                target_blocks = ["down_blocks.2", "mid_block"]
+            elif ip_adapter_info.base == "sdxl":
+                target_blocks = ["down_blocks.2.attentions.1"]
+            else:
+                raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
+        elif self.method == "full":
+            target_blocks = ["block"]
+        else:
+            raise ValueError(f"Unexpected IP-Adapter method: '{self.method}'.")
+
         return IPAdapterOutput(
             ip_adapter=IPAdapterField(
                 image=self.image,
                 ip_adapter_model=self.ip_adapter_model,
                 image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
                 weight=self.weight,
+                target_blocks=target_blocks,
                 begin_step_percent=self.begin_step_percent,
                 end_step_percent=self.end_step_percent,
                 mask=self.mask,
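The branching added above is effectively a small lookup from (method, base model) to UNet block names. A rough, illustrative sketch of that mapping as a standalone helper (assumptions: the helper name is invented for illustration and is not part of the diff; the block names are copied from the hunk above):

```python
# Illustrative sketch only: the same method/base -> target_blocks mapping as the
# if/elif chain in IPAdapterInvocation.invoke(), expressed as a lookup table.
_TARGET_BLOCKS = {
    ("style", "sd-1"): ["up_blocks.1"],
    ("style", "sdxl"): ["up_blocks.0.attentions.1"],
    ("composition", "sd-1"): ["down_blocks.2", "mid_block"],
    ("composition", "sdxl"): ["down_blocks.2.attentions.1"],
}


def resolve_target_blocks(method: str, base: str) -> list[str]:
    if method == "full":
        # Every attention-processor name in a diffusers UNet contains "block",
        # so this effectively applies the IP-Adapter to every layer.
        return ["block"]
    try:
        return _TARGET_BLOCKS[(method, base)]
    except KeyError:
        raise ValueError(f"Unsupported IP-Adapter method/base: '{method}'/'{base}'.") from None


print(resolve_target_blocks("composition", "sdxl"))  # ['down_blocks.2.attentions.1']
```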
@@ -679,6 +679,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 IPAdapterData(
                     ip_adapter_model=ip_adapter_model,
                     weight=single_ip_adapter.weight,
+                    target_blocks=single_ip_adapter.target_blocks,
                     begin_step_percent=single_ip_adapter.begin_step_percent,
                     end_step_percent=single_ip_adapter.end_step_percent,
                     ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
@@ -36,6 +36,7 @@ class IPAdapterMetadataField(BaseModel):
     image: ImageField = Field(description="The IP-Adapter image prompt.")
     ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
     clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
+    method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
     weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
     begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
     end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
@@ -754,6 +754,8 @@ class ModelInstallService(ModelInstallServiceBase):
             self._download_cache[download_job.source] = install_job  # matches a download job to an install job
             install_job.download_parts.add(download_job)
 
+        # only start the jobs once install_job.download_parts is fully populated
+        for download_job in install_job.download_parts:
             self._download_queue.submit_download_job(
                 download_job,
                 on_start=self._download_started_callback,
@@ -762,6 +764,7 @@ class ModelInstallService(ModelInstallServiceBase):
                 on_error=self._download_error_callback,
                 on_cancelled=self._download_cancelled_callback,
             )
+
         return install_job
 
     def _stat_size(self, path: Path) -> int:
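The two hunks above reorder the multifile-install flow: every download job is registered in `install_job.download_parts` first, and only then submitted, presumably so that no callback ever observes a half-populated set of parts. A minimal, hypothetical sketch of that ordering (the class and function names here are illustrative, not InvokeAI APIs):

```python
# Hypothetical sketch of the "populate first, then submit" ordering used above.
class InstallJob:
    def __init__(self) -> None:
        self.download_parts: set[str] = set()


def start_multifile_install(job: InstallJob, parts: list[str], submit) -> None:
    # 1. Register every part before any work starts.
    for part in parts:
        job.download_parts.add(part)

    # 2. Only then hand the jobs to the download queue, so a completion
    #    callback can safely compare finished parts against the full set.
    for part in job.download_parts:
        submit(part)


start_multifile_install(InstallJob(), ["model.safetensors", "config.json"], submit=print)
```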
@@ -21,12 +21,9 @@ from pydantic import Field
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    IPAdapterData,
-    TextConditioningData,
-)
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
-from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
 from invokeai.backend.util.attention import auto_detect_slice_size
 from invokeai.backend.util.devices import TorchDevice
 
@@ -394,8 +391,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         unet_attention_patcher = None
         self.use_ip_adapter = use_ip_adapter
         attn_ctx = nullcontext()
 
         if use_ip_adapter or use_regional_prompting:
-            ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
+            ip_adapters: Optional[List[UNetIPAdapterData]] = (
+                [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
+                if use_ip_adapter
+                else None
+            )
             unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
             attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
+
@@ -53,6 +53,7 @@ class IPAdapterData:
     ip_adapter_model: IPAdapter
     ip_adapter_conditioning: IPAdapterConditioningInfo
     mask: torch.Tensor
+    target_blocks: List[str]
 
     # Either a single weight applied to all steps, or a list of weights for each step.
     weight: Union[float, List[float]] = 1.0
@@ -1,4 +1,5 @@
-from typing import Optional
+from dataclasses import dataclass
+from typing import List, Optional, cast
 
 import torch
 import torch.nn.functional as F
@@ -9,6 +10,12 @@ from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import Regiona
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
 
 
+@dataclass
+class IPAdapterAttentionWeights:
+    ip_adapter_weights: IPAttentionProcessorWeights
+    skip: bool
+
+
 class CustomAttnProcessor2_0(AttnProcessor2_0):
     """A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
     This implementation is based on
@@ -20,7 +27,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
     def __init__(
         self,
-        ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
+        ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
    ):
         """Initialize a CustomAttnProcessor2_0.
         Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
@@ -30,23 +37,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         for the i'th IP-Adapter.
         """
         super().__init__()
-        self._ip_adapter_weights = ip_adapter_weights
+        self._ip_adapter_attention_weights = ip_adapter_attention_weights
 
-    def _is_ip_adapter_enabled(self) -> bool:
-        return self._ip_adapter_weights is not None
-
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
-        # For regional prompting:
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        # For Regional Prompting:
         regional_prompt_data: Optional[RegionalPromptData] = None,
-        percent_through: Optional[torch.FloatTensor] = None,
+        percent_through: Optional[torch.Tensor] = None,
         # For IP-Adapter:
         regional_ip_data: Optional[RegionalIPData] = None,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
         """Apply attention.
         Args:
@@ -130,17 +136,19 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
         # Apply IP-Adapter conditioning.
         if is_cross_attention:
-            if self._is_ip_adapter_enabled():
+            if self._ip_adapter_attention_weights:
                 assert regional_ip_data is not None
                 ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
+
                 assert (
                     len(regional_ip_data.image_prompt_embeds)
-                    == len(self._ip_adapter_weights)
+                    == len(self._ip_adapter_attention_weights)
                     == len(regional_ip_data.scales)
                     == ip_masks.shape[1]
                 )
+
                 for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
-                    ipa_weights = self._ip_adapter_weights[ipa_index]
+                    ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
                     ipa_scale = regional_ip_data.scales[ipa_index]
                     ip_mask = ip_masks[0, ipa_index, ...]
 
@@ -153,15 +161,18 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
                     # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
 
+                    if not self._ip_adapter_attention_weights[ipa_index].skip:
                         ip_key = ipa_weights.to_k_ip(ip_hidden_states)
                         ip_value = ipa_weights.to_v_ip(ip_hidden_states)
 
-                    # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
+                        # Expected ip_key and ip_value shape:
+                        # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
 
                         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-                    # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
+                        # Expected ip_key and ip_value shape:
+                        # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
 
                         # TODO: add support for attn.scale when we move to Torch 2.1
                         ip_hidden_states = F.scaled_dot_product_attention(
@@ -169,12 +180,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
                         )
 
                         # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
-
-                    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+                        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
+                            batch_size, -1, attn.heads * head_dim
+                        )
                         ip_hidden_states = ip_hidden_states.to(query.dtype)
 
                         # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
 
                         hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
             else:
                 # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
@@ -188,11 +200,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
             hidden_states = attn.to_out[1](hidden_states)
 
         if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
 
         if attn.residual_connection:
             hidden_states = hidden_states + residual
 
         hidden_states = hidden_states / attn.rescale_output_factor
+        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+        # End of unmodified block from AttnProcessor2_0
 
-        return hidden_states
+        # casting torch.Tensor to torch.FloatTensor to avoid type issues
+        return cast(torch.FloatTensor, hidden_states)
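Taken together, the attention-processor hunks above make each layer consult a per-adapter `skip` flag: when an IP-Adapter does not target that layer, its key/value projections and its additive contribution are bypassed entirely. A simplified sketch of that gating (illustrative only; the names are invented, and masks and the real attention math are omitted):

```python
# Simplified sketch of how a per-layer skip flag gates each IP-Adapter's
# contribution; not the InvokeAI implementation.
from dataclasses import dataclass

import torch


@dataclass
class AdapterSlot:
    scale: float
    skip: bool


def apply_ip_adapters(
    hidden_states: torch.Tensor,
    contributions: list[torch.Tensor],
    slots: list[AdapterSlot],
) -> torch.Tensor:
    for contribution, slot in zip(contributions, slots):
        if slot.skip:
            continue  # this adapter does not target the current attention layer
        hidden_states = hidden_states + slot.scale * contribution
    return hidden_states


x = torch.zeros(1, 4, 8)
print(apply_ip_adapters(x, [torch.ones(1, 4, 8)], [AdapterSlot(scale=0.5, skip=False)]).mean())
```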
@@ -1,17 +1,25 @@
 from contextlib import contextmanager
-from typing import Optional
+from typing import List, Optional, TypedDict
 
 from diffusers.models import UNet2DConditionModel
 
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
+from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
+    CustomAttnProcessor2_0,
+    IPAdapterAttentionWeights,
+)
+
+
+class UNetIPAdapterData(TypedDict):
+    ip_adapter: IPAdapter
+    target_blocks: List[str]
 
 
 class UNetAttentionPatcher:
     """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
 
-    def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
-        self._ip_adapters = ip_adapters
+    def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
+        self._ip_adapters = ip_adapter_data
 
     def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
@@ -26,9 +34,22 @@ class UNetAttentionPatcher:
                 attn_procs[name] = CustomAttnProcessor2_0()
             else:
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
-                attn_procs[name] = CustomAttnProcessor2_0(
-                    [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
+                ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
+
+                for ip_adapter in self._ip_adapters:
+                    ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
+                    skip = True
+                    for block in ip_adapter["target_blocks"]:
+                        if block in name:
+                            skip = False
+                            break
+                    ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
+                        ip_adapter_weights=ip_adapter_weights, skip=skip
                     )
+                    ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
+
+                attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)
+
         return attn_procs
 
     @contextmanager
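The `skip` flag above is decided by plain substring matching of each `target_blocks` entry against the attention-processor name, so `"up_blocks.1"` catches every cross-attention layer inside that block and `"block"` catches them all. A rough sketch of that check (the processor names below are typical diffusers-style names, shown only for illustration):

```python
# Rough sketch of the substring matching performed in
# _prepare_attention_processors above.
def skip_for_layer(name: str, target_blocks: list[str]) -> bool:
    """True when none of the target blocks occur in the processor name."""
    return not any(block in name for block in target_blocks)


example_names = [
    "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor",
    "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor",
]
for layer_name in example_names:
    print(layer_name, "-> skip:", skip_for_layer(layer_name, ["up_blocks.1"]))
```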
@@ -213,6 +213,10 @@
     "resize": "Resize",
     "resizeSimple": "Resize (Simple)",
     "resizeMode": "Resize Mode",
+    "ipAdapterMethod": "Method",
+    "full": "Full",
+    "style": "Style Only",
+    "composition": "Composition Only",
     "safe": "Safe",
     "saveControlImage": "Save Control Image",
     "scribble": "scribble",
@@ -21,6 +21,7 @@ import ControlAdapterShouldAutoConfig from './ControlAdapterShouldAutoConfig';
 import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
 import { ParamControlAdapterBeginEnd } from './parameters/ParamControlAdapterBeginEnd';
 import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode';
+import ParamControlAdapterIPMethod from './parameters/ParamControlAdapterIPMethod';
 import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect';
 import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode';
 import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight';
@@ -111,7 +112,8 @@ const ControlAdapterConfig = (props: { id: string; number: number }) => {
 
       <Flex w="full" flexDir="column" gap={4}>
         <Flex gap={8} w="full" alignItems="center">
-          <Flex flexDir="column" gap={2} h={32} w="full">
+          <Flex flexDir="column" gap={4} h={controlAdapterType === 'ip_adapter' ? 40 : 32} w="full">
+            <ParamControlAdapterIPMethod id={id} />
             <ParamControlAdapterWeight id={id} />
             <ParamControlAdapterBeginEnd id={id} />
           </Flex>
@@ -0,0 +1,63 @@
+import type { ComboboxOnChange } from '@invoke-ai/ui-library';
+import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
+import { useAppDispatch } from 'app/store/storeHooks';
+import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
+import { useControlAdapterIPMethod } from 'features/controlAdapters/hooks/useControlAdapterIPMethod';
+import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
+import { controlAdapterIPMethodChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
+import type { IPMethod } from 'features/controlAdapters/store/types';
+import { isIPMethod } from 'features/controlAdapters/store/types';
+import { memo, useCallback, useMemo } from 'react';
+import { useTranslation } from 'react-i18next';
+
+type Props = {
+  id: string;
+};
+
+const ParamControlAdapterIPMethod = ({ id }: Props) => {
+  const isEnabled = useControlAdapterIsEnabled(id);
+  const method = useControlAdapterIPMethod(id);
+  const dispatch = useAppDispatch();
+  const { t } = useTranslation();
+
+  const options: { label: string; value: IPMethod }[] = useMemo(
+    () => [
+      { label: t('controlnet.full'), value: 'full' },
+      { label: t('controlnet.style'), value: 'style' },
+      { label: t('controlnet.composition'), value: 'composition' },
+    ],
+    [t]
+  );
+
+  const handleIPMethodChanged = useCallback<ComboboxOnChange>(
+    (v) => {
+      if (!isIPMethod(v?.value)) {
+        return;
+      }
+      dispatch(
+        controlAdapterIPMethodChanged({
+          id,
+          method: v.value,
+        })
+      );
+    },
+    [id, dispatch]
+  );
+
+  const value = useMemo(() => options.find((o) => o.value === method), [options, method]);
+
+  if (!method) {
+    return null;
+  }
+
+  return (
+    <FormControl>
+      <InformationalPopover feature="controlNetResizeMode">
+        <FormLabel>{t('controlnet.ipAdapterMethod')}</FormLabel>
+      </InformationalPopover>
+      <Combobox value={value} options={options} isDisabled={!isEnabled} onChange={handleIPMethodChanged} />
+    </FormControl>
+  );
+};
+
+export default memo(ParamControlAdapterIPMethod);
@@ -0,0 +1,24 @@
+import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
+import { useAppSelector } from 'app/store/storeHooks';
+import {
+  selectControlAdapterById,
+  selectControlAdaptersSlice,
+} from 'features/controlAdapters/store/controlAdaptersSlice';
+import { useMemo } from 'react';
+
+export const useControlAdapterIPMethod = (id: string) => {
+  const selector = useMemo(
+    () =>
+      createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
+        const cn = selectControlAdapterById(controlAdapters, id);
+        if (cn && cn?.type === 'ip_adapter') {
+          return cn.method;
+        }
+      }),
+    [id]
+  );
+
+  const method = useAppSelector(selector);
+
+  return method;
+};
@@ -21,6 +21,7 @@ import type {
   ControlAdapterType,
   ControlMode,
   ControlNetConfig,
+  IPMethod,
   RequiredControlAdapterProcessorNode,
   ResizeMode,
   T2IAdapterConfig,
@@ -245,6 +246,10 @@ export const controlAdaptersSlice = createSlice({
       }
       caAdapter.updateOne(state, { id, changes: { controlMode } });
     },
+    controlAdapterIPMethodChanged: (state, action: PayloadAction<{ id: string; method: IPMethod }>) => {
+      const { id, method } = action.payload;
+      caAdapter.updateOne(state, { id, changes: { method } });
+    },
     controlAdapterCLIPVisionModelChanged: (
       state,
       action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
@@ -390,6 +395,7 @@ export const {
   controlAdapterIsEnabledChanged,
   controlAdapterModelChanged,
   controlAdapterCLIPVisionModelChanged,
+  controlAdapterIPMethodChanged,
   controlAdapterWeightChanged,
   controlAdapterBeginStepPctChanged,
   controlAdapterEndStepPctChanged,
@@ -210,6 +210,10 @@ const zResizeMode = z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_r
 export type ResizeMode = z.infer<typeof zResizeMode>;
 export const isResizeMode = (v: unknown): v is ResizeMode => zResizeMode.safeParse(v).success;
 
+const zIPMethod = z.enum(['full', 'style', 'composition']);
+export type IPMethod = z.infer<typeof zIPMethod>;
+export const isIPMethod = (v: unknown): v is IPMethod => zIPMethod.safeParse(v).success;
+
 export type ControlNetConfig = {
   type: 'controlnet';
   id: string;
@@ -253,6 +257,7 @@ export type IPAdapterConfig = {
   model: ParameterIPAdapterModel | null;
   clipVisionModel: CLIPVisionModel;
   weight: number;
+  method: IPMethod;
   beginStepPct: number;
   endStepPct: number;
 };
@@ -46,6 +46,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
   isEnabled: true,
   controlImage: null,
   model: null,
+  method: 'full',
   clipVisionModel: 'ViT-H',
   weight: 1,
   beginStepPct: 0,
@@ -386,6 +386,10 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
     .nullish()
     .catch(null)
     .parse(await getProperty(metadataItem, 'weight'));
+  const method = zIPAdapterField.shape.method
+    .nullish()
+    .catch(null)
+    .parse(await getProperty(metadataItem, 'method'));
   const begin_step_percent = zIPAdapterField.shape.begin_step_percent
     .nullish()
     .catch(null)
@@ -403,6 +407,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
     clipVisionModel: 'ViT-H',
     controlImage: image?.image_name ?? null,
     weight: weight ?? initialIPAdapter.weight,
+    method: method ?? initialIPAdapter.method,
     beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
     endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
   };
@@ -109,6 +109,7 @@ export const zIPAdapterField = z.object({
   image: zImageField,
   ip_adapter_model: zModelIdentifierField,
   weight: z.number(),
+  method: z.enum(['full', 'style', 'composition']),
   begin_step_percent: z.number().optional(),
   end_step_percent: z.number().optional(),
 });
@@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
   if (!ipAdapter.model) {
     return;
   }
-  const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter;
+  const { id, weight, model, clipVisionModel, method, beginStepPct, endStepPct, controlImage } = ipAdapter;
 
   assert(controlImage, 'IP Adapter image is required');
 
@@ -57,6 +57,7 @@ export const addIPAdapterToLinearGraph = async (
     type: 'ip_adapter',
     is_intermediate: true,
     weight: weight,
+    method: method,
     ip_adapter_model: model,
     clip_vision_model: clipVisionModel,
     begin_step_percent: beginStepPct,
@@ -84,7 +85,7 @@ export const addIPAdapterToLinearGraph = async (
 };
 
 const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
-  const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter;
+  const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, method, weight } = ipAdapter;
 
   assert(model, 'IP Adapter model is required');
 
@@ -102,6 +103,7 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat
     ip_adapter_model: model,
     clip_vision_model: clipVisionModel,
     weight,
+    method,
     begin_step_percent: beginStepPct,
     end_step_percent: endStepPct,
     image,
File diff suppressed because one or more lines are too long