Initial (barely) working version of IP-Adapter model management.

Ryan Dick 2023-09-12 19:09:10 -04:00
parent 0d823901ef
commit 3ee9a21647
8 changed files with 182 additions and 85 deletions

View File

@@ -16,10 +16,10 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.primitives import ImageField
IP_ADAPTER_MODELS = Literal[
"models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
"models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin",
"models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin",
"models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin",
"ip-adapter_sd15",
"ip-adapter-plus_sd15",
"ip-adapter-plus-face_sd15",
"ip-adapter_sdxl",
]
IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
@@ -52,7 +52,7 @@ class IPAdapterInvocation(BaseInvocation):
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
ip_adapter_model: IP_ADAPTER_MODELS = InputField(
default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
default="ip-adapter_sd15.bin",
description="The name of the IP-Adapter model.",
title="IP-Adapter Model",
)
@@ -65,12 +65,8 @@ class IPAdapterInvocation(BaseInvocation):
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=(
context.services.configuration.get_config().root_dir / self.ip_adapter_model
).as_posix(),
image_encoder_model=(
context.services.configuration.get_config().root_dir / self.image_encoder_model
).as_posix(),
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=self.image_encoder_model,
weight=self.weight,
),
)
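With this change the invocation emits bare model names rather than paths resolved against root_dir; resolving a name to a loaded model is deferred to the denoise step (next file). A minimal sketch of the field the node now produces, using the names visible in this hunk; the image name and image-encoder value are placeholders, since the encoder Literal is truncated above:

from invokeai.app.invocations.primitives import ImageField

IPAdapterField(
    image=ImageField(image_name="00001.png"),  # placeholder image name
    ip_adapter_model="ip-adapter_sd15",        # a model-manager name, no longer a root_dir-relative path
    image_encoder_model="image_encoder",       # placeholder; the encoder Literal is not shown in this hunk
    weight=1.0,
)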

View File

@@ -403,14 +403,23 @@ class DenoiseLatentsInvocation(BaseInvocation):
self,
context: InvocationContext,
ip_adapter: Optional[IPAdapterField],
) -> IPAdapterData:
exit_stack: ExitStack,
) -> Optional[IPAdapterData]:
if ip_adapter is None:
return None
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
ip_adapter_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=ip_adapter.ip_adapter_model,
model_type=ModelType.IPAdapter,
base_model=BaseModelType.StableDiffusion1, # HACK(ryand): Pass this in properly
context=context,
)
)
return IPAdapterData(
ip_adapter_model=ip_adapter.ip_adapter_model, # name of model, NOT model object.
image_encoder_model=ip_adapter.image_encoder_model, # name of model, NOT model object.
ip_adapter_model=ip_adapter_model,
image=input_image,
weight=ip_adapter.weight,
)
@@ -543,6 +552,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapter=self.ip_adapter,
exit_stack=exit_stack,
)
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
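The ExitStack threaded into prep_ip_adapter_data() is what keeps the IP-Adapter loaded for the whole denoise call: get_model() returns a context manager, and entering it on the shared stack defers releasing the model back to the cache until the stack unwinds after denoising. A condensed sketch of the pattern, reusing the names from the hunk above (context, ip_adapter, and the hard-coded SD-1 base noted in the HACK comment):

from contextlib import ExitStack

from invokeai.backend.model_management import BaseModelType, ModelType

with ExitStack() as exit_stack:
    # Entering the context on the stack keeps the model checked out until the
    # `with` block ends, i.e. for the full denoising loop.
    ip_adapter_model = exit_stack.enter_context(
        context.services.model_manager.get_model(
            model_name=ip_adapter.ip_adapter_model,
            model_type=ModelType.IPAdapter,
            base_model=BaseModelType.StableDiffusion1,  # still hard-coded, per the HACK above
            context=context,
        )
    )
    # ... build IPAdapterData and run the denoising loop while the model stays loaded ...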

View File

@@ -7,23 +7,33 @@ import warnings
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, List, Dict, Callable, Union, Set
from typing import Callable, Dict, List, Optional, Set, Union
import requests
import torch
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
import torch
from huggingface_hub import hf_hub_url, HfFolder, HfApi
from huggingface_hub import HfApi, HfFolder, hf_hub_url
from omegaconf import OmegaConf
from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
from invokeai.backend.model_management import (
AddModelResult,
BaseModelType,
ModelManager,
ModelType,
ModelVariantType,
)
from invokeai.backend.model_management.model_probe import (
ModelProbe,
ModelProbeInfo,
SchedulerPredictionType,
)
from invokeai.backend.util import download_with_resume
from invokeai.backend.util.devices import torch_dtype, choose_torch_device
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore")
@@ -308,6 +318,7 @@ class ModelInstall(object):
location = self._download_hf_pipeline(repo_id, staging) # pipeline
elif "unet/model.onnx" in files:
location = self._download_hf_model(repo_id, files, staging)
# TODO(ryand): Add special handling for ip_adapter?
else:
for suffix in ["safetensors", "bin"]:
if f"pytorch_lora_weights.{suffix}" in files:
@@ -534,14 +545,17 @@ def hf_download_with_resume(
logger.info(f"{model_name}: Downloading...")
try:
with open(model_dest, open_mode) as file, tqdm(
desc=model_name,
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar:
with (
open(model_dest, open_mode) as file,
tqdm(
desc=model_name,
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar,
):
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
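As context for the parameters preserved through this reformat (and the new parenthesized with form, which needs Python 3.10+): the function appends to a partially downloaded file, seeds the progress bar with the bytes already on disk (initial=exist_size), and adds them to the remaining byte count reported by the server (total=total + exist_size). A stripped-down sketch of the same resume pattern with plain requests and tqdm; the InvokeAI-specific bookkeeping (status checks, skip-if-complete, logging) is omitted:

import requests
from tqdm import tqdm

def download_with_resume_sketch(url: str, dest: str, exist_size: int = 0) -> None:
    # Ask the server only for the bytes we do not have yet.
    headers = {"Range": f"bytes={exist_size}-"} if exist_size > 0 else {}
    resp = requests.get(url, headers=headers, stream=True)
    total = int(resp.headers.get("content-length", 0))  # remaining bytes when a Range was sent
    open_mode = "ab" if exist_size > 0 else "wb"
    with (
        open(dest, open_mode) as file,
        tqdm(
            desc=dest,
            initial=exist_size,        # progress bar starts at what is already on disk
            total=total + exist_size,  # so the bar ends at the full file size
            unit="iB",
            unit_scale=True,
            unit_divisor=1000,
        ) as bar,
    ):
        for data in resp.iter_content(chunk_size=1024):
            bar.update(file.write(data))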

View File

@@ -2,6 +2,7 @@
# and modified as needed
from contextlib import contextmanager
from typing import Optional
import torch
from diffusers.models import UNet2DConditionModel
@@ -45,36 +46,74 @@ class IPAdapter:
def __init__(
self,
unet: UNet2DConditionModel,
image_encoder_path: str,
ip_adapter_ckpt_path: str,
device: torch.device,
dtype: torch.dtype = torch.float16,
num_tokens: int = 4,
):
self._unet = unet
self._device = device
self.device = device
self.dtype = dtype
self._image_encoder_path = image_encoder_path
self._ip_adapter_ckpt_path = ip_adapter_ckpt_path
self._num_tokens = num_tokens
self._attn_processors = self._prepare_attention_processors()
# load image encoder
self._image_encoder = CLIPVisionModelWithProjection.from_pretrained(self._image_encoder_path).to(
self._device, dtype=torch.float16
self.device, dtype=self.dtype
)
self._clip_image_processor = CLIPImageProcessor()
# image proj model
self._image_proj_model = self._init_image_proj_model()
self._load_weights()
# Fields to be initialized later in initialize().
self._unet = None
self._image_proj_model = None
self._attn_processors = None
self._state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu")
def is_initialized(self):
return self._unet is not None and self._image_proj_model is not None and self._attn_processors is not None
def initialize(self, unet: UNet2DConditionModel):
"""Finish the model initialization process.
HACK: This is separate from __init__ for compatibility with the model manager. The full initialization requires
access to the UNet model to be patched, which can not easily be passed to __init__ by the model manager.
Args:
unet (UNet2DConditionModel): The UNet whose attention blocks will be patched by this IP-Adapter.
"""
if self.is_initialized():
raise Exception("IPAdapter has already been initialized.")
self._unet = unet
self._image_proj_model = self._init_image_proj_model()
self._attn_processors = self._prepare_attention_processors()
# Copy the weights from the _state_dict into the models.
self._image_proj_model.load_state_dict(self._state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self._attn_processors.values())
ip_layers.load_state_dict(self._state_dict["ip_adapter"])
self._state_dict = None
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
self.device = device
if dtype is not None:
self.dtype = dtype
for model in [self._image_encoder, self._image_proj_model, self._attn_processors]:
# If this is called before initialize(), then some models will still be None. We just update the non-None
# models.
if model is not None:
model.to(device=self.device, dtype=self.dtype)
def _init_image_proj_model(self):
image_proj_model = ImageProjModel(
cross_attention_dim=self._unet.config.cross_attention_dim,
clip_embeddings_dim=self._image_encoder.config.projection_dim,
clip_extra_context_tokens=self._num_tokens,
).to(self._device, dtype=torch.float16)
).to(self.device, dtype=self.dtype)
return image_proj_model
def _prepare_attention_processors(self):
@@ -99,7 +138,7 @@ class IPAdapter:
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
).to(self._device, dtype=torch.float16)
).to(self.device, dtype=self.dtype)
return attn_procs
@contextmanager
@@ -109,30 +148,36 @@ class IPAdapter:
Yields:
None
"""
if not self.is_initialized():
raise Exception("Call IPAdapter.initialize() before calling IPAdapter.apply_ip_adapter_attention().")
orig_attn_processors = self._unet.attn_processors
# Make a (moderately-) shallow copy of the self._attn_processors dict, because set_attn_processor(...) actually
# pops elements from the passed dict.
ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()}
try:
self._unet.set_attn_processor(self._attn_processors)
self._unet.set_attn_processor(ip_adapter_attn_processors)
yield None
finally:
self._unet.set_attn_processor(orig_attn_processors)
def _load_weights(self):
state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu")
self._image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self._attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])
@torch.inference_mode()
def get_image_embeds(self, pil_image):
if not self.is_initialized():
raise Exception("Call IPAdapter.initialize() before calling IPAdapter.get_image_embeds().")
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self._image_encoder(clip_image.to(self._device, dtype=torch.float16)).image_embeds
clip_image_embeds = self._image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
def set_scale(self, scale):
if not self.is_initialized():
raise Exception("Call IPAdapter.initialize() before calling IPAdapter.set_scale().")
for attn_processor in self._attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale
@@ -151,15 +196,18 @@ class IPAdapterPlus(IPAdapter):
embedding_dim=self._image_encoder.config.hidden_size,
output_dim=self._unet.config.cross_attention_dim,
ff_mult=4,
).to(self._device, dtype=torch.float16)
).to(self.device, dtype=self.dtype)
return image_proj_model
@torch.inference_mode()
def get_image_embeds(self, pil_image):
if not self.is_initialized():
raise Exception("Call IPAdapter.initialize() before calling IPAdapter.get_image_embeds().")
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self._device, dtype=torch.float16)
clip_image = clip_image.to(self.device, dtype=self.dtype)
clip_image_embeds = self._image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self._image_encoder(
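Taken together, these edits split the adapter's life into two phases: __init__() (run by the model manager, with no UNet available) loads the CLIP image encoder and the raw checkpoint into self._state_dict, while initialize(unet) (run later by the pipeline) builds the image-projection model and attention processors and copies the weights in. A usage sketch of the intended call order, assuming the constructor signature shown above; the paths are placeholders, and unet / pil_image stand for an already-loaded UNet2DConditionModel and a PIL image:

import torch

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter

# Phase 1: the model manager constructs the adapter without a UNet.
ip_adapter = IPAdapter(
    image_encoder_path="/path/to/image_encoder",          # placeholder
    ip_adapter_ckpt_path="/path/to/ip-adapter_sd15.bin",  # placeholder
    device="cpu",                                         # moved to the GPU later via .to()
    dtype=torch.float16,
)
assert not ip_adapter.is_initialized()

# Phase 2: the pipeline, which owns the UNet, finishes initialization.
ip_adapter.initialize(unet)  # builds image_proj_model + attn processors, loads the cached weights
ip_adapter.set_scale(0.75)
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(pil_image)

# The IP-Adapter attention processors are only swapped into the UNet for the duration of this context.
with ip_adapter.apply_ip_adapter_attention():
    ...  # run the denoising loop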

View File

@@ -1001,8 +1001,8 @@ class ModelManager(object):
new_models_found = True
except DuplicateModelException as e:
self.logger.warning(e)
except InvalidModelException:
self.logger.warning(f"Not a valid model: {model_path}")
except InvalidModelException as e:
self.logger.warning(f"Not a valid model: {model_path}. {e}")
except NotImplementedError as e:
self.logger.warning(e)

View File

@@ -61,7 +61,7 @@ class ModelType(str, Enum):
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ipadapter"
IPAdapter = "ip_adapter"
class SubModelType(str, Enum):

View File

@@ -1,24 +1,31 @@
import os
import typing
from enum import Enum
from typing import Any, Optional
from typing import Any, Literal, Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelType,
SubModelType,
classproperty,
)
class IPAdapterModelFormat(Enum):
# The 'official' IP-Adapter model format from Tencent (i.e. https://huggingface.co/h94/IP-Adapter)
Tencent = "tencent"
class IPAdapterModelFormat(str, Enum):
# Checkpoint is the 'official' IP-Adapter model format from Tencent (i.e. https://huggingface.co/h94/IP-Adapter)
Checkpoint = "checkpoint"
class IPAdapterModel(ModelBase):
class CheckpointConfig(ModelConfigBase):
model_format: Literal[IPAdapterModelFormat.Checkpoint]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.IPAdapter
super().__init__(model_path, base_model, model_type)
@@ -31,23 +38,58 @@ class IPAdapterModel(ModelBase):
if not os.path.exists(path):
raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
raise NotImplementedError()
if os.path.isfile(path):
if path.endswith((".safetensors", ".ckpt", ".pt", ".pth", ".bin")):
return IPAdapterModelFormat.Checkpoint
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
@classproperty
def save_to_config(cls) -> bool:
raise NotImplementedError()
return True
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
raise NotImplementedError()
# TODO(ryand): Update self.model_size when the model is loaded from disk.
return self.model_size
def _get_text_encoder_path(self) -> str:
# TODO(ryand): Move the CLIP image encoder to its own model directory.
return os.path.join(os.path.dirname(self.model_path), "image_encoder")
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> Any:
) -> typing.Union[IPAdapter, IPAdapterPlus]:
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
raise NotImplementedError()
# TODO(ryand): Update IPAdapter to accept a torch_dtype param.
# TODO(ryand): Checking for "plus" in the file name is fragile. It should be possible to infer whether this is a
# "plus" variant by loading the state_dict.
if "plus" in str(self.model_path):
return IPAdapterPlus(
image_encoder_path=self._get_text_encoder_path(), ip_adapter_ckpt_path=self.model_path, device="cpu"
)
else:
return IPAdapter(
image_encoder_path=self._get_text_encoder_path(), ip_adapter_ckpt_path=self.model_path, device="cpu"
)
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == IPAdapterModelFormat.Checkpoint:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")
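The TODO in get_model() above notes that keying the "plus" variant off the file name is fragile and that it should be inferable from the state dict. A sketch of that idea; the key names ("proj.weight" for the plain ImageProjModel vs. "proj_in.weight" for the Resampler used by the plus models) are assumptions based on the reference IP-Adapter checkpoints, not something this commit verifies:

import torch

def is_plus_checkpoint(ckpt_path: str) -> bool:
    # The checkpoints loaded elsewhere in this commit have top-level "image_proj"
    # and "ip_adapter" sub-dicts; only "image_proj" differs between the variants.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    image_proj_keys = set(state_dict["image_proj"].keys())
    if "proj.weight" in image_proj_keys:
        return False  # plain ImageProjModel => base IP-Adapter
    if "proj_in.weight" in image_proj_keys:
        return True   # Resampler => "plus" variant
    raise ValueError(f"Unrecognized image_proj keys in '{ckpt_path}'.")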

View File

@@ -171,8 +171,7 @@ class ControlNetData:
@dataclass
class IPAdapterData:
ip_adapter_model: str = Field(default=None)
image_encoder_model: str = Field(default=None)
ip_adapter_model: IPAdapter = Field(default=None)
image: PIL.Image = Field(default=None)
# TODO: change to polymorphic so can do different weights per step (once implemented...)
# weight: Union[float, List[float]] = Field(default=1.0)
@@ -417,27 +416,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return latents, attention_map_saver
if ip_adapter_data is not None:
# Initialize IPAdapter
# TODO(ryand): Refactor to use model management for the IP-Adapter.
if "plus" in ip_adapter_data.ip_adapter_model:
ip_adapter = IPAdapterPlus(
self.unet,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
num_tokens=16,
)
else:
ip_adapter = IPAdapter(
self.unet,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
)
ip_adapter.set_scale(ip_adapter_data.weight)
if not ip_adapter_data.ip_adapter_model.is_initialized():
# TODO(ryan): Do we need to initialize every time? How long does initialize take?
ip_adapter_data.ip_adapter_model.initialize(self.unet)
ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_data.ip_adapter_model.get_image_embeds(
ip_adapter_data.image
)
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
image_prompt_embeds, uncond_image_prompt_embeds
)
@@ -451,7 +438,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
attn_ctx = ip_adapter.apply_ip_adapter_attention()
attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention()
else:
attn_ctx = nullcontext()
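The TODO in this last hunk points out that when a custom attention context is active, the IP-Adapter is silently skipped rather than rejected. Purely as an illustration of the guard that TODO is asking about (not part of this commit; custom_attention_requested is a placeholder for whatever condition selects the custom-attention branch above):

# Hypothetical guard, not in this commit.
if custom_attention_requested and ip_adapter_data is not None:
    raise Exception(
        "Custom attention and IP-Adapter attention cannot be applied at the same time; "
        "both replace the UNet's attention processors."
    )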