Move monkeypatch for diffusers/torch bug to hotfixes.py

This commit is contained in:
Sergey Borisov 2023-08-28 18:29:49 +03:00
parent 3efb1f6f17
commit bb085c5fba
2 changed files with 17 additions and 11 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from typing import Union from typing import Union, List
import diffusers import diffusers
import torch.nn as nn import torch.nn as nn
from diffusers.models import UNet2DModel, AutoencoderKL from diffusers.models import UNet2DModel, AutoencoderKL
@ -24,11 +24,8 @@ def _conv_forward_asymmetric(self, input, weight, bias):
) )
ModelType = Union[UNet2DModel, AutoencoderKL]
@contextmanager @contextmanager
def set_seamless(model: ModelType, seamless_axes): def set_seamless(model: Union[UNet2DModel, AutoencoderKL], seamless_axes: List[str]):
try: try:
to_restore = [] to_restore = []
@ -53,8 +50,6 @@ def set_seamless(model: ModelType, seamless_axes):
to_restore.append((m, m._conv_forward)) to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
if isinstance(m, diffusers.models.lora.LoRACompatibleConv) and m.lora_layer is None:
m.forward = nn.Conv2d.forward.__get__(m, nn.Conv2d)
yield yield
@ -65,7 +60,3 @@ def set_seamless(model: ModelType, seamless_axes):
del m.asymmetric_padding_mode del m.asymmetric_padding_mode
if hasattr(m, "asymmetric_padding"): if hasattr(m, "asymmetric_padding"):
del m.asymmetric_padding del m.asymmetric_padding
if isinstance(m, diffusers.models.lora.LoRACompatibleConv):
m.forward = diffusers.models.lora.LoRACompatibleConv.forward.__get__(
m, diffusers.models.lora.LoRACompatibleConv
)

View File

@ -761,3 +761,18 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
diffusers.ControlNetModel = ControlNetModel diffusers.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel diffusers.models.controlnet.ControlNetModel = ControlNetModel
# patch LoRACompatibleConv to use original Conv2D forward function
# this needed to make work seamless patch
# NOTE: with this patch, torch.compile crashes on 2.0 torch(already fixed in nightly)
# https://github.com/huggingface/diffusers/pull/4315
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/lora.py#L96C18-L96C18
def new_LoRACompatibleConv_forward(self, x):
if self.lora_layer is None:
return super(diffusers.models.lora.LoRACompatibleConv, self).forward(x)
else:
return super(diffusers.models.lora.LoRACompatibleConv, self).forward(x) + self.lora_layer(x)
diffusers.models.lora.LoRACompatibleConv.forward = new_LoRACompatibleConv_forward