mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
reduce peak VRAM memory usage of IP adapter
This commit is contained in:
parent
6b24424727
commit
1c59fce6ad
@ -50,7 +50,6 @@ from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput,
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
|
||||
@ -672,6 +671,39 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return controlnet_data
|
||||
|
||||
def prep_ip_adapter_image_prompts(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
|
||||
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
|
||||
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
|
||||
if not isinstance(ip_adapter, list):
|
||||
ip_adapter = [ip_adapter]
|
||||
|
||||
if len(ip_adapter) == 0:
|
||||
return []
|
||||
|
||||
image_prompts = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
|
||||
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
single_ipa_image_fields = [single_ipa_image_fields]
|
||||
|
||||
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
|
||||
with image_encoder_model_info as image_encoder_model:
|
||||
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||
single_ipa_images, image_encoder_model
|
||||
)
|
||||
image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds))
|
||||
|
||||
return image_prompts
|
||||
|
||||
def prep_ip_adapter_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@ -680,6 +712,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
dtype: torch.dtype,
|
||||
image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> Optional[list[IPAdapterData]]:
|
||||
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
|
||||
to the `conditioning_data` (in-place).
|
||||
@ -696,26 +729,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
ip_adapter_data_list = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.models.load(single_ip_adapter.ip_adapter_model)
|
||||
)
|
||||
ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
|
||||
|
||||
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
single_ipa_image_fields = [single_ipa_image_fields]
|
||||
|
||||
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
|
||||
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
with image_encoder_model_info as image_encoder_model:
|
||||
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||
single_ipa_images, image_encoder_model
|
||||
)
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = image_prompts.pop(0)
|
||||
|
||||
mask = single_ip_adapter.mask
|
||||
if mask is not None:
|
||||
@ -912,6 +928,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapter=self.ip_adapter)
|
||||
|
||||
# get the unet's config so that we can pass the base to dispatch_progress()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
@ -975,6 +993,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=unet.dtype,
|
||||
image_prompts=image_prompts,
|
||||
)
|
||||
|
||||
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
|
Loading…
Reference in New Issue
Block a user