add support for "balanced" attention slice size

Lincoln Stein
2023-08-17 16:11:09 -04:00
parent 23b4e1cea0
commit b69f26c85c
7 changed files with 38 additions and 12 deletions


@@ -33,7 +33,7 @@ from .diffusion import (
     PostprocessingSettings,
     BasicConditioningInfo,
 )
-from ..util import normalize_device
+from ..util import normalize_device, auto_detect_slice_size
 @dataclass
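The body of the newly imported auto_detect_slice_size helper is not part of this diff; only its name and its latents argument are visible below. As a rough illustration of the kind of heuristic such a helper could apply (purely an assumption, not the repository's actual implementation), it might compare an estimated attention footprint against free VRAM:

import torch

def auto_detect_slice_size(latents: torch.Tensor) -> str:
    # Illustrative sketch only -- the real helper lives in ..util and
    # its implementation is not shown in this commit.
    if not torch.cuda.is_available():
        return "max"  # conservative default off-CUDA
    free_mem, _total = torch.cuda.mem_get_info()
    # Number of spatial tokens the attention matrix is built over.
    tokens = latents.size(-2) * latents.size(-1)
    # Rough upper bound on one attention matrix, in bytes.
    attn_bytes = tokens * tokens * latents.element_size()
    if attn_bytes * 4 < free_mem:
        return "auto"  # headroom to spare: mild slicing
    return "max"  # memory-tight: slice as aggressively as possible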
@@ -296,8 +296,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             return
         elif config.attention_type == "sliced":
             slice_size = config.attention_slice_size
-            if torch.backends.mps.is_available():  # doesn't auto already do this?
-                slice_size = "max"
+            if slice_size == "auto":
+                slice_size = auto_detect_slice_size(latents)
+            elif slice_size == "balanced":
+                slice_size = "auto"
             self.enable_attention_slicing(slice_size=slice_size)
             return
         elif config.attention_type == "normal":
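The resulting slice_size feeds straight into diffusers' enable_attention_slicing, which accepts "auto", "max", or an integer slice count. A short usage sketch of what the new config values translate to (the model id is a placeholder for illustration only):

from diffusers import StableDiffusionPipeline

# Placeholder model id, for illustration only.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# The new "balanced" config value maps to diffusers' own "auto" mode,
# which slices the attention computation into chunks (roughly halving
# it) -- a middle ground between speed and memory use.
pipe.enable_attention_slicing("auto")

# "max" passes through unchanged: one slice at a time, the most
# memory-frugal (and slowest) setting.
# pipe.enable_attention_slicing("max")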