diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 3284f990ce..08145b1e76 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -4,6 +4,7 @@ from typing import Optional
 
 import torch
 
+
 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
 
@@ -255,7 +256,7 @@ def inject_attention_function(unet, context: Context):
                 lambda module: context.get_slicing_strategy(identifier)
             )
         except AttributeError as e:
-            if e.name == 'set_attention_slice_wrangler':
+            if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
                 warnings.warn(f"TODO: implement for {type(module)}")  # TODO
             else:
                 raise
@@ -270,7 +271,14 @@ def remove_attention_function(unet):
             module.set_attention_slice_wrangler(None)
             module.set_slicing_strategy_getter(None)
         except AttributeError as e:
-            if e.name == 'set_attention_slice_wrangler':
+            if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
                 warnings.warn(f"TODO: implement for {type(module)}")  # TODO
             else:
                 raise
+
+
+def is_attribute_error_about(error: AttributeError, attribute: str):
+    if hasattr(error, 'name'):  # Python 3.10
+        return error.name == attribute
+    else:  # Python 3.9
+        return attribute in str(error)
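
Why the helper is needed: Python 3.10 added a 'name' attribute to AttributeError identifying the missing attribute, while Python 3.9 only exposes that information in the exception message, which is what the string-containment fallback relies on. A minimal standalone sketch (separate from the patch) of the behavior difference that is_attribute_error_about() bridges:

    # Accessing a missing attribute raises AttributeError on any Python version,
    # but only 3.10+ records which attribute was missing in e.name.
    try:
        object().set_attention_slice_wrangler  # deliberately trigger AttributeError
    except AttributeError as e:
        print(getattr(e, 'name', None))                   # 'set_attention_slice_wrangler' on 3.10+, None on 3.9
        print('set_attention_slice_wrangler' in str(e))   # True on both versions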