Respect torch-sdp in config.yaml (#5353)

If the user specifies `torch-sdp` as the attention type in `config.yaml`, we can go ahead and use it (if available) rather than always throwing an exception.
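For context, `torch.nn.functional.scaled_dot_product_attention` only exists on PyTorch 2.0+, which is exactly what the new branch checks before returning. A minimal standalone sketch of that check (illustrative only, not part of this commit; based on the diff below, the `config.yaml` entry would presumably be `attention_type: torch-sdp`):

```python
import torch

# Same availability check the commit uses: PyTorch 2.0+ ships
# torch.nn.functional.scaled_dot_product_attention.
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
    q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([1, 4, 8, 16])
else:
    raise Exception("torch-sdp attention slicing not available")
```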
commit 83a9e26cd8
parent 80812cf7cd
Author: Jonathan
Date:   2023-12-27 23:46:28 -06:00

@@ -276,7 +276,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             self.disable_attention_slicing()
             return
         elif config.attention_type == "torch-sdp":
-            raise Exception("torch-sdp attention slicing not yet implemented")
+            if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+                # diffusers enables sdp automatically
+                return
+            else:
+                raise Exception("torch-sdp attention slicing not available")

         # the remainder if this code is called when attention_type=='auto'
         if self.unet.device.type == "cuda":
@@ -284,7 +288,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 self.enable_xformers_memory_efficient_attention()
                 return
             elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-                # diffusers enable sdp automatically
+                # diffusers enables sdp automatically
                 return

         if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
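The "diffusers enables sdp automatically" comment refers to diffusers defaulting to its `AttnProcessor2_0` (which calls `torch.nn.functional.scaled_dot_product_attention`) whenever PyTorch 2.x is present, so the pipeline can simply return without configuring anything. A hedged sketch of what opting in explicitly would look like (illustrative only, not what this commit does; the model id is just an example):

```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import AttnProcessor2_0

# Assumes a PyTorch 2.x environment.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Set the sdp-based processor explicitly instead of relying on the default.
pipe.unet.set_attn_processor(AttnProcessor2_0())
```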