Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Remove need for the image_encoder param in IPAdapter.initialize().
Parent commit: cc8b7a74da
Commit: d114d0ba95
@@ -434,13 +434,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
 
+                if not ip_adapter_model.is_initialized():
+                    # TODO(ryan): Do we need to initialize every time? How long does initialize take?
+                    ip_adapter_model.initialize(unet)
+
                 # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
                 # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
                 with image_encoder_model_info as image_encoder_model:
-                    if not ip_adapter_model.is_initialized():
-                        # TODO(ryan): Do we need to initialize every time? How long does initialize take?
-                        ip_adapter_model.initialize(unet, image_encoder_model)
-
                     # Get image embeddings from CLIP and ImageProjModel.
                     image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
                         input_image, image_encoder_model
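For orientation, a minimal standalone sketch of the flow this hunk produces; ip_adapter, unet, load_image_encoder, and pil_image are hypothetical stand-ins for objects the invocation actually provides:

# Hypothetical stand-ins for objects provided by the invocation context.
if not ip_adapter.is_initialized():
    # Wiring the IP-Adapter into the UNet no longer requires the CLIP Vision encoder.
    ip_adapter.initialize(unet)

# The CLIP Vision encoder only needs to be in memory for the embedding step.
with load_image_encoder() as image_encoder:
    image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(
        pil_image, image_encoder
    )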
@@ -32,6 +32,27 @@ class ImageProjModel(torch.nn.Module):
         self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
         self.norm = torch.nn.LayerNorm(cross_attention_dim)
 
+    @classmethod
+    def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
+        """Initialize an ImageProjModel from a state_dict.
+
+        The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
+
+        Args:
+            state_dict (dict[torch.Tensor]): The state_dict of model weights.
+            clip_extra_context_tokens (int, optional): Defaults to 4.
+
+        Returns:
+            ImageProjModel
+        """
+        cross_attention_dim = state_dict["norm.weight"].shape[0]
+        clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+
+        model = cls(cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens)
+
+        model.load_state_dict(state_dict)
+        return model
+
     def forward(self, image_embeds):
         embeds = image_embeds
         clip_extra_context_tokens = self.proj(embeds).reshape(
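As a sanity check on those two inference lines: from the constructor context above, self.proj is an nn.Linear(clip_embeddings_dim, clip_extra_context_tokens * cross_attention_dim) and self.norm is an nn.LayerNorm(cross_attention_dim), so both dimensions can be read straight off the weight shapes. A minimal sketch (the dimension values are arbitrary):

import torch

# Example shapes only; 768/1024/4 are arbitrary illustrative values.
cross_attention_dim, clip_embeddings_dim, tokens = 768, 1024, 4

proj = torch.nn.Linear(clip_embeddings_dim, tokens * cross_attention_dim)
norm = torch.nn.LayerNorm(cross_attention_dim)

# nn.Linear stores its weight as (out_features, in_features), and
# nn.LayerNorm stores a weight of shape (normalized_dim,).
assert proj.weight.shape[-1] == clip_embeddings_dim
assert norm.weight.shape[0] == cross_attention_dim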
@@ -69,7 +90,7 @@ class IPAdapter:
     def is_initialized(self):
         return self._unet is not None and self._image_proj_model is not None and self._attn_processors is not None
 
-    def initialize(self, unet: UNet2DConditionModel, image_encoder: CLIPVisionModelWithProjection):
+    def initialize(self, unet: UNet2DConditionModel):
         """Finish the model initialization process.
 
         HACK: This is separate from __init__ for compatibility with the model manager. The full initialization requires
@@ -82,13 +103,11 @@ class IPAdapter:
             raise Exception("IPAdapter has already been initialized.")
 
         self._unet = unet
-        # TODO(ryand): Eliminate the need to pass the image_encoder to initialize(). It should be possible to infer the
-        # necessary information from the state_dict.
-        self._image_proj_model = self._init_image_proj_model(image_encoder)
+        self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
         self._attn_processors = self._prepare_attention_processors()
 
         # Copy the weights from the _state_dict into the models.
-        self._image_proj_model.load_state_dict(self._state_dict["image_proj"])
         ip_layers = torch.nn.ModuleList(self._attn_processors.values())
         ip_layers.load_state_dict(self._state_dict["ip_adapter"])
 
@@ -105,12 +124,9 @@ class IPAdapter:
             if model is not None:
                 model.to(device=self.device, dtype=self.dtype)
 
-    def _init_image_proj_model(self, image_encoder: CLIPVisionModelWithProjection):
-        image_proj_model = ImageProjModel(
-            cross_attention_dim=self._unet.config.cross_attention_dim,
-            clip_embeddings_dim=image_encoder.config.projection_dim,
-            clip_extra_context_tokens=self._num_tokens,
-        ).to(self.device, dtype=self.dtype)
+    def _init_image_proj_model(self, state_dict):
+        image_proj_model = ImageProjModel.from_state_dict(state_dict, self._num_tokens)
+        image_proj_model.to(self.device, dtype=self.dtype)
         return image_proj_model
 
     def _prepare_attention_processors(self):
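For reference, a hedged sketch of exercising this helper path outside the model manager. The checkpoint filename is illustrative, and ImageProjModel refers to the class from the earlier hunk; the top-level "image_proj" key matches the _state_dict access in initialize() above.

import torch

# Illustrative path; the checkpoint is assumed to contain "image_proj" and
# "ip_adapter" sub-dicts, matching the _state_dict accesses shown above.
state_dict = torch.load("ip_adapter.bin", map_location="cpu")

# Build the projection model purely from the weights; its dims are inferred.
image_proj_model = ImageProjModel.from_state_dict(
    state_dict["image_proj"], clip_extra_context_tokens=4
)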
@@ -183,15 +199,13 @@ class IPAdapter:
 class IPAdapterPlus(IPAdapter):
     """IP-Adapter with fine-grained features"""
 
-    def _init_image_proj_model(self, image_encoder: CLIPVisionModelWithProjection):
-        image_proj_model = Resampler(
-            dim=self._unet.config.cross_attention_dim,
+    def _init_image_proj_model(self, state_dict):
+        image_proj_model = Resampler.from_state_dict(
+            state_dict=state_dict,
             depth=4,
             dim_head=64,
             heads=12,
             num_queries=self._num_tokens,
-            embedding_dim=image_encoder.config.hidden_size,
-            output_dim=self._unet.config.cross_attention_dim,
             ff_mult=4,
         ).to(self.device, dtype=self.dtype)
         return image_proj_model
@@ -109,6 +109,42 @@ class Resampler(nn.Module):
                 )
             )
 
+    @classmethod
+    def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
+        """A convenience function that initializes a Resampler from a state_dict.
+
+        Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
+        writing, we did not have a need for inferring ALL of the shape parameters from the state_dict, but this would be
+        possible if needed in the future.
+
+        Args:
+            state_dict (dict[torch.Tensor]): The state_dict to load.
+            depth (int, optional):
+            dim_head (int, optional):
+            heads (int, optional):
+            ff_mult (int, optional):
+
+        Returns:
+            Resampler
+        """
+        dim = state_dict["latents"].shape[2]
+        num_queries = state_dict["latents"].shape[1]
+        embedding_dim = state_dict["proj_in.weight"].shape[-1]
+        output_dim = state_dict["norm_out.weight"].shape[0]
+
+        model = cls(
+            dim=dim,
+            depth=depth,
+            dim_head=dim_head,
+            heads=heads,
+            num_queries=num_queries,
+            embedding_dim=embedding_dim,
+            output_dim=output_dim,
+            ff_mult=ff_mult,
+        )
+        model.load_state_dict(state_dict)
+        return model
+
     def forward(self, x):
         latents = self.latents.repeat(x.size(0), 1, 1)
 
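The inference lines above rely on how the Resampler registers its tensors: latents as a (1, num_queries, dim) parameter, proj_in as an nn.Linear(embedding_dim, dim), and norm_out as an nn.LayerNorm(output_dim). Those definitions are not part of this diff, so treat them as assumptions; a minimal sketch with dummy tensors shows how the shapes map to the inferred values:

import torch

# Dummy tensors with the assumed layouts; the values are zeros, only the shapes matter.
state_dict_shapes = {
    "latents": torch.zeros(1, 16, 768),         # assumed (1, num_queries, dim)
    "proj_in.weight": torch.zeros(768, 1280),   # nn.Linear(embedding_dim, dim) -> (dim, embedding_dim)
    "norm_out.weight": torch.zeros(2048),       # nn.LayerNorm(output_dim) -> (output_dim,)
}

assert state_dict_shapes["latents"].shape[2] == 768            # dim
assert state_dict_shapes["latents"].shape[1] == 16             # num_queries
assert state_dict_shapes["proj_in.weight"].shape[-1] == 1280   # embedding_dim
assert state_dict_shapes["norm_out.weight"].shape[0] == 2048   # output_dim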