attention maps callback stuff for diffusers

Damian Stewart 2022-12-14 21:04:55 +01:00
parent 2c6db2e77c
commit 23eb80b404
8 changed files with 63 additions and 151 deletions

View File

@@ -63,7 +63,6 @@ class Generator:
     def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
-                 attention_maps_callback = None,
                  **kwargs):
         scope = choose_autocast(self.precision)
         self.safety_checker = safety_checker
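For reference, a minimal sketch of a callback matching the attention_maps_callback parameter that the generator subclasses below still accept (only the parameter name and the AttentionMapSaver type come from this commit; the body is a hypothetical example):

    def attention_maps_callback(saver):
        # `saver` is the AttentionMapSaver populated during sampling; a real
        # consumer would render or persist its per-token attention maps.
        print("attention maps captured:", saver)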

View File

@@ -7,12 +7,14 @@ from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any,
 import PIL.Image
 import einops
+import numpy as np
 import torch
 import torchvision.transforms as T

 from diffusers.models import attention
 from diffusers.utils.import_utils import is_xformers_available

 from ...models.diffusion import cross_attention_control
+from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver

 # monkeypatch diffusers CrossAttention 🙈
 # this is to make prompt2prompt and (future) attention maps work
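The monkeypatch the comment above refers to swaps diffusers' CrossAttention class for the InvokeAI subclass defined in cross_attention_control (see the hunk for that module below). A sketch of the idea, assuming a module-level assignment of roughly this shape:

    # assumed shape of the monkeypatch; the exact statement is outside this hunk
    attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention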
@@ -41,6 +43,7 @@ class PipelineIntermediateState:
     timestep: int
     latents: torch.Tensor
     predicted_original: Optional[torch.Tensor] = None
+    attention_map_saver: Optional[AttentionMapSaver] = None


 # copied from configs/stable-diffusion/v1-inference.yaml
@@ -180,6 +183,17 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
             raise AssertionError("why was that an empty generator?")
         return result


+@dataclass
+class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
+    r"""
+    Output class for InvokeAI's Stable Diffusion pipeline.
+
+    Args:
+        attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
+            after generation completes. Optional.
+    """
+    attention_map_saver: Optional[AttentionMapSaver]
+
 class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
@@ -255,7 +269,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                               *, callback: Callable[[PipelineIntermediateState], None]=None,
                               extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
                               run_id=None,
-                              **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+                              **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
         r"""
         Function invoked when calling the pipeline for generation.

@@ -273,7 +287,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         :param run_id:
         :param extra_step_kwargs:
         """
-        result_latents = self.latents_from_embeddings(
+        result_latents, result_attention_map_saver = self.latents_from_embeddings(
             latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
             extra_conditioning_info=extra_conditioning_info,
             run_id=run_id, callback=callback, **extra_step_kwargs
@@ -283,7 +297,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

         with torch.inference_mode():
             image = self.decode_latents(result_latents)
-            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_map_saver)
             return self.check_for_safety(output, dtype=text_embeddings.dtype)

     def latents_from_embeddings(
@@ -302,13 +316,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
             timesteps = self.scheduler.timesteps
         infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
-        return infer_latents_from_embeddings(
+        result: PipelineIntermediateState = infer_latents_from_embeddings(
             latents, timesteps, text_embeddings, unconditioned_embeddings, guidance_scale,
             extra_conditioning_info=extra_conditioning_info,
             additional_guidance=additional_guidance,
             run_id=run_id,
             callback=callback,
-            **extra_step_kwargs).latents
+            **extra_step_kwargs)
+        return result.latents, result.attention_map_saver

     def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps, text_embeddings: torch.Tensor,
                                          unconditioned_embeddings: torch.Tensor, guidance_scale: float, *,
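latents_from_embeddings now returns a (latents, attention_map_saver) pair instead of a bare tensor, so every caller in this commit unpacks two values; a caller that only needs latents can discard the saver, as the txt2img2img hunk below does. A hypothetical wrapper illustrating the new contract:

    def latents_only(pipeline, *args, **kwargs):
        # discard the attention map saver when only latents are wanted
        latents, _attention_map_saver = pipeline.latents_from_embeddings(*args, **kwargs)
        return latents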
@@ -334,6 +349,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         batched_t = torch.full((batch_size,), timesteps[0],
                                dtype=timesteps.dtype, device=self.unet.device)
+        attention_map_saver: AttentionMapSaver = None
+        self.invokeai_diffuser.remove_attention_map_saving()

         for i, t in enumerate(self.progress_bar(timesteps)):
             batched_t.fill_(t)
             step_output = self.step(batched_t, latents, guidance_scale,
@@ -342,9 +359,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                     **extra_step_kwargs)
             latents = step_output.prev_sample
             predicted_original = getattr(step_output, 'pred_original_sample', None)
+
+            if i == len(timesteps)-1 and extra_conditioning_info is not None:
+                eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
+                attention_map_token_ids = range(1, eos_token_index)
+                attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
+                self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
+
             yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
-                                            predicted_original=predicted_original)
-        return latents
+                                            predicted_original=predicted_original, attention_map_saver=attention_map_saver)
+
+        self.invokeai_diffuser.remove_attention_map_saving()
+        return latents, attention_map_saver

     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
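Note that the saver is only created on the last timestep, and the token-id range skips the BOS token (id 0) and the EOS token. A worked example of that arithmetic with an assumed four-token prompt:

    tokens_count_including_eos_bos = 4                     # e.g. [BOS, "a", "cat", EOS]
    eos_token_index = tokens_count_including_eos_bos - 1   # 3
    attention_map_token_ids = range(1, eos_token_index)    # ids 1 and 2: "a" and "cat"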
@@ -393,7 +419,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                               extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
                               run_id=None,
                               noise_func=None,
-                              **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+                              **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:

         if isinstance(init_image, PIL.Image.Image):
             init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
@@ -412,7 +438,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

     def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, text_embeddings,
                                             unconditioned_embeddings, guidance_scale, strength, extra_conditioning_info,
-                                            noise_func, run_id=None, callback=None, **extra_step_kwargs):
+                                            noise_func, run_id=None, callback=None, **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
         device = self.unet.device
         batch_size = initial_latents.size(0)
         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
@@ -423,7 +449,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noised_latents = self.scheduler.add_noise(initial_latents, noise, latent_timestep)
         latents = noised_latents

-        result_latents = self.latents_from_embeddings(
+        result_latents, result_attention_maps = self.latents_from_embeddings(
             latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
             extra_conditioning_info=extra_conditioning_info,
             timesteps=timesteps,
@@ -435,7 +461,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

         with torch.inference_mode():
             image = self.decode_latents(result_latents)
-            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
             return self.check_for_safety(output, dtype=text_embeddings.dtype)

     def inpaint_from_embeddings(
@@ -450,7 +476,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
                                run_id=None,
                                noise_func=None,
-                               **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+                               **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
         device = self.unet.device
         latents_dtype = self.unet.dtype
         batch_size = 1
@@ -493,7 +519,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise))

         try:
-            result_latents = self.latents_from_embeddings(
+            result_latents, result_attention_maps = self.latents_from_embeddings(
                 latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
                 extra_conditioning_info=extra_conditioning_info,
                 timesteps=timesteps,
@@ -508,7 +534,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

         with torch.inference_mode():
             image = self.decode_latents(result_latents)
-            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
             return self.check_for_safety(output, dtype=text_embeddings.dtype)

     def non_noised_latents_from_image(self, init_image, *, device, dtype):
@@ -523,7 +549,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         with torch.inference_mode():
             screened_images, has_nsfw_concept = self.run_safety_checker(
                 output.images, device=self._execution_device, dtype=dtype)
-        return StableDiffusionPipelineOutput(screened_images, has_nsfw_concept)
+        screened_attention_map_saver = None
+        if has_nsfw_concept is None or not has_nsfw_concept:
+            screened_attention_map_saver = output.attention_map_saver
+        return InvokeAIStableDiffusionPipelineOutput(screened_images,
+                                                     has_nsfw_concept,
+                                                     # block the attention maps if NSFW content is detected
+                                                     attention_map_saver=screened_attention_map_saver)

     @torch.inference_mode()
     def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):

View File

@@ -14,7 +14,9 @@ class Img2Img(Generator):
         self.init_latent = None    # by get_noise()

     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
+                       conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,
+                       attention_maps_callback=None,
+                       **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it.
@@ -35,7 +37,8 @@ class Img2Img(Generator):
                 noise_func=self.get_noise_like,
                 callback=step_callback
             )
-
+            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
+                attention_maps_callback(pipeline_output.attention_map_saver)
             return pipeline.numpy_to_pil(pipeline_output.images)[0]

         return make_image

View File

@@ -273,6 +273,8 @@ class Inpaint(Img2Img):
                 callback=step_callback,
             )

+            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
+                attention_maps_callback(pipeline_output.attention_map_saver)
             result = pipeline.numpy_to_pil(pipeline_output.images)[0]

             # Seam paint if this is our first pass (seam_size set to 0 during seam painting)

View File

@@ -37,14 +37,12 @@ class Txt2Img(Generator):
                 unconditioned_embeddings=uc,
                 guidance_scale=cfg_scale,
                 callback=step_callback,
-                extra_conditioning_info=extra_conditioning_info,
+                extra_conditioning_info=extra_conditioning_info
                 # TODO: eta = ddim_eta,
                 # TODO: threshold = threshold,
-                # FIXME: Attention Maps Callback merged from main, but not hooked up
-                #        in diffusers branch yet. - keturn
-                # attention_maps_callback = attention_maps_callback,
             )
-
+            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
+                attention_maps_callback(pipeline_output.attention_map_saver)
             return pipeline.numpy_to_pil(pipeline_output.images)[0]

         return make_image

View File

@@ -36,7 +36,7 @@ class Txt2Img2Img(Generator):

         def make_image(x_T):

-            first_pass_latent_output = pipeline.latents_from_embeddings(
+            first_pass_latent_output, _ = pipeline.latents_from_embeddings(
                 latents=x_T,
                 num_inference_steps=steps,
                 text_embeddings=c,

View File

@@ -442,128 +442,6 @@ def get_mem_free_total(device):
     return mem_free_total


-class InvokeAICrossAttentionMixin:
-    """
-    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
-    through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
-    and dynamic slicing strategy selection.
-    """
-    def __init__(self):
-        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
-        self.attention_slice_wrangler = None
-        self.slicing_strategy_getter = None
-
-    def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
-        '''
-        Set custom attention calculator to be called when attention is calculated
-        :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
-                         which returns either the suggested_attention_slice or an adjusted equivalent.
-            `module` is the current CrossAttention module for which the callback is being invoked.
-            `suggested_attention_slice` is the default-calculated attention slice
-            `dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
-            If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
-
-            Pass None to use the default attention calculation.
-        :return:
-        '''
-        self.attention_slice_wrangler = wrangler
-
-    def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
-        self.slicing_strategy_getter = getter
-
-    def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
-        # calculate attention scores
-        #attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
-        if dim is not None:
-            print(f"sliced dim {dim}, offset {offset}, slice_size {slice_size}")
-        attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
-            query,
-            key.transpose(-1, -2),
-            beta=0,
-            alpha=self.scale,
-        )
-
-        # calculate attention slice by taking the best scores for each latent pixel
-        default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
-        attention_slice_wrangler = self.attention_slice_wrangler
-        if attention_slice_wrangler is not None:
-            attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
-        else:
-            attention_slice = default_attention_slice
-
-        hidden_states = torch.bmm(attention_slice, value)
-        return hidden_states
-
-    def einsum_op_slice_dim0(self, q, k, v, slice_size):
-        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        for i in range(0, q.shape[0], slice_size):
-            end = i + slice_size
-            r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
-        return r
-
-    def einsum_op_slice_dim1(self, q, k, v, slice_size):
-        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        for i in range(0, q.shape[1], slice_size):
-            end = i + slice_size
-            r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
-        return r
-
-    def einsum_op_mps_v1(self, q, k, v):
-        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        else:
-            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-            return self.einsum_op_slice_dim1(q, k, v, slice_size)
-
-    def einsum_op_mps_v2(self, q, k, v):
-        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        else:
-            return self.einsum_op_slice_dim0(q, k, v, 1)
-
-    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
-        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
-        if size_mb <= max_tensor_mb:
-            return self.einsum_lowest_level(q, k, v, None, None, None)
-        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
-        if div <= q.shape[0]:
-            return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
-        return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
-
-    def einsum_op_cuda(self, q, k, v):
-        # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
-        slicing_strategy_getter = self.slicing_strategy_getter
-        if slicing_strategy_getter is not None:
-            (dim, slice_size) = slicing_strategy_getter(self)
-            if dim is not None:
-                # print("using saved slicing strategy with dim", dim, "slice size", slice_size)
-                if dim == 0:
-                    return self.einsum_op_slice_dim0(q, k, v, slice_size)
-                elif dim == 1:
-                    return self.einsum_op_slice_dim1(q, k, v, slice_size)
-
-        # fallback for when there is no saved strategy, or saved strategy does not slice
-        mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
-        # Divide factor of safety as there's copying and fragmentation
-        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
-
-    def get_invokeai_attention_mem_efficient(self, q, k, v):
-        if q.device.type == 'cuda':
-            #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
-            return self.einsum_op_cuda(q, k, v)
-
-        if q.device.type == 'mps' or q.device.type == 'cpu':
-            if self.mem_total_gb >= 32:
-                return self.einsum_op_mps_v1(q, k, v)
-            return self.einsum_op_mps_v2(q, k, v)
-
-        # Smaller slices are faster due to L2/L3/SLC caches.
-        # Tested on i7 with 8MB L3 cache.
-        return self.einsum_op_tensor_mem(q, k, v, 32)
-
-
 class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):

     def __init__(self, **kwargs):

View File

@@ -209,12 +209,12 @@ class KSampler(Sampler):
         model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)

         # setup attention maps saving. checks for None are because there are multiple code paths to get here.
-        attention_maps_saver = None
+        attention_map_saver = None
         if attention_maps_callback is not None and extra_conditioning_info is not None:
             eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
             attention_map_token_ids = range(1, eos_token_index)
-            attention_maps_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
-            model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)
+            attention_map_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
+            model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

         extra_args = {
             'cond': conditioning,
@@ -229,8 +229,8 @@ class KSampler(Sampler):
             ),
             None,
         )
-        if attention_maps_saver is not None:
-            attention_maps_callback(attention_maps_saver)
+        if attention_map_saver is not None:
+            attention_maps_callback(attention_map_saver)
         return sampling_result

     # this code will support inpainting if and when ksampler API modified or