Fixes to get IP-Adapter tests working with new multi-IP-Adapter support.

This commit is contained in:
Ryan Dick 2023-10-06 11:46:11 -04:00 committed by Kent Keirsey
parent 7ca456d674
commit 26b91a538a
4 changed files with 23 additions and 13 deletions

View File

@ -86,7 +86,7 @@ class IPAdapter:
self.attn_weights.to(device=self.device, dtype=self.dtype)
def calc_size(self):
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self._attn_weights)
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
def _init_image_proj_model(self, state_dict):
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)

View File

@ -22,25 +22,31 @@ class IPAttentionWeights(torch.nn.Module):
forward(...) method.
"""
def __init__(self, weights: dict[int, IPAttentionProcessorWeights]):
def __init__(self, weights: torch.nn.ModuleDict):
super().__init__()
self.weights = weights
self._weights = weights
def set_scale(self, scale: float):
"""Set the scale (a.k.a. 'weight') for all of the `IPAttentionProcessorWeights` in this collection."""
for w in self.weights.values():
for w in self._weights.values():
w.scale = scale
def get_attention_processor_weights(self, idx: int) -> IPAttentionProcessorWeights:
"""Get the `IPAttentionProcessorWeights` for the idx'th attention processor."""
# Cast to int first, because we expect the key to represent an int. Then cast back to str, because
# `torch.nn.ModuleDict` only supports str keys.
return self._weights[str(int(idx))]
@classmethod
def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
attn_proc_weights: dict[int, IPAttentionProcessorWeights] = {}
attn_proc_weights: dict[str, IPAttentionProcessorWeights] = {}
for tensor_name, tensor in state_dict.items():
if "to_k_ip.weight" in tensor_name:
index = int(tensor_name.split(".")[0])
index = str(int(tensor_name.split(".")[0]))
attn_proc_weights[index] = IPAttentionProcessorWeights(tensor.shape[1], tensor.shape[0])
attn_proc_weights_module_dict = torch.nn.ModuleDict(attn_proc_weights)
attn_proc_weights_module_dict.load_state_dict(state_dict)
attn_proc_weights_module = torch.nn.ModuleDict(attn_proc_weights)
attn_proc_weights_module.load_state_dict(state_dict)
return cls(attn_proc_weights)
return cls(attn_proc_weights_module)

View File

@ -31,11 +31,14 @@ def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[
attn_procs[name] = AttnProcessor2_0()
else:
# Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = IPAttnProcessor2_0([ip_adapter.attn_weights.weights[idx] for ip_adapter in ip_adapters])
attn_procs[name] = IPAttnProcessor2_0(
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in ip_adapters]
)
return attn_procs
@contextmanager
def apply_ip_adapter_attention(cls, unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
"""A context manager that patches `unet` with IP-Adapter attention processors."""
attn_procs = _prepare_attention_processors(unet, ip_adapters)

View File

@ -1,6 +1,7 @@
import pytest
import torch
from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.util.test_utils import install_and_load_model
@ -64,8 +65,8 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
ip_adapter.to(torch_device, dtype=torch.float32)
unet.to(torch_device, dtype=torch.float32)
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": torch.randn((1, 4, 768)).to(torch_device)}
with ip_adapter.apply_ip_adapter_attention(unet, 1.0):
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [torch.randn((1, 4, 768)).to(torch_device)]}
with apply_ip_adapter_attention(unet, [ip_adapter]):
output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
assert output.shape == dummy_unet_input["sample"].shape