Fixed dict error

This commit is contained in:
Kent Keirsey 2023-08-27 14:21:56 -04:00
parent c6d00387a7
commit 0d2e194213

View File

@ -23,31 +23,31 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager @contextmanager
def set_unet_seamless(model: UNet2DModel, seamless: bool, seamless_axes): def set_unet_seamless(model: UNet2DModel, seamless_axes):
try: try:
to_restore = dict() to_restore = []
if seamless:
for m in model.modules(): for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.asymmetric_padding_mode = {} m.asymmetric_padding_mode = {}
m.asymmetric_padding = {} m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = ( m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0], m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1], m._reversed_padding_repeated_twice[1],
0, 0,
0, 0,
) )
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = ( m.asymmetric_padding["y"] = (
0, 0,
0, 0,
m._reversed_padding_repeated_twice[2], m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3], m._reversed_padding_repeated_twice[3],
) )
to_restore.append((m, m._conv_forward)) to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
yield yield