mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixes to get IP-Adapter tests working with new multi-IP-Adapter support.
This commit is contained in:
parent
7ca456d674
commit
26b91a538a
@ -86,7 +86,7 @@ class IPAdapter:
|
||||
self.attn_weights.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def calc_size(self):
|
||||
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self._attn_weights)
|
||||
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
|
||||
|
||||
def _init_image_proj_model(self, state_dict):
|
||||
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
|
||||
|
@ -22,25 +22,31 @@ class IPAttentionWeights(torch.nn.Module):
|
||||
forward(...) method.
|
||||
"""
|
||||
|
||||
def __init__(self, weights: dict[int, IPAttentionProcessorWeights]):
|
||||
def __init__(self, weights: torch.nn.ModuleDict):
|
||||
super().__init__()
|
||||
self.weights = weights
|
||||
self._weights = weights
|
||||
|
||||
def set_scale(self, scale: float):
|
||||
"""Set the scale (a.k.a. 'weight') for all of the `IPAttentionProcessorWeights` in this collection."""
|
||||
for w in self.weights.values():
|
||||
for w in self._weights.values():
|
||||
w.scale = scale
|
||||
|
||||
def get_attention_processor_weights(self, idx: int) -> IPAttentionProcessorWeights:
|
||||
"""Get the `IPAttentionProcessorWeights` for the idx'th attention processor."""
|
||||
# Cast to int first, because we expect the key to represent an int. Then cast back to str, because
|
||||
# `torch.nn.ModuleDict` only supports str keys.
|
||||
return self._weights[str(int(idx))]
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
|
||||
attn_proc_weights: dict[int, IPAttentionProcessorWeights] = {}
|
||||
attn_proc_weights: dict[str, IPAttentionProcessorWeights] = {}
|
||||
|
||||
for tensor_name, tensor in state_dict.items():
|
||||
if "to_k_ip.weight" in tensor_name:
|
||||
index = int(tensor_name.split(".")[0])
|
||||
index = str(int(tensor_name.split(".")[0]))
|
||||
attn_proc_weights[index] = IPAttentionProcessorWeights(tensor.shape[1], tensor.shape[0])
|
||||
|
||||
attn_proc_weights_module_dict = torch.nn.ModuleDict(attn_proc_weights)
|
||||
attn_proc_weights_module_dict.load_state_dict(state_dict)
|
||||
attn_proc_weights_module = torch.nn.ModuleDict(attn_proc_weights)
|
||||
attn_proc_weights_module.load_state_dict(state_dict)
|
||||
|
||||
return cls(attn_proc_weights)
|
||||
return cls(attn_proc_weights_module)
|
||||
|
@ -31,11 +31,14 @@ def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[
|
||||
attn_procs[name] = AttnProcessor2_0()
|
||||
else:
|
||||
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
||||
attn_procs[name] = IPAttnProcessor2_0([ip_adapter.attn_weights.weights[idx] for ip_adapter in ip_adapters])
|
||||
attn_procs[name] = IPAttnProcessor2_0(
|
||||
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in ip_adapters]
|
||||
)
|
||||
return attn_procs
|
||||
|
||||
|
||||
@contextmanager
|
||||
def apply_ip_adapter_attention(cls, unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
|
||||
def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
|
||||
"""A context manager that patches `unet` with IP-Adapter attention processors."""
|
||||
attn_procs = _prepare_attention_processors(unet, ip_adapters)
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.util.test_utils import install_and_load_model
|
||||
|
||||
@ -64,8 +65,8 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
|
||||
ip_adapter.to(torch_device, dtype=torch.float32)
|
||||
unet.to(torch_device, dtype=torch.float32)
|
||||
|
||||
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": torch.randn((1, 4, 768)).to(torch_device)}
|
||||
with ip_adapter.apply_ip_adapter_attention(unet, 1.0):
|
||||
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [torch.randn((1, 4, 768)).to(torch_device)]}
|
||||
with apply_ip_adapter_attention(unet, [ip_adapter]):
|
||||
output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
|
||||
|
||||
assert output.shape == dummy_unet_input["sample"].shape
|
||||
|
Loading…
Reference in New Issue
Block a user