wip: Initial implementation of safetensor support for IP Adapter
@@ -9,8 +9,8 @@ import torch.nn as nn

 # FFN
-def FeedForward(dim, mult=4):
-    inner_dim = int(dim * mult)
+def FeedForward(dim: int, mult: int = 4):
+    inner_dim = dim * mult
     return nn.Sequential(
         nn.LayerNorm(dim),
         nn.Linear(dim, inner_dim, bias=False),
@@ -19,8 +19,8 @@ def FeedForward(dim, mult=4):
     )


-def reshape_tensor(x, heads):
-    bs, length, width = x.shape
+def reshape_tensor(x: torch.Tensor, heads: int):
+    bs, length, _ = x.shape
     # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
     x = x.view(bs, length, heads, -1)
     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
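Note: for readers skimming the diff, a minimal shape check of what the visible lines of reshape_tensor do (the tensor sizes below are illustrative assumptions, not from the commit):

import torch

x = torch.randn(2, 16, 512)                     # (bs, length, width)
heads = 8
y = x.view(2, 16, heads, -1).transpose(1, 2)    # split width across heads, then swap axes
assert y.shape == (2, heads, 16, 512 // heads)  # (bs, n_heads, length, dim_per_head)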
@@ -31,7 +31,7 @@ def reshape_tensor(x, heads):


 class PerceiverAttention(nn.Module):
-    def __init__(self, *, dim, dim_head=64, heads=8):
+    def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
         super().__init__()
         self.scale = dim_head**-0.5
         self.dim_head = dim_head
@@ -45,7 +45,7 @@ class PerceiverAttention(nn.Module):
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)

-    def forward(self, x, latents):
+    def forward(self, x: torch.Tensor, latents: torch.Tensor):
         """
         Args:
             x (torch.Tensor): image features
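Note: the typed forward keeps the same calling convention. A sketch of the expected shapes, where all sizes are assumptions chosen for illustration:

import torch

attn = PerceiverAttention(dim=1024, dim_head=64, heads=16)
x = torch.randn(1, 257, 1024)      # image features, e.g. CLIP patch tokens (assumed shape)
latents = torch.randn(1, 8, 1024)  # learned latent queries, broadcast to the batch upstream
out = attn(x, latents)             # -> (1, 8, 1024): one attended output per latent query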
@@ -80,14 +80,14 @@ class PerceiverAttention(nn.Module):
 class Resampler(nn.Module):
     def __init__(
         self,
-        dim=1024,
-        depth=8,
-        dim_head=64,
-        heads=16,
-        num_queries=8,
-        embedding_dim=768,
-        output_dim=1024,
-        ff_mult=4,
+        dim: int = 1024,
+        depth: int = 8,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        embedding_dim: int = 768,
+        output_dim: int = 1024,
+        ff_mult: int = 4,
     ):
         super().__init__()

@@ -110,7 +110,15 @@ class Resampler(nn.Module):
         )

     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
+    def from_state_dict(
+        cls,
+        state_dict: dict[str, torch.Tensor],
+        depth: int = 8,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        ff_mult: int = 4,
+    ):
         """A convenience function that initializes a Resampler from a state_dict.

         Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
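Note: given the commit's goal (safetensor support for IP Adapter), a plausible call site for the retyped from_state_dict would look like the sketch below. The file name and the "image_proj." key prefix are illustrative assumptions; load_file is the standard safetensors.torch loader and returns exactly the dict[str, torch.Tensor] the new annotation describes:

import torch
from safetensors.torch import load_file

state_dict = load_file("ip_adapter.safetensors")  # hypothetical path; yields a flat str -> tensor mapping

# Keep only the image-projection weights (the prefix is an assumption; it varies between checkpoint formats).
resampler_sd = {k.removeprefix("image_proj."): v for k, v in state_dict.items() if k.startswith("image_proj.")}

image_proj_model = Resampler.from_state_dict(resampler_sd)  # dim, embedding_dim, etc. inferred from tensor shapes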
@@ -145,7 +153,7 @@ class Resampler(nn.Module):
         model.load_state_dict(state_dict)
         return model

-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         latents = self.latents.repeat(x.size(0), 1, 1)

         x = self.proj_in(x)
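Note: with the defaults above, a minimal end-to-end sketch of the typed forward (the input shape is an assumption; embedding_dim=768 matches e.g. CLIP ViT-L token width):

import torch

model = Resampler()                      # dim=1024, num_queries=8, embedding_dim=768, output_dim=1024
image_embeds = torch.randn(2, 257, 768)  # assumed batch of CLIP image token embeddings
tokens = model(image_embeds)
assert tokens.shape == (2, 8, 1024)      # num_queries resampled tokens per image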