mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit 23eb80b404 (parent 2c6db2e77c)

    attention maps callback stuff for diffusers
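In outline, this change threads attention-map capture through the diffusers code path: latents_from_embeddings now returns a (latents, attention_map_saver) tuple, the new InvokeAIStableDiffusionPipelineOutput carries the saver alongside the generated images, and the Img2Img/Inpaint/Txt2Img generators accept an attention_maps_callback and invoke it once generation finishes. A minimal, self-contained sketch of that callback hand-off follows; the classes and functions below are simplified stand-ins for illustration, not the real InvokeAI or diffusers objects.

# Sketch only: AttentionMapSaver and run_generation are stand-ins, not the real classes.
from typing import Callable, Optional

class AttentionMapSaver:
    # the real saver collects per-token attention maps during the final denoising step
    def __init__(self, token_ids, latents_shape):
        self.token_ids = token_ids
        self.latents_shape = latents_shape

def run_generation(attention_maps_callback: Optional[Callable[[AttentionMapSaver], None]] = None):
    # the real pipeline builds the saver on the last step and returns it with the latents;
    # faked here so the hand-off pattern used by the generators is visible
    attention_map_saver = AttentionMapSaver(token_ids=range(1, 6), latents_shape=(64, 64))
    if attention_map_saver is not None and attention_maps_callback is not None:
        attention_maps_callback(attention_map_saver)

run_generation(lambda saver: print("maps captured for token ids:", list(saver.token_ids)))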
@@ -63,7 +63,6 @@ class Generator:
     def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
-                 attention_maps_callback = None,
                  **kwargs):
         scope = choose_autocast(self.precision)
         self.safety_checker = safety_checker
@@ -7,12 +7,14 @@ from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any,
 
 import PIL.Image
 import einops
+import numpy as np
 import torch
 import torchvision.transforms as T
 from diffusers.models import attention
 from diffusers.utils.import_utils import is_xformers_available
 
 from ...models.diffusion import cross_attention_control
+from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
 
 # monkeypatch diffusers CrossAttention 🙈
 # this is to make prompt2prompt and (future) attention maps work
@@ -41,6 +43,7 @@ class PipelineIntermediateState:
     timestep: int
     latents: torch.Tensor
     predicted_original: Optional[torch.Tensor] = None
+    attention_map_saver: Optional[AttentionMapSaver] = None
 
 
 # copied from configs/stable-diffusion/v1-inference.yaml
@@ -180,6 +183,17 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
             raise AssertionError("why was that an empty generator?")
         return result
 
+@dataclass
+class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
+    r"""
+    Output class for InvokeAI's Stable Diffusion pipeline.
+
+    Args:
+        attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
+            after generation completes. Optional.
+    """
+    attention_map_saver: Optional[AttentionMapSaver]
+
 
 class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
@@ -255,7 +269,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 *, callback: Callable[[PipelineIntermediateState], None]=None,
                 extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
                 run_id=None,
-                **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+                **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
         r"""
         Function invoked when calling the pipeline for generation.
 
@@ -273,7 +287,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         :param run_id:
         :param extra_step_kwargs:
         """
-        result_latents = self.latents_from_embeddings(
+        result_latents, result_attention_map_saver = self.latents_from_embeddings(
             latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
             extra_conditioning_info=extra_conditioning_info,
             run_id=run_id, callback=callback, **extra_step_kwargs
@@ -283,7 +297,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         with torch.inference_mode():
             image = self.decode_latents(result_latents)
-            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_map_saver)
             return self.check_for_safety(output, dtype=text_embeddings.dtype)
 
     def latents_from_embeddings(
@@ -302,13 +316,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
         timesteps = self.scheduler.timesteps
         infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
-        return infer_latents_from_embeddings(
+        result: PipelineIntermediateState = infer_latents_from_embeddings(
             latents, timesteps, text_embeddings, unconditioned_embeddings, guidance_scale,
             extra_conditioning_info=extra_conditioning_info,
             additional_guidance=additional_guidance,
             run_id=run_id,
             callback=callback,
-            **extra_step_kwargs).latents
+            **extra_step_kwargs)
+        return result.latents, result.attention_map_saver
 
     def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps, text_embeddings: torch.Tensor,
                                          unconditioned_embeddings: torch.Tensor, guidance_scale: float, *,
@@ -334,6 +349,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         batched_t = torch.full((batch_size,), timesteps[0],
                                dtype=timesteps.dtype, device=self.unet.device)
 
+        attention_map_saver: AttentionMapSaver = None
+        self.invokeai_diffuser.remove_attention_map_saving()
         for i, t in enumerate(self.progress_bar(timesteps)):
             batched_t.fill_(t)
             step_output = self.step(batched_t, latents, guidance_scale,
@@ -342,9 +359,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                     **extra_step_kwargs)
             latents = step_output.prev_sample
             predicted_original = getattr(step_output, 'pred_original_sample', None)
+
+            if i == len(timesteps)-1 and extra_conditioning_info is not None:
+                eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
+                attention_map_token_ids = range(1, eos_token_index)
+                attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
+                self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
+
             yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
-                                            predicted_original=predicted_original)
-        return latents
+                                            predicted_original=predicted_original, attention_map_saver=attention_map_saver)
+
+        self.invokeai_diffuser.remove_attention_map_saving()
+        return latents, attention_map_saver
 
     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
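The hunk above only installs the AttentionMapSaver on the final denoising step, and only for the prompt tokens strictly between BOS and EOS. A small stand-alone sketch of that index arithmetic (the token count below is made up for illustration):

# token positions that receive attention maps, per the setup in the hunk above
tokens_count_including_eos_bos = 7              # e.g. BOS + 5 prompt tokens + EOS (made-up count)
eos_token_index = tokens_count_including_eos_bos - 1
attention_map_token_ids = range(1, eos_token_index)
print(list(attention_map_token_ids))            # -> [1, 2, 3, 4, 5]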
@@ -393,7 +419,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                    extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
                                    run_id=None,
                                    noise_func=None,
-                                   **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+                                   **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
         if isinstance(init_image, PIL.Image.Image):
             init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
 
@@ -412,7 +438,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
     def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, text_embeddings,
                                             unconditioned_embeddings, guidance_scale, strength, extra_conditioning_info,
-                                            noise_func, run_id=None, callback=None, **extra_step_kwargs):
+                                            noise_func, run_id=None, callback=None, **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
         device = self.unet.device
         batch_size = initial_latents.size(0)
         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
@@ -423,7 +449,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noised_latents = self.scheduler.add_noise(initial_latents, noise, latent_timestep)
         latents = noised_latents
 
-        result_latents = self.latents_from_embeddings(
+        result_latents, result_attention_maps = self.latents_from_embeddings(
             latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
             extra_conditioning_info=extra_conditioning_info,
             timesteps=timesteps,
@@ -435,7 +461,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         with torch.inference_mode():
             image = self.decode_latents(result_latents)
-            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
             return self.check_for_safety(output, dtype=text_embeddings.dtype)
 
     def inpaint_from_embeddings(
@@ -450,7 +476,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
                                run_id=None,
                                noise_func=None,
-                               **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+                               **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
         device = self.unet.device
         latents_dtype = self.unet.dtype
         batch_size = 1
@@ -493,7 +519,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise))
 
         try:
-            result_latents = self.latents_from_embeddings(
+            result_latents, result_attention_maps = self.latents_from_embeddings(
                 latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
                 extra_conditioning_info=extra_conditioning_info,
                 timesteps=timesteps,
@@ -508,7 +534,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         with torch.inference_mode():
             image = self.decode_latents(result_latents)
-            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
             return self.check_for_safety(output, dtype=text_embeddings.dtype)
 
     def non_noised_latents_from_image(self, init_image, *, device, dtype):
@@ -523,7 +549,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         with torch.inference_mode():
             screened_images, has_nsfw_concept = self.run_safety_checker(
                 output.images, device=self._execution_device, dtype=dtype)
-        return StableDiffusionPipelineOutput(screened_images, has_nsfw_concept)
+        screened_attention_map_saver = None
+        if has_nsfw_concept is None or not has_nsfw_concept:
+            screened_attention_map_saver = output.attention_map_saver
+        return InvokeAIStableDiffusionPipelineOutput(screened_images,
+                                                     has_nsfw_concept,
+                                                     # block the attention maps if NSFW content is detected
+                                                     attention_map_saver=screened_attention_map_saver)
 
     @torch.inference_mode()
     def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
@@ -14,7 +14,9 @@ class Img2Img(Generator):
         self.init_latent = None # by get_noise()
 
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
+                       conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,
+                       attention_maps_callback=None,
+                       **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it.
@@ -35,7 +37,8 @@ class Img2Img(Generator):
                 noise_func=self.get_noise_like,
                 callback=step_callback
             )
+            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
+                attention_maps_callback(pipeline_output.attention_map_saver)
             return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
         return make_image
@@ -273,6 +273,8 @@ class Inpaint(Img2Img):
                 callback=step_callback,
             )
 
+            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
+                attention_maps_callback(pipeline_output.attention_map_saver)
             result = pipeline.numpy_to_pil(pipeline_output.images)[0]
 
             # Seam paint if this is our first pass (seam_size set to 0 during seam painting)
@@ -37,14 +37,12 @@ class Txt2Img(Generator):
                 unconditioned_embeddings=uc,
                 guidance_scale=cfg_scale,
                 callback=step_callback,
-                extra_conditioning_info=extra_conditioning_info,
+                extra_conditioning_info=extra_conditioning_info
                 # TODO: eta = ddim_eta,
                 # TODO: threshold = threshold,
-                # FIXME: Attention Maps Callback merged from main, but not hooked up
-                #        in diffusers branch yet. - keturn
-                # attention_maps_callback = attention_maps_callback,
             )
+            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
+                attention_maps_callback(pipeline_output.attention_map_saver)
             return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
         return make_image
@@ -36,7 +36,7 @@ class Txt2Img2Img(Generator):
 
         def make_image(x_T):
 
-            first_pass_latent_output = pipeline.latents_from_embeddings(
+            first_pass_latent_output, _ = pipeline.latents_from_embeddings(
                 latents=x_T,
                 num_inference_steps=steps,
                 text_embeddings=c,
@@ -442,128 +442,6 @@ def get_mem_free_total(device):
     return mem_free_total
 
 
-class InvokeAICrossAttentionMixin:
-    """
-    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
-    through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
-    and dymamic slicing strategy selection.
-    """
-    def __init__(self):
-        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
-        self.attention_slice_wrangler = None
-        self.slicing_strategy_getter = None
-
-    def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
-        '''
-        Set custom attention calculator to be called when attention is calculated
-        :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
-            which returns either the suggested_attention_slice or an adjusted equivalent.
-            `module` is the current CrossAttention module for which the callback is being invoked.
-            `suggested_attention_slice` is the default-calculated attention slice
-            `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
-            If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
-
-        Pass None to use the default attention calculation.
-        :return:
-        '''
-        self.attention_slice_wrangler = wrangler
-
-    def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
-        self.slicing_strategy_getter = getter
-
-    def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
-        # calculate attention scores
-        #attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
-        if dim is not None:
-            print(f"sliced dim {dim}, offset {offset}, slice_size {slice_size}")
-        attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
-            query,
-            key.transpose(-1, -2),
-            beta=0,
-            alpha=self.scale,
-        )
-
-        # calculate attention slice by taking the best scores for each latent pixel
-        default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
-        attention_slice_wrangler = self.attention_slice_wrangler
-        if attention_slice_wrangler is not None:
-            attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
-        else:
-            attention_slice = default_attention_slice
-
-        hidden_states = torch.bmm(attention_slice, value)
-        return hidden_states
-
-    def einsum_op_slice_dim0(self, q, k, v, slice_size):
-        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        for i in range(0, q.shape[0], slice_size):
-            end = i + slice_size
-            r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
-        return r
-
-    def einsum_op_slice_dim1(self, q, k, v, slice_size):
-        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        for i in range(0, q.shape[1], slice_size):
-            end = i + slice_size
-            r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
-        return r
-
-    def einsum_op_mps_v1(self, q, k, v):
-        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        else:
-            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-            return self.einsum_op_slice_dim1(q, k, v, slice_size)
-
-    def einsum_op_mps_v2(self, q, k, v):
-        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        else:
-            return self.einsum_op_slice_dim0(q, k, v, 1)
-
-    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
-        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
-        if size_mb <= max_tensor_mb:
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
-        if div <= q.shape[0]:
-            return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
-        return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
-
-    def einsum_op_cuda(self, q, k, v):
-        # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
-        slicing_strategy_getter = self.slicing_strategy_getter
-        if slicing_strategy_getter is not None:
-            (dim, slice_size) = slicing_strategy_getter(self)
-            if dim is not None:
-                # print("using saved slicing strategy with dim", dim, "slice size", slice_size)
-                if dim == 0:
-                    return self.einsum_op_slice_dim0(q, k, v, slice_size)
-                elif dim == 1:
-                    return self.einsum_op_slice_dim1(q, k, v, slice_size)
-
-        # fallback for when there is no saved strategy, or saved strategy does not slice
-        mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
-        # Divide factor of safety as there's copying and fragmentation
-        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
-
-
-    def get_invokeai_attention_mem_efficient(self, q, k, v):
-        if q.device.type == 'cuda':
-            #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
-            return self.einsum_op_cuda(q, k, v)
-
-        if q.device.type == 'mps' or q.device.type == 'cpu':
-            if self.mem_total_gb >= 32:
-                return self.einsum_op_mps_v1(q, k, v)
-            return self.einsum_op_mps_v2(q, k, v)
-
-        # Smaller slices are faster due to L2/L3/SLC caches.
-        # Tested on i7 with 8MB L3 cache.
-        return self.einsum_op_tensor_mem(q, k, v, 32)
-
 
 class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
 
     def __init__(self, **kwargs):
@@ -209,12 +209,12 @@ class KSampler(Sampler):
             model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
 
             # setup attention maps saving. checks for None are because there are multiple code paths to get here.
-            attention_maps_saver = None
+            attention_map_saver = None
             if attention_maps_callback is not None and extra_conditioning_info is not None:
                 eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
                 attention_map_token_ids = range(1, eos_token_index)
-                attention_maps_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
-                model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)
+                attention_map_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
+                model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
 
             extra_args = {
                 'cond': conditioning,
@@ -229,8 +229,8 @@ class KSampler(Sampler):
                 ),
                 None,
             )
-            if attention_maps_saver is not None:
-                attention_maps_callback(attention_maps_saver)
+            if attention_map_saver is not None:
+                attention_maps_callback(attention_map_saver)
             return sampling_result
 
     # this code will support inpainting if and when ksampler API modified or