Remove need for the image_encoder param in IPAdapter.initialize().

This commit is contained in:
Ryan Dick
2023-09-14 14:14:35 -04:00
parent cc8b7a74da
commit d114d0ba95
3 changed files with 70 additions and 20 deletions

View File

@ -109,6 +109,42 @@ class Resampler(nn.Module):
)
)
@classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
"""A convenience function that initializes a Resampler from a state_dict.
Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
writing, we did not have a need for inferring ALL of the shape parameters from the state_dict, but this would be
possible if needed in the future.
Args:
state_dict (dict[torch.Tensor]): The state_dict to load.
depth (int, optional):
dim_head (int, optional):
heads (int, optional):
ff_mult (int, optional):
Returns:
Resampler
"""
dim = state_dict["latents"].shape[2]
num_queries = state_dict["latents"].shape[1]
embedding_dim = state_dict["proj_in.weight"].shape[-1]
output_dim = state_dict["norm_out.weight"].shape[0]
model = cls(
dim=dim,
depth=depth,
dim_head=dim_head,
heads=heads,
num_queries=num_queries,
embedding_dim=embedding_dim,
output_dim=output_dim,
ff_mult=ff_mult,
)
model.load_state_dict(state_dict)
return model
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)