mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Re-factor IPAdapter to patch UNet in a context manager.
This commit is contained in:
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
@ -419,21 +420,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
if ip_adapter_data is not None:
|
||||
# Initialize IPAdapter
|
||||
# FIXME:
|
||||
# WARNING!
|
||||
# IPAdapter constructor modifies UNet model in-place
|
||||
# Adds additional cross-attention layers to UNet model for image embedding
|
||||
# need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
|
||||
# and how to undo if ip_adapter_image is removed
|
||||
# Should reimplement to use existing model management context etc.
|
||||
#
|
||||
# TODO(ryand): Refactor to use model management for the IP-Adapter.
|
||||
if "sdxl" in ip_adapter_data.ip_adapter_model:
|
||||
ip_adapter = IPAdapterXL(
|
||||
self, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
|
||||
self.unet, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
|
||||
)
|
||||
elif "plus" in ip_adapter_data.ip_adapter_model:
|
||||
ip_adapter = IPAdapterPlus(
|
||||
self, # IPAdapterPlus first arg is StableDiffusionPipeline
|
||||
self.unet,
|
||||
ip_adapter_data.image_encoder_model,
|
||||
ip_adapter_data.ip_adapter_model,
|
||||
self.unet.device,
|
||||
@ -441,7 +435,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
else:
|
||||
ip_adapter = IPAdapter(
|
||||
self, # IPAdapter first arg is StableDiffusionPipeline
|
||||
self.unet,
|
||||
ip_adapter_data.image_encoder_model,
|
||||
ip_adapter_data.ip_adapter_model,
|
||||
self.unet.device,
|
||||
@ -454,13 +448,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
image_prompt_embeds, uncond_image_prompt_embeds
|
||||
)
|
||||
|
||||
# TODO(ryand): Apply IP-Adapter or custom attention control
|
||||
extra_conditioning_info = conditioning_data.extra
|
||||
with self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
):
|
||||
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
|
||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=conditioning_data.extra,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
)
|
||||
elif ip_adapter_data is not None:
|
||||
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
|
||||
# As it is now, the IP-Adapter will silently be skipped.
|
||||
attn_ctx = ip_adapter.apply_ip_adapter_attention()
|
||||
else:
|
||||
attn_ctx = nullcontext()
|
||||
|
||||
with attn_ctx:
|
||||
if callback is not None:
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
|
Reference in New Issue
Block a user