mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: improve types in ip_adapter backend file
This commit is contained in:
parent
9ff729a7e6
commit
936b99bd3c
@ -53,7 +53,7 @@ class ImageProjModel(torch.nn.Module):
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
def forward(self, image_embeds):
|
||||
def forward(self, image_embeds: torch.Tensor):
|
||||
embeds = image_embeds
|
||||
clip_extra_context_tokens = self.proj(embeds).reshape(
|
||||
-1, self.clip_extra_context_tokens, self.cross_attention_dim
|
||||
@ -95,7 +95,7 @@ class MLPProjModel(torch.nn.Module):
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
def forward(self, image_embeds):
|
||||
def forward(self, image_embeds: torch.Tensor):
|
||||
clip_extra_context_tokens = self.proj(image_embeds)
|
||||
return clip_extra_context_tokens
|
||||
|
||||
@ -137,7 +137,9 @@ class IPAdapter(RawModel):
|
||||
|
||||
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
|
||||
|
||||
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
|
||||
def _init_image_proj_model(
|
||||
self, state_dict: dict[str, torch.Tensor]
|
||||
) -> Union[ImageProjModel, Resampler, MLPProjModel]:
|
||||
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
|
||||
|
||||
@torch.inference_mode()
|
||||
@ -152,7 +154,7 @@ class IPAdapter(RawModel):
|
||||
class IPAdapterPlus(IPAdapter):
|
||||
"""IP-Adapter with fine-grained features"""
|
||||
|
||||
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
|
||||
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]) -> Union[Resampler, MLPProjModel]:
|
||||
return Resampler.from_state_dict(
|
||||
state_dict=state_dict,
|
||||
depth=4,
|
||||
@ -196,36 +198,46 @@ class IPAdapterPlusXL(IPAdapterPlus):
|
||||
).to(self.device, dtype=self.dtype)
|
||||
|
||||
|
||||
def build_ip_adapter(
|
||||
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
|
||||
) -> Union[IPAdapter, IPAdapterPlus]:
|
||||
def load_ip_adapter_tensors(ip_adapter_ckpt_path: str, device: str) -> IPAdapterStateDict:
|
||||
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
|
||||
|
||||
if ip_adapter_ckpt_path.endswith("safetensors"):
|
||||
state_dict = {"ip_adapter": {}, "image_proj": {}}
|
||||
model = safe_open(ip_adapter_ckpt_path, device=device.type, framework="pt")
|
||||
model = safe_open(ip_adapter_ckpt_path, device=device, framework="pt")
|
||||
for key in model.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = model.get_tensor(key)
|
||||
if key.startswith("ip_adapter."):
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model.get_tensor(key)
|
||||
else:
|
||||
ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path + "/ip_adapter.bin"
|
||||
state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
|
||||
|
||||
if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel).
|
||||
return state_dict
|
||||
|
||||
|
||||
def build_ip_adapter(
|
||||
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
|
||||
) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterPlus]:
|
||||
state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)
|
||||
|
||||
# IPAdapter (with ImageProjModel)
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
return IPAdapter(state_dict, device=device, dtype=dtype)
|
||||
elif "proj_in.weight" in state_dict["image_proj"]: # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
|
||||
|
||||
# IPAdaterPlus or IPAdapterPlusXL (with Resampler)
|
||||
elif "proj_in.weight" in state_dict["image_proj"]:
|
||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||
if cross_attention_dim == 768:
|
||||
# SD1 IP-Adapter Plus
|
||||
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
|
||||
return IPAdapterPlus(state_dict, device=device, dtype=dtype) # SD1 IP-Adapter Plus
|
||||
elif cross_attention_dim == 2048:
|
||||
# SDXL IP-Adapter Plus
|
||||
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
|
||||
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype) # SDXL IP-Adapter Plus
|
||||
else:
|
||||
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
|
||||
elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel).
|
||||
|
||||
# IPAdapterFull (with MLPProjModel)
|
||||
elif "proj.0.weight" in state_dict["image_proj"]:
|
||||
return IPAdapterFull(state_dict, device=device, dtype=dtype)
|
||||
|
||||
# Unrecognized IP Adapter Architectures
|
||||
else:
|
||||
raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")
|
||||
|
Loading…
Reference in New Issue
Block a user