Infer the `clip_extra_context_tokens` parameter from the state dict to support FLUX XLabs IP-Adapter V2 models, which use 16 extra context tokens (the original adapter uses 4).

This commit is contained in:
Ryan Dick
2024-11-15 20:58:30 +00:00
committed by psychedelicious
parent 9a77e951d2
commit c6fc82f756
3 changed files with 20 additions and 3 deletions

View File

@ -41,10 +41,12 @@ def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Te
hidden_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[0]
context_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[1]
clip_embeddings_dim = state_dict["ip_adapter_proj_model.proj.weight"].shape[1]
clip_extra_context_tokens = state_dict["ip_adapter_proj_model.proj.weight"].shape[0] // context_dim
return XlabsIpAdapterParams(
num_double_blocks=num_double_blocks,
context_dim=context_dim,
hidden_dim=hidden_dim,
clip_embeddings_dim=clip_embeddings_dim,
clip_extra_context_tokens=clip_extra_context_tokens,
)

View File

@ -31,13 +31,16 @@ class XlabsIpAdapterParams:
hidden_dim: int
clip_embeddings_dim: int
clip_extra_context_tokens: int
class XlabsIpAdapterFlux(torch.nn.Module):
def __init__(self, params: XlabsIpAdapterParams):
super().__init__()
self.image_proj = ImageProjModel(
cross_attention_dim=params.context_dim, clip_embeddings_dim=params.clip_embeddings_dim
cross_attention_dim=params.context_dim,
clip_embeddings_dim=params.clip_embeddings_dim,
clip_extra_context_tokens=params.clip_extra_context_tokens,
)
self.ip_adapter_double_blocks = IPAdapterDoubleBlocks(
num_double_blocks=params.num_double_blocks, context_dim=params.context_dim, hidden_dim=params.hidden_dim

View File

@ -30,11 +30,23 @@ def test_is_state_dict_xlabs_ip_adapter(sd_shapes: dict[str, list[int]]):
[
(
xlabs_flux_ip_adapter_sd_shapes,
XlabsIpAdapterParams(num_double_blocks=19, context_dim=4096, hidden_dim=3072, clip_embeddings_dim=768),
XlabsIpAdapterParams(
num_double_blocks=19,
context_dim=4096,
hidden_dim=3072,
clip_embeddings_dim=768,
clip_extra_context_tokens=4,
),
),
(
xlabs_flux_ip_adapter_v2_sd_shapes,
XlabsIpAdapterParams(num_double_blocks=19, context_dim=4096, hidden_dim=3072, clip_embeddings_dim=768),
XlabsIpAdapterParams(
num_double_blocks=19,
context_dim=4096,
hidden_dim=3072,
clip_embeddings_dim=768,
clip_extra_context_tokens=16,
),
),
],
)