From 1c8991a3df372da3d51e0daeec7777d98c1cebeb Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 13 Sep 2023 19:10:02 -0400 Subject: [PATCH] Use CLIPVisionModel under model management for IP-Adapter. --- invokeai/app/invocations/latent.py | 39 +++++++++++++++++-- invokeai/backend/ip_adapter/ip_adapter.py | 35 ++++++++--------- .../model_management/models/clip_vision.py | 2 +- .../model_management/models/ip_adapter.py | 12 +----- .../stable_diffusion/diffusers_pipeline.py | 17 +------- 5 files changed, 57 insertions(+), 48 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 8660f6c353..7c13eacdbb 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -8,6 +8,7 @@ import numpy as np import torch import torchvision.transforms as T from diffusers.image_processor import VaeImageProcessor +from diffusers.models import UNet2DConditionModel from diffusers.models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -32,9 +33,11 @@ from invokeai.app.invocations.primitives import ( ) from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.step_callback import stable_diffusion_step_callback +from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( ConditioningData, + IPAdapterConditioningInfo, ) from ...backend.model_management.lora import ModelPatcher @@ -403,14 +406,25 @@ class DenoiseLatentsInvocation(BaseInvocation): self, context: InvocationContext, ip_adapter: Optional[IPAdapterField], + conditioning_data: ConditioningData, + unet: UNet2DConditionModel, exit_stack: ExitStack, ) -> Optional[IPAdapterData]: + """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings + to the `conditioning_data` (in-place). + """ if ip_adapter is None: return None - input_image = context.services.images.get_pil_image(ip_adapter.image.image_name) + image_encoder_model_info = context.services.model_manager.get_model( + # TODO(ryand): Get this model_name from the IPAdapterField. + model_name="ip_adapter_clip_vision", + model_type=ModelType.CLIPVision, + base_model=ip_adapter.ip_adapter_model.base_model, + context=context, + ) - ip_adapter_model = exit_stack.enter_context( + ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( context.services.model_manager.get_model( model_name=ip_adapter.ip_adapter_model.model_name, model_type=ModelType.IPAdapter, @@ -418,9 +432,26 @@ class DenoiseLatentsInvocation(BaseInvocation): context=context, ) ) + + input_image = context.services.images.get_pil_image(ip_adapter.image.image_name) + + # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other + # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. + with image_encoder_model_info as image_encoder_model: + if not ip_adapter_model.is_initialized(): + # TODO(ryan): Do we need to initialize every time? How long does initialize take? + ip_adapter_model.initialize(unet, image_encoder_model) + + # Get image embeddings from CLIP and ImageProjModel. 
+ image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds( + input_image, image_encoder_model + ) + conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo( + image_prompt_embeds, uncond_image_prompt_embeds + ) + return IPAdapterData( ip_adapter_model=ip_adapter_model, - image=input_image, weight=ip_adapter.weight, ) @@ -552,6 +583,8 @@ class DenoiseLatentsInvocation(BaseInvocation): ip_adapter_data = self.prep_ip_adapter_data( context=context, ip_adapter=self.ip_adapter, + conditioning_data=conditioning_data, + unet=unet, exit_stack=exit_stack, ) diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 7f95750aaf..8d70986594 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -46,7 +46,6 @@ class IPAdapter: def __init__( self, - image_encoder_path: str, ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16, @@ -55,13 +54,9 @@ class IPAdapter: self.device = device self.dtype = dtype - self._image_encoder_path = image_encoder_path self._ip_adapter_ckpt_path = ip_adapter_ckpt_path self._num_tokens = num_tokens - self._image_encoder = CLIPVisionModelWithProjection.from_pretrained(self._image_encoder_path).to( - self.device, dtype=self.dtype - ) self._clip_image_processor = CLIPImageProcessor() # Fields to be initialized later in initialize(). @@ -74,7 +69,7 @@ class IPAdapter: def is_initialized(self): return self._unet is not None and self._image_proj_model is not None and self._attn_processors is not None - def initialize(self, unet: UNet2DConditionModel): + def initialize(self, unet: UNet2DConditionModel, image_encoder: CLIPVisionModelWithProjection): """Finish the model initialization process. HACK: This is separate from __init__ for compatibility with the model manager. The full initialization requires @@ -87,7 +82,9 @@ class IPAdapter: raise Exception("IPAdapter has already been initialized.") self._unet = unet - self._image_proj_model = self._init_image_proj_model() + # TODO(ryand): Eliminate the need to pass the image_encoder to initialize(). It should be possible to infer the + # necessary information from the state_dict. + self._image_proj_model = self._init_image_proj_model(image_encoder) self._attn_processors = self._prepare_attention_processors() # Copy the weights from the _state_dict into the models. @@ -102,16 +99,16 @@ class IPAdapter: if dtype is not None: self.dtype = dtype - for model in [self._image_encoder, self._image_proj_model, self._attn_processors]: + for model in [self._image_proj_model, self._attn_processors]: # If this is called before initialize(), then some models will still be None. We just update the non-None # models. 
if model is not None: model.to(device=self.device, dtype=self.dtype) - def _init_image_proj_model(self): + def _init_image_proj_model(self, image_encoder: CLIPVisionModelWithProjection): image_proj_model = ImageProjModel( cross_attention_dim=self._unet.config.cross_attention_dim, - clip_embeddings_dim=self._image_encoder.config.projection_dim, + clip_embeddings_dim=image_encoder.config.projection_dim, clip_extra_context_tokens=self._num_tokens, ).to(self.device, dtype=self.dtype) return image_proj_model @@ -162,14 +159,14 @@ class IPAdapter: self._unet.set_attn_processor(orig_attn_processors) @torch.inference_mode() - def get_image_embeds(self, pil_image): + def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection): if not self.is_initialized(): raise Exception("Call IPAdapter.initialize() before calling IPAdapter.get_image_embeds().") if isinstance(pil_image, Image.Image): pil_image = [pil_image] clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image_embeds = self._image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds + clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds image_prompt_embeds = self._image_proj_model(clip_image_embeds) uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds)) return image_prompt_embeds, uncond_image_prompt_embeds @@ -186,21 +183,21 @@ class IPAdapter: class IPAdapterPlus(IPAdapter): """IP-Adapter with fine-grained features""" - def _init_image_proj_model(self): + def _init_image_proj_model(self, image_encoder: CLIPVisionModelWithProjection): image_proj_model = Resampler( dim=self._unet.config.cross_attention_dim, depth=4, dim_head=64, heads=12, num_queries=self._num_tokens, - embedding_dim=self._image_encoder.config.hidden_size, + embedding_dim=image_encoder.config.hidden_size, output_dim=self._unet.config.cross_attention_dim, ff_mult=4, ).to(self.device, dtype=self.dtype) return image_proj_model @torch.inference_mode() - def get_image_embeds(self, pil_image): + def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection): if not self.is_initialized(): raise Exception("Call IPAdapter.initialize() before calling IPAdapter.get_image_embeds().") @@ -208,10 +205,10 @@ class IPAdapterPlus(IPAdapter): pil_image = [pil_image] clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values clip_image = clip_image.to(self.device, dtype=self.dtype) - clip_image_embeds = self._image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] image_prompt_embeds = self._image_proj_model(clip_image_embeds) - uncond_clip_image_embeds = self._image_encoder( - torch.zeros_like(clip_image), output_hidden_states=True - ).hidden_states[-2] + uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[ + -2 + ] uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds) return image_prompt_embeds, uncond_image_prompt_embeds diff --git a/invokeai/backend/model_management/models/clip_vision.py b/invokeai/backend/model_management/models/clip_vision.py index 7df3119f9c..2276c6beed 100644 --- a/invokeai/backend/model_management/models/clip_vision.py +++ b/invokeai/backend/model_management/models/clip_vision.py @@ -60,7 +60,7 @@ class CLIPVisionModel(ModelBase): if child_type is not None: raise 
ValueError("There are no child models in a CLIP Vision model.") - model = CLIPVisionModelWithProjection.from_pretrained(self._image_encoder_path, torch_dtype=torch_dtype) + model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype) # Calculate a more accurate model size. self.model_size = calc_model_size_by_data(model) diff --git a/invokeai/backend/model_management/models/ip_adapter.py b/invokeai/backend/model_management/models/ip_adapter.py index 01a97b6dfc..737d5da046 100644 --- a/invokeai/backend/model_management/models/ip_adapter.py +++ b/invokeai/backend/model_management/models/ip_adapter.py @@ -55,10 +55,6 @@ class IPAdapterModel(ModelBase): # TODO(ryand): Update self.model_size when the model is loaded from disk. return self.model_size - def _get_text_encoder_path(self) -> str: - # TODO(ryand): Move the CLIP image encoder to its own model directory. - return os.path.join(os.path.dirname(self.model_path), "image_encoder") - def get_model( self, torch_dtype: Optional[torch.dtype], @@ -72,13 +68,9 @@ class IPAdapterModel(ModelBase): # TODO(ryand): Checking for "plus" in the file name is fragile. It should be possible to infer whether this is a # "plus" variant by loading the state_dict. if "plus" in str(self.model_path): - return IPAdapterPlus( - image_encoder_path=self._get_text_encoder_path(), ip_adapter_ckpt_path=self.model_path, device="cpu" - ) + return IPAdapterPlus(ip_adapter_ckpt_path=self.model_path, device="cpu") else: - return IPAdapter( - image_encoder_path=self._get_text_encoder_path(), ip_adapter_ckpt_path=self.model_path, device="cpu" - ) + return IPAdapter(ip_adapter_ckpt_path=self.model_path, device="cpu") @classmethod def convert_if_required( diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 1408c56989..51df61ce03 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -172,7 +172,6 @@ class ControlNetData: @dataclass class IPAdapterData: ip_adapter_model: IPAdapter = Field(default=None) - image: PIL.Image = Field(default=None) # TODO: change to polymorphic so can do different weights per step (once implemented...) # weight: Union[float, List[float]] = Field(default=1.0) weight: float = Field(default=1.0) @@ -415,20 +414,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if timesteps.shape[0] == 0: return latents, attention_map_saver - if ip_adapter_data is not None: - if not ip_adapter_data.ip_adapter_model.is_initialized(): - # TODO(ryan): Do we need to initialize every time? How long does initialize take? - ip_adapter_data.ip_adapter_model.initialize(self.unet) - ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight) - - # Get image embeddings from CLIP and ImageProjModel. - image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_data.ip_adapter_model.get_image_embeds( - ip_adapter_data.image - ) - conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo( - image_prompt_embeds, uncond_image_prompt_embeds - ) - if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control: attn_ctx = self.invokeai_diffuser.custom_attention_context( self.invokeai_diffuser.model, @@ -438,6 +423,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): elif ip_adapter_data is not None: # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active? 
# As it is now, the IP-Adapter will silently be skipped. + + ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight) attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention() else: attn_ctx = nullcontext()
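
For reviewers, a minimal sketch of the call sequence this patch establishes: the CLIP Vision encoder is now loaded alongside the IP-Adapter (instead of inside `IPAdapter.__init__`) and passed explicitly to `initialize()` and `get_image_embeds()`, with the resulting embeddings attached to the conditioning data up front. The checkpoint paths below are placeholders and the models are loaded directly rather than through the model manager, so treat this as illustrative only, not the invocation code.

```python
import torch
from PIL import Image
from diffusers.models import UNet2DConditionModel
from transformers import CLIPVisionModelWithProjection

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterConditioningInfo

device = torch.device("cuda")
dtype = torch.float16

# Both models are now loaded independently of the IP-Adapter checkpoint.
# "<sd-model>", "<clip-vision-model>", and "<ip_adapter.bin>" are placeholder paths.
unet = UNet2DConditionModel.from_pretrained("<sd-model>", subfolder="unet", torch_dtype=dtype).to(device)
image_encoder = CLIPVisionModelWithProjection.from_pretrained("<clip-vision-model>", torch_dtype=dtype).to(device)

# The model manager constructs the IPAdapter on the CPU and moves it later; mirrored here.
ip_adapter = IPAdapter(ip_adapter_ckpt_path="<ip_adapter.bin>", device="cpu", dtype=dtype)
ip_adapter.to(device=device, dtype=dtype)

# Deferred initialization now needs the image encoder as well as the UNet.
if not ip_adapter.is_initialized():
    ip_adapter.initialize(unet, image_encoder)

# Image-prompt embeddings are computed once, up front, and attached to the conditioning data.
pil_image = Image.open("<prompt_image.png>")
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(pil_image, image_encoder)
ip_adapter_conditioning = IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
```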
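
On the pipeline side, the only per-run IP-Adapter work left is setting the scale and entering the attention-processor context. The sketch below continues the one above; the `weight` value is arbitrary, the denoising loop is a placeholder, and the assumption that `apply_ip_adapter_attention()` is used as a `with` context follows from the `nullcontext()` fallback in the patched code.

```python
from contextlib import nullcontext
from typing import Optional

from invokeai.backend.stable_diffusion.diffusers_pipeline import IPAdapterData

# Continues the sketch above (reuses `ip_adapter` from it). Only the IP-Adapter
# plumbing mirrors the patched pipeline; everything else is a placeholder.
ip_adapter_data: Optional[IPAdapterData] = IPAdapterData(ip_adapter_model=ip_adapter, weight=0.75)

if ip_adapter_data is not None:
    # set_scale() now happens here; initialize()/get_image_embeds() already ran up front.
    ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight)
    attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention()
else:
    attn_ctx = nullcontext()

with attn_ctx:
    # Placeholder for the denoising loop: the IP-Adapter attention processors are
    # installed for the duration of this block and restored on exit.
    pass
```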