Modular backend - Seamless (#6651)

## Summary

Seamless code from #6577.

## Related Issues / Discussions

#6606 

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without set `USE_MODULAR_DENOISE` environment.

## Merge Plan

Nope.
If you think that there should be some kind of tests - feel free to add.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
This commit is contained in:
Ryan Dick 2024-07-28 13:57:38 -04:00 committed by GitHub
commit e8e24822ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 80 additions and 57 deletions

View File

@ -39,7 +39,7 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import ( from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData, ControlNetData,
@ -62,6 +62,7 @@ from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetEx
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@ -833,6 +834,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
if self.unet.freeu_config: if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config)) ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
### seamless
if self.unet.seamless_axes:
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
# context for loading additional models # context for loading additional models
with ExitStack() as exit_stack: with ExitStack() as exit_stack:
# later should be smth like: # later should be smth like:
@ -915,7 +920,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
ExitStack() as exit_stack, ExitStack() as exit_stack,
unet_info.model_on_device() as (model_state_dict, unet), unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config), ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching. # Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet( ModelPatcher.apply_lora_unet(
unet, unet,

View File

@ -24,7 +24,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion import set_seamless from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
@ -59,7 +59,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae_info = context.models.load(self.vae.vae) vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)) assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)) assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device) latents = latents.to(vae.device)
if self.fp32: if self.fp32:

View File

@ -7,11 +7,9 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import ( # noqa: F401
StableDiffusionGeneratorPipeline, StableDiffusionGeneratorPipeline,
) )
from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401 from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401
from invokeai.backend.stable_diffusion.seamless import set_seamless # noqa: F401
__all__ = [ __all__ = [
"PipelineIntermediateState", "PipelineIntermediateState",
"StableDiffusionGeneratorPipeline", "StableDiffusionGeneratorPipeline",
"InvokeAIDiffuserComponent", "InvokeAIDiffuserComponent",
"set_seamless",
] ]

View File

@ -0,0 +1,71 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from diffusers.models.lora import LoRACompatibleConv
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
class SeamlessExt(ExtensionBase):
def __init__(
self,
seamless_axes: List[str],
):
super().__init__()
self._seamless_axes = seamless_axes
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
with self.static_patch_model(
model=unet,
seamless_axes=self._seamless_axes,
):
yield
@staticmethod
@contextmanager
def static_patch_model(
model: torch.nn.Module,
seamless_axes: List[str],
):
if not seamless_axes:
yield
return
x_mode = "circular" if "x" in seamless_axes else "constant"
y_mode = "circular" if "y" in seamless_axes else "constant"
# override conv_forward
# https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
def _conv_forward_asymmetric(
self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
):
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
return torch.nn.functional.conv2d(
working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
)
original_layers: List[Tuple[nn.Conv2d, Callable]] = []
try:
for layer in model.modules():
if not isinstance(layer, torch.nn.Conv2d):
continue
if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
layer.lora_layer = lambda *x: 0
original_layers.append((layer, layer._conv_forward))
layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
yield
finally:
for layer, orig_conv_forward in original_layers:
layer._conv_forward = orig_conv_forward

View File

@ -1,51 +0,0 @@
from contextlib import contextmanager
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.lora import LoRACompatibleConv
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
if not seamless_axes:
yield
return
# override conv_forward
# https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
return torch.nn.functional.conv2d(
working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
)
original_layers: List[Tuple[nn.Conv2d, Callable]] = []
try:
x_mode = "circular" if "x" in seamless_axes else "constant"
y_mode = "circular" if "y" in seamless_axes else "constant"
conv_layers: List[torch.nn.Conv2d] = []
for module in model.modules():
if isinstance(module, torch.nn.Conv2d):
conv_layers.append(module)
for layer in conv_layers:
if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
layer.lora_layer = lambda *x: 0
original_layers.append((layer, layer._conv_forward))
layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
yield
finally:
for layer, orig_conv_forward in original_layers:
layer._conv_forward = orig_conv_forward