Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Fixes to get IP-Adapter tests working with new multi-IP-Adapter support.
parent 7ca456d674
commit 26b91a538a
@@ -86,7 +86,7 @@ class IPAdapter:
         self.attn_weights.to(device=self.device, dtype=self.dtype)
 
     def calc_size(self):
-        return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self._attn_weights)
+        return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
 
     def _init_image_proj_model(self, state_dict):
         return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
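The `calc_size` change simply points the size calculation at the renamed public `attn_weights` attribute (the `.to()` call above already uses it), so the old reference to `self._attn_weights` would presumably have failed with an `AttributeError`. For orientation only, here is a rough sketch of what a size-by-data helper typically computes; `calc_model_size_by_data` itself is not shown in this diff, so this is an assumption, not its actual implementation:

import torch

def approx_size_by_data(module: torch.nn.Module) -> int:
    # Assumption: byte count of all parameters and buffers, which is the usual meaning
    # of "model size by data". InvokeAI's calc_model_size_by_data may count differently.
    size = 0
    for p in module.parameters():
        size += p.numel() * p.element_size()
    for b in module.buffers():
        size += b.numel() * b.element_size()
    return size

print(approx_size_by_data(torch.nn.Linear(768, 320)))  # ~984 KB for float32 weights + bias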
@@ -22,25 +22,31 @@ class IPAttentionWeights(torch.nn.Module):
     forward(...) method.
     """
 
-    def __init__(self, weights: dict[int, IPAttentionProcessorWeights]):
+    def __init__(self, weights: torch.nn.ModuleDict):
         super().__init__()
-        self.weights = weights
+        self._weights = weights
 
     def set_scale(self, scale: float):
         """Set the scale (a.k.a. 'weight') for all of the `IPAttentionProcessorWeights` in this collection."""
-        for w in self.weights.values():
+        for w in self._weights.values():
             w.scale = scale
 
+    def get_attention_processor_weights(self, idx: int) -> IPAttentionProcessorWeights:
+        """Get the `IPAttentionProcessorWeights` for the idx'th attention processor."""
+        # Cast to int first, because we expect the key to represent an int. Then cast back to str, because
+        # `torch.nn.ModuleDict` only supports str keys.
+        return self._weights[str(int(idx))]
+
     @classmethod
     def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
-        attn_proc_weights: dict[int, IPAttentionProcessorWeights] = {}
+        attn_proc_weights: dict[str, IPAttentionProcessorWeights] = {}
 
         for tensor_name, tensor in state_dict.items():
             if "to_k_ip.weight" in tensor_name:
-                index = int(tensor_name.split(".")[0])
+                index = str(int(tensor_name.split(".")[0]))
                 attn_proc_weights[index] = IPAttentionProcessorWeights(tensor.shape[1], tensor.shape[0])
 
-        attn_proc_weights_module_dict = torch.nn.ModuleDict(attn_proc_weights)
-        attn_proc_weights_module_dict.load_state_dict(state_dict)
+        attn_proc_weights_module = torch.nn.ModuleDict(attn_proc_weights)
+        attn_proc_weights_module.load_state_dict(state_dict)
 
-        return cls(attn_proc_weights)
+        return cls(attn_proc_weights_module)
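The int-to-str casting in `get_attention_processor_weights` and `from_state_dict` is forced by PyTorch: `torch.nn.ModuleDict` only registers submodules under string keys, so the integer processor index has to be stored as `str(idx)` and normalized back on lookup. A minimal, standalone illustration (the `Linear` placeholders stand in for `IPAttentionProcessorWeights` and are not part of the diff):

import torch

# torch.nn.ModuleDict only accepts string keys; ModuleDict({0: ...}) raises a TypeError.
weights = torch.nn.ModuleDict({str(i): torch.nn.Linear(8, 8, bias=False) for i in range(3)})

def get_attention_processor_weights(idx: int) -> torch.nn.Module:
    # Cast to int first to normalize the index, then back to str for the ModuleDict lookup.
    return weights[str(int(idx))]

assert get_attention_processor_weights(1) is weights["1"]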
@@ -31,11 +31,14 @@ def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[
             attn_procs[name] = AttnProcessor2_0()
         else:
             # Collect the weights from each IP Adapter for the idx'th attention processor.
-            attn_procs[name] = IPAttnProcessor2_0([ip_adapter.attn_weights.weights[idx] for ip_adapter in ip_adapters])
+            attn_procs[name] = IPAttnProcessor2_0(
+                [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in ip_adapters]
+            )
+    return attn_procs
 
 
 @contextmanager
-def apply_ip_adapter_attention(cls, unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
+def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
     """A context manager that patches `unet` with IP-Adapter attention processors."""
     attn_procs = _prepare_attention_processors(unet, ip_adapters)
 
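The core of the multi-IP-Adapter change is visible in the list comprehension: each `IPAttnProcessor2_0` now receives one weights object per active adapter for its processor index, and `apply_ip_adapter_attention` takes the list of adapters directly instead of being invoked as a method on a single adapter. A self-contained sketch of that gathering pattern, using stand-in classes rather than the real `IPAdapter`/`IPAttentionWeights` types:

import torch

class StubAttentionWeights(torch.nn.Module):
    """Stand-in for IPAttentionWeights: per-processor weights keyed by str(idx)."""

    def __init__(self, num_processors: int):
        super().__init__()
        self._weights = torch.nn.ModuleDict(
            {str(i): torch.nn.Linear(8, 8, bias=False) for i in range(num_processors)}
        )

    def get_attention_processor_weights(self, idx: int) -> torch.nn.Module:
        return self._weights[str(int(idx))]

# Two "adapters" loaded side by side (hypothetical; the diff's test exercises only one).
adapters = [StubAttentionWeights(2), StubAttentionWeights(2)]

# For the idx'th attention processor, collect one weights object from each adapter,
# mirroring the argument list built for IPAttnProcessor2_0 in the diff.
per_processor = {idx: [a.get_attention_processor_weights(idx) for a in adapters] for idx in range(2)}
assert all(len(ws) == len(adapters) for ws in per_processor.values())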
@@ -1,6 +1,7 @@
 import pytest
 import torch
 
+from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention
 from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
 from invokeai.backend.util.test_utils import install_and_load_model
 
@@ -64,8 +65,8 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
     ip_adapter.to(torch_device, dtype=torch.float32)
     unet.to(torch_device, dtype=torch.float32)
 
-    cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": torch.randn((1, 4, 768)).to(torch_device)}
-    with ip_adapter.apply_ip_adapter_attention(unet, 1.0):
+    cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [torch.randn((1, 4, 768)).to(torch_device)]}
+    with apply_ip_adapter_attention(unet, [ip_adapter]):
         output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
 
     assert output.shape == dummy_unet_input["sample"].shape
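The test now wraps the image-prompt embeddings in a list and passes the adapter in a list as well, which is the shape of the new multi-adapter contract: one embedding tensor per adapter, in the same order as the `ip_adapters` argument. For two adapters the kwargs would presumably look like the sketch below (the second entry is an extrapolation; this commit's test only exercises a single adapter):

import torch

# One (batch, num_tokens, embed_dim) tensor per active IP-Adapter, ordered like `ip_adapters`.
cross_attention_kwargs = {
    "ip_adapter_image_prompt_embeds": [
        torch.randn((1, 4, 768)),  # embeds for the first adapter (shape taken from the test)
        torch.randn((1, 4, 768)),  # embeds for a hypothetical second adapter
    ]
}
assert len(cross_attention_kwargs["ip_adapter_image_prompt_embeds"]) == 2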