Eliminate the need for IPAdapter.initialize().

Ryan Dick 2023-09-14 15:02:59 -04:00
parent d114d0ba95
commit 781e8521d5
3 changed files with 44 additions and 76 deletions
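At a glance, the caller-facing change: instead of calling initialize(unet) and set_scale(...) on an IPAdapter before use, the UNet and scale are now passed directly to apply_ip_adapter_attention(). A minimal sketch of the new usage, assuming an already-loaded ip_adapter_model, a unet, and a weight (the names stand in for the objects used in the diffs below):

    with ip_adapter_model.apply_ip_adapter_attention(unet=unet, scale=weight):
        ...  # run denoising; the UNet's attention processors are patched only inside this block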


@@ -434,10 +434,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
if not ip_adapter_model.is_initialized():
# TODO(ryan): Do we need to initialize every time? How long does initialize take?
ip_adapter_model.initialize(unet)
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
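With the is_initialized()/initialize() calls removed, the invocation-side flow reduces to roughly the sketch below. This is a simplification: context, ip_adapter, ip_adapter_model, and image_encoder_model_info all come from the surrounding DenoiseLatentsInvocation code, and model-management details are elided.

    input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)

    with image_encoder_model_info as image_encoder_model:
        # Run the CLIP Vision encoder to produce the conditional and unconditional image embeddings.
        image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
            input_image, image_encoder_model
        )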


@@ -80,70 +80,44 @@ class IPAdapter:
self._clip_image_processor = CLIPImageProcessor()
# Fields to be initialized later in initialize().
self._unet = None
self._image_proj_model = None
self._attn_processors = None
self._state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu")
def is_initialized(self):
return self._unet is not None and self._image_proj_model is not None and self._attn_processors is not None
def initialize(self, unet: UNet2DConditionModel):
"""Finish the model initialization process.
HACK: This is separate from __init__ for compatibility with the model manager. The full initialization requires
access to the UNet model to be patched, which cannot easily be passed to __init__ by the model manager.
Args:
unet (UNet2DConditionModel): The UNet whose attention blocks will be patched by this IP-Adapter.
"""
if self.is_initialized():
raise Exception("IPAdapter has already been initialized.")
self._unet = unet
self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
self._attn_processors = self._prepare_attention_processors()
# Copy the weights from the _state_dict into the models.
ip_layers = torch.nn.ModuleList(self._attn_processors.values())
ip_layers.load_state_dict(self._state_dict["ip_adapter"])
self._state_dict = None
# The _attn_processors will be initialized later when we have access to the UNet.
self._attn_processors = None
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
self.device = device
if dtype is not None:
self.dtype = dtype
for model in [self._image_proj_model, self._attn_processors]:
# If this is called before initialize(), then some models will still be None. We just update the non-None
# models.
if model is not None:
model.to(device=self.device, dtype=self.dtype)
self._image_proj_model.to(device=self.device, dtype=self.dtype)
if self._attn_processors is not None:
torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
def _init_image_proj_model(self, state_dict):
image_proj_model = ImageProjModel.from_state_dict(state_dict, self._num_tokens)
image_proj_model.to(self.device, dtype=self.dtype)
return image_proj_model
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
def _prepare_attention_processors(self):
"""Creates a dict of attention processors that can later be injected into `self.unet`, and loads the IP-Adapter
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can later be injected into a unet, and load the IP-Adapter
attention weights into them.
Note that the `unet` param is only used to determine attention block dimensions and naming.
TODO(ryand): As a future improvement, this could all be inferred from the state_dict when the IPAdapter is
initialized.
"""
attn_procs = {}
for name in self._unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else self._unet.config.cross_attention_dim
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = self._unet.config.block_out_channels[-1]
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self._unet.config.block_out_channels))[block_id]
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self._unet.config.block_out_channels[block_id]
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
@@ -152,33 +126,43 @@ class IPAdapter:
cross_attention_dim=cross_attention_dim,
scale=1.0,
).to(self.device, dtype=self.dtype)
return attn_procs
ip_layers = torch.nn.ModuleList(attn_procs.values())
ip_layers.load_state_dict(self._state_dict["ip_adapter"])
self._attn_processors = attn_procs
self._state_dict = None
@contextmanager
def apply_ip_adapter_attention(self):
"""A context manager that patches `self._unet` with this IP-Adapter's attention processors while it is active.
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel, scale: float):
"""A context manager that patches `unet` with this IP-Adapter's attention processors while it is active.
Yields:
None
"""
if not self.is_initialized():
raise Exception("Call IPAdapter.initialize() before calling IPAdapter.apply_ip_adapter_attention().")
if self._attn_processors is None:
# We only have to call _prepare_attention_processors(...) once, and then the result is cached and can be
# used on any UNet model (with the same dimensions).
self._prepare_attention_processors(unet)
orig_attn_processors = self._unet.attn_processors
# Make a (moderately-) shallow copy of the self._attn_processors dict, because set_attn_processor(...) actually
# pops elements from the passed dict.
# Set scale.
for attn_processor in self._attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale
orig_attn_processors = unet.attn_processors
# Make a (moderately-) shallow copy of the self._attn_processors dict, because unet.set_attn_processor(...)
# actually pops elements from the passed dict.
ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()}
try:
self._unet.set_attn_processor(ip_adapter_attn_processors)
unet.set_attn_processor(ip_adapter_attn_processors)
yield None
finally:
self._unet.set_attn_processor(orig_attn_processors)
unet.set_attn_processor(orig_attn_processors)
@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if not self.is_initialized():
raise Exception("Call IPAdapter.initialize() before calling IPAdapter.get_image_embeds().")
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
@@ -187,20 +171,12 @@ class IPAdapter:
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
def set_scale(self, scale):
if not self.is_initialized():
raise Exception("Call IPAdapter.initialize() before calling IPAdapter.set_scale().")
for attn_processor in self._attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale
class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features"""
def _init_image_proj_model(self, state_dict):
image_proj_model = Resampler.from_state_dict(
return Resampler.from_state_dict(
state_dict=state_dict,
depth=4,
dim_head=64,
@@ -208,13 +184,9 @@ class IPAdapterPlus(IPAdapter):
num_queries=self._num_tokens,
ff_mult=4,
).to(self.device, dtype=self.dtype)
return image_proj_model
@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if not self.is_initialized():
raise Exception("Call IPAdapter.initialize() before calling IPAdapter.get_image_embeds().")
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
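Two things carry the weight of this file's changes: _prepare_attention_processors(unet) only needs the UNet to determine attention block dimensions and names, and its result is cached so the state-dict loading happens once; apply_ip_adapter_attention() then applies a plain patch/restore pattern to whichever unet it is given. A condensed, illustrative sketch of that pattern (the helper name below is hypothetical, not part of the diff):

    from contextlib import contextmanager

    @contextmanager
    def patched_attention(unet, ip_adapter_attn_processors):
        orig_attn_processors = unet.attn_processors
        try:
            # Pass a shallow copy: set_attn_processor(...) pops entries from the dict it is given.
            unet.set_attn_processor(dict(ip_adapter_attn_processors))
            yield
        finally:
            # Restore the original processors even if denoising raises.
            unet.set_attn_processor(orig_attn_processors)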


@@ -423,9 +423,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight)
attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention()
attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention(
unet=self.invokeai_diffuser.model, scale=ip_adapter_data.weight
)
else:
attn_ctx = nullcontext()
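
For context, the pipeline-side selection now reads roughly as below, condensed from the hunk above; the final `with attn_ctx:` line is an assumption about how the surrounding code consumes the context manager.

    from contextlib import nullcontext

    if ip_adapter_data is not None:
        attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention(
            unet=self.invokeai_diffuser.model, scale=ip_adapter_data.weight
        )
    else:
        attn_ctx = nullcontext()

    with attn_ctx:
        ...  # the denoising loop runs with (or without) the IP-Adapter attention patch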