diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py
index 81514a9f8b..5444c76c8c 100644
--- a/invokeai/backend/ip_adapter/ip_adapter.py
+++ b/invokeai/backend/ip_adapter/ip_adapter.py
@@ -53,7 +53,7 @@ class ImageProjModel(torch.nn.Module):
         model.load_state_dict(state_dict)
         return model
 
-    def forward(self, image_embeds):
+    def forward(self, image_embeds: torch.Tensor):
         embeds = image_embeds
         clip_extra_context_tokens = self.proj(embeds).reshape(
             -1, self.clip_extra_context_tokens, self.cross_attention_dim
@@ -95,7 +95,7 @@ class MLPProjModel(torch.nn.Module):
         model.load_state_dict(state_dict)
         return model
 
-    def forward(self, image_embeds):
+    def forward(self, image_embeds: torch.Tensor):
         clip_extra_context_tokens = self.proj(image_embeds)
         return clip_extra_context_tokens
 
@@ -137,7 +137,9 @@ class IPAdapter(RawModel):
 
         return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
 
-    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
+    def _init_image_proj_model(
+        self, state_dict: dict[str, torch.Tensor]
+    ) -> Union[ImageProjModel, Resampler, MLPProjModel]:
         return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
 
     @torch.inference_mode()
@@ -152,7 +154,7 @@ class IPAdapter(RawModel):
 class IPAdapterPlus(IPAdapter):
     """IP-Adapter with fine-grained features"""
 
-    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]) -> Union[Resampler, MLPProjModel]:
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
@@ -196,36 +198,46 @@ class IPAdapterPlusXL(IPAdapterPlus):
     ).to(self.device, dtype=self.dtype)
 
 
-def build_ip_adapter(
-    ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
-) -> Union[IPAdapter, IPAdapterPlus]:
+def load_ip_adapter_tensors(ip_adapter_ckpt_path: str, device: str) -> IPAdapterStateDict:
     state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
 
     if ip_adapter_ckpt_path.endswith("safetensors"):
-        state_dict = {"ip_adapter": {}, "image_proj": {}}
-        model = safe_open(ip_adapter_ckpt_path, device=device.type, framework="pt")
+        model = safe_open(ip_adapter_ckpt_path, device=device, framework="pt")
         for key in model.keys():
             if key.startswith("image_proj."):
                 state_dict["image_proj"][key.replace("image_proj.", "")] = model.get_tensor(key)
-            if key.startswith("ip_adapter."):
+            elif key.startswith("ip_adapter."):
                 state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model.get_tensor(key)
     else:
         ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path + "/ip_adapter.bin"
         state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
 
-    if "proj.weight" in state_dict["image_proj"]:  # IPAdapter (with ImageProjModel).
+    return state_dict
+
+
+def build_ip_adapter(
+    ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
+) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterFull]:
+    state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)
+
+    # IPAdapter (with ImageProjModel)
+    if "proj.weight" in state_dict["image_proj"]:
         return IPAdapter(state_dict, device=device, dtype=dtype)
-    elif "proj_in.weight" in state_dict["image_proj"]:  # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
+
+    # IPAdapterPlus or IPAdapterPlusXL (with Resampler)
+    elif "proj_in.weight" in state_dict["image_proj"]:
         cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
         if cross_attention_dim == 768:
-            # SD1 IP-Adapter Plus
-            return IPAdapterPlus(state_dict, device=device, dtype=dtype)
+            return IPAdapterPlus(state_dict, device=device, dtype=dtype)  # SD1 IP-Adapter Plus
         elif cross_attention_dim == 2048:
-            # SDXL IP-Adapter Plus
-            return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
+            return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)  # SDXL IP-Adapter Plus
         else:
             raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
-    elif "proj.0.weight" in state_dict["image_proj"]:  # IPAdapterFull (with MLPProjModel).
+
+    # IPAdapterFull (with MLPProjModel)
+    elif "proj.0.weight" in state_dict["image_proj"]:
        return IPAdapterFull(state_dict, device=device, dtype=dtype)
+
+    # Unrecognized IP-Adapter architecture
     else:
         raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")
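As a quick sanity check for reviewers, a minimal sketch of how the refactored functions might be exercised; the checkpoint path below is a placeholder, not part of this diff.

import torch

from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter, load_ip_adapter_tensors

# Placeholder checkpoint path; substitute a real IP-Adapter checkpoint.
ckpt_path = "models/ip_adapter/ip-adapter-plus_sd15.safetensors"

# The new helper loads the raw tensors without constructing a model,
# so the two state-dict groups can be inspected directly.
state_dict = load_ip_adapter_tensors(ckpt_path, device="cpu")
print(sorted(state_dict["image_proj"].keys())[:5])

# build_ip_adapter() still dispatches on the state-dict keys and returns the
# matching class: IPAdapter, IPAdapterPlus, IPAdapterPlusXL, or IPAdapterFull.
adapter = build_ip_adapter(ckpt_path, device=torch.device("cpu"), dtype=torch.float32)
print(type(adapter).__name__)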