https://github.com/invoke-ai/InvokeAI

wip: Initial implementation of safetensor support for IP Adapter
@@ -1,10 +1,11 @@
 # copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
 # and modified as needed
 
-from typing import Optional, Union
+from typing import List, Optional, TypedDict, Union
 
 import torch
 from PIL import Image
+from safetensors import safe_open
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
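For context on the new import: unlike torch.load, safetensors' safe_open memory-maps the checkpoint and exposes tensors lazily by name, with no pickle execution. A minimal sketch of the access pattern used later in this diff (the filename is a placeholder):

    from safetensors import safe_open

    # Tensors are read on demand; keys are flat, dot-separated names.
    with safe_open("ip_adapter.safetensors", framework="pt", device="cpu") as f:
        for key in f.keys():
            print(key, tuple(f.get_tensor(key).shape))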
@@ -13,10 +14,17 @@ from ..raw_model import RawModel
 from .resampler import Resampler
 
 
+class IPAdapterStateDict(TypedDict):
+    ip_adapter: dict[str, torch.Tensor]
+    image_proj: dict[str, torch.Tensor]
+
+
 class ImageProjModel(torch.nn.Module):
     """Image Projection Model"""
 
-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
+    def __init__(
+        self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4
+    ):
         super().__init__()
 
         self.cross_attention_dim = cross_attention_dim
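The new IPAdapterStateDict pins down the two-section layout of an IP-Adapter checkpoint ("image_proj" weights for the projection model, "ip_adapter" weights for the attention processors), so a type checker can flag a missing or misnamed section. A small illustration, assuming IPAdapterStateDict as defined above (tensor shapes are placeholders):

    import torch

    sd: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
    sd["image_proj"]["proj.weight"] = torch.zeros(4096, 1280)  # ordinary dict at runtime

    # A type checker rejects this: the required "image_proj" key is missing.
    bad: IPAdapterStateDict = {"ip_adapter": {}}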
@@ -25,7 +33,7 @@ class ImageProjModel(torch.nn.Module):
         self.norm = torch.nn.LayerNorm(cross_attention_dim)
 
     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor], clip_extra_context_tokens: int = 4):
         """Initialize an ImageProjModel from a state_dict.
 
         The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
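Two things worth noting here. First, the old annotation dict[torch.Tensor] was malformed (dict takes a key type and a value type); dict[str, torch.Tensor] is the corrected form. Second, the inference the docstring describes follows from the two known layers: norm is a LayerNorm over cross_attention_dim, and proj maps clip_embeddings_dim to clip_extra_context_tokens * cross_attention_dim. A sketch, assuming those key names:

    def infer_dims(state_dict: dict[str, torch.Tensor]) -> tuple[int, int]:
        # LayerNorm weight is 1-D over the normalized (cross-attention) dim.
        cross_attention_dim = state_dict["norm.weight"].shape[0]
        # proj.weight has shape (tokens * cross_attention_dim, clip_embeddings_dim).
        clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
        return cross_attention_dim, clip_embeddings_dim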
@@ -57,7 +65,7 @@
 class MLPProjModel(torch.nn.Module):
     """SD model with image prompt"""
 
-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
+    def __init__(self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024):
         super().__init__()
 
         self.proj = torch.nn.Sequential(
@@ -68,7 +76,7 @@ class MLPProjModel(torch.nn.Module):
         )
 
     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor]):
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
         """Initialize an MLPProjModel from a state_dict.
 
         The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -97,7 +105,7 @@ class IPAdapter(RawModel):
 
     def __init__(
         self,
-        state_dict: dict[str, torch.Tensor],
+        state_dict: IPAdapterStateDict,
         device: torch.device,
         dtype: torch.dtype = torch.float16,
         num_tokens: int = 4,
@@ -129,13 +137,11 @@ class IPAdapter(RawModel):
 
         return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
 
-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
 
     @torch.inference_mode()
-    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
+    def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
         image_prompt_embeds = self._image_proj_model(clip_image_embeds)
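Note the behavioral change: get_image_embeds no longer silently wraps a bare PIL image in a list, so callers must batch explicitly. A usage sketch (adapter, image_encoder, and the file path are placeholders):

    from PIL import Image

    image = Image.open("reference.png").convert("RGB")
    # Previously a bare Image also worked; now the caller batches:
    embeds = adapter.get_image_embeds([image], image_encoder)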
@@ -146,7 +152,7 @@ class IPAdapter(RawModel):
 class IPAdapterPlus(IPAdapter):
     """IP-Adapter with fine-grained features"""
 
-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
@@ -157,9 +163,7 @@ class IPAdapterPlus(IPAdapter):
         ).to(self.device, dtype=self.dtype)
 
     @torch.inference_mode()
-    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
+    def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image = clip_image.to(self.device, dtype=self.dtype)
         clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
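The same list-only signature change lands here. Independent of this diff, the Plus variant differs from the base class in what it takes from CLIP: rather than the pooled, projected image_embeds, it uses the penultimate hidden layer, which keeps one feature vector per image patch for the Resampler to attend over. Roughly (shapes are illustrative):

    out = image_encoder(clip_image, output_hidden_states=True)
    out.image_embeds       # (batch, projection_dim): one pooled vector per image
    out.hidden_states[-2]  # (batch, num_patches + 1, hidden_size): per-patch features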
@@ -174,14 +178,14 @@ class IPAdapterPlus(IPAdapter):
 class IPAdapterFull(IPAdapterPlus):
     """IP-Adapter Plus with full features."""
 
-    def _init_image_proj_model(self, state_dict: dict[torch.Tensor]):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)
 
 
 class IPAdapterPlusXL(IPAdapterPlus):
     """IP-Adapter Plus for SDXL."""
 
-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
@@ -195,7 +199,19 @@ class IPAdapterPlusXL(IPAdapterPlus):
 def build_ip_adapter(
     ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
 ) -> Union[IPAdapter, IPAdapterPlus]:
-    state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
+    state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
+
+    if ip_adapter_ckpt_path.endswith("safetensors"):
+        state_dict = {"ip_adapter": {}, "image_proj": {}}
+        model = safe_open(ip_adapter_ckpt_path, device=device.type, framework="pt")
+        for key in model.keys():
+            if key.startswith("image_proj."):
+                state_dict["image_proj"][key.replace("image_proj.", "")] = model.get_tensor(key)
+            if key.startswith("ip_adapter."):
+                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model.get_tensor(key)
+    else:
+        ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path + "/ip_adapter.bin"
+        state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
 
     if "proj.weight" in state_dict["image_proj"]:  # IPAdapter (with ImageProjModel).
         return IPAdapter(state_dict, device=device, dtype=dtype)
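The safetensors branch re-nests the flat, dot-separated key namespace into the two sections IPAdapterStateDict expects, matching the layout of the diffusers-format ip_adapter.bin. A sketch of the mapping and a call (keys and paths are illustrative):

    # "image_proj.proj.weight"       -> state_dict["image_proj"]["proj.weight"]
    # "ip_adapter.1.to_k_ip.weight"  -> state_dict["ip_adapter"]["1.to_k_ip.weight"]

    ip_adapter = build_ip_adapter(
        "models/ip_adapter.safetensors",  # a non-safetensors path is treated as a
        device=torch.device("cuda"),      # directory containing ip_adapter.bin
        dtype=torch.float16,
    )

One caveat in the WIP code: key.replace strips every occurrence of the prefix, not only the leading one. With these checkpoints the prefixes appear only at the start of a key, so in practice it behaves as a prefix strip.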