Respect torch-sdp in config.yaml (#5353)

If the user specifies `torch-sdp` as the attention type in `config.yaml`, we can go ahead and use it (if available) rather than always throwing an exception.
This commit is contained in:
Jonathan 2023-12-27 23:46:28 -06:00 committed by GitHub
parent 80812cf7cd
commit 83a9e26cd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -276,7 +276,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.disable_attention_slicing()
return
elif config.attention_type == "torch-sdp":
raise Exception("torch-sdp attention slicing not yet implemented")
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
# diffusers enables sdp automatically
return
else:
raise Exception("torch-sdp attention slicing not available")
# the remainder if this code is called when attention_type=='auto'
if self.unet.device.type == "cuda":
@ -284,7 +288,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.enable_xformers_memory_efficient_attention()
return
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
# diffusers enable sdp automatically
# diffusers enables sdp automatically
return
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":