Remove need for the image_encoder param in IPAdapter.initialize().

This commit is contained in:
Ryan Dick 2023-09-14 14:14:35 -04:00
parent cc8b7a74da
commit d114d0ba95
3 changed files with 70 additions and 20 deletions

View File

@ -434,13 +434,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
if not ip_adapter_model.is_initialized():
# TODO(ryan): Do we need to initialize every time? How long does initialize take?
ip_adapter_model.initialize(unet)
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
if not ip_adapter_model.is_initialized():
# TODO(ryan): Do we need to initialize every time? How long does initialize take?
ip_adapter_model.initialize(unet, image_encoder_model)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
input_image, image_encoder_model

View File

@ -32,6 +32,27 @@ class ImageProjModel(torch.nn.Module):
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
@classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
"""Initialize an ImageProjModel from a state_dict.
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
Args:
state_dict (dict[torch.Tensor]): The state_dict of model weights.
clip_extra_context_tokens (int, optional): Defaults to 4.
Returns:
ImageProjModel
"""
cross_attention_dim = state_dict["norm.weight"].shape[0]
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
model = cls(cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens)
model.load_state_dict(state_dict)
return model
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
@ -69,7 +90,7 @@ class IPAdapter:
def is_initialized(self):
return self._unet is not None and self._image_proj_model is not None and self._attn_processors is not None
def initialize(self, unet: UNet2DConditionModel, image_encoder: CLIPVisionModelWithProjection):
def initialize(self, unet: UNet2DConditionModel):
"""Finish the model initialization process.
HACK: This is separate from __init__ for compatibility with the model manager. The full initialization requires
@ -82,13 +103,11 @@ class IPAdapter:
raise Exception("IPAdapter has already been initialized.")
self._unet = unet
# TODO(ryand): Eliminate the need to pass the image_encoder to initialize(). It should be possible to infer the
# necessary information from the state_dict.
self._image_proj_model = self._init_image_proj_model(image_encoder)
self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
self._attn_processors = self._prepare_attention_processors()
# Copy the weights from the _state_dict into the models.
self._image_proj_model.load_state_dict(self._state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self._attn_processors.values())
ip_layers.load_state_dict(self._state_dict["ip_adapter"])
@ -105,12 +124,9 @@ class IPAdapter:
if model is not None:
model.to(device=self.device, dtype=self.dtype)
def _init_image_proj_model(self, image_encoder: CLIPVisionModelWithProjection):
image_proj_model = ImageProjModel(
cross_attention_dim=self._unet.config.cross_attention_dim,
clip_embeddings_dim=image_encoder.config.projection_dim,
clip_extra_context_tokens=self._num_tokens,
).to(self.device, dtype=self.dtype)
def _init_image_proj_model(self, state_dict):
image_proj_model = ImageProjModel.from_state_dict(state_dict, self._num_tokens)
image_proj_model.to(self.device, dtype=self.dtype)
return image_proj_model
def _prepare_attention_processors(self):
@ -183,15 +199,13 @@ class IPAdapter:
class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features"""
def _init_image_proj_model(self, image_encoder: CLIPVisionModelWithProjection):
image_proj_model = Resampler(
dim=self._unet.config.cross_attention_dim,
def _init_image_proj_model(self, state_dict):
image_proj_model = Resampler.from_state_dict(
state_dict=state_dict,
depth=4,
dim_head=64,
heads=12,
num_queries=self._num_tokens,
embedding_dim=image_encoder.config.hidden_size,
output_dim=self._unet.config.cross_attention_dim,
ff_mult=4,
).to(self.device, dtype=self.dtype)
return image_proj_model

View File

@ -109,6 +109,42 @@ class Resampler(nn.Module):
)
)
@classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
"""A convenience function that initializes a Resampler from a state_dict.
Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
writing, we did not have a need for inferring ALL of the shape parameters from the state_dict, but this would be
possible if needed in the future.
Args:
state_dict (dict[torch.Tensor]): The state_dict to load.
depth (int, optional):
dim_head (int, optional):
heads (int, optional):
ff_mult (int, optional):
Returns:
Resampler
"""
dim = state_dict["latents"].shape[2]
num_queries = state_dict["latents"].shape[1]
embedding_dim = state_dict["proj_in.weight"].shape[-1]
output_dim = state_dict["norm_out.weight"].shape[0]
model = cls(
dim=dim,
depth=depth,
dim_head=dim_head,
heads=heads,
num_queries=num_queries,
embedding_dim=embedding_dim,
output_dim=output_dim,
ff_mult=ff_mult,
)
model.load_state_dict(state_dict)
return model
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)