mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Improve robustness of check for IPAdapter vs IPAdapterPlus.
This commit is contained in:
parent
781e8521d5
commit
a22c8cb3a1
@ -1,8 +1,9 @@
|
||||
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||
# and modified as needed
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
@ -67,7 +68,7 @@ class IPAdapter:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ip_adapter_ckpt_path: str,
|
||||
state_dict: dict[torch.Tensor],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
num_tokens: int = 4,
|
||||
@ -75,12 +76,11 @@ class IPAdapter:
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
self._ip_adapter_ckpt_path = ip_adapter_ckpt_path
|
||||
self._num_tokens = num_tokens
|
||||
|
||||
self._clip_image_processor = CLIPImageProcessor()
|
||||
|
||||
self._state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu")
|
||||
self._state_dict = state_dict
|
||||
|
||||
self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
|
||||
|
||||
@ -198,3 +198,18 @@ class IPAdapterPlus(IPAdapter):
|
||||
]
|
||||
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
|
||||
|
||||
def build_ip_adapter(
|
||||
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
|
||||
) -> Union[IPAdapter, IPAdapterPlus]:
|
||||
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
|
||||
|
||||
# Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it
|
||||
# contains.
|
||||
is_plus = "proj.weight" not in state_dict["image_proj"]
|
||||
|
||||
if is_plus:
|
||||
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
|
||||
else:
|
||||
return IPAdapter(state_dict, device=device, dtype=dtype)
|
||||
|
@ -5,7 +5,11 @@ from typing import Any, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.ip_adapter.ip_adapter import (
|
||||
IPAdapter,
|
||||
IPAdapterPlus,
|
||||
build_ip_adapter,
|
||||
)
|
||||
from invokeai.backend.model_management.models.base import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
@ -79,16 +83,9 @@ class IPAdapterModel(ModelBase):
|
||||
if child_type is not None:
|
||||
raise ValueError("There are no child models in an IP-Adapter model.")
|
||||
|
||||
# TODO(ryand): Checking for "plus" in the file path is fragile. It should be possible to infer whether this is a
|
||||
# "plus" variant by loading the state_dict.
|
||||
if "plus" in str(self.model_path):
|
||||
return IPAdapterPlus(
|
||||
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
|
||||
)
|
||||
else:
|
||||
return IPAdapter(
|
||||
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
|
||||
)
|
||||
return build_ip_adapter(
|
||||
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
|
Loading…
Reference in New Issue
Block a user