Handle seamless in modular denoise

Sergey Borisov 2024-07-23 18:03:59 +03:00
parent 7c975f0d00
commit 62aa064e56
2 changed files with 80 additions and 0 deletions

invokeai/app/invocations/denoise_latents.py

@@ -62,6 +62,7 @@ from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -833,6 +834,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
        if self.unet.freeu_config:
            ext_manager.add_extension(FreeUExt(self.unet.freeu_config))

        ### seamless
        if self.unet.seamless_axes:
            ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))

        # context for loading additional models
        with ExitStack() as exit_stack:
            # later should be smth like:
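For orientation, here is a minimal sketch of how a manager along these lines could enter each registered extension's patch_unet context for the duration of denoising. HypotheticalExtensionsManager and its patch_unet method are illustrative assumptions, not InvokeAI's actual ExtensionsManager API:

from contextlib import ExitStack, contextmanager

class HypotheticalExtensionsManager:
    # illustrative stand-in, not InvokeAI's ExtensionsManager
    def __init__(self):
        self._extensions = []

    def add_extension(self, ext):
        self._extensions.append(ext)

    @contextmanager
    def patch_unet(self, unet):
        # enter every extension's patch_unet context; ExitStack unwinds
        # them in reverse order, even if denoising raises
        with ExitStack() as stack:
            for ext in self._extensions:
                stack.enter_context(ext.patch_unet(unet))
            yield unet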

invokeai/backend/stable_diffusion/extensions/seamless.py

@@ -0,0 +1,75 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from diffusers.models.lora import LoRACompatibleConv

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase


class SeamlessExt(ExtensionBase):
    def __init__(
        self,
        seamless_axes: List[str],
    ):
        super().__init__()
        self._seamless_axes = seamless_axes

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        with self.static_patch_model(
            model=unet,
            seamless_axes=self._seamless_axes,
        ):
            yield

    @staticmethod
    @contextmanager
    def static_patch_model(
        model: torch.nn.Module,
        seamless_axes: List[str],
    ):
        if not seamless_axes:
            yield
            return

        # override conv_forward
        # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
        def _conv_forward_asymmetric(
            self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
        ):
            self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
            self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
            working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
            working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
            return torch.nn.functional.conv2d(
                working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
            )

        original_layers: List[Tuple[nn.Conv2d, Callable]] = []
        try:
            x_mode = "circular" if "x" in seamless_axes else "constant"
            y_mode = "circular" if "y" in seamless_axes else "constant"

            conv_layers: List[torch.nn.Conv2d] = []

            for module in model.modules():
                if isinstance(module, torch.nn.Conv2d):
                    conv_layers.append(module)

            for layer in conv_layers:
                if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
                    # give LoRACompatibleConv a dummy lora_layer so it does not take
                    # the code path that bypasses the patched _conv_forward
                    # (see the diffusers issue linked above)
                    layer.lora_layer = lambda *x: 0
                original_layers.append((layer, layer._conv_forward))
                layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)

            yield

        finally:
            # restore the original conv implementations on exit
            for layer, orig_conv_forward in original_layers:
                layer._conv_forward = orig_conv_forward
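
As a quick sanity check of the patch's effect (a standalone sketch, not part of the commit), static_patch_model can be applied to a bare Conv2d. With the "x" axis seamless, the rightmost output column wraps around to "see" the left edge, and the stock zero-padding behavior returns once the context exits:

import torch

from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt

conv = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
torch.nn.init.ones_(conv.weight)

x = torch.zeros(1, 1, 4, 4)
x[..., 0] = 1.0  # put mass on the left edge only

with SeamlessExt.static_patch_model(model=conv, seamless_axes=["x"]):
    seamless_out = conv(x)
plain_out = conv(x)  # original _conv_forward restored after the context exits

print(seamless_out[..., -1])  # nonzero: circular x-padding wraps to the left edge
print(plain_out[..., -1])     # zeros: constant padding sees nothing beyond the right edge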