import torch


class IPAttentionProcessorWeights(torch.nn.Module):
    """The IP-Adapter weights for a single attention processor.

    This class is a torch.nn.Module sub-class to facilitate loading from a state_dict. It does not have a
    forward(...) method.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.to_k_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
        self.to_v_ip = torch.nn.Linear(in_dim, out_dim, bias=False)


class IPAttentionWeights(torch.nn.Module):
    """A collection of all the `IPAttentionProcessorWeights` objects for an IP-Adapter model.

    This class is a torch.nn.Module sub-class so that it inherits the `.to(...)` functionality. It does not have a
    forward(...) method.
    """

    def __init__(self, weights: torch.nn.ModuleDict):
        super().__init__()
        self._weights = weights

    def get_attention_processor_weights(self, idx: int) -> IPAttentionProcessorWeights:
        """Get the `IPAttentionProcessorWeights` for the idx'th attention processor."""
        # Cast to int first, because we expect the key to represent an int. Then cast back to str, because
        # `torch.nn.ModuleDict` only supports str keys.
        return self._weights[str(int(idx))]

    @classmethod
    def from_state_dict(cls, state_dict: dict[str, torch.Tensor]) -> "IPAttentionWeights":
        """Initialize an `IPAttentionWeights` object from a state_dict.

        Keys are expected to have the form "<idx>.to_k_ip.weight" / "<idx>.to_v_ip.weight".
        """
        attn_proc_weights: dict[str, IPAttentionProcessorWeights] = {}

        for tensor_name, tensor in state_dict.items():
            if "to_k_ip.weight" in tensor_name:
                # The leading key segment is the attention processor index; the Linear layer's
                # weight has shape (out_dim, in_dim).
                index = str(int(tensor_name.split(".")[0]))
                attn_proc_weights[index] = IPAttentionProcessorWeights(tensor.shape[1], tensor.shape[0])

        attn_proc_weights_module = torch.nn.ModuleDict(attn_proc_weights)
        attn_proc_weights_module.load_state_dict(state_dict)

        return cls(attn_proc_weights_module)
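

# ---------------------------------------------------------------------------
# Minimal usage sketch. The key layout ("<idx>.to_k_ip.weight" /
# "<idx>.to_v_ip.weight") matches what `from_state_dict` expects; the tensor
# shapes and the processor count below are illustrative placeholders, not
# values taken from a real IP-Adapter checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Build a small synthetic state_dict with two attention processors.
    state_dict: dict[str, torch.Tensor] = {}
    for idx in range(2):
        state_dict[f"{idx}.to_k_ip.weight"] = torch.randn(320, 768)
        state_dict[f"{idx}.to_v_ip.weight"] = torch.randn(320, 768)

    ip_weights = IPAttentionWeights.from_state_dict(state_dict)

    # `.to(...)` is inherited from torch.nn.Module, so the whole collection can
    # be moved/cast in one call.
    ip_weights.to(dtype=torch.float16)

    proc_weights = ip_weights.get_attention_processor_weights(0)
    print(proc_weights.to_k_ip.weight.shape)  # torch.Size([320, 768])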