mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
attention maps callback stuff for diffusers
This commit is contained in:
parent
2c6db2e77c
commit
23eb80b404
@ -63,7 +63,6 @@ class Generator:
|
||||
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||
safety_checker:dict=None,
|
||||
attention_maps_callback = None,
|
||||
**kwargs):
|
||||
scope = choose_autocast(self.precision)
|
||||
self.safety_checker = safety_checker
|
||||
|
@ -7,12 +7,14 @@ from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any,
|
||||
|
||||
import PIL.Image
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers.models import attention
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from ...models.diffusion import cross_attention_control
|
||||
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
|
||||
# monkeypatch diffusers CrossAttention 🙈
|
||||
# this is to make prompt2prompt and (future) attention maps work
|
||||
@ -41,6 +43,7 @@ class PipelineIntermediateState:
|
||||
timestep: int
|
||||
latents: torch.Tensor
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
|
||||
# copied from configs/stable-diffusion/v1-inference.yaml
|
||||
@ -180,6 +183,17 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
||||
raise AssertionError("why was that an empty generator?")
|
||||
return result
|
||||
|
||||
@dataclass
|
||||
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
||||
r"""
|
||||
Output class for InvokeAI's Stable Diffusion pipeline.
|
||||
|
||||
Args:
|
||||
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
|
||||
after generation completes. Optional.
|
||||
"""
|
||||
attention_map_saver: Optional[AttentionMapSaver]
|
||||
|
||||
|
||||
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
@ -255,7 +269,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
*, callback: Callable[[PipelineIntermediateState], None]=None,
|
||||
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
|
||||
run_id=None,
|
||||
**extra_step_kwargs) -> StableDiffusionPipelineOutput:
|
||||
**extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
@ -273,7 +287,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
:param run_id:
|
||||
:param extra_step_kwargs:
|
||||
"""
|
||||
result_latents = self.latents_from_embeddings(
|
||||
result_latents, result_attention_map_saver = self.latents_from_embeddings(
|
||||
latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
run_id=run_id, callback=callback, **extra_step_kwargs
|
||||
@ -283,7 +297,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
with torch.inference_mode():
|
||||
image = self.decode_latents(result_latents)
|
||||
output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
|
||||
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_map_saver)
|
||||
return self.check_for_safety(output, dtype=text_embeddings.dtype)
|
||||
|
||||
def latents_from_embeddings(
|
||||
@ -302,13 +316,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
|
||||
return infer_latents_from_embeddings(
|
||||
result: PipelineIntermediateState = infer_latents_from_embeddings(
|
||||
latents, timesteps, text_embeddings, unconditioned_embeddings, guidance_scale,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
additional_guidance=additional_guidance,
|
||||
run_id=run_id,
|
||||
callback=callback,
|
||||
**extra_step_kwargs).latents
|
||||
**extra_step_kwargs)
|
||||
return result.latents, result.attention_map_saver
|
||||
|
||||
def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps, text_embeddings: torch.Tensor,
|
||||
unconditioned_embeddings: torch.Tensor, guidance_scale: float, *,
|
||||
@ -334,6 +349,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
batched_t = torch.full((batch_size,), timesteps[0],
|
||||
dtype=timesteps.dtype, device=self.unet.device)
|
||||
|
||||
attention_map_saver: AttentionMapSaver = None
|
||||
self.invokeai_diffuser.remove_attention_map_saving()
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t.fill_(t)
|
||||
step_output = self.step(batched_t, latents, guidance_scale,
|
||||
@ -342,9 +359,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
**extra_step_kwargs)
|
||||
latents = step_output.prev_sample
|
||||
predicted_original = getattr(step_output, 'pred_original_sample', None)
|
||||
|
||||
if i == len(timesteps)-1 and extra_conditioning_info is not None:
|
||||
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||
attention_map_token_ids = range(1, eos_token_index)
|
||||
attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
|
||||
self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||
|
||||
yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
|
||||
predicted_original=predicted_original)
|
||||
return latents
|
||||
predicted_original=predicted_original, attention_map_saver=attention_map_saver)
|
||||
|
||||
self.invokeai_diffuser.remove_attention_map_saving()
|
||||
return latents, attention_map_saver
|
||||
|
||||
@torch.inference_mode()
|
||||
def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
|
||||
@ -393,7 +419,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
|
||||
run_id=None,
|
||||
noise_func=None,
|
||||
**extra_step_kwargs) -> StableDiffusionPipelineOutput:
|
||||
**extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
|
||||
|
||||
@ -412,7 +438,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, text_embeddings,
|
||||
unconditioned_embeddings, guidance_scale, strength, extra_conditioning_info,
|
||||
noise_func, run_id=None, callback=None, **extra_step_kwargs):
|
||||
noise_func, run_id=None, callback=None, **extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
device = self.unet.device
|
||||
batch_size = initial_latents.size(0)
|
||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||
@ -423,7 +449,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
noised_latents = self.scheduler.add_noise(initial_latents, noise, latent_timestep)
|
||||
latents = noised_latents
|
||||
|
||||
result_latents = self.latents_from_embeddings(
|
||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||
latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
timesteps=timesteps,
|
||||
@ -435,7 +461,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
with torch.inference_mode():
|
||||
image = self.decode_latents(result_latents)
|
||||
output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
|
||||
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
|
||||
return self.check_for_safety(output, dtype=text_embeddings.dtype)
|
||||
|
||||
def inpaint_from_embeddings(
|
||||
@ -450,7 +476,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
|
||||
run_id=None,
|
||||
noise_func=None,
|
||||
**extra_step_kwargs) -> StableDiffusionPipelineOutput:
|
||||
**extra_step_kwargs) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
device = self.unet.device
|
||||
latents_dtype = self.unet.dtype
|
||||
batch_size = 1
|
||||
@ -493,7 +519,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise))
|
||||
|
||||
try:
|
||||
result_latents = self.latents_from_embeddings(
|
||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||
latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
timesteps=timesteps,
|
||||
@ -508,7 +534,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
with torch.inference_mode():
|
||||
image = self.decode_latents(result_latents)
|
||||
output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
|
||||
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
|
||||
return self.check_for_safety(output, dtype=text_embeddings.dtype)
|
||||
|
||||
def non_noised_latents_from_image(self, init_image, *, device, dtype):
|
||||
@ -523,7 +549,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
with torch.inference_mode():
|
||||
screened_images, has_nsfw_concept = self.run_safety_checker(
|
||||
output.images, device=self._execution_device, dtype=dtype)
|
||||
return StableDiffusionPipelineOutput(screened_images, has_nsfw_concept)
|
||||
screened_attention_map_saver = None
|
||||
if has_nsfw_concept is None or not has_nsfw_concept:
|
||||
screened_attention_map_saver = output.attention_map_saver
|
||||
return InvokeAIStableDiffusionPipelineOutput(screened_images,
|
||||
has_nsfw_concept,
|
||||
# block the attention maps if NSFW content is detected
|
||||
attention_map_saver=screened_attention_map_saver)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
|
||||
|
@ -14,7 +14,9 @@ class Img2Img(Generator):
|
||||
self.init_latent = None # by get_noise()
|
||||
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
|
||||
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,
|
||||
attention_maps_callback=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it.
|
||||
@ -35,7 +37,8 @@ class Img2Img(Generator):
|
||||
noise_func=self.get_noise_like,
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
return make_image
|
||||
|
@ -273,6 +273,8 @@ class Inpaint(Img2Img):
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
result = pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
|
||||
|
@ -37,14 +37,12 @@ class Txt2Img(Generator):
|
||||
unconditioned_embeddings=uc,
|
||||
guidance_scale=cfg_scale,
|
||||
callback=step_callback,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
extra_conditioning_info=extra_conditioning_info
|
||||
# TODO: eta = ddim_eta,
|
||||
# TODO: threshold = threshold,
|
||||
# FIXME: Attention Maps Callback merged from main, but not hooked up
|
||||
# in diffusers branch yet. - keturn
|
||||
# attention_maps_callback = attention_maps_callback,
|
||||
)
|
||||
|
||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
return make_image
|
||||
|
@ -36,7 +36,7 @@ class Txt2Img2Img(Generator):
|
||||
|
||||
def make_image(x_T):
|
||||
|
||||
first_pass_latent_output = pipeline.latents_from_embeddings(
|
||||
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
|
||||
latents=x_T,
|
||||
num_inference_steps=steps,
|
||||
text_embeddings=c,
|
||||
|
@ -442,128 +442,6 @@ def get_mem_free_total(device):
|
||||
return mem_free_total
|
||||
|
||||
|
||||
class InvokeAICrossAttentionMixin:
|
||||
"""
|
||||
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
|
||||
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
|
||||
and dymamic slicing strategy selection.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
self.attention_slice_wrangler = None
|
||||
self.slicing_strategy_getter = None
|
||||
|
||||
def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
|
||||
'''
|
||||
Set custom attention calculator to be called when attention is calculated
|
||||
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||
`module` is the current CrossAttention module for which the callback is being invoked.
|
||||
`suggested_attention_slice` is the default-calculated attention slice
|
||||
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
||||
|
||||
Pass None to use the default attention calculation.
|
||||
:return:
|
||||
'''
|
||||
self.attention_slice_wrangler = wrangler
|
||||
|
||||
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
|
||||
self.slicing_strategy_getter = getter
|
||||
|
||||
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
|
||||
# calculate attention scores
|
||||
#attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
|
||||
if dim is not None:
|
||||
print(f"sliced dim {dim}, offset {offset}, slice_size {slice_size}")
|
||||
attention_scores = torch.baddbmm(
|
||||
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
||||
query,
|
||||
key.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
# calculate attention slice by taking the best scores for each latent pixel
|
||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||
attention_slice_wrangler = self.attention_slice_wrangler
|
||||
if attention_slice_wrangler is not None:
|
||||
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
|
||||
else:
|
||||
attention_slice = default_attention_slice
|
||||
|
||||
hidden_states = torch.bmm(attention_slice, value)
|
||||
return hidden_states
|
||||
|
||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
return self.einsum_op_slice_dim0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
|
||||
slicing_strategy_getter = self.slicing_strategy_getter
|
||||
if slicing_strategy_getter is not None:
|
||||
(dim, slice_size) = slicing_strategy_getter(self)
|
||||
if dim is not None:
|
||||
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
|
||||
if dim == 0:
|
||||
return self.einsum_op_slice_dim0(q, k, v, slice_size)
|
||||
elif dim == 1:
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
# fallback for when there is no saved strategy, or saved strategy does not slice
|
||||
mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device)
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
|
||||
def get_invokeai_attention_mem_efficient(self, q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
#print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
if q.device.type == 'mps' or q.device.type == 'cpu':
|
||||
if self.mem_total_gb >= 32:
|
||||
return self.einsum_op_mps_v1(q, k, v)
|
||||
return self.einsum_op_mps_v2(q, k, v)
|
||||
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
|
||||
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
@ -209,12 +209,12 @@ class KSampler(Sampler):
|
||||
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
|
||||
|
||||
# setup attention maps saving. checks for None are because there are multiple code paths to get here.
|
||||
attention_maps_saver = None
|
||||
attention_map_saver = None
|
||||
if attention_maps_callback is not None and extra_conditioning_info is not None:
|
||||
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||
attention_map_token_ids = range(1, eos_token_index)
|
||||
attention_maps_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
|
||||
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)
|
||||
attention_map_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
|
||||
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
@ -229,8 +229,8 @@ class KSampler(Sampler):
|
||||
),
|
||||
None,
|
||||
)
|
||||
if attention_maps_saver is not None:
|
||||
attention_maps_callback(attention_maps_saver)
|
||||
if attention_map_saver is not None:
|
||||
attention_maps_callback(attention_map_saver)
|
||||
return sampling_result
|
||||
|
||||
# this code will support inpainting if and when ksampler API modified or
|
||||
|
Loading…
Reference in New Issue
Block a user