lint: correct AttributeError.name reference for Python 3.9.

This commit is contained in:
Kevin Turner 2022-11-25 14:11:19 -08:00
parent 56153c2ebf
commit 09728dd1e0

View File

@ -4,6 +4,7 @@ from typing import Optional
import torch import torch
# adapted from bloc97's CrossAttentionControl colab # adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl # https://github.com/bloc97/CrossAttentionControl
@ -255,7 +256,7 @@ def inject_attention_function(unet, context: Context):
lambda module: context.get_slicing_strategy(identifier) lambda module: context.get_slicing_strategy(identifier)
) )
except AttributeError as e: except AttributeError as e:
if e.name == 'set_attention_slice_wrangler': if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
warnings.warn(f"TODO: implement for {type(module)}") # TODO warnings.warn(f"TODO: implement for {type(module)}") # TODO
else: else:
raise raise
@ -270,7 +271,14 @@ def remove_attention_function(unet):
module.set_attention_slice_wrangler(None) module.set_attention_slice_wrangler(None)
module.set_slicing_strategy_getter(None) module.set_slicing_strategy_getter(None)
except AttributeError as e: except AttributeError as e:
if e.name == 'set_attention_slice_wrangler': if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
warnings.warn(f"TODO: implement for {type(module)}") # TODO warnings.warn(f"TODO: implement for {type(module)}") # TODO
else: else:
raise raise
def is_attribute_error_about(error: AttributeError, attribute: str):
if hasattr(error, 'name'): # Python 3.10
return error.name == attribute
else: # Python 3.9
return attribute in str(error)