update to diffusers 0.15 and fix code for name changes (#3201)

- This is a port of #3184 to the main branch
Lincoln Stein 2023-04-25 03:23:24 +01:00 committed by GitHub
commit fe12938c23
4 changed files with 31 additions and 29 deletions
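For context: diffusers 0.15 moved the attention machinery out of diffusers.models.cross_attention and renamed the classes, which is what most of this diff tracks. A minimal sketch of the mapping, assuming the stock diffusers 0.15 API (this sketch is illustrative, not part of the changed lines below):

# diffusers 0.14 (old):
#   from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, AttnProcessor
# diffusers 0.15 (new):
from diffusers.models.attention_processor import (
    Attention,            # was CrossAttention
    AttnProcessor,        # default processor class, was CrossAttnProcessor
    AttentionProcessor,   # type alias over the processor classes, was called AttnProcessor
    SlicedAttnProcessor,  # same name, new home
)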


@@ -445,8 +445,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @property
     def _submodels(self) -> Sequence[torch.nn.Module]:
         module_names, _, _ = self.extract_init_dict(dict(self.config))
-        values = [getattr(self, name) for name in module_names.keys()]
-        return [m for m in values if isinstance(m, torch.nn.Module)]
+        submodels = []
+        for name in module_names.keys():
+            if hasattr(self, name):
+                value = getattr(self, name)
+            else:
+                value = getattr(self.config, name)
+            if isinstance(value, torch.nn.Module):
+                submodels.append(value)
+        return submodels

     def image_from_embeddings(
         self,
@@ -544,7 +551,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         yield PipelineIntermediateState(
             run_id=run_id,
             step=-1,
-            timestep=self.scheduler.num_train_timesteps,
+            timestep=self.scheduler.config.num_train_timesteps,
             latents=latents,
         )

@@ -915,7 +922,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @property
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
-        return self.unet.in_channels
+        return self.unet.config.in_channels

     def decode_latents(self, latents):
         # Explicit call to get the vae loaded, since `decode` isn't the forward method.
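The changes in this first file all route config reads through .config, since diffusers 0.15 deprecates direct attribute access such as unet.in_channels and scheduler.num_train_timesteps. A minimal sketch of the pattern against a stock pipeline; the model id is only illustrative:

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# 0.14 style (now emits a deprecation warning): pipe.unet.in_channels, pipe.scheduler.num_train_timesteps
latent_channels = pipe.unet.config.in_channels                # 4 for SD 1.x UNets
train_timesteps = pipe.scheduler.config.num_train_timesteps   # typically 1000

The reworked _submodels property follows the same idea: names that are only recorded in the pipeline config are read from self.config instead of assuming a same-named attribute exists on the pipeline object.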


@@ -10,8 +10,7 @@ import diffusers
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
-from diffusers.models.cross_attention import AttnProcessor
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.attention_processor import AttentionProcessor
 from torch import nn

 from ...util import torch_dtype
@@ -188,7 +187,7 @@ class Context:

 class InvokeAICrossAttentionMixin:
     """
-    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
+    Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
     through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
     and dymamic slicing strategy selection.
     """
@@ -209,7 +208,7 @@ class InvokeAICrossAttentionMixin:
         Set custom attention calculator to be called when attention is calculated
         :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
             which returns either the suggested_attention_slice or an adjusted equivalent.
-            `module` is the current CrossAttention module for which the callback is being invoked.
+            `module` is the current Attention module for which the callback is being invoked.
             `suggested_attention_slice` is the default-calculated attention slice
             `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
             If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
@@ -345,11 +344,11 @@ class InvokeAICrossAttentionMixin:

 def restore_default_cross_attention(
     model,
     is_running_diffusers: bool,
-    restore_attention_processor: Optional[AttnProcessor] = None,
+    restore_attention_processor: Optional[AttentionProcessor] = None,
 ):
     if is_running_diffusers:
         unet = model
-        unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
+        unet.set_attn_processor(restore_attention_processor or AttnProcessor())
     else:
         remove_attention_function(model)
@@ -408,12 +407,9 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False

 def get_cross_attention_modules(
     model, which: CrossAttentionType
 ) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
-    from ldm.modules.attention import CrossAttention  # avoid circular import
     cross_attention_class: type = (
         InvokeAIDiffusersCrossAttention
-        if isinstance(model, UNet2DConditionModel)
-        else CrossAttention
     )
     which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
     attention_module_tuples = [
@@ -428,10 +424,10 @@ def get_cross_attention_modules(
         print(
             f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
             + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
-            + f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
+            + "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
             + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
-            + f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
-            + f"work properly until it is fixed."
+            + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+            + "work properly until it is fixed."
         )

     return attention_module_tuples
@@ -550,7 +546,7 @@ def get_mem_free_total(device):


 class InvokeAIDiffusersCrossAttention(
-    diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin
+    diffusers.models.attention.Attention, InvokeAICrossAttentionMixin
 ):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -572,8 +568,8 @@ class InvokeAIDiffusersCrossAttention(
 """
 # base implementation

-class CrossAttnProcessor:
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+class AttnProcessor:
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
@@ -601,9 +597,9 @@ class CrossAttnProcessor:
 from dataclasses import dataclass, field

 import torch
-from diffusers.models.cross_attention import (
-    CrossAttention,
-    CrossAttnProcessor,
+from diffusers.models.attention_processor import (
+    Attention,
+    AttnProcessor,
     SlicedAttnProcessor,
 )
@@ -653,7 +649,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):

     def __call__(
         self,
-        attn: CrossAttention,
+        attn: Attention,
         hidden_states,
         encoder_hidden_states=None,
         attention_mask=None,
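The renames above also affect any code that installs or restores attention processors on the UNet. A minimal sketch of a custom processor written against the 0.15 names, plus the restore call mirrored by restore_default_cross_attention(); the counting processor here is purely illustrative, not InvokeAI's implementation:

from diffusers.models.attention_processor import Attention, AttnProcessor


class CountingAttnProcessor(AttnProcessor):
    """Illustrative: defer to the default processor, just count invocations."""

    calls = 0

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        CountingAttnProcessor.calls += 1
        return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)


# unet.set_attn_processor(CountingAttnProcessor())  # install a replacement processor
# unet.set_attn_processor(AttnProcessor())          # restore the library default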


@@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, Optional, Union

 import numpy as np
 import torch
-from diffusers.models.cross_attention import AttnProcessor
+from diffusers.models.attention_processor import AttentionProcessor
 from typing_extensions import TypeAlias

 from invokeai.backend.globals import Globals
@@ -101,7 +101,7 @@ class InvokeAIDiffuserComponent:

     def override_cross_attention(
         self, conditioning: ExtraConditioningInfo, step_count: int
-    ) -> Dict[str, AttnProcessor]:
+    ) -> Dict[str, AttentionProcessor]:
         """
         setup cross attention .swap control. for diffusers this replaces the attention processor, so
         the previous attention processor is returned so that the caller can restore it later.
@@ -118,7 +118,7 @@ class InvokeAIDiffuserComponent:
         )

     def restore_default_cross_attention(
-        self, restore_attention_processor: Optional["AttnProcessor"] = None
+        self, restore_attention_processor: Optional["AttentionProcessor"] = None
     ):
         self.conditioning = None
         self.cross_attention_control_context = None
@@ -262,7 +262,7 @@ class InvokeAIDiffuserComponent:
         # TODO remove when compvis codepath support is dropped
         if step_index is None and sigma is None:
             raise ValueError(
-                f"Either step_index or sigma is required when doing cross attention control, but both are None."
+                "Either step_index or sigma is required when doing cross attention control, but both are None."
             )
         percent_through = self.estimate_percent_through(step_index, sigma)
         return percent_through
@@ -599,7 +599,6 @@ class InvokeAIDiffuserComponent:
         )

         # below is fugly omg
-        num_actual_conditionings = len(c_or_weighted_c_list)
         conditionings = [uc] + [c for c, weight in weighted_cond_list]
         weights = [1] + [weight for c, weight in weighted_cond_list]
         chunk_count = ceil(len(conditionings) / 2)
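The annotation change in this file switches to AttentionProcessor, which in 0.15 is the union type covering the processor classes. A small sketch of how the Dict[str, AttentionProcessor] that override_cross_attention() returns can be captured and later handed back to the UNet; the function names here are made up for illustration:

from typing import Dict, Optional

from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor


def swap_in(unet, new_processor) -> Dict[str, AttentionProcessor]:
    saved: Dict[str, AttentionProcessor] = unet.attn_processors  # module path -> current processor
    unet.set_attn_processor(new_processor)
    return saved


def restore(unet, saved: Optional[Dict[str, AttentionProcessor]] = None):
    # With no saved mapping, fall back to the stock default processor.
    unet.set_attn_processor(saved or AttnProcessor())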


@@ -40,7 +40,7 @@ dependencies = [
   "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
   "compel==1.0.5",
   "datasets",
-  "diffusers[torch]==0.14",
+  "diffusers[torch]==0.15.1",
   "dnspython==2.2.1",
   "einops",
   "eventlet",