Change attention processor application logic

Sergey Borisov
2024-07-16 20:03:29 +03:00
parent 608cbe3f5c
commit cec345cb5c
5 changed files with 36 additions and 34 deletions


@@ -5,7 +5,7 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
import numpy as np
import torch
@@ -32,8 +32,27 @@ with LoRAHelper.apply_lora_unet(unet, loras):
"""
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
+    @staticmethod
+    @contextmanager
+    def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
+        """A context manager that patches `unet` with the provided attention processor class.
+
+        Args:
+            unet (UNet2DConditionModel): The UNet model to patch.
+            processor_cls (Type[Any]): A class that is instantiated once per attention key and passed to set_attn_processor(...).
+        """
+        unet_orig_processors = unet.attn_processors
+        try:
+            # Create a separate instance for each attention layer so that each one can be modified independently.
+            new_attn_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
+            unet.set_attn_processor(new_attn_processors)
+            yield None
+        finally:
+            unet.set_attn_processor(unet_orig_processors)
+
    @staticmethod
    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
        assert "." not in lora_key