mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
update to diffusers 0.15 and fix code for name changes (#3201)
- This is a port of #3184 to the main branch
Commit: fe12938c23
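Summary of the port, as reflected in the hunks below: diffusers 0.15 moved the attention machinery out of `diffusers.models.cross_attention` into `diffusers.models.attention_processor` and renamed the classes (`CrossAttention` -> `Attention`, `CrossAttnProcessor` -> `AttnProcessor`, and the processor type used in annotations, `AttnProcessor` -> `AttentionProcessor`); hyper-parameters registered on components are now read through `.config`; and the dependency pin moves to 0.15.1. A minimal sketch of the new-style imports this commit switches to, assuming diffusers 0.15.x is installed:

    # New import locations used throughout this commit (diffusers >= 0.15).
    from diffusers.models.attention_processor import (
        Attention,            # formerly diffusers.models.cross_attention.CrossAttention
        AttnProcessor,        # formerly CrossAttnProcessor (the default processor)
        AttentionProcessor,   # formerly AttnProcessor (the type used in annotations)
    )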
@@ -445,8 +445,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @property
     def _submodels(self) -> Sequence[torch.nn.Module]:
         module_names, _, _ = self.extract_init_dict(dict(self.config))
-        values = [getattr(self, name) for name in module_names.keys()]
-        return [m for m in values if isinstance(m, torch.nn.Module)]
+        submodels = []
+        for name in module_names.keys():
+            if hasattr(self, name):
+                value = getattr(self, name)
+            else:
+                value = getattr(self.config, name)
+            if isinstance(value, torch.nn.Module):
+                submodels.append(value)
+        return submodels
 
     def image_from_embeddings(
         self,
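The `_submodels` rewrite above no longer assumes every key returned by `extract_init_dict` is an attribute of the pipeline; with diffusers 0.15 some registered config entries only exist on `self.config`, so the property now falls back to the config before filtering for `torch.nn.Module` values. A standalone sketch of the same logic, assuming a `DiffusionPipeline`-like object that exposes `extract_init_dict` and `config`:

    import torch

    def collect_submodels(pipe) -> list[torch.nn.Module]:
        # Mirror of the new _submodels: prefer the pipeline attribute, fall back to
        # the config entry, and keep only values that are actual torch modules.
        module_names, _, _ = pipe.extract_init_dict(dict(pipe.config))
        submodels = []
        for name in module_names.keys():
            value = getattr(pipe, name) if hasattr(pipe, name) else getattr(pipe.config, name)
            if isinstance(value, torch.nn.Module):
                submodels.append(value)
        return submodels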
@@ -544,7 +551,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         yield PipelineIntermediateState(
             run_id=run_id,
             step=-1,
-            timestep=self.scheduler.num_train_timesteps,
+            timestep=self.scheduler.config.num_train_timesteps,
             latents=latents,
         )
 
@@ -915,7 +922,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @property
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
-        return self.unet.in_channels
+        return self.unet.config.in_channels
 
     def decode_latents(self, latents):
         # Explicit call to get the vae loaded, since `decode` isn't the forward method.
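Both one-line changes above follow the same diffusers 0.15 convention: hyper-parameters registered to a component's config are read through the `.config` namespace rather than as direct attributes (direct attribute access is deprecated in 0.15 and emits a warning). A brief usage sketch, assuming `pipe` is an instantiated pipeline of this class:

    # Hypothetical usage; `pipe` is assumed to be a loaded pipeline instance.
    timesteps = pipe.scheduler.config.num_train_timesteps  # was pipe.scheduler.num_train_timesteps
    channels = pipe.unet.config.in_channels                 # was pipe.unet.in_channels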
@@ -10,8 +10,7 @@ import diffusers
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
-from diffusers.models.cross_attention import AttnProcessor
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.attention_processor import AttentionProcessor
 from torch import nn
 
 from ...util import torch_dtype
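This import move is the heart of the rename: what lived in `diffusers.models.cross_attention` is now in `diffusers.models.attention_processor`, and the `UNet2DConditionModel` import is dropped because the `isinstance` check that used it is removed further down. The commit targets 0.15 only; for code that had to straddle both versions, a hypothetical fallback (not part of this commit) could look like:

    try:  # diffusers >= 0.15
        from diffusers.models.attention_processor import (
            Attention,
            AttnProcessor,
            AttentionProcessor,
        )
    except ImportError:  # diffusers <= 0.14
        from diffusers.models.cross_attention import (
            CrossAttention as Attention,
            CrossAttnProcessor as AttnProcessor,
            AttnProcessor as AttentionProcessor,
        )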
@@ -188,7 +187,7 @@ class Context:
 
 class InvokeAICrossAttentionMixin:
     """
-    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
+    Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
     through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
     and dymamic slicing strategy selection.
     """
@@ -209,7 +208,7 @@ class InvokeAICrossAttentionMixin:
         Set custom attention calculator to be called when attention is calculated
         :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
             which returns either the suggested_attention_slice or an adjusted equivalent.
-            `module` is the current CrossAttention module for which the callback is being invoked.
+            `module` is the current Attention module for which the callback is being invoked.
             `suggested_attention_slice` is the default-calculated attention slice
             `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
             If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
@@ -345,11 +344,11 @@ class InvokeAICrossAttentionMixin:
 def restore_default_cross_attention(
     model,
     is_running_diffusers: bool,
-    restore_attention_processor: Optional[AttnProcessor] = None,
+    restore_attention_processor: Optional[AttentionProcessor] = None,
 ):
     if is_running_diffusers:
         unet = model
-        unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
+        unet.set_attn_processor(restore_attention_processor or AttnProcessor())
     else:
         remove_attention_function(model)
 
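Restoring stock attention now means installing the 0.15 default processor, `AttnProcessor` (previously named `CrossAttnProcessor`), via the UNet's `set_attn_processor` method. A minimal usage sketch, assuming `unet` is a diffusers `UNet2DConditionModel`:

    from diffusers.models.attention_processor import AttnProcessor

    # Swap out any custom attention processors for the stock implementation.
    unet.set_attn_processor(AttnProcessor())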
@@ -408,12 +407,9 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
 def get_cross_attention_modules(
     model, which: CrossAttentionType
 ) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
-    from ldm.modules.attention import CrossAttention  # avoid circular import
 
     cross_attention_class: type = (
         InvokeAIDiffusersCrossAttention
-        if isinstance(model, UNet2DConditionModel)
-        else CrossAttention
     )
     which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
     attention_module_tuples = [
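With the ldm fallback removed, module discovery only needs to match the diffusers-derived `InvokeAIDiffusersCrossAttention` class, selecting `attn1` (self-attention) or `attn2` (cross-attention) blocks by name. The list comprehension that does the matching sits outside this hunk; a hedged sketch of that kind of lookup (illustrative only, not the repository's exact code):

    import torch

    def find_attention_modules(model: torch.nn.Module, which_attn: str):
        # Collect (qualified name, module) pairs whose name ends in "attn1" or "attn2",
        # mirroring the which_attn convention chosen above.
        return [
            (name, module)
            for name, module in model.named_modules()
            if name.endswith(which_attn)
        ]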
@@ -428,10 +424,10 @@ def get_cross_attention_modules(
         print(
             f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
             + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
-            + f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
+            + "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
             + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
-            + f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
-            + f"work properly until it is fixed."
+            + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+            + "work properly until it is fixed."
         )
         return attention_module_tuples
 
@@ -550,7 +546,7 @@ def get_mem_free_total(device):
 
 
 class InvokeAIDiffusersCrossAttention(
-    diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin
+    diffusers.models.attention.Attention, InvokeAICrossAttentionMixin
 ):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -572,8 +568,8 @@ class InvokeAIDiffusersCrossAttention(
 """
 # base implementation
 
-class CrossAttnProcessor:
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+class AttnProcessor:
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
@@ -601,9 +597,9 @@ class CrossAttnProcessor:
 from dataclasses import dataclass, field
 
 import torch
-from diffusers.models.cross_attention import (
-    CrossAttention,
-    CrossAttnProcessor,
+from diffusers.models.attention_processor import (
+    Attention,
+    AttnProcessor,
     SlicedAttnProcessor,
 )
 
@@ -653,7 +649,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
 
     def __call__(
         self,
-        attn: CrossAttention,
+        attn: Attention,
         hidden_states,
         encoder_hidden_states=None,
         attention_mask=None,
@@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, Optional, Union
 
 import numpy as np
 import torch
-from diffusers.models.cross_attention import AttnProcessor
+from diffusers.models.attention_processor import AttentionProcessor
 from typing_extensions import TypeAlias
 
 from invokeai.backend.globals import Globals
@@ -101,7 +101,7 @@ class InvokeAIDiffuserComponent:
 
     def override_cross_attention(
         self, conditioning: ExtraConditioningInfo, step_count: int
-    ) -> Dict[str, AttnProcessor]:
+    ) -> Dict[str, AttentionProcessor]:
         """
         setup cross attention .swap control. for diffusers this replaces the attention processor, so
         the previous attention processor is returned so that the caller can restore it later.
@@ -118,7 +118,7 @@ class InvokeAIDiffuserComponent:
         )
 
     def restore_default_cross_attention(
-        self, restore_attention_processor: Optional["AttnProcessor"] = None
+        self, restore_attention_processor: Optional["AttentionProcessor"] = None
     ):
         self.conditioning = None
         self.cross_attention_control_context = None
@@ -262,7 +262,7 @@ class InvokeAIDiffuserComponent:
         # TODO remove when compvis codepath support is dropped
         if step_index is None and sigma is None:
             raise ValueError(
-                f"Either step_index or sigma is required when doing cross attention control, but both are None."
+                "Either step_index or sigma is required when doing cross attention control, but both are None."
             )
         percent_through = self.estimate_percent_through(step_index, sigma)
         return percent_through
@@ -599,7 +599,6 @@ class InvokeAIDiffuserComponent:
         )
 
         # below is fugly omg
-        num_actual_conditionings = len(c_or_weighted_c_list)
         conditionings = [uc] + [c for c, weight in weighted_cond_list]
         weights = [1] + [weight for c, weight in weighted_cond_list]
         chunk_count = ceil(len(conditionings) / 2)
@@ -40,7 +40,7 @@ dependencies = [
     "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
     "compel==1.0.5",
     "datasets",
-    "diffusers[torch]==0.14",
+    "diffusers[torch]==0.15.1",
     "dnspython==2.2.1",
     "einops",
     "eventlet",
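The pin moves from 0.14 to 0.15.1 so that the renamed modules used above actually exist at runtime; the `[torch]` extra keeps pulling in the PyTorch backend. A hypothetical sanity check (not in the repository) for catching a stale install early:

    import diffusers

    # The code in this commit assumes the diffusers 0.15 module layout.
    assert diffusers.__version__.startswith("0.15"), diffusers.__version__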