Base code from draft PR

This commit is contained in:
Sergey Borisov
2024-07-12 20:31:26 +03:00
parent 712cf00a82
commit 9cc852cf7f
8 changed files with 781 additions and 11 deletions

View File

@ -0,0 +1,9 @@
"""
Initialization file for the invokeai.backend.stable_diffusion.extensions package
"""
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
__all__ = [
"ExtensionBase",
]

View File

@ -0,0 +1,58 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional
import torch
from diffusers import UNet2DConditionModel
@dataclass
class InjectionInfo:
type: str
name: str
order: Optional[str]
function: Callable
def modifier(name: str, order: str = "any"):
def _decorator(func):
func.__inj_info__ = {
"type": "modifier",
"name": name,
"order": order,
}
return func
return _decorator
def override(name: str):
def _decorator(func):
func.__inj_info__ = {
"type": "override",
"name": name,
"order": None,
}
return func
return _decorator
class ExtensionBase:
def __init__(self, priority: int):
self.priority = priority
self.injections: List[InjectionInfo] = []
for func_name in dir(self):
func = getattr(self, func_name)
if not callable(func) or not hasattr(func, "__inj_info__"):
continue
self.injections.append(InjectionInfo(**func.__inj_info__, function=func))
@contextmanager
def patch_attention_processor(self, attention_processor_cls: object):
yield None
@contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
yield None