From 3ee9a216478cdb42bddf243795985c507e99ee62 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 12 Sep 2023 19:09:10 -0400 Subject: [PATCH] Initial (barely) working version of IP-Adapter model management. --- invokeai/app/invocations/ip_adapter.py | 18 ++-- invokeai/app/invocations/latent.py | 16 +++- .../backend/install/model_install_backend.py | 44 ++++++--- invokeai/backend/ip_adapter/ip_adapter.py | 92 ++++++++++++++----- .../backend/model_management/model_manager.py | 4 +- .../backend/model_management/models/base.py | 2 +- .../model_management/models/ip_adapter.py | 60 ++++++++++-- .../stable_diffusion/diffusers_pipeline.py | 31 ++----- 8 files changed, 182 insertions(+), 85 deletions(-) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index d73941c376..89d440b6ea 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -16,10 +16,10 @@ from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.primitives import ImageField IP_ADAPTER_MODELS = Literal[ - "models/core/ip_adapters/sd-1/ip-adapter_sd15.bin", - "models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin", - "models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin", - "models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin", + "ip-adapter_sd15", + "ip-adapter-plus_sd15", + "ip-adapter-plus-face_sd15", + "ip-adapter_sdxl", ] IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[ @@ -52,7 +52,7 @@ class IPAdapterInvocation(BaseInvocation): # Inputs image: ImageField = InputField(description="The IP-Adapter image prompt.") ip_adapter_model: IP_ADAPTER_MODELS = InputField( - default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin", + default="ip-adapter_sd15.bin", description="The name of the IP-Adapter model.", title="IP-Adapter Model", ) @@ -65,12 +65,8 @@ class IPAdapterInvocation(BaseInvocation): return IPAdapterOutput( ip_adapter=IPAdapterField( image=self.image, - ip_adapter_model=( - context.services.configuration.get_config().root_dir / self.ip_adapter_model - ).as_posix(), - image_encoder_model=( - context.services.configuration.get_config().root_dir / self.image_encoder_model - ).as_posix(), + ip_adapter_model=self.ip_adapter_model, + image_encoder_model=self.image_encoder_model, weight=self.weight, ), ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index fbad7a8988..d63e16df8c 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -403,14 +403,23 @@ class DenoiseLatentsInvocation(BaseInvocation): self, context: InvocationContext, ip_adapter: Optional[IPAdapterField], - ) -> IPAdapterData: + exit_stack: ExitStack, + ) -> Optional[IPAdapterData]: if ip_adapter is None: return None input_image = context.services.images.get_pil_image(ip_adapter.image.image_name) + + ip_adapter_model = exit_stack.enter_context( + context.services.model_manager.get_model( + model_name=ip_adapter.ip_adapter_model, + model_type=ModelType.IPAdapter, + base_model=BaseModelType.StableDiffusion1, # HACK(ryand): Pass this in properly + context=context, + ) + ) return IPAdapterData( - ip_adapter_model=ip_adapter.ip_adapter_model, # name of model, NOT model object. - image_encoder_model=ip_adapter.image_encoder_model, # name of model, NOT model object. 
+ ip_adapter_model=ip_adapter_model, image=input_image, weight=ip_adapter.weight, ) @@ -543,6 +552,7 @@ class DenoiseLatentsInvocation(BaseInvocation): ip_adapter_data = self.prep_ip_adapter_data( context=context, ip_adapter=self.ip_adapter, + exit_stack=exit_stack, ) num_inference_steps, timesteps, init_timestep = self.init_scheduler( diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index e41783ab09..711877ea56 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -7,23 +7,33 @@ import warnings from dataclasses import dataclass, field from pathlib import Path from tempfile import TemporaryDirectory -from typing import Optional, List, Dict, Callable, Union, Set +from typing import Callable, Dict, List, Optional, Set, Union import requests +import torch from diffusers import DiffusionPipeline from diffusers import logging as dlogging -import torch -from huggingface_hub import hf_hub_url, HfFolder, HfApi +from huggingface_hub import HfApi, HfFolder, hf_hub_url from omegaconf import OmegaConf from tqdm import tqdm import invokeai.configs as configs - from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult -from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo +from invokeai.backend.model_management import ( + AddModelResult, + BaseModelType, + ModelManager, + ModelType, + ModelVariantType, +) +from invokeai.backend.model_management.model_probe import ( + ModelProbe, + ModelProbeInfo, + SchedulerPredictionType, +) from invokeai.backend.util import download_with_resume -from invokeai.backend.util.devices import torch_dtype, choose_torch_device +from invokeai.backend.util.devices import choose_torch_device, torch_dtype + from ..util.logging import InvokeAILogger warnings.filterwarnings("ignore") @@ -308,6 +318,7 @@ class ModelInstall(object): location = self._download_hf_pipeline(repo_id, staging) # pipeline elif "unet/model.onnx" in files: location = self._download_hf_model(repo_id, files, staging) + # TODO(ryand): Add special handling for ip_adapter? 
else: for suffix in ["safetensors", "bin"]: if f"pytorch_lora_weights.{suffix}" in files: @@ -534,14 +545,17 @@ def hf_download_with_resume( logger.info(f"{model_name}: Downloading...") try: - with open(model_dest, open_mode) as file, tqdm( - desc=model_name, - initial=exist_size, - total=total + exist_size, - unit="iB", - unit_scale=True, - unit_divisor=1000, - ) as bar: + with ( + open(model_dest, open_mode) as file, + tqdm( + desc=model_name, + initial=exist_size, + total=total + exist_size, + unit="iB", + unit_scale=True, + unit_divisor=1000, + ) as bar, + ): for data in resp.iter_content(chunk_size=1024): size = file.write(data) bar.update(size) diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 9d62ce58a3..7f95750aaf 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -2,6 +2,7 @@ # and modified as needed from contextlib import contextmanager +from typing import Optional import torch from diffusers.models import UNet2DConditionModel @@ -45,36 +46,74 @@ class IPAdapter: def __init__( self, - unet: UNet2DConditionModel, image_encoder_path: str, ip_adapter_ckpt_path: str, device: torch.device, + dtype: torch.dtype = torch.float16, num_tokens: int = 4, ): - self._unet = unet - self._device = device + self.device = device + self.dtype = dtype + self._image_encoder_path = image_encoder_path self._ip_adapter_ckpt_path = ip_adapter_ckpt_path self._num_tokens = num_tokens - self._attn_processors = self._prepare_attention_processors() - - # load image encoder self._image_encoder = CLIPVisionModelWithProjection.from_pretrained(self._image_encoder_path).to( - self._device, dtype=torch.float16 + self.device, dtype=self.dtype ) self._clip_image_processor = CLIPImageProcessor() - # image proj model - self._image_proj_model = self._init_image_proj_model() - self._load_weights() + # Fields to be initialized later in initialize(). + self._unet = None + self._image_proj_model = None + self._attn_processors = None + + self._state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu") + + def is_initialized(self): + return self._unet is not None and self._image_proj_model is not None and self._attn_processors is not None + + def initialize(self, unet: UNet2DConditionModel): + """Finish the model initialization process. + + HACK: This is separate from __init__ for compatibility with the model manager. The full initialization requires + access to the UNet model to be patched, which can not easily be passed to __init__ by the model manager. + + Args: + unet (UNet2DConditionModel): The UNet whose attention blocks will be patched by this IP-Adapter. + """ + if self.is_initialized(): + raise Exception("IPAdapter has already been initialized.") + + self._unet = unet + self._image_proj_model = self._init_image_proj_model() + self._attn_processors = self._prepare_attention_processors() + + # Copy the weights from the _state_dict into the models. + self._image_proj_model.load_state_dict(self._state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self._attn_processors.values()) + ip_layers.load_state_dict(self._state_dict["ip_adapter"]) + + self._state_dict = None + + def to(self, device: torch.device, dtype: Optional[torch.dtype] = None): + self.device = device + if dtype is not None: + self.dtype = dtype + + for model in [self._image_encoder, self._image_proj_model, self._attn_processors]: + # If this is called before initialize(), then some models will still be None. 
We just update the non-None + # models. + if model is not None: + model.to(device=self.device, dtype=self.dtype) def _init_image_proj_model(self): image_proj_model = ImageProjModel( cross_attention_dim=self._unet.config.cross_attention_dim, clip_embeddings_dim=self._image_encoder.config.projection_dim, clip_extra_context_tokens=self._num_tokens, - ).to(self._device, dtype=torch.float16) + ).to(self.device, dtype=self.dtype) return image_proj_model def _prepare_attention_processors(self): @@ -99,7 +138,7 @@ class IPAdapter: hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, - ).to(self._device, dtype=torch.float16) + ).to(self.device, dtype=self.dtype) return attn_procs @contextmanager @@ -109,30 +148,36 @@ class IPAdapter: Yields: None """ + if not self.is_initialized(): + raise Exception("Call IPAdapter.initialize() before calling IPAdapter.apply_ip_adapter_attention().") + orig_attn_processors = self._unet.attn_processors + # Make a (moderately-) shallow copy of the self._attn_processors dict, because set_attn_processor(...) actually + # pops elements from the passed dict. + ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()} try: - self._unet.set_attn_processor(self._attn_processors) + self._unet.set_attn_processor(ip_adapter_attn_processors) yield None finally: self._unet.set_attn_processor(orig_attn_processors) - def _load_weights(self): - state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu") - self._image_proj_model.load_state_dict(state_dict["image_proj"]) - ip_layers = torch.nn.ModuleList(self._attn_processors.values()) - ip_layers.load_state_dict(state_dict["ip_adapter"]) - @torch.inference_mode() def get_image_embeds(self, pil_image): + if not self.is_initialized(): + raise Exception("Call IPAdapter.initialize() before calling IPAdapter.get_image_embeds().") + if isinstance(pil_image, Image.Image): pil_image = [pil_image] clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image_embeds = self._image_encoder(clip_image.to(self._device, dtype=torch.float16)).image_embeds + clip_image_embeds = self._image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds image_prompt_embeds = self._image_proj_model(clip_image_embeds) uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds)) return image_prompt_embeds, uncond_image_prompt_embeds def set_scale(self, scale): + if not self.is_initialized(): + raise Exception("Call IPAdapter.initialize() before calling IPAdapter.set_scale().") + for attn_processor in self._attn_processors.values(): if isinstance(attn_processor, IPAttnProcessor): attn_processor.scale = scale @@ -151,15 +196,18 @@ class IPAdapterPlus(IPAdapter): embedding_dim=self._image_encoder.config.hidden_size, output_dim=self._unet.config.cross_attention_dim, ff_mult=4, - ).to(self._device, dtype=torch.float16) + ).to(self.device, dtype=self.dtype) return image_proj_model @torch.inference_mode() def get_image_embeds(self, pil_image): + if not self.is_initialized(): + raise Exception("Call IPAdapter.initialize() before calling IPAdapter.get_image_embeds().") + if isinstance(pil_image, Image.Image): pil_image = [pil_image] clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image = clip_image.to(self._device, dtype=torch.float16) + clip_image = clip_image.to(self.device, dtype=self.dtype) clip_image_embeds = self._image_encoder(clip_image, 
output_hidden_states=True).hidden_states[-2] image_prompt_embeds = self._image_proj_model(clip_image_embeds) uncond_clip_image_embeds = self._image_encoder( diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index d746a83a9e..7bb188cb4e 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -1001,8 +1001,8 @@ class ModelManager(object): new_models_found = True except DuplicateModelException as e: self.logger.warning(e) - except InvalidModelException: - self.logger.warning(f"Not a valid model: {model_path}") + except InvalidModelException as e: + self.logger.warning(f"Not a valid model: {model_path}. {e}") except NotImplementedError as e: self.logger.warning(e) diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index 16b6bc26a6..0bff479412 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -61,7 +61,7 @@ class ModelType(str, Enum): Lora = "lora" ControlNet = "controlnet" # used by model_probe TextualInversion = "embedding" - IPAdapter = "ipadapter" + IPAdapter = "ip_adapter" class SubModelType(str, Enum): diff --git a/invokeai/backend/model_management/models/ip_adapter.py b/invokeai/backend/model_management/models/ip_adapter.py index 028f358aaa..01a97b6dfc 100644 --- a/invokeai/backend/model_management/models/ip_adapter.py +++ b/invokeai/backend/model_management/models/ip_adapter.py @@ -1,24 +1,31 @@ import os +import typing from enum import Enum -from typing import Any, Optional +from typing import Any, Literal, Optional import torch +from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.model_management.models.base import ( BaseModelType, + InvalidModelException, ModelBase, + ModelConfigBase, ModelType, SubModelType, classproperty, ) -class IPAdapterModelFormat(Enum): - # The 'official' IP-Adapter model format from Tencent (i.e. https://huggingface.co/h94/IP-Adapter) - Tencent = "tencent" +class IPAdapterModelFormat(str, Enum): + # Checkpoint is the 'official' IP-Adapter model format from Tencent (i.e. https://huggingface.co/h94/IP-Adapter) + Checkpoint = "checkpoint" class IPAdapterModel(ModelBase): + class CheckpointConfig(ModelConfigBase): + model_format: Literal[IPAdapterModelFormat.Checkpoint] + def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.IPAdapter super().__init__(model_path, base_model, model_type) @@ -31,23 +38,58 @@ class IPAdapterModel(ModelBase): if not os.path.exists(path): raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.") - raise NotImplementedError() + if os.path.isfile(path): + if path.endswith((".safetensors", ".ckpt", ".pt", ".pth", ".bin")): + return IPAdapterModelFormat.Checkpoint + + raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}") @classproperty def save_to_config(cls) -> bool: - raise NotImplementedError() + return True def get_size(self, child_type: Optional[SubModelType] = None) -> int: if child_type is not None: raise ValueError("There are no child models in an IP-Adapter model.") - raise NotImplementedError() + # TODO(ryand): Update self.model_size when the model is loaded from disk. + return self.model_size + + def _get_text_encoder_path(self) -> str: + # TODO(ryand): Move the CLIP image encoder to its own model directory. 
+ return os.path.join(os.path.dirname(self.model_path), "image_encoder") def get_model( self, torch_dtype: Optional[torch.dtype], child_type: Optional[SubModelType] = None, - ) -> Any: + ) -> typing.Union[IPAdapter, IPAdapterPlus]: if child_type is not None: raise ValueError("There are no child models in an IP-Adapter model.") - raise NotImplementedError() + + # TODO(ryand): Update IPAdapter to accept a torch_dtype param. + + # TODO(ryand): Checking for "plus" in the file name is fragile. It should be possible to infer whether this is a + # "plus" variant by loading the state_dict. + if "plus" in str(self.model_path): + return IPAdapterPlus( + image_encoder_path=self._get_text_encoder_path(), ip_adapter_ckpt_path=self.model_path, device="cpu" + ) + else: + return IPAdapter( + image_encoder_path=self._get_text_encoder_path(), ip_adapter_ckpt_path=self.model_path, device="cpu" + ) + + @classmethod + def convert_if_required( + cls, + model_path: str, + output_path: str, + config: ModelConfigBase, + base_model: BaseModelType, + ) -> str: + format = cls.detect_format(model_path) + if format == IPAdapterModelFormat.Checkpoint: + return model_path + else: + raise ValueError(f"Unsupported format: '{format}'.") diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index a7bbfb23e2..1408c56989 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -171,8 +171,7 @@ class ControlNetData: @dataclass class IPAdapterData: - ip_adapter_model: str = Field(default=None) - image_encoder_model: str = Field(default=None) + ip_adapter_model: IPAdapter = Field(default=None) image: PIL.Image = Field(default=None) # TODO: change to polymorphic so can do different weights per step (once implemented...) # weight: Union[float, List[float]] = Field(default=1.0) @@ -417,27 +416,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): return latents, attention_map_saver if ip_adapter_data is not None: - # Initialize IPAdapter - # TODO(ryand): Refactor to use model management for the IP-Adapter. - if "plus" in ip_adapter_data.ip_adapter_model: - ip_adapter = IPAdapterPlus( - self.unet, - ip_adapter_data.image_encoder_model, - ip_adapter_data.ip_adapter_model, - self.unet.device, - num_tokens=16, - ) - else: - ip_adapter = IPAdapter( - self.unet, - ip_adapter_data.image_encoder_model, - ip_adapter_data.ip_adapter_model, - self.unet.device, - ) - ip_adapter.set_scale(ip_adapter_data.weight) + if not ip_adapter_data.ip_adapter_model.is_initialized(): + # TODO(ryan): Do we need to initialize every time? How long does initialize take? + ip_adapter_data.ip_adapter_model.initialize(self.unet) + ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight) # Get image embeddings from CLIP and ImageProjModel. - image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image) + image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_data.ip_adapter_model.get_image_embeds( + ip_adapter_data.image + ) conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo( image_prompt_embeds, uncond_image_prompt_embeds ) @@ -451,7 +438,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): elif ip_adapter_data is not None: # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active? # As it is now, the IP-Adapter will silently be skipped. 
- attn_ctx = ip_adapter.apply_ip_adapter_attention() + attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention() else: attn_ctx = nullcontext()
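
The core of this change is a two-stage IPAdapter lifecycle: the model manager constructs the adapter from its checkpoint without needing a UNet, and the pipeline later calls initialize() on the UNet it is about to patch. Below is a minimal sketch of that flow, assuming only the API added in this patch; `unet`, `pil_image`, and the two paths are illustrative placeholders, and in InvokeAI the adapter instance actually comes from ModelManager.get_model(..., model_type=ModelType.IPAdapter) via IPAdapterModel.get_model(), as shown in prep_ip_adapter_data above.

import torch
from diffusers import UNet2DConditionModel
from PIL import Image

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter

# Stage 1: construction without a UNet. Normally IPAdapterModel.get_model()
# builds this from the installed checkpoint and the sibling image_encoder
# directory; the paths here are placeholders.
ip_adapter = IPAdapter(
    image_encoder_path="/path/to/image_encoder",          # placeholder
    ip_adapter_ckpt_path="/path/to/ip-adapter_sd15.bin",  # placeholder
    device=torch.device("cpu"),
    dtype=torch.float32,
)

# Placeholders standing in for the pipeline's UNet and the IP-Adapter image prompt.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
pil_image = Image.open("prompt_image.png")

# Stage 2: patch-time initialization, deferred until the pipeline can hand over
# the UNet whose attention blocks will be patched.
if not ip_adapter.is_initialized():
    ip_adapter.initialize(unet)

ip_adapter.set_scale(0.8)

# CLIP image embeddings for the conditioned and unconditioned branches.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(pil_image)

# Temporarily swap in the IP-Adapter attention processors for the denoising loop;
# the UNet's original processors are restored when the context exits.
with ip_adapter.apply_ip_adapter_attention():
    pass  # run denoising with IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)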