diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py
index e4dca8c9a4..9d62ce58a3 100644
--- a/invokeai/backend/ip_adapter/ip_adapter.py
+++ b/invokeai/backend/ip_adapter/ip_adapter.py
@@ -21,7 +21,7 @@ from .resampler import Resampler
 
 
 class ImageProjModel(torch.nn.Module):
-    """Projection Model"""
+    """Image Projection Model"""
 
     def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
         super().__init__()
@@ -43,31 +43,38 @@ class ImageProjModel(torch.nn.Module):
 class IPAdapter:
     """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
 
-    def __init__(self, unet: UNet2DConditionModel, image_encoder_path, ip_ckpt, device, num_tokens=4):
+    def __init__(
+        self,
+        unet: UNet2DConditionModel,
+        image_encoder_path: str,
+        ip_adapter_ckpt_path: str,
+        device: torch.device,
+        num_tokens: int = 4,
+    ):
         self._unet = unet
-        self.device = device
-        self.image_encoder_path = image_encoder_path
-        self.ip_ckpt = ip_ckpt
-        self.num_tokens = num_tokens
+        self._device = device
+        self._image_encoder_path = image_encoder_path
+        self._ip_adapter_ckpt_path = ip_adapter_ckpt_path
+        self._num_tokens = num_tokens
 
         self._attn_processors = self._prepare_attention_processors()
 
         # load image encoder
-        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
-            self.device, dtype=torch.float16
+        self._image_encoder = CLIPVisionModelWithProjection.from_pretrained(self._image_encoder_path).to(
+            self._device, dtype=torch.float16
         )
-        self.clip_image_processor = CLIPImageProcessor()
+        self._clip_image_processor = CLIPImageProcessor()
         # image proj model
-        self.image_proj_model = self.init_proj()
+        self._image_proj_model = self._init_image_proj_model()
 
-        self.load_ip_adapter()
+        self._load_weights()
 
-    def init_proj(self):
+    def _init_image_proj_model(self):
         image_proj_model = ImageProjModel(
             cross_attention_dim=self._unet.config.cross_attention_dim,
-            clip_embeddings_dim=self.image_encoder.config.projection_dim,
-            clip_extra_context_tokens=self.num_tokens,
-        ).to(self.device, dtype=torch.float16)
+            clip_embeddings_dim=self._image_encoder.config.projection_dim,
+            clip_extra_context_tokens=self._num_tokens,
+        ).to(self._device, dtype=torch.float16)
         return image_proj_model
 
     def _prepare_attention_processors(self):
@@ -92,7 +99,7 @@ class IPAdapter:
                     hidden_size=hidden_size,
                     cross_attention_dim=cross_attention_dim,
                     scale=1.0,
-                ).to(self.device, dtype=torch.float16)
+                ).to(self._device, dtype=torch.float16)
         return attn_procs
 
     @contextmanager
@@ -109,9 +116,9 @@ class IPAdapter:
         finally:
             self._unet.set_attn_processor(orig_attn_processors)
 
-    def load_ip_adapter(self):
-        state_dict = torch.load(self.ip_ckpt, map_location="cpu")
-        self.image_proj_model.load_state_dict(state_dict["image_proj"])
+    def _load_weights(self):
+        state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu")
+        self._image_proj_model.load_state_dict(state_dict["image_proj"])
         ip_layers = torch.nn.ModuleList(self._attn_processors.values())
         ip_layers.load_state_dict(state_dict["ip_adapter"])
 
@@ -119,10 +126,10 @@ class IPAdapter:
     def get_image_embeds(self, pil_image):
         if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
-        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
-        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
-        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
+        clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
+        clip_image_embeds = self._image_encoder(clip_image.to(self._device, dtype=torch.float16)).image_embeds
+        image_prompt_embeds = self._image_proj_model(clip_image_embeds)
+        uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
         return image_prompt_embeds, uncond_image_prompt_embeds
 
     def set_scale(self, scale):
@@ -134,29 +141,29 @@ class IPAdapter:
 class IPAdapterPlus(IPAdapter):
     """IP-Adapter with fine-grained features"""
 
-    def init_proj(self):
+    def _init_image_proj_model(self):
         image_proj_model = Resampler(
             dim=self._unet.config.cross_attention_dim,
             depth=4,
             dim_head=64,
             heads=12,
-            num_queries=self.num_tokens,
-            embedding_dim=self.image_encoder.config.hidden_size,
+            num_queries=self._num_tokens,
+            embedding_dim=self._image_encoder.config.hidden_size,
             output_dim=self._unet.config.cross_attention_dim,
             ff_mult=4,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self._device, dtype=torch.float16)
         return image_proj_model
 
     @torch.inference_mode()
     def get_image_embeds(self, pil_image):
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
-        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float16)
-        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
-        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
-        uncond_clip_image_embeds = self.image_encoder(
+        clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
+        clip_image = clip_image.to(self._device, dtype=torch.float16)
+        clip_image_embeds = self._image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
+        image_prompt_embeds = self._image_proj_model(clip_image_embeds)
+        uncond_clip_image_embeds = self._image_encoder(
             torch.zeros_like(clip_image), output_hidden_states=True
         ).hidden_states[-2]
-        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
+        uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
         return image_prompt_embeds, uncond_image_prompt_embeds
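Reviewer note: below is a minimal sketch of the renamed public surface (constructor keywords, get_image_embeds, set_scale), assuming an SD-1.5 UNet. The repo id and local paths are hypothetical placeholders, not part of this diff, and the attention-processor context manager is omitted because its name does not appear in these hunks.

# Sketch only: repo id and file paths are placeholders.
import torch
from PIL import Image
from diffusers import UNet2DConditionModel

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter

# Load a UNet for the adapter to patch (placeholder SD-1.5 checkpoint).
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# Constructor now uses ip_adapter_ckpt_path (renamed from ip_ckpt) and a torch.device.
ip_adapter = IPAdapter(
    unet=unet,
    image_encoder_path="path/to/clip_image_encoder",  # placeholder
    ip_adapter_ckpt_path="path/to/ip_adapter.bin",    # placeholder
    device=torch.device("cuda"),
)

# Conditional and unconditional image-prompt embeddings from a reference image.
reference = Image.open("reference.png")  # placeholder image
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(reference)

# Weight of the image prompt applied inside the IP attention processors.
ip_adapter.set_scale(0.75)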