Initial (barely) working version of IP-Adapter model management.

Ryan Dick 2023-09-12 19:09:10 -04:00
parent 0d823901ef
commit 3ee9a21647
8 changed files with 182 additions and 85 deletions
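For orientation before the file-by-file diffs: the IP-Adapter checkpoint is now resolved through the model manager, and the heavyweight, UNet-dependent setup is deferred to an explicit initialize() call. Below is a minimal sketch of the intended loading flow, pieced together from the hunks that follow; `unet` is an illustrative stand-in, and the keyword arguments simply mirror the ones used in the diff rather than a documented API.

    # Sketch only: load the adapter via the model manager and keep it alive on the exit stack.
    ip_adapter_model = exit_stack.enter_context(
        context.services.model_manager.get_model(
            model_name="ip-adapter_sd15",
            model_type=ModelType.IPAdapter,
            base_model=BaseModelType.StableDiffusion1,  # hard-coded to SD 1.x for now (see HACK note below)
            context=context,
        )
    )

    # __init__ only loads the CLIP image encoder and the checkpoint state dict; the image-projection
    # model and attention processors need the UNet, so they are built lazily here.
    if not ip_adapter_model.is_initialized():
        ip_adapter_model.initialize(unet)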

View File

@@ -16,10 +16,10 @@ from invokeai.app.invocations.baseinvocation import (
 from invokeai.app.invocations.primitives import ImageField
 
 IP_ADAPTER_MODELS = Literal[
-    "models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
-    "models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin",
-    "models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin",
-    "models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin",
+    "ip-adapter_sd15",
+    "ip-adapter-plus_sd15",
+    "ip-adapter-plus-face_sd15",
+    "ip-adapter_sdxl",
 ]
 
 IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
@@ -52,7 +52,7 @@ class IPAdapterInvocation(BaseInvocation):
     # Inputs
     image: ImageField = InputField(description="The IP-Adapter image prompt.")
     ip_adapter_model: IP_ADAPTER_MODELS = InputField(
-        default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
+        default="ip-adapter_sd15.bin",
         description="The name of the IP-Adapter model.",
         title="IP-Adapter Model",
     )
@@ -65,12 +65,8 @@ class IPAdapterInvocation(BaseInvocation):
         return IPAdapterOutput(
             ip_adapter=IPAdapterField(
                 image=self.image,
-                ip_adapter_model=(
-                    context.services.configuration.get_config().root_dir / self.ip_adapter_model
-                ).as_posix(),
-                image_encoder_model=(
-                    context.services.configuration.get_config().root_dir / self.image_encoder_model
-                ).as_posix(),
+                ip_adapter_model=self.ip_adapter_model,
+                image_encoder_model=self.image_encoder_model,
                 weight=self.weight,
             ),
         )

View File

@@ -403,14 +403,23 @@ class DenoiseLatentsInvocation(BaseInvocation):
         self,
         context: InvocationContext,
         ip_adapter: Optional[IPAdapterField],
-    ) -> IPAdapterData:
+        exit_stack: ExitStack,
+    ) -> Optional[IPAdapterData]:
         if ip_adapter is None:
             return None
 
         input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
+
+        ip_adapter_model = exit_stack.enter_context(
+            context.services.model_manager.get_model(
+                model_name=ip_adapter.ip_adapter_model,
+                model_type=ModelType.IPAdapter,
+                base_model=BaseModelType.StableDiffusion1,  # HACK(ryand): Pass this in properly
+                context=context,
+            )
+        )
+
         return IPAdapterData(
-            ip_adapter_model=ip_adapter.ip_adapter_model,  # name of model, NOT model object.
-            image_encoder_model=ip_adapter.image_encoder_model,  # name of model, NOT model object.
+            ip_adapter_model=ip_adapter_model,
             image=input_image,
             weight=ip_adapter.weight,
         )
@@ -543,6 +552,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         ip_adapter_data = self.prep_ip_adapter_data(
             context=context,
             ip_adapter=self.ip_adapter,
+            exit_stack=exit_stack,
         )
 
         num_inference_steps, timesteps, init_timestep = self.init_scheduler(

View File

@@ -7,23 +7,33 @@ import warnings
 from dataclasses import dataclass, field
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Optional, List, Dict, Callable, Union, Set
+from typing import Callable, Dict, List, Optional, Set, Union
 
 import requests
+import torch
 from diffusers import DiffusionPipeline
 from diffusers import logging as dlogging
-import torch
-from huggingface_hub import hf_hub_url, HfFolder, HfApi
+from huggingface_hub import HfApi, HfFolder, hf_hub_url
 from omegaconf import OmegaConf
 from tqdm import tqdm
 
 import invokeai.configs as configs
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
-from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
+from invokeai.backend.model_management import (
+    AddModelResult,
+    BaseModelType,
+    ModelManager,
+    ModelType,
+    ModelVariantType,
+)
+from invokeai.backend.model_management.model_probe import (
+    ModelProbe,
+    ModelProbeInfo,
+    SchedulerPredictionType,
+)
 from invokeai.backend.util import download_with_resume
-from invokeai.backend.util.devices import torch_dtype, choose_torch_device
+from invokeai.backend.util.devices import choose_torch_device, torch_dtype
 
 from ..util.logging import InvokeAILogger
 
 warnings.filterwarnings("ignore")
@@ -308,6 +318,7 @@ class ModelInstall(object):
             location = self._download_hf_pipeline(repo_id, staging)  # pipeline
         elif "unet/model.onnx" in files:
             location = self._download_hf_model(repo_id, files, staging)
+        # TODO(ryand): Add special handling for ip_adapter?
         else:
             for suffix in ["safetensors", "bin"]:
                 if f"pytorch_lora_weights.{suffix}" in files:
@@ -534,14 +545,17 @@ def hf_download_with_resume(
     logger.info(f"{model_name}: Downloading...")
 
     try:
-        with open(model_dest, open_mode) as file, tqdm(
-            desc=model_name,
-            initial=exist_size,
-            total=total + exist_size,
-            unit="iB",
-            unit_scale=True,
-            unit_divisor=1000,
-        ) as bar:
+        with (
+            open(model_dest, open_mode) as file,
+            tqdm(
+                desc=model_name,
+                initial=exist_size,
+                total=total + exist_size,
+                unit="iB",
+                unit_scale=True,
+                unit_divisor=1000,
+            ) as bar,
+        ):
             for data in resp.iter_content(chunk_size=1024):
                 size = file.write(data)
                 bar.update(size)

View File

@ -2,6 +2,7 @@
# and modified as needed # and modified as needed
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional
import torch import torch
from diffusers.models import UNet2DConditionModel from diffusers.models import UNet2DConditionModel
@@ -45,36 +46,74 @@ class IPAdapter:
     def __init__(
         self,
-        unet: UNet2DConditionModel,
         image_encoder_path: str,
         ip_adapter_ckpt_path: str,
         device: torch.device,
+        dtype: torch.dtype = torch.float16,
         num_tokens: int = 4,
     ):
-        self._unet = unet
-        self._device = device
+        self.device = device
+        self.dtype = dtype
+
         self._image_encoder_path = image_encoder_path
         self._ip_adapter_ckpt_path = ip_adapter_ckpt_path
         self._num_tokens = num_tokens
 
-        self._attn_processors = self._prepare_attention_processors()
-
-        # load image encoder
         self._image_encoder = CLIPVisionModelWithProjection.from_pretrained(self._image_encoder_path).to(
-            self._device, dtype=torch.float16
+            self.device, dtype=self.dtype
         )
         self._clip_image_processor = CLIPImageProcessor()
-        # image proj model
-        self._image_proj_model = self._init_image_proj_model()
 
-        self._load_weights()
+        # Fields to be initialized later in initialize().
+        self._unet = None
+        self._image_proj_model = None
+        self._attn_processors = None
+
+        self._state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu")
+
+    def is_initialized(self):
+        return self._unet is not None and self._image_proj_model is not None and self._attn_processors is not None
+
+    def initialize(self, unet: UNet2DConditionModel):
+        """Finish the model initialization process.
+
+        HACK: This is separate from __init__ for compatibility with the model manager. The full initialization requires
+        access to the UNet model to be patched, which can not easily be passed to __init__ by the model manager.
+
+        Args:
+            unet (UNet2DConditionModel): The UNet whose attention blocks will be patched by this IP-Adapter.
+        """
+        if self.is_initialized():
+            raise Exception("IPAdapter has already been initialized.")
+
+        self._unet = unet
+        self._image_proj_model = self._init_image_proj_model()
+        self._attn_processors = self._prepare_attention_processors()
+
+        # Copy the weights from the _state_dict into the models.
+        self._image_proj_model.load_state_dict(self._state_dict["image_proj"])
+        ip_layers = torch.nn.ModuleList(self._attn_processors.values())
+        ip_layers.load_state_dict(self._state_dict["ip_adapter"])
+        self._state_dict = None
+
+    def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
+        self.device = device
+        if dtype is not None:
+            self.dtype = dtype
+
+        for model in [self._image_encoder, self._image_proj_model, self._attn_processors]:
+            # If this is called before initialize(), then some models will still be None. We just update the non-None
+            # models.
+            if model is not None:
+                model.to(device=self.device, dtype=self.dtype)
 
     def _init_image_proj_model(self):
         image_proj_model = ImageProjModel(
             cross_attention_dim=self._unet.config.cross_attention_dim,
             clip_embeddings_dim=self._image_encoder.config.projection_dim,
             clip_extra_context_tokens=self._num_tokens,
-        ).to(self._device, dtype=torch.float16)
+        ).to(self.device, dtype=self.dtype)
         return image_proj_model
 
     def _prepare_attention_processors(self):
@@ -99,7 +138,7 @@ class IPAdapter:
                 hidden_size=hidden_size,
                 cross_attention_dim=cross_attention_dim,
                 scale=1.0,
-            ).to(self._device, dtype=torch.float16)
+            ).to(self.device, dtype=self.dtype)
 
         return attn_procs
 
     @contextmanager
@@ -109,30 +148,36 @@ class IPAdapter:
         Yields:
             None
         """
+        if not self.is_initialized():
+            raise Exception("Call IPAdapter.initialize() before calling IPAdapter.apply_ip_adapter_attention().")
+
         orig_attn_processors = self._unet.attn_processors
+
+        # Make a (moderately-) shallow copy of the self._attn_processors dict, because set_attn_processor(...) actually
+        # pops elements from the passed dict.
+        ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()}
+
         try:
-            self._unet.set_attn_processor(self._attn_processors)
+            self._unet.set_attn_processor(ip_adapter_attn_processors)
             yield None
         finally:
             self._unet.set_attn_processor(orig_attn_processors)
 
-    def _load_weights(self):
-        state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu")
-
-        self._image_proj_model.load_state_dict(state_dict["image_proj"])
-        ip_layers = torch.nn.ModuleList(self._attn_processors.values())
-        ip_layers.load_state_dict(state_dict["ip_adapter"])
-
     @torch.inference_mode()
     def get_image_embeds(self, pil_image):
+        if not self.is_initialized():
+            raise Exception("Call IPAdapter.initialize() before calling IPAdapter.get_image_embeds().")
+
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image_embeds = self._image_encoder(clip_image.to(self._device, dtype=torch.float16)).image_embeds
+        clip_image_embeds = self._image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
         image_prompt_embeds = self._image_proj_model(clip_image_embeds)
         uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
         return image_prompt_embeds, uncond_image_prompt_embeds
 
     def set_scale(self, scale):
+        if not self.is_initialized():
+            raise Exception("Call IPAdapter.initialize() before calling IPAdapter.set_scale().")
+
         for attn_processor in self._attn_processors.values():
             if isinstance(attn_processor, IPAttnProcessor):
                 attn_processor.scale = scale
@@ -151,15 +196,18 @@ class IPAdapterPlus(IPAdapter):
             embedding_dim=self._image_encoder.config.hidden_size,
             output_dim=self._unet.config.cross_attention_dim,
             ff_mult=4,
-        ).to(self._device, dtype=torch.float16)
+        ).to(self.device, dtype=self.dtype)
         return image_proj_model
 
     @torch.inference_mode()
     def get_image_embeds(self, pil_image):
+        if not self.is_initialized():
+            raise Exception("Call IPAdapter.initialize() before calling IPAdapter.get_image_embeds().")
+
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self._device, dtype=torch.float16)
+        clip_image = clip_image.to(self.device, dtype=self.dtype)
         clip_image_embeds = self._image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         image_prompt_embeds = self._image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = self._image_encoder(
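As a usage illustration (not part of the diff): once initialize() has run, the adapter is meant to be driven like this during denoising. `pil_image`, `latents`, `t`, `run_unet`, and the scale value are hypothetical stand-ins for the pipeline's own variables and UNet call.

    ip_adapter.set_scale(0.8)  # relative strength of the image prompt
    image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(pil_image)

    with ip_adapter.apply_ip_adapter_attention():
        # Inside the block the UNet runs with the IP-Adapter attention processors; a shallow copy of the
        # processor dict is swapped in because set_attn_processor() consumes the dict it is given.
        noise_pred = run_unet(latents, t)
    # On exit (even on an exception) the original attention processors are restored.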

View File

@@ -1001,8 +1001,8 @@ class ModelManager(object):
                         new_models_found = True
                     except DuplicateModelException as e:
                         self.logger.warning(e)
-                    except InvalidModelException:
-                        self.logger.warning(f"Not a valid model: {model_path}")
+                    except InvalidModelException as e:
+                        self.logger.warning(f"Not a valid model: {model_path}. {e}")
                     except NotImplementedError as e:
                         self.logger.warning(e)

View File

@@ -61,7 +61,7 @@ class ModelType(str, Enum):
     Lora = "lora"
     ControlNet = "controlnet"  # used by model_probe
     TextualInversion = "embedding"
-    IPAdapter = "ipadapter"
+    IPAdapter = "ip_adapter"
 
 
 class SubModelType(str, Enum):

View File

@@ -1,24 +1,31 @@
 import os
+import typing
 from enum import Enum
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 import torch
 
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
 from invokeai.backend.model_management.models.base import (
     BaseModelType,
+    InvalidModelException,
     ModelBase,
+    ModelConfigBase,
     ModelType,
     SubModelType,
     classproperty,
 )
 
 
-class IPAdapterModelFormat(Enum):
-    # The 'official' IP-Adapter model format from Tencent (i.e. https://huggingface.co/h94/IP-Adapter)
-    Tencent = "tencent"
+class IPAdapterModelFormat(str, Enum):
+    # Checkpoint is the 'official' IP-Adapter model format from Tencent (i.e. https://huggingface.co/h94/IP-Adapter)
+    Checkpoint = "checkpoint"
 
 
 class IPAdapterModel(ModelBase):
+    class CheckpointConfig(ModelConfigBase):
+        model_format: Literal[IPAdapterModelFormat.Checkpoint]
+
     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         assert model_type == ModelType.IPAdapter
         super().__init__(model_path, base_model, model_type)
@@ -31,23 +38,58 @@ class IPAdapterModel(ModelBase):
         if not os.path.exists(path):
             raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
 
-        raise NotImplementedError()
+        if os.path.isfile(path):
+            if path.endswith((".safetensors", ".ckpt", ".pt", ".pth", ".bin")):
+                return IPAdapterModelFormat.Checkpoint
+
+        raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
 
     @classproperty
     def save_to_config(cls) -> bool:
-        raise NotImplementedError()
+        return True
 
     def get_size(self, child_type: Optional[SubModelType] = None) -> int:
         if child_type is not None:
             raise ValueError("There are no child models in an IP-Adapter model.")
 
-        raise NotImplementedError()
+        # TODO(ryand): Update self.model_size when the model is loaded from disk.
+        return self.model_size
+
+    def _get_text_encoder_path(self) -> str:
+        # TODO(ryand): Move the CLIP image encoder to its own model directory.
+        return os.path.join(os.path.dirname(self.model_path), "image_encoder")
 
     def get_model(
         self,
         torch_dtype: Optional[torch.dtype],
         child_type: Optional[SubModelType] = None,
-    ) -> Any:
+    ) -> typing.Union[IPAdapter, IPAdapterPlus]:
         if child_type is not None:
             raise ValueError("There are no child models in an IP-Adapter model.")
 
-        raise NotImplementedError()
+        # TODO(ryand): Update IPAdapter to accept a torch_dtype param.
+
+        # TODO(ryand): Checking for "plus" in the file name is fragile. It should be possible to infer whether this is a
+        # "plus" variant by loading the state_dict.
+        if "plus" in str(self.model_path):
+            return IPAdapterPlus(
+                image_encoder_path=self._get_text_encoder_path(), ip_adapter_ckpt_path=self.model_path, device="cpu"
+            )
+        else:
+            return IPAdapter(
+                image_encoder_path=self._get_text_encoder_path(), ip_adapter_ckpt_path=self.model_path, device="cpu"
+            )
+
+    @classmethod
+    def convert_if_required(
+        cls,
+        model_path: str,
+        output_path: str,
+        config: ModelConfigBase,
+        base_model: BaseModelType,
+    ) -> str:
+        format = cls.detect_format(model_path)
+        if format == IPAdapterModelFormat.Checkpoint:
+            return model_path
+        else:
+            raise ValueError(f"Unsupported format: '{format}'.")

View File

@@ -171,8 +171,7 @@ class ControlNetData:
 
 @dataclass
 class IPAdapterData:
-    ip_adapter_model: str = Field(default=None)
-    image_encoder_model: str = Field(default=None)
+    ip_adapter_model: IPAdapter = Field(default=None)
     image: PIL.Image = Field(default=None)
     # TODO: change to polymorphic so can do different weights per step (once implemented...)
     # weight: Union[float, List[float]] = Field(default=1.0)
@@ -417,27 +416,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             return latents, attention_map_saver
 
         if ip_adapter_data is not None:
-            # Initialize IPAdapter
-            # TODO(ryand): Refactor to use model management for the IP-Adapter.
-            if "plus" in ip_adapter_data.ip_adapter_model:
-                ip_adapter = IPAdapterPlus(
-                    self.unet,
-                    ip_adapter_data.image_encoder_model,
-                    ip_adapter_data.ip_adapter_model,
-                    self.unet.device,
-                    num_tokens=16,
-                )
-            else:
-                ip_adapter = IPAdapter(
-                    self.unet,
-                    ip_adapter_data.image_encoder_model,
-                    ip_adapter_data.ip_adapter_model,
-                    self.unet.device,
-                )
-            ip_adapter.set_scale(ip_adapter_data.weight)
+            if not ip_adapter_data.ip_adapter_model.is_initialized():
+                # TODO(ryan): Do we need to initialize every time? How long does initialize take?
+                ip_adapter_data.ip_adapter_model.initialize(self.unet)
+            ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight)
 
             # Get image embeddings from CLIP and ImageProjModel.
-            image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
+            image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_data.ip_adapter_model.get_image_embeds(
+                ip_adapter_data.image
+            )
             conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
                 image_prompt_embeds, uncond_image_prompt_embeds
             )
@@ -451,7 +438,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         elif ip_adapter_data is not None:
             # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
             # As it is now, the IP-Adapter will silently be skipped.
-            attn_ctx = ip_adapter.apply_ip_adapter_attention()
+            attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention()
         else:
             attn_ctx = nullcontext()