diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 0f21cb9842..826112156d 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -86,7 +86,7 @@ class IPAdapter: self.attn_weights.to(device=self.device, dtype=self.dtype) def calc_size(self): - return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self._attn_weights) + return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights) def _init_image_proj_model(self, state_dict): return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype) diff --git a/invokeai/backend/ip_adapter/ip_attention_weights.py b/invokeai/backend/ip_adapter/ip_attention_weights.py index 6eb1630890..e7ed9e9c76 100644 --- a/invokeai/backend/ip_adapter/ip_attention_weights.py +++ b/invokeai/backend/ip_adapter/ip_attention_weights.py @@ -22,25 +22,31 @@ class IPAttentionWeights(torch.nn.Module): forward(...) method. """ - def __init__(self, weights: dict[int, IPAttentionProcessorWeights]): + def __init__(self, weights: torch.nn.ModuleDict): super().__init__() - self.weights = weights + self._weights = weights def set_scale(self, scale: float): """Set the scale (a.k.a. 'weight') for all of the `IPAttentionProcessorWeights` in this collection.""" - for w in self.weights.values(): + for w in self._weights.values(): w.scale = scale + def get_attention_processor_weights(self, idx: int) -> IPAttentionProcessorWeights: + """Get the `IPAttentionProcessorWeights` for the idx'th attention processor.""" + # Cast to int first, because we expect the key to represent an int. Then cast back to str, because + # `torch.nn.ModuleDict` only supports str keys. + return self._weights[str(int(idx))] + @classmethod def from_state_dict(cls, state_dict: dict[str, torch.Tensor]): - attn_proc_weights: dict[int, IPAttentionProcessorWeights] = {} + attn_proc_weights: dict[str, IPAttentionProcessorWeights] = {} for tensor_name, tensor in state_dict.items(): if "to_k_ip.weight" in tensor_name: - index = int(tensor_name.split(".")[0]) + index = str(int(tensor_name.split(".")[0])) attn_proc_weights[index] = IPAttentionProcessorWeights(tensor.shape[1], tensor.shape[0]) - attn_proc_weights_module_dict = torch.nn.ModuleDict(attn_proc_weights) - attn_proc_weights_module_dict.load_state_dict(state_dict) + attn_proc_weights_module = torch.nn.ModuleDict(attn_proc_weights) + attn_proc_weights_module.load_state_dict(state_dict) - return cls(attn_proc_weights) + return cls(attn_proc_weights_module) diff --git a/invokeai/backend/ip_adapter/unet_patcher.py b/invokeai/backend/ip_adapter/unet_patcher.py index 17f3c4e595..ac9cf6de83 100644 --- a/invokeai/backend/ip_adapter/unet_patcher.py +++ b/invokeai/backend/ip_adapter/unet_patcher.py @@ -31,11 +31,14 @@ def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[ attn_procs[name] = AttnProcessor2_0() else: # Collect the weights from each IP Adapter for the idx'th attention processor. - attn_procs[name] = IPAttnProcessor2_0([ip_adapter.attn_weights.weights[idx] for ip_adapter in ip_adapters]) + attn_procs[name] = IPAttnProcessor2_0( + [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in ip_adapters] + ) + return attn_procs @contextmanager -def apply_ip_adapter_attention(cls, unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]): +def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]): """A context manager that patches `unet` with IP-Adapter attention processors.""" attn_procs = _prepare_attention_processors(unet, ip_adapters) diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py index 1248ead98b..f2ca243a93 100644 --- a/tests/backend/ip_adapter/test_ip_adapter.py +++ b/tests/backend/ip_adapter/test_ip_adapter.py @@ -1,6 +1,7 @@ import pytest import torch +from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.util.test_utils import install_and_load_model @@ -64,8 +65,8 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device): ip_adapter.to(torch_device, dtype=torch.float32) unet.to(torch_device, dtype=torch.float32) - cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": torch.randn((1, 4, 768)).to(torch_device)} - with ip_adapter.apply_ip_adapter_attention(unet, 1.0): + cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [torch.randn((1, 4, 768)).to(torch_device)]} + with apply_ip_adapter_attention(unet, [ip_adapter]): output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample assert output.shape == dummy_unet_input["sample"].shape