InvokeAI/invokeai/backend/stable_diffusion/extensions/base.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

51 lines
1.2 KiB
Python
Raw Normal View History

2024-07-16 17:03:29 +00:00
from __future__ import annotations
2024-07-12 17:31:26 +00:00
from contextlib import contextmanager
from dataclasses import dataclass
2024-07-16 17:03:29 +00:00
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
2024-07-12 17:31:26 +00:00
import torch
from diffusers import UNet2DConditionModel
2024-07-16 17:03:29 +00:00
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
2024-07-12 17:31:26 +00:00
@dataclass
class InjectionInfo:
    """Record describing one injection registered by an extension.

    ``ExtensionBase.__init__`` builds one of these per decorated method by
    unpacking the method's ``__inj_info__`` dict (``type``, ``name``,
    ``order``) and adding the bound method as ``function``.
    """

    # Injection kind; the ``callback`` decorator writes the string "callback".
    type: str
    # Name the injection is registered under (supplied to the decorator).
    name: str
    # Ordering value from the decorator (``callback`` defaults it to 0). How it
    # is interpreted is up to whoever consumes ``ExtensionBase.injections``.
    order: Optional[int]
    # The bound method to invoke for this injection.
    function: Callable
def callback(name: str, order: int = 0):
    """Build a decorator that marks a method as an extension callback.

    The decorated function itself is returned (not wrapped); the only change is
    a new ``__inj_info__`` attribute equal to
    ``{"type": "callback", "name": name, "order": order}``.
    ``ExtensionBase.__init__`` looks for that attribute to populate its
    ``injections`` list.

    Args:
        name: Name under which the callback is registered.
        order: Ordering value stored with the callback (default 0).

    Returns:
        A decorator that tags and returns the function it is applied to.
    """

    def _decorator(func):
        # Build a new dict for every decorated function so that metadata is
        # never shared between functions.
        metadata = dict(type="callback", name=name, order=order)
        setattr(func, "__inj_info__", metadata)
        return func

    return _decorator
class ExtensionBase:
    """Base class for stable-diffusion denoising extensions.

    Subclasses mark methods with the ``callback`` decorator; ``__init__``
    discovers those methods and records them in ``self.injections``. The two
    context-manager hooks, ``patch_extension`` and ``patch_unet``, do nothing
    here and may be overridden by subclasses.
    """

    def __init__(self):
        """Collect every ``callback``-decorated method into ``self.injections``.

        A method qualifies when it is callable and carries an ``__inj_info__``
        attribute. Attributes are visited in ``dir()`` order (alphabetical), so
        ``injections`` is ordered by attribute name, not by ``order``.
        """
        # Local imports keep this block self-contained (no new module-level deps).
        import inspect
        from functools import cached_property

        self.injections: List[InjectionInfo] = []
        for func_name in dir(self):
            # Inspect the raw attribute first, without triggering descriptors,
            # so properties are never evaluated during construction. A property
            # getter may rely on state a subclass assigns only after calling
            # super().__init__(), may have side effects, or (cached_property)
            # would cache a value prematurely. Decorated callbacks are plain
            # methods, not properties, so skipping these loses nothing.
            raw = inspect.getattr_static(self, func_name, None)
            if isinstance(raw, (property, cached_property)):
                continue
            func = getattr(self, func_name)
            if not callable(func) or not hasattr(func, "__inj_info__"):
                continue
            self.injections.append(InjectionInfo(**func.__inj_info__, function=func))

    @contextmanager
    def patch_extension(self, context: DenoiseContext):
        """Context-manager hook around use of this extension with ``context``.

        The base implementation does nothing: it yields ``None`` and performs
        no setup or teardown. Subclasses may override it.
        NOTE(review): presumably entered for the duration of a denoise run by
        the extension manager — confirm against the caller.
        """
        yield None

    @contextmanager
    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
        """Context-manager hook for temporarily modifying ``unet``.

        The base implementation does nothing: it yields ``None`` and leaves
        ``unet`` and ``state_dict`` untouched. Subclasses may override it.
        NOTE(review): ``state_dict`` is presumably the UNet's original weights,
        available for restoring after patching — TODO confirm with the caller.
        """
        yield None