Mirror of https://github.com/invoke-ai/InvokeAI
Remove need for the image_encoder param in IPAdapter.initialize().
commit d114d0ba95 (parent cc8b7a74da)
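This change removes the CLIPVisionModelWithProjection argument from IPAdapter.initialize(): the shape parameters of the image projection models are now inferred from the tensor shapes in the IP-Adapter state_dict itself, so the CLIP image encoder no longer needs to be in memory just to construct them. Below is a minimal sketch of the inference performed by the new ImageProjModel.from_state_dict() (see the hunk further down); the keys come from the diff, while the concrete tensor sizes are dummy values chosen purely for illustration.

import torch

# Toy ImageProjModel state_dict. The keys ("norm.weight", "proj.weight") match the
# diff below; the sizes here are made up for the example.
state_dict = {
    "proj.weight": torch.zeros(4 * 768, 1024),  # (clip_extra_context_tokens * cross_attention_dim, clip_embeddings_dim)
    "proj.bias": torch.zeros(4 * 768),
    "norm.weight": torch.zeros(768),            # (cross_attention_dim,)
    "norm.bias": torch.zeros(768),
}

cross_attention_dim = state_dict["norm.weight"].shape[0]   # 768
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]  # 1024
print(cross_attention_dim, clip_embeddings_dim)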
@@ -434,13 +434,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
             input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
 
+            if not ip_adapter_model.is_initialized():
+                # TODO(ryan): Do we need to initialize every time? How long does initialize take?
+                ip_adapter_model.initialize(unet)
+
             # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
             # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
             with image_encoder_model_info as image_encoder_model:
-                if not ip_adapter_model.is_initialized():
-                    # TODO(ryan): Do we need to initialize every time? How long does initialize take?
-                    ip_adapter_model.initialize(unet, image_encoder_model)
-
                 # Get image embeddings from CLIP and ImageProjModel.
                 image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
                     input_image, image_encoder_model
@@ -32,6 +32,27 @@ class ImageProjModel(torch.nn.Module):
         self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
         self.norm = torch.nn.LayerNorm(cross_attention_dim)
 
+    @classmethod
+    def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
+        """Initialize an ImageProjModel from a state_dict.
+
+        The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
+
+        Args:
+            state_dict (dict[torch.Tensor]): The state_dict of model weights.
+            clip_extra_context_tokens (int, optional): Defaults to 4.
+
+        Returns:
+            ImageProjModel
+        """
+        cross_attention_dim = state_dict["norm.weight"].shape[0]
+        clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+
+        model = cls(cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens)
+
+        model.load_state_dict(state_dict)
+        return model
+
     def forward(self, image_embeds):
         embeds = image_embeds
         clip_extra_context_tokens = self.proj(embeds).reshape(
@@ -69,7 +90,7 @@ class IPAdapter:
     def is_initialized(self):
         return self._unet is not None and self._image_proj_model is not None and self._attn_processors is not None
 
-    def initialize(self, unet: UNet2DConditionModel, image_encoder: CLIPVisionModelWithProjection):
+    def initialize(self, unet: UNet2DConditionModel):
         """Finish the model initialization process.
 
         HACK: This is separate from __init__ for compatibility with the model manager. The full initialization requires
@@ -82,13 +103,11 @@ class IPAdapter:
             raise Exception("IPAdapter has already been initialized.")
 
         self._unet = unet
-        # TODO(ryand): Eliminate the need to pass the image_encoder to initialize(). It should be possible to infer the
-        # necessary information from the state_dict.
-        self._image_proj_model = self._init_image_proj_model(image_encoder)
+        self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
         self._attn_processors = self._prepare_attention_processors()
 
         # Copy the weights from the _state_dict into the models.
-        self._image_proj_model.load_state_dict(self._state_dict["image_proj"])
         ip_layers = torch.nn.ModuleList(self._attn_processors.values())
         ip_layers.load_state_dict(self._state_dict["ip_adapter"])
 
@@ -105,12 +124,9 @@ class IPAdapter:
             if model is not None:
                 model.to(device=self.device, dtype=self.dtype)
 
-    def _init_image_proj_model(self, image_encoder: CLIPVisionModelWithProjection):
-        image_proj_model = ImageProjModel(
-            cross_attention_dim=self._unet.config.cross_attention_dim,
-            clip_embeddings_dim=image_encoder.config.projection_dim,
-            clip_extra_context_tokens=self._num_tokens,
-        ).to(self.device, dtype=self.dtype)
+    def _init_image_proj_model(self, state_dict):
+        image_proj_model = ImageProjModel.from_state_dict(state_dict, self._num_tokens)
+        image_proj_model.to(self.device, dtype=self.dtype)
         return image_proj_model
 
     def _prepare_attention_processors(self):
@@ -183,15 +199,13 @@ class IPAdapter:
 class IPAdapterPlus(IPAdapter):
     """IP-Adapter with fine-grained features"""
 
-    def _init_image_proj_model(self, image_encoder: CLIPVisionModelWithProjection):
-        image_proj_model = Resampler(
-            dim=self._unet.config.cross_attention_dim,
+    def _init_image_proj_model(self, state_dict):
+        image_proj_model = Resampler.from_state_dict(
+            state_dict=state_dict,
             depth=4,
             dim_head=64,
             heads=12,
             num_queries=self._num_tokens,
-            embedding_dim=image_encoder.config.hidden_size,
-            output_dim=self._unet.config.cross_attention_dim,
             ff_mult=4,
         ).to(self.device, dtype=self.dtype)
         return image_proj_model
@@ -109,6 +109,42 @@ class Resampler(nn.Module):
                 )
             )
 
+    @classmethod
+    def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
+        """A convenience function that initializes a Resampler from a state_dict.
+
+        Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
+        writing, we did not have a need for inferring ALL of the shape parameters from the state_dict, but this would be
+        possible if needed in the future.
+
+        Args:
+            state_dict (dict[torch.Tensor]): The state_dict to load.
+            depth (int, optional):
+            dim_head (int, optional):
+            heads (int, optional):
+            ff_mult (int, optional):
+
+        Returns:
+            Resampler
+        """
+        dim = state_dict["latents"].shape[2]
+        num_queries = state_dict["latents"].shape[1]
+        embedding_dim = state_dict["proj_in.weight"].shape[-1]
+        output_dim = state_dict["norm_out.weight"].shape[0]
+
+        model = cls(
+            dim=dim,
+            depth=depth,
+            dim_head=dim_head,
+            heads=heads,
+            num_queries=num_queries,
+            embedding_dim=embedding_dim,
+            output_dim=output_dim,
+            ff_mult=ff_mult,
+        )
+        model.load_state_dict(state_dict)
+        return model
+
     def forward(self, x):
         latents = self.latents.repeat(x.size(0), 1, 1)
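The Resampler used by IPAdapterPlus follows the same pattern: from_state_dict() above reads dim, num_queries, embedding_dim, and output_dim from the "latents", "proj_in.weight", and "norm_out.weight" tensors. A minimal sketch of that inference, again with dummy tensor sizes chosen only for illustration:

import torch

# Toy Resampler state_dict entries; the keys match the hunk above, the sizes are made up.
state_dict = {
    "latents": torch.zeros(1, 16, 1280),        # (1, num_queries, dim)
    "proj_in.weight": torch.zeros(1280, 1280),  # (dim, embedding_dim)
    "norm_out.weight": torch.zeros(2048),       # (output_dim,)
}

dim = state_dict["latents"].shape[2]                     # 1280
num_queries = state_dict["latents"].shape[1]             # 16
embedding_dim = state_dict["proj_in.weight"].shape[-1]   # 1280
output_dim = state_dict["norm_out.weight"].shape[0]      # 2048
print(dim, num_queries, embedding_dim, output_dim)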