mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove need for the image_encoder param in IPAdapter.initialize().
This commit is contained in:
@ -32,6 +32,27 @@ class ImageProjModel(torch.nn.Module):
|
||||
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
||||
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
|
||||
"""Initialize an ImageProjModel from a state_dict.
|
||||
|
||||
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
|
||||
|
||||
Args:
|
||||
state_dict (dict[torch.Tensor]): The state_dict of model weights.
|
||||
clip_extra_context_tokens (int, optional): Defaults to 4.
|
||||
|
||||
Returns:
|
||||
ImageProjModel
|
||||
"""
|
||||
cross_attention_dim = state_dict["norm.weight"].shape[0]
|
||||
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
|
||||
|
||||
model = cls(cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens)
|
||||
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
def forward(self, image_embeds):
|
||||
embeds = image_embeds
|
||||
clip_extra_context_tokens = self.proj(embeds).reshape(
|
||||
@ -69,7 +90,7 @@ class IPAdapter:
|
||||
def is_initialized(self):
|
||||
return self._unet is not None and self._image_proj_model is not None and self._attn_processors is not None
|
||||
|
||||
def initialize(self, unet: UNet2DConditionModel, image_encoder: CLIPVisionModelWithProjection):
|
||||
def initialize(self, unet: UNet2DConditionModel):
|
||||
"""Finish the model initialization process.
|
||||
|
||||
HACK: This is separate from __init__ for compatibility with the model manager. The full initialization requires
|
||||
@ -82,13 +103,11 @@ class IPAdapter:
|
||||
raise Exception("IPAdapter has already been initialized.")
|
||||
|
||||
self._unet = unet
|
||||
# TODO(ryand): Eliminate the need to pass the image_encoder to initialize(). It should be possible to infer the
|
||||
# necessary information from the state_dict.
|
||||
self._image_proj_model = self._init_image_proj_model(image_encoder)
|
||||
|
||||
self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
|
||||
self._attn_processors = self._prepare_attention_processors()
|
||||
|
||||
# Copy the weights from the _state_dict into the models.
|
||||
self._image_proj_model.load_state_dict(self._state_dict["image_proj"])
|
||||
ip_layers = torch.nn.ModuleList(self._attn_processors.values())
|
||||
ip_layers.load_state_dict(self._state_dict["ip_adapter"])
|
||||
|
||||
@ -105,12 +124,9 @@ class IPAdapter:
|
||||
if model is not None:
|
||||
model.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def _init_image_proj_model(self, image_encoder: CLIPVisionModelWithProjection):
|
||||
image_proj_model = ImageProjModel(
|
||||
cross_attention_dim=self._unet.config.cross_attention_dim,
|
||||
clip_embeddings_dim=image_encoder.config.projection_dim,
|
||||
clip_extra_context_tokens=self._num_tokens,
|
||||
).to(self.device, dtype=self.dtype)
|
||||
def _init_image_proj_model(self, state_dict):
|
||||
image_proj_model = ImageProjModel.from_state_dict(state_dict, self._num_tokens)
|
||||
image_proj_model.to(self.device, dtype=self.dtype)
|
||||
return image_proj_model
|
||||
|
||||
def _prepare_attention_processors(self):
|
||||
@ -183,15 +199,13 @@ class IPAdapter:
|
||||
class IPAdapterPlus(IPAdapter):
|
||||
"""IP-Adapter with fine-grained features"""
|
||||
|
||||
def _init_image_proj_model(self, image_encoder: CLIPVisionModelWithProjection):
|
||||
image_proj_model = Resampler(
|
||||
dim=self._unet.config.cross_attention_dim,
|
||||
def _init_image_proj_model(self, state_dict):
|
||||
image_proj_model = Resampler.from_state_dict(
|
||||
state_dict=state_dict,
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=12,
|
||||
num_queries=self._num_tokens,
|
||||
embedding_dim=image_encoder.config.hidden_size,
|
||||
output_dim=self._unet.config.cross_attention_dim,
|
||||
ff_mult=4,
|
||||
).to(self.device, dtype=self.dtype)
|
||||
return image_proj_model
|
||||
|
Reference in New Issue
Block a user