Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
IP-Adapter Safetensor Support (#6041)
## Summary

This PR adds support for IP-Adapter safetensor files for direct use inside InvokeAI.

## Test

You can download the [Composition Adapters](https://huggingface.co/ostris/ip-composition-adapter), which weren't previously supported in Invoke, and try them out. Every other IP-Adapter model should work too. If you pick a safetensor IP-Adapter model, you will also need to select ViT-H or ViT-G next to it. This is a raw implementation; it can be refined further based on feedback.

Prompt: `Spiderman holding a bunny` -- exact same composition as the adapter image.
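For reviewers, here is a minimal sketch of the loading approach the diff below takes for safetensor IP-Adapter files. It mirrors the `load_ip_adapter_tensors` helper added in this PR; the standalone function name and the `"cpu"` device choice here are illustrative, not the exact code.

```python
# Sketch only: a safetensors IP-Adapter file stores flat keys prefixed with
# "image_proj." and "ip_adapter."; they are regrouped into the two sub-dicts
# that the IPAdapter classes expect. Folder (diffusers/InvokeAI) models keep
# their weights in ip_adapter.bin instead.
import pathlib

import safetensors.torch
import torch


def load_ip_adapter_state_dict(path: pathlib.Path) -> dict[str, dict[str, torch.Tensor]]:
    state_dict: dict[str, dict[str, torch.Tensor]] = {"image_proj": {}, "ip_adapter": {}}
    if path.suffix == ".safetensors":
        flat = safetensors.torch.load_file(path, device="cpu")
        for key, tensor in flat.items():
            prefix, _, rest = key.partition(".")
            if prefix not in state_dict:
                raise RuntimeError(f"Unexpected IP-Adapter state dict key: '{key}'.")
            state_dict[prefix][rest] = tensor
    else:
        # Diffusers/InvokeAI folder layout: the weights live in ip_adapter.bin.
        state_dict = torch.load(path / "ip_adapter.bin", map_location="cpu")
    return state_dict
```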
This commit is contained in: commit 7da04b8333
@@ -1,21 +1,22 @@
 from builtins import float
-from typing import List, Union
+from typing import List, Literal, Union

 from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self

-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
+from invokeai.backend.model_manager.config import (
+    AnyModelConfig,
+    BaseModelType,
+    IPAdapterCheckpointConfig,
+    IPAdapterInvokeAIConfig,
+    ModelType,
+)


 class IPAdapterField(BaseModel):
@@ -48,12 +49,15 @@ class IPAdapterOutput(BaseInvocationOutput):
     ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")


+CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
+
+
 @invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
 class IPAdapterInvocation(BaseInvocation):
     """Collects IP-Adapter info to pass to other nodes."""

     # Inputs
-    image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
+    image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).", ui_order=1)
     ip_adapter_model: ModelIdentifierField = InputField(
         description="The IP-Adapter model.",
         title="IP-Adapter Model",
@@ -61,7 +65,11 @@ class IPAdapterInvocation(BaseInvocation):
         ui_order=-1,
         ui_type=UIType.IPAdapterModel,
     )
+    clip_vision_model: Literal["auto", "ViT-H", "ViT-G"] = InputField(
+        description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
+        default="auto",
+        ui_order=2,
+    )
     weight: Union[float, List[float]] = InputField(
         default=1, description="The weight given to the IP-Adapter", title="Weight"
     )
@@ -86,10 +94,21 @@ class IPAdapterInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
         ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
-        assert isinstance(ip_adapter_info, IPAdapterConfig)
+        assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))

+        if self.clip_vision_model == "auto":
+            if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
                 image_encoder_model_id = ip_adapter_info.image_encoder_model_id
                 image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
+            else:
+                raise RuntimeError(
+                    "You need to set the appropriate CLIP Vision model for checkpoint IP Adapter models."
+                )
+        else:
+            image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]

         image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)

         return IPAdapterOutput(
             ip_adapter=IPAdapterField(
                 image=self.image,
@@ -102,19 +121,25 @@ class IPAdapterInvocation(BaseInvocation):
             )
         )

     def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
-        found = False
-        while not found:
-            image_encoder_models = context.models.search_by_attrs(
-                name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
-            )
-            found = len(image_encoder_models) > 0
-            if not found:
-                context.logger.warning(
-                    f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
-                )
-                context.logger.warning("Downloading and installing now. This may take a while.")
-                installer = context._services.model_manager.install
-                job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
-                installer.wait_for_job(job, timeout=600)  # wait up to 10 minutes - then raise a TimeoutException
+        image_encoder_models = context.models.search_by_attrs(
+            name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
+        )
+        if not len(image_encoder_models) > 0:
+            context.logger.warning(
+                f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed. \
+                    Downloading and installing now. This may take a while."
+            )
+            installer = context._services.model_manager.install
+            job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
+            installer.wait_for_job(job, timeout=600)  # Wait for up to 10 minutes
+            image_encoder_models = context.models.search_by_attrs(
+                name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
+            )
+
+            if len(image_encoder_models) == 0:
+                context.logger.error("Error while fetching CLIP Vision Image Encoder")
+
         assert len(image_encoder_models) == 1

         return image_encoder_models[0]
@@ -43,11 +43,7 @@ from invokeai.app.invocations.fields import (
     WithMetadata,
 )
 from invokeai.app.invocations.ip_adapter import IPAdapterField
-from invokeai.app.invocations.primitives import (
-    DenoiseMaskOutput,
-    ImageOutput,
-    LatentsOutput,
-)
+from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
@@ -68,12 +64,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 )
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import choose_precision, choose_torch_device
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .controlnet_image_processors import ControlField
 from .model import ModelIdentifierField, UNetField, VAEField

@@ -2,16 +2,8 @@ from typing import Any, Literal, Optional, Union

 from pydantic import BaseModel, ConfigDict, Field

-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
-from invokeai.app.invocations.controlnet_image_processors import (
-    CONTROLNET_MODE_VALUES,
-    CONTROLNET_RESIZE_VALUES,
-)
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     ImageField,
@@ -43,6 +35,7 @@ class IPAdapterMetadataField(BaseModel):

     image: ImageField = Field(description="The IP-Adapter image prompt.")
     ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
+    clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
     weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
     begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
     end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
@@ -1,8 +1,11 @@
 # copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
 # and modified as needed

-from typing import Optional, Union
+import pathlib
+from typing import List, Optional, TypedDict, Union

+import safetensors
+import safetensors.torch
 import torch
 from PIL import Image
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -13,10 +16,17 @@ from ..raw_model import RawModel
 from .resampler import Resampler


+class IPAdapterStateDict(TypedDict):
+    ip_adapter: dict[str, torch.Tensor]
+    image_proj: dict[str, torch.Tensor]
+
+
 class ImageProjModel(torch.nn.Module):
     """Image Projection Model"""

-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
+    def __init__(
+        self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4
+    ):
         super().__init__()

         self.cross_attention_dim = cross_attention_dim
@@ -25,7 +35,7 @@ class ImageProjModel(torch.nn.Module):
         self.norm = torch.nn.LayerNorm(cross_attention_dim)

     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor], clip_extra_context_tokens: int = 4):
         """Initialize an ImageProjModel from a state_dict.

         The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -45,7 +55,7 @@ class ImageProjModel(torch.nn.Module):
         model.load_state_dict(state_dict)
         return model

-    def forward(self, image_embeds):
+    def forward(self, image_embeds: torch.Tensor):
         embeds = image_embeds
         clip_extra_context_tokens = self.proj(embeds).reshape(
             -1, self.clip_extra_context_tokens, self.cross_attention_dim
@@ -57,7 +67,7 @@ class ImageProjModel(torch.nn.Module):
 class MLPProjModel(torch.nn.Module):
     """SD model with image prompt"""

-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
+    def __init__(self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024):
         super().__init__()

         self.proj = torch.nn.Sequential(
@@ -68,7 +78,7 @@ class MLPProjModel(torch.nn.Module):
         )

     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor]):
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
         """Initialize an MLPProjModel from a state_dict.

         The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -87,7 +97,7 @@ class MLPProjModel(torch.nn.Module):
         model.load_state_dict(state_dict)
         return model

-    def forward(self, image_embeds):
+    def forward(self, image_embeds: torch.Tensor):
         clip_extra_context_tokens = self.proj(image_embeds)
         return clip_extra_context_tokens

@@ -97,7 +107,7 @@ class IPAdapter(RawModel):

     def __init__(
         self,
-        state_dict: dict[str, torch.Tensor],
+        state_dict: IPAdapterStateDict,
         device: torch.device,
         dtype: torch.dtype = torch.float16,
         num_tokens: int = 4,
@@ -129,24 +139,27 @@ class IPAdapter(RawModel):

         return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)

-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(
+        self, state_dict: dict[str, torch.Tensor]
+    ) -> Union[ImageProjModel, Resampler, MLPProjModel]:
         return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)

     @torch.inference_mode()
-    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
+    def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
+        try:
             image_prompt_embeds = self._image_proj_model(clip_image_embeds)
             uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
             return image_prompt_embeds, uncond_image_prompt_embeds
+        except RuntimeError as e:
+            raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e


 class IPAdapterPlus(IPAdapter):
     """IP-Adapter with fine-grained features"""

-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]) -> Union[Resampler, MLPProjModel]:
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
@@ -157,31 +170,32 @@ class IPAdapterPlus(IPAdapter):
         ).to(self.device, dtype=self.dtype)

     @torch.inference_mode()
-    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
+    def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image = clip_image.to(self.device, dtype=self.dtype)
         clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
-        image_prompt_embeds = self._image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
             -2
         ]
+        try:
+            image_prompt_embeds = self._image_proj_model(clip_image_embeds)
             uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
             return image_prompt_embeds, uncond_image_prompt_embeds
+        except RuntimeError as e:
+            raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e


 class IPAdapterFull(IPAdapterPlus):
     """IP-Adapter Plus with full features."""

-    def _init_image_proj_model(self, state_dict: dict[torch.Tensor]):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)


 class IPAdapterPlusXL(IPAdapterPlus):
     """IP-Adapter Plus for SDXL."""

-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
@@ -192,24 +206,48 @@ class IPAdapterPlusXL(IPAdapterPlus):
         ).to(self.device, dtype=self.dtype)


-def build_ip_adapter(
-    ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
-) -> Union[IPAdapter, IPAdapterPlus]:
-    state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
+def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> IPAdapterStateDict:
+    state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}

-    if "proj.weight" in state_dict["image_proj"]:  # IPAdapter (with ImageProjModel).
+    if ip_adapter_ckpt_path.suffix == ".safetensors":
+        model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
+        for key in model.keys():
+            if key.startswith("image_proj."):
+                state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
+            elif key.startswith("ip_adapter."):
+                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
+            else:
+                raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
+    else:
+        ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path / "ip_adapter.bin"
+        state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
+
+    return state_dict
+
+
+def build_ip_adapter(
+    ip_adapter_ckpt_path: pathlib.Path, device: torch.device, dtype: torch.dtype = torch.float16
+) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterPlus]:
+    state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)
+
+    # IPAdapter (with ImageProjModel)
+    if "proj.weight" in state_dict["image_proj"]:
         return IPAdapter(state_dict, device=device, dtype=dtype)
-    elif "proj_in.weight" in state_dict["image_proj"]:  # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
+
+    # IPAdaterPlus or IPAdapterPlusXL (with Resampler)
+    elif "proj_in.weight" in state_dict["image_proj"]:
         cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
         if cross_attention_dim == 768:
-            # SD1 IP-Adapter Plus
-            return IPAdapterPlus(state_dict, device=device, dtype=dtype)
+            return IPAdapterPlus(state_dict, device=device, dtype=dtype)  # SD1 IP-Adapter Plus
         elif cross_attention_dim == 2048:
-            # SDXL IP-Adapter Plus
-            return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
+            return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)  # SDXL IP-Adapter Plus
         else:
             raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
-    elif "proj.0.weight" in state_dict["image_proj"]:  # IPAdapterFull (with MLPProjModel).
+
+    # IPAdapterFull (with MLPProjModel)
+    elif "proj.0.weight" in state_dict["image_proj"]:
         return IPAdapterFull(state_dict, device=device, dtype=dtype)
+
+    # Unrecognized IP Adapter Architectures
     else:
         raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")
@@ -9,8 +9,8 @@ import torch.nn as nn


 # FFN
-def FeedForward(dim, mult=4):
-    inner_dim = int(dim * mult)
+def FeedForward(dim: int, mult: int = 4):
+    inner_dim = dim * mult
     return nn.Sequential(
         nn.LayerNorm(dim),
         nn.Linear(dim, inner_dim, bias=False),
@@ -19,8 +19,8 @@ def FeedForward(dim, mult=4):
     )


-def reshape_tensor(x, heads):
-    bs, length, width = x.shape
+def reshape_tensor(x: torch.Tensor, heads: int):
+    bs, length, _ = x.shape
     # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
     x = x.view(bs, length, heads, -1)
     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
@@ -31,7 +31,7 @@ def reshape_tensor(x, heads):


 class PerceiverAttention(nn.Module):
-    def __init__(self, *, dim, dim_head=64, heads=8):
+    def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
         super().__init__()
         self.scale = dim_head**-0.5
         self.dim_head = dim_head
@@ -45,7 +45,7 @@ class PerceiverAttention(nn.Module):
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)

-    def forward(self, x, latents):
+    def forward(self, x: torch.Tensor, latents: torch.Tensor):
         """
         Args:
             x (torch.Tensor): image features
@@ -80,14 +80,14 @@ class PerceiverAttention(nn.Module):
 class Resampler(nn.Module):
     def __init__(
         self,
-        dim=1024,
-        depth=8,
-        dim_head=64,
-        heads=16,
-        num_queries=8,
-        embedding_dim=768,
-        output_dim=1024,
-        ff_mult=4,
+        dim: int = 1024,
+        depth: int = 8,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        embedding_dim: int = 768,
+        output_dim: int = 1024,
+        ff_mult: int = 4,
     ):
         super().__init__()

@@ -110,7 +110,15 @@ class Resampler(nn.Module):
         )

     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
+    def from_state_dict(
+        cls,
+        state_dict: dict[str, torch.Tensor],
+        depth: int = 8,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        ff_mult: int = 4,
+    ):
         """A convenience function that initializes a Resampler from a state_dict.

         Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
@@ -145,7 +153,7 @@ class Resampler(nn.Module):
         model.load_state_dict(state_dict)
         return model

-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         latents = self.latents.repeat(x.size(0), 1, 1)

         x = self.proj_in(x)
@@ -323,10 +323,13 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
         return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")


-class IPAdapterConfig(ModelConfigBase):
-    """Model config for IP Adaptor format models."""
-
+class IPAdapterBaseConfig(ModelConfigBase):
     type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
+
+
+class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
+    """Model config for IP Adapter diffusers format models."""
+
     image_encoder_model_id: str
     format: Literal[ModelFormat.InvokeAI]

@@ -335,6 +338,16 @@ class IPAdapterConfig(ModelConfigBase):
         return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")


+class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
+    """Model config for IP Adapter checkpoint format models."""
+
+    format: Literal[ModelFormat.Checkpoint]
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
+
+
 class CLIPVisionDiffusersConfig(DiffusersConfigBase):
     """Model config for CLIPVision."""

@@ -390,7 +403,8 @@ AnyModelConfig = Annotated[
         Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
         Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
         Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
-        Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
+        Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
+        Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
         Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
         Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
     ],
@@ -7,19 +7,13 @@ from typing import Optional
 import torch

 from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
-from invokeai.backend.model_manager import (
-    AnyModel,
-    AnyModelConfig,
-    BaseModelType,
-    ModelFormat,
-    ModelType,
-    SubModelType,
-)
+from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
 from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
 from invokeai.backend.raw_model import RawModel


 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint)
 class IPAdapterInvokeAILoader(ModelLoader):
     """Class to load IP Adapter diffusers models."""

@@ -32,7 +26,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
             raise ValueError("There are no submodels in an IP-Adapter model.")
         model_path = Path(config.path)
         model: RawModel = build_ip_adapter(
-            ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
+            ip_adapter_ckpt_path=model_path,
             device=torch.device("cpu"),
             dtype=self._torch_dtype,
         )
@@ -230,9 +230,10 @@ class ModelProbe(object):
                 return ModelType.LoRA
             elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
                 return ModelType.ControlNet
+            elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
+                return ModelType.IPAdapter
             elif key in {"emb_params", "string_to_param"}:
                 return ModelType.TextualInversion

         else:
             # diffusers-ti
             if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
@@ -527,8 +528,25 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):


 class IPAdapterCheckpointProbe(CheckpointProbeBase):
+    """Class for probing IP Adapters"""
+
     def get_base_type(self) -> BaseModelType:
-        raise NotImplementedError()
+        checkpoint = self.checkpoint
+        for key in checkpoint.keys():
+            if not key.startswith(("image_proj.", "ip_adapter.")):
+                continue
+            cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]
+            if cross_attention_dim == 768:
+                return BaseModelType.StableDiffusion1
+            elif cross_attention_dim == 1024:
+                return BaseModelType.StableDiffusion2
+            elif cross_attention_dim == 2048:
+                return BaseModelType.StableDiffusionXL
+            else:
+                raise InvalidModelConfigException(
+                    f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
+                )
+        raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")


 class CLIPVisionCheckpointProbe(CheckpointProbeBase):
@@ -768,7 +786,7 @@ class T2IAdapterFolderProbe(FolderProbeBase):
         )


-############## register probe classes ######
+# Register probe classes
 ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
@@ -217,6 +217,7 @@
     "saveControlImage": "Save Control Image",
     "scribble": "scribble",
     "selectModel": "Select a model",
+    "selectCLIPVisionModel": "Select a CLIP Vision model",
     "setControlImageDimensions": "Set Control Image Dimensions To W/H",
     "showAdvanced": "Show Advanced",
     "small": "Small",
@@ -655,6 +656,7 @@
     "install": "Install",
     "installAll": "Install All",
     "installRepo": "Install Repo",
+    "ipAdapters": "IP Adapters",
     "load": "Load",
     "localOnly": "local only",
     "manual": "Manual",
@ -1,12 +1,18 @@
|
|||||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
|
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
|
import { useControlAdapterCLIPVisionModel } from 'features/controlAdapters/hooks/useControlAdapterCLIPVisionModel';
|
||||||
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||||
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
||||||
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
|
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
|
||||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import {
|
||||||
|
controlAdapterCLIPVisionModelChanged,
|
||||||
|
controlAdapterModelChanged,
|
||||||
|
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
|
import type { CLIPVisionModel } from 'features/controlAdapters/store/types';
|
||||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -29,6 +35,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
|||||||
const { modelConfig } = useControlAdapterModel(id);
|
const { modelConfig } = useControlAdapterModel(id);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||||
|
const currentCLIPVisionModel = useControlAdapterCLIPVisionModel(id);
|
||||||
const mainModel = useAppSelector(selectMainModel);
|
const mainModel = useAppSelector(selectMainModel);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
@ -49,6 +56,16 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
|||||||
[dispatch, id]
|
[dispatch, id]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const onCLIPVisionModelChange = useCallback<ComboboxOnChange>(
|
||||||
|
(v) => {
|
||||||
|
if (!v?.value) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(controlAdapterCLIPVisionModelChanged({ id, clipVisionModel: v.value as CLIPVisionModel }));
|
||||||
|
},
|
||||||
|
[dispatch, id]
|
||||||
|
);
|
||||||
|
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
|
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
|
||||||
[controlAdapterType, modelConfig]
|
[controlAdapterType, modelConfig]
|
||||||
@ -71,9 +88,27 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
|||||||
isLoading,
|
isLoading,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const clipVisionOptions = useMemo<ComboboxOption[]>(
|
||||||
|
() => [
|
||||||
|
{ label: 'ViT-H', value: 'ViT-H' },
|
||||||
|
{ label: 'ViT-G', value: 'ViT-G' },
|
||||||
|
],
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
|
||||||
|
const clipVisionModel = useMemo(
|
||||||
|
() => clipVisionOptions.find((o) => o.value === currentCLIPVisionModel),
|
||||||
|
[clipVisionOptions, currentCLIPVisionModel]
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
<Flex sx={{ gap: 2 }}>
|
||||||
<Tooltip label={value?.description}>
|
<Tooltip label={value?.description}>
|
||||||
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base}>
|
<FormControl
|
||||||
|
isDisabled={!isEnabled}
|
||||||
|
isInvalid={!value || mainModel?.base !== modelConfig?.base}
|
||||||
|
sx={{ width: '100%' }}
|
||||||
|
>
|
||||||
<Combobox
|
<Combobox
|
||||||
options={options}
|
options={options}
|
||||||
placeholder={t('controlnet.selectModel')}
|
placeholder={t('controlnet.selectModel')}
|
||||||
@ -83,6 +118,21 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
|||||||
/>
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
|
{modelConfig?.type === 'ip_adapter' && modelConfig.format === 'checkpoint' && (
|
||||||
|
<FormControl
|
||||||
|
isDisabled={!isEnabled}
|
||||||
|
isInvalid={!value || mainModel?.base !== modelConfig?.base}
|
||||||
|
sx={{ width: 'max-content', minWidth: 28 }}
|
||||||
|
>
|
||||||
|
<Combobox
|
||||||
|
options={clipVisionOptions}
|
||||||
|
placeholder={t('controlnet.selectCLIPVisionModel')}
|
||||||
|
value={clipVisionModel}
|
||||||
|
onChange={onCLIPVisionModelChange}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -0,0 +1,24 @@
+import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
+import { useAppSelector } from 'app/store/storeHooks';
+import {
+  selectControlAdapterById,
+  selectControlAdaptersSlice,
+} from 'features/controlAdapters/store/controlAdaptersSlice';
+import { useMemo } from 'react';
+
+export const useControlAdapterCLIPVisionModel = (id: string) => {
+  const selector = useMemo(
+    () =>
+      createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
+        const cn = selectControlAdapterById(controlAdapters, id);
+        if (cn && cn?.type === 'ip_adapter') {
+          return cn.clipVisionModel;
+        }
+      }),
+    [id]
+  );
+
+  const clipVisionModel = useAppSelector(selector);
+
+  return clipVisionModel;
+};
@@ -14,6 +14,7 @@ import { v4 as uuidv4 } from 'uuid';
 import { controlAdapterImageProcessed } from './actions';
 import { CONTROLNET_PROCESSORS } from './constants';
 import type {
+  CLIPVisionModel,
   ControlAdapterConfig,
   ControlAdapterProcessorType,
   ControlAdaptersState,
@@ -244,6 +245,13 @@ export const controlAdaptersSlice = createSlice({
       }
       caAdapter.updateOne(state, { id, changes: { controlMode } });
     },
+    controlAdapterCLIPVisionModelChanged: (
+      state,
+      action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
+    ) => {
+      const { id, clipVisionModel } = action.payload;
+      caAdapter.updateOne(state, { id, changes: { clipVisionModel } });
+    },
     controlAdapterResizeModeChanged: (
       state,
       action: PayloadAction<{
@@ -381,6 +389,7 @@ export const {
   controlAdapterProcessedImageChanged,
   controlAdapterIsEnabledChanged,
   controlAdapterModelChanged,
+  controlAdapterCLIPVisionModelChanged,
   controlAdapterWeightChanged,
   controlAdapterBeginStepPctChanged,
   controlAdapterEndStepPctChanged,
@@ -243,12 +243,15 @@ export type T2IAdapterConfig = {
   shouldAutoConfig: boolean;
 };

+export type CLIPVisionModel = 'ViT-H' | 'ViT-G';
+
 export type IPAdapterConfig = {
   type: 'ip_adapter';
   id: string;
   isEnabled: boolean;
   controlImage: string | null;
   model: ParameterIPAdapterModel | null;
+  clipVisionModel: CLIPVisionModel;
   weight: number;
   beginStepPct: number;
   endStepPct: number;
@@ -46,6 +46,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
   isEnabled: true,
   controlImage: null,
   model: null,
+  clipVisionModel: 'ViT-H',
   weight: 1,
   beginStepPct: 0,
   endStepPct: 1,
@@ -372,6 +372,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
     type: 'ip_adapter',
     isEnabled: true,
     model: zModelIdentifierField.parse(ipAdapterModel),
+    clipVisionModel: 'ViT-H',
     controlImage: image?.image_name ?? null,
     weight: weight ?? initialIPAdapter.weight,
     beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
@@ -53,7 +53,7 @@ export const ModelView = () => {
         </>
       )}

-      {data.type === 'ip_adapter' && (
+      {data.type === 'ip_adapter' && data.format === 'invokeai' && (
         <Flex gap={2}>
           <ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
         </Flex>
@@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
   if (!ipAdapter.model) {
     return;
   }
-  const { id, weight, model, beginStepPct, endStepPct, controlImage } = ipAdapter;
+  const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter;

   assert(controlImage, 'IP Adapter image is required');

@@ -58,6 +58,7 @@
     is_intermediate: true,
     weight: weight,
     ip_adapter_model: model,
+    clip_vision_model: clipVisionModel,
     begin_step_percent: beginStepPct,
     end_step_percent: endStepPct,
     image: {
@@ -83,7 +84,7 @@
   };

 const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
-  const { controlImage, beginStepPct, endStepPct, model, weight } = ipAdapter;
+  const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter;

   assert(model, 'IP Adapter model is required');

@@ -99,6 +100,7 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat

   return {
     ip_adapter_model: model,
+    clip_vision_model: clipVisionModel,
     weight,
     begin_step_percent: beginStepPct,
     end_step_percent: endStepPct,
File diff suppressed because one or more lines are too long
@@ -46,7 +46,7 @@ export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
 // TODO(MM2): Can we rename this from Vae -> VAE
 export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
 export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
-export type IPAdapterModelConfig = S['IPAdapterConfig'];
+export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
 export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
 type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
 type DiffusersModelConfig = S['MainDiffusersConfig'];