Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 6ea183f0d4 (parent 24f2cde862)
wip: Initial Implementation IP Adapter Style & Comp Modes

@@ -4,20 +4,8 @@ from typing import List, Literal, Optional, Union
 from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self
 
-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    OutputField,
-    TensorField,
-    UIType,
-)
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@@ -36,6 +24,7 @@ class IPAdapterField(BaseModel):
     ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
     image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
     weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
+    target_blocks: List[str] = Field(default=[], description="The IP Adapter blocks to apply")
     begin_step_percent: float = Field(
         default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
     )
@@ -90,6 +79,9 @@ class IPAdapterInvocation(BaseInvocation):
     weight: Union[float, List[float]] = InputField(
         default=1, description="The weight given to the IP-Adapter", title="Weight"
     )
+    method: Literal["full", "style", "composition"] = InputField(
+        default="full", description="The method to apply the IP-Adapter"
+    )
     begin_step_percent: float = InputField(
         default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
     )
@@ -124,12 +116,19 @@ class IPAdapterInvocation(BaseInvocation):
 
         image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
 
+        target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"]
+        if self.method == "style":
+            target_blocks = ["up_blocks.0.attentions.1"]
+        elif self.method == "composition":
+            target_blocks = ["down_blocks.2.attentions.1"]
+
         return IPAdapterOutput(
             ip_adapter=IPAdapterField(
                 image=self.image,
                 ip_adapter_model=self.ip_adapter_model,
                 image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
                 weight=self.weight,
+                target_blocks=target_blocks,
                 begin_step_percent=self.begin_step_percent,
                 end_step_percent=self.end_step_percent,
                 mask=self.mask,
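
Note (editor sketch, not part of the commit): a minimal standalone Python sketch of the mapping the new "method" field resolves to in the hunk above, using the two SDXL attention-block prefixes named there; the helper name is hypothetical.

    from typing import List

    def resolve_target_blocks(method: str) -> List[str]:
        # Hypothetical helper mirroring the if/elif chain above; not an InvokeAI API.
        if method == "style":
            return ["up_blocks.0.attentions.1"]
        if method == "composition":
            return ["down_blocks.2.attentions.1"]
        return ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"]  # "full"

    print(resolve_target_blocks("style"))  # ['up_blocks.0.attentions.1']
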

@@ -15,12 +15,10 @@ from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.adapter import T2IAdapter
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
+from diffusers.models.attention_processor import (AttnProcessor2_0,
+                                                   LoRAAttnProcessor2_0,
+                                                   LoRAXFormersAttnProcessor,
+                                                   XFormersAttnProcessor)
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.schedulers import DPMSolverSDEScheduler
 from diffusers.schedulers import SchedulerMixin as Scheduler
@@ -29,22 +27,17 @@ from pydantic import field_validator
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPVisionModelWithProjection
 
-from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
-from invokeai.app.invocations.fields import (
-    ConditioningField,
-    DenoiseMaskField,
-    FieldDescriptions,
-    ImageField,
-    Input,
-    InputField,
-    LatentsField,
-    OutputField,
-    UIType,
-    WithBoard,
-    WithMetadata,
-)
+from invokeai.app.invocations.constants import (LATENT_SCALE_FACTOR,
+                                                SCHEDULER_NAME_VALUES)
+from invokeai.app.invocations.fields import (ConditioningField,
+                                             DenoiseMaskField,
+                                             FieldDescriptions, ImageField,
+                                             Input, InputField, LatentsField,
+                                             OutputField, UIType, WithBoard,
+                                             WithMetadata)
 from invokeai.app.invocations.ip_adapter import IPAdapterField
-from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
+from invokeai.app.invocations.primitives import (DenoiseMaskOutput,
+                                                 ImageOutput, LatentsOutput)
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
@@ -52,28 +45,21 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.model_patcher import ModelPatcher
-from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion import (PipelineIntermediateState,
+                                               set_seamless)
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    BasicConditioningInfo,
-    IPAdapterConditioningInfo,
-    IPAdapterData,
-    Range,
-    SDXLConditioningInfo,
-    TextConditioningData,
-    TextConditioningRegions,
-)
+    BasicConditioningInfo, IPAdapterConditioningInfo, IPAdapterData, Range,
+    SDXLConditioningInfo, TextConditioningData, TextConditioningRegions)
 from invokeai.backend.util.mask import to_standard_float_mask
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 from ...backend.stable_diffusion.diffusers_pipeline import (
-    ControlNetData,
-    StableDiffusionGeneratorPipeline,
-    T2IAdapterData,
-    image_resized_to_grid_as_tensor,
-)
+    ControlNetData, StableDiffusionGeneratorPipeline, T2IAdapterData,
+    image_resized_to_grid_as_tensor)
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import choose_precision, choose_torch_device
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from .baseinvocation import (BaseInvocation, BaseInvocationOutput, invocation,
+                             invocation_output)
 from .controlnet_image_processors import ControlField
 from .model import ModelIdentifierField, UNetField, VAEField
 
@@ -682,6 +668,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 IPAdapterData(
                     ip_adapter_model=ip_adapter_model,
                     weight=single_ip_adapter.weight,
+                    target_blocks=single_ip_adapter.target_blocks,
                     begin_step_percent=single_ip_adapter.begin_step_percent,
                     end_step_percent=single_ip_adapter.end_step_percent,
                     ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),

@@ -21,12 +21,9 @@ from pydantic import Field
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    IPAdapterData,
-    TextConditioningData,
-)
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
-from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
 from invokeai.backend.util.attention import auto_detect_slice_size
 from invokeai.backend.util.devices import normalize_device
 
@@ -394,8 +391,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         unet_attention_patcher = None
         self.use_ip_adapter = use_ip_adapter
         attn_ctx = nullcontext()
+
         if use_ip_adapter or use_regional_prompting:
-            ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
+            ip_adapters: Optional[List[UNetIPAdapterData]] = (
+                [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
+                if use_ip_adapter
+                else None
+            )
             unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
             attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
 
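
Note (editor sketch, not part of the commit): the shape of the payload the pipeline now hands to UNetAttentionPatcher per the hunk above; each entry pairs a loaded IP-Adapter with its target blocks, and the value falls back to None when only regional prompting is active. The adapter object below is a stub standing in for the real loaded model.

    from typing import List, Optional, TypedDict

    class IPAdapterStub:
        """Stand-in for the loaded IP-Adapter model object (hypothetical)."""

    class UNetIPAdapterData(TypedDict):
        ip_adapter: IPAdapterStub
        target_blocks: List[str]

    ip_adapters: Optional[List[UNetIPAdapterData]] = [
        {"ip_adapter": IPAdapterStub(), "target_blocks": ["up_blocks.0.attentions.1"]},
    ]
    print(ip_adapters)
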

@@ -53,6 +53,7 @@ class IPAdapterData:
     ip_adapter_model: IPAdapter
     ip_adapter_conditioning: IPAdapterConditioningInfo
     mask: torch.Tensor
+    target_blocks: List[str]
 
     # Either a single weight applied to all steps, or a list of weights for each step.
     weight: Union[float, List[float]] = 1.0

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import List, Optional, TypedDict
 
 import torch
 import torch.nn.functional as F
@@ -9,6 +9,11 @@ from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import Regiona
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
 
 
+class IPAdapterAttentionWeights(TypedDict):
+    ip_adapter_weights: List[IPAttentionProcessorWeights]
+    skip: bool
+
+
 class CustomAttnProcessor2_0(AttnProcessor2_0):
     """A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
     This implementation is based on
@@ -20,7 +25,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
     def __init__(
         self,
-        ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
+        ip_adapter_attention_weights: Optional[IPAdapterAttentionWeights] = None,
     ):
         """Initialize a CustomAttnProcessor2_0.
         Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
@@ -30,10 +35,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         for the i'th IP-Adapter.
         """
         super().__init__()
-        self._ip_adapter_weights = ip_adapter_weights
-
-    def _is_ip_adapter_enabled(self) -> bool:
-        return self._ip_adapter_weights is not None
+        self._ip_adapter_attention_weights = ip_adapter_attention_weights
 
     def __call__(
         self,
@@ -130,17 +132,17 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
         # Apply IP-Adapter conditioning.
         if is_cross_attention:
-            if self._is_ip_adapter_enabled():
+            if self._ip_adapter_attention_weights:
                 assert regional_ip_data is not None
                 ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
                 assert (
                     len(regional_ip_data.image_prompt_embeds)
-                    == len(self._ip_adapter_weights)
+                    == len(self._ip_adapter_attention_weights["ip_adapter_weights"])
                     == len(regional_ip_data.scales)
                     == ip_masks.shape[1]
                 )
                 for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
-                    ipa_weights = self._ip_adapter_weights[ipa_index]
+                    ipa_weights = self._ip_adapter_attention_weights["ip_adapter_weights"][ipa_index]
                     ipa_scale = regional_ip_data.scales[ipa_index]
                     ip_mask = ip_masks[0, ipa_index, ...]
 
@@ -153,6 +155,8 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
                     # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
 
+                    if self._ip_adapter_attention_weights["skip"]:
+
                     ip_key = ipa_weights.to_k_ip(ip_hidden_states)
                     ip_value = ipa_weights.to_v_ip(ip_hidden_states)
 
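
Note (editor sketch, not part of the commit): the branch guarded by the "skip" flag above appears to be left empty in this WIP commit, so the following is only one plausible reading of the intent, with plain floats standing in for the attention tensors: when a layer is flagged, the IP-Adapter contribution for that layer is simply not added.

    def add_ip_adapter_contribution(hidden_states: float, ip_contribution: float, skip: bool, scale: float = 1.0) -> float:
        # Hypothetical helper; floats stand in for tensors, not the committed behavior.
        if skip:
            return hidden_states  # leave the base attention output untouched for this layer
        return hidden_states + scale * ip_contribution

    print(add_ip_adapter_contribution(1.0, 0.5, skip=True))   # 1.0
    print(add_ip_adapter_contribution(1.0, 0.5, skip=False))  # 1.5
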
@@ -170,7 +174,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
                     # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
 
-                    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+                    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
+                        batch_size, -1, attn.heads * head_dim
+                    )
                     ip_hidden_states = ip_hidden_states.to(query.dtype)
 
                     # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)

@@ -1,17 +1,25 @@
 from contextlib import contextmanager
-from typing import Optional
+from typing import List, Optional, TypedDict
 
 from diffusers.models import UNet2DConditionModel
 
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
+from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
+    CustomAttnProcessor2_0,
+    IPAdapterAttentionWeights,
+)
+
+
+class UNetIPAdapterData(TypedDict):
+    ip_adapter: IPAdapter
+    target_blocks: List[str]
 
 
 class UNetAttentionPatcher:
     """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
 
-    def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
-        self._ip_adapters = ip_adapters
+    def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
+        self._ip_adapters = ip_adapter_data
 
     def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
@@ -25,10 +33,23 @@ class UNetAttentionPatcher:
                 # "attn1" processors do not use IP-Adapters.
                 attn_procs[name] = CustomAttnProcessor2_0()
             else:
+
+                ip_adapter_attention_weights: IPAdapterAttentionWeights = {"ip_adapter_weights": [], "skip": False}
+                for ip_adapter in self._ip_adapters:
+
+                    ip_adapter_weight = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
+                    skip = False
+                    for block in ip_adapter["target_blocks"]:
+                        if block in name:
+                            skip = True
+                            break
+
+                    ip_adapter_attention_weights.update({"ip_adapter_weights": [ip_adapter_weight], "skip": skip})
+
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
-                attn_procs[name] = CustomAttnProcessor2_0(
-                    [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
-                )
+
+                attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights)
+
         return attn_procs
 
     @contextmanager
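
Note (editor sketch, not part of the commit): a standalone sketch of how the loop above derives a per-layer flag by substring-matching the configured target blocks against diffusers attention-processor names; the flag is set when any target block string occurs in the processor's name. The processor name below is illustrative.

    from typing import List

    def name_matches_target_blocks(processor_name: str, target_blocks: List[str]) -> bool:
        # Hypothetical helper mirroring the "if block in name" loop above.
        return any(block in processor_name for block in target_blocks)

    name = "up_blocks.0.attentions.1.transformer_blocks.0.attn2.processor"
    print(name_matches_target_blocks(name, ["up_blocks.0.attentions.1"]))   # True
    print(name_matches_target_blocks(name, ["down_blocks.2.attentions.1"]))  # False
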

@@ -57,6 +57,7 @@ export const addIPAdapterToLinearGraph = async (
       type: 'ip_adapter',
       is_intermediate: true,
      weight: weight,
+      method: 'composition',
       ip_adapter_model: model,
       clip_vision_model: clipVisionModel,
       begin_step_percent: beginStepPct,

File diff suppressed because one or more lines are too long