chore: improve types in ip_adapter backend file

This commit is contained in:
blessedcoolant 2024-03-24 08:34:11 +05:30
parent 9ff729a7e6
commit 936b99bd3c

View File

@ -53,7 +53,7 @@ class ImageProjModel(torch.nn.Module):
model.load_state_dict(state_dict)
return model
def forward(self, image_embeds):
def forward(self, image_embeds: torch.Tensor):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
@ -95,7 +95,7 @@ class MLPProjModel(torch.nn.Module):
model.load_state_dict(state_dict)
return model
def forward(self, image_embeds):
def forward(self, image_embeds: torch.Tensor):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
@ -137,7 +137,9 @@ class IPAdapter(RawModel):
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
def _init_image_proj_model(
self, state_dict: dict[str, torch.Tensor]
) -> Union[ImageProjModel, Resampler, MLPProjModel]:
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
@torch.inference_mode()
@ -152,7 +154,7 @@ class IPAdapter(RawModel):
class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features"""
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]) -> Union[Resampler, MLPProjModel]:
return Resampler.from_state_dict(
state_dict=state_dict,
depth=4,
@ -196,36 +198,46 @@ class IPAdapterPlusXL(IPAdapterPlus):
).to(self.device, dtype=self.dtype)
def build_ip_adapter(
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus]:
def load_ip_adapter_tensors(ip_adapter_ckpt_path: str, device: str) -> IPAdapterStateDict:
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
if ip_adapter_ckpt_path.endswith("safetensors"):
state_dict = {"ip_adapter": {}, "image_proj": {}}
model = safe_open(ip_adapter_ckpt_path, device=device.type, framework="pt")
model = safe_open(ip_adapter_ckpt_path, device=device, framework="pt")
for key in model.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = model.get_tensor(key)
if key.startswith("ip_adapter."):
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model.get_tensor(key)
else:
ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path + "/ip_adapter.bin"
state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel).
return state_dict
def build_ip_adapter(
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterPlus]:
state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)
# IPAdapter (with ImageProjModel)
if "proj.weight" in state_dict["image_proj"]:
return IPAdapter(state_dict, device=device, dtype=dtype)
elif "proj_in.weight" in state_dict["image_proj"]: # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
# IPAdaterPlus or IPAdapterPlusXL (with Resampler)
elif "proj_in.weight" in state_dict["image_proj"]:
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
# SD1 IP-Adapter Plus
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
return IPAdapterPlus(state_dict, device=device, dtype=dtype) # SD1 IP-Adapter Plus
elif cross_attention_dim == 2048:
# SDXL IP-Adapter Plus
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype) # SDXL IP-Adapter Plus
else:
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel).
# IPAdapterFull (with MLPProjModel)
elif "proj.0.weight" in state_dict["image_proj"]:
return IPAdapterFull(state_dict, device=device, dtype=dtype)
# Unrecognized IP Adapter Architectures
else:
raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")