Re-factor IPAdapter to patch UNet in a context manager.

This commit is contained in:
Ryan Dick
2023-09-08 15:39:22 -04:00
parent d669f0855d
commit 91596d9527
4 changed files with 76 additions and 188 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import dataclasses
import inspect
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Union
@ -419,21 +420,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if ip_adapter_data is not None:
# Initialize IPAdapter
# FIXME:
# WARNING!
# IPAdapter constructor modifies UNet model in-place
# Adds additional cross-attention layers to UNet model for image embedding
# need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
# and how to undo if ip_adapter_image is removed
# Should reimplement to use existing model management context etc.
#
# TODO(ryand): Refactor to use model management for the IP-Adapter.
if "sdxl" in ip_adapter_data.ip_adapter_model:
ip_adapter = IPAdapterXL(
self, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
self.unet, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
)
elif "plus" in ip_adapter_data.ip_adapter_model:
ip_adapter = IPAdapterPlus(
self, # IPAdapterPlus first arg is StableDiffusionPipeline
self.unet,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
@ -441,7 +435,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
else:
ip_adapter = IPAdapter(
self, # IPAdapter first arg is StableDiffusionPipeline
self.unet,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
@ -454,13 +448,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
image_prompt_embeds, uncond_image_prompt_embeds
)
# TODO(ryand): Apply IP-Adapter or custom attention control
extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
):
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=conditioning_data.extra,
step_count=len(self.scheduler.timesteps),
)
elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
attn_ctx = ip_adapter.apply_ip_adapter_attention()
else:
attn_ctx = nullcontext()
with attn_ctx:
if callback is not None:
callback(
PipelineIntermediateState(