Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
chore: improve types in ip_adapter backend file
parent 318bc938fe
commit 688a0f30bb
@@ -53,7 +53,7 @@ class ImageProjModel(torch.nn.Module):
         model.load_state_dict(state_dict)
         return model
 
-    def forward(self, image_embeds):
+    def forward(self, image_embeds: torch.Tensor):
         embeds = image_embeds
         clip_extra_context_tokens = self.proj(embeds).reshape(
             -1, self.clip_extra_context_tokens, self.cross_attention_dim
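For context on what the newly annotated `forward` computes: `ImageProjModel` projects one pooled CLIP image embedding into several extra cross-attention context tokens. A self-contained sketch of that pattern, with illustrative dimensions rather than the repo's actual defaults:

import torch


class MiniImageProjModel(torch.nn.Module):
    """Illustrative stand-in for ImageProjModel: one pooled image embedding
    becomes `clip_extra_context_tokens` cross-attention tokens."""

    def __init__(self, cross_attention_dim: int = 768, clip_embed_dim: int = 1024, num_tokens: int = 4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = num_tokens
        self.proj = torch.nn.Linear(clip_embed_dim, num_tokens * cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        # (batch, clip_embed_dim) -> (batch, num_tokens, cross_attention_dim)
        return self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)


tokens = MiniImageProjModel()(torch.randn(2, 1024))
assert tokens.shape == (2, 4, 768)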
@@ -95,7 +95,7 @@ class MLPProjModel(torch.nn.Module):
         model.load_state_dict(state_dict)
         return model
 
-    def forward(self, image_embeds):
+    def forward(self, image_embeds: torch.Tensor):
         clip_extra_context_tokens = self.proj(image_embeds)
         return clip_extra_context_tokens
 
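`MLPProjModel` differs in that it maps per-token CLIP features through an MLP without changing the token count, which is why its `forward` has no reshape. A minimal sketch under that assumption (the repo's actual layer stack may differ):

import torch


class MiniMLPProjModel(torch.nn.Module):
    """Illustrative stand-in for MLPProjModel: per-token features in,
    per-token cross-attention features out; the token count is unchanged."""

    def __init__(self, cross_attention_dim: int = 768, clip_embed_dim: int = 1024):
        super().__init__()
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embed_dim, clip_embed_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embed_dim, cross_attention_dim),
        )

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        # (batch, n_tokens, clip_embed_dim) -> (batch, n_tokens, cross_attention_dim)
        return self.proj(image_embeds)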
@@ -137,7 +137,9 @@ class IPAdapter(RawModel):
 
         return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
 
-    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
+    def _init_image_proj_model(
+        self, state_dict: dict[str, torch.Tensor]
+    ) -> Union[ImageProjModel, Resampler, MLPProjModel]:
         return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
 
     @torch.inference_mode()
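The base method is annotated with the union of everything its overrides return: `IPAdapter._init_image_proj_model` itself builds an `ImageProjModel`, while the subclasses below return `Resampler` or `MLPProjModel`. Widening the base signature this way keeps narrowed overrides type-safe. A minimal sketch of the pattern with stub classes:

from typing import Union


class ProjA: ...
class ProjB: ...


class Base:
    def _init_proj(self) -> Union[ProjA, ProjB]:  # wide: covers every subclass
        return ProjA()


class Plus(Base):
    def _init_proj(self) -> ProjB:  # narrowing the return type is allowed
        return ProjB()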
@@ -152,7 +154,7 @@ class IPAdapter(RawModel):
 class IPAdapterPlus(IPAdapter):
     """IP-Adapter with fine-grained features"""
 
-    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]) -> Union[Resampler, MLPProjModel]:
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
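A note on the `from_state_dict` constructors used throughout this file: tensor dimensions can usually be recovered from weight shapes in the checkpoint, while structural hyperparameters like `depth=4` cannot and must be supplied by the caller. A minimal sketch of the shape-inference half (the key name here is hypothetical):

import torch


def infer_linear_dims(state_dict: dict[str, torch.Tensor], key: str = "proj_in.weight") -> tuple[int, int]:
    # torch.nn.Linear stores its weight as (out_features, in_features)
    weight = state_dict[key]
    return weight.shape[1], weight.shape[0]  # (in_dim, out_dim)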
@@ -196,36 +198,46 @@ class IPAdapterPlusXL(IPAdapterPlus):
         ).to(self.device, dtype=self.dtype)
 
 
-def build_ip_adapter(
-    ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
-) -> Union[IPAdapter, IPAdapterPlus]:
+def load_ip_adapter_tensors(ip_adapter_ckpt_path: str, device: str) -> IPAdapterStateDict:
     state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
 
     if ip_adapter_ckpt_path.endswith("safetensors"):
-        state_dict = {"ip_adapter": {}, "image_proj": {}}
-        model = safe_open(ip_adapter_ckpt_path, device=device.type, framework="pt")
+        model = safe_open(ip_adapter_ckpt_path, device=device, framework="pt")
         for key in model.keys():
             if key.startswith("image_proj."):
                 state_dict["image_proj"][key.replace("image_proj.", "")] = model.get_tensor(key)
-            if key.startswith("ip_adapter."):
+            elif key.startswith("ip_adapter."):
                 state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model.get_tensor(key)
     else:
         ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path + "/ip_adapter.bin"
         state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
 
-    if "proj.weight" in state_dict["image_proj"]:  # IPAdapter (with ImageProjModel).
+    return state_dict
+
+
+def build_ip_adapter(
+    ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
+) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterFull]:
+    state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)
+
+    # IPAdapter (with ImageProjModel)
+    if "proj.weight" in state_dict["image_proj"]:
         return IPAdapter(state_dict, device=device, dtype=dtype)
-    elif "proj_in.weight" in state_dict["image_proj"]:  # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
+
+    # IPAdapterPlus or IPAdapterPlusXL (with Resampler)
+    elif "proj_in.weight" in state_dict["image_proj"]:
         cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
         if cross_attention_dim == 768:
-            # SD1 IP-Adapter Plus
-            return IPAdapterPlus(state_dict, device=device, dtype=dtype)
+            return IPAdapterPlus(state_dict, device=device, dtype=dtype)  # SD1 IP-Adapter Plus
         elif cross_attention_dim == 2048:
-            # SDXL IP-Adapter Plus
-            return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
+            return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)  # SDXL IP-Adapter Plus
         else:
             raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
-    elif "proj.0.weight" in state_dict["image_proj"]:  # IPAdapterFull (with MLPProjModel).
+
+    # IPAdapterFull (with MLPProjModel)
+    elif "proj.0.weight" in state_dict["image_proj"]:
         return IPAdapterFull(state_dict, device=device, dtype=dtype)
+
+    # Unrecognized IP Adapter Architectures
     else:
         raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")
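The diff references `IPAdapterStateDict` without showing its definition. From how it is initialized and indexed above, a definition along these lines would be consistent (an assumption to verify against the repo, not verbatim code), followed by a hypothetical call into the refactored entry point:

from typing import TypedDict

import torch


class IPAdapterStateDict(TypedDict):
    # Shape inferred from usage in the hunk above; verify against the repo.
    ip_adapter: dict[str, torch.Tensor]
    image_proj: dict[str, torch.Tensor]


# Hypothetical usage; the checkpoint path is illustrative.
ip_adapter = build_ip_adapter(
    "models/ip-adapter-plus_sd15.safetensors",
    device=torch.device("cpu"),
    dtype=torch.float32,
)

Splitting loading out of `build_ip_adapter` leaves the architecture sniffing (`proj.weight` vs. `proj_in.weight` vs. `proj.0.weight`) operating on one consistently typed mapping, whether the source was a safetensors file or a diffusers `ip_adapter.bin`.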