mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Change attention processor apply logic
This commit is contained in:
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -32,8 +32,27 @@ with LoRAHelper.apply_lora_unet(unet, loras):
|
||||
"""
|
||||
|
||||
|
||||
# TODO: rename smth like ModelPatcher and add TI method?
|
||||
class ModelPatcher:
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
|
||||
"""A context manager that patches `unet` with the provided attention processor.
|
||||
|
||||
Args:
|
||||
unet (UNet2DConditionModel): The UNet model to patch.
|
||||
processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
|
||||
"""
|
||||
unet_orig_processors = unet.attn_processors
|
||||
try:
|
||||
# create separate instance for each attention, to be able modify each attention separately
|
||||
new_attn_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
|
||||
unet.set_attn_processor(new_attn_processors)
|
||||
|
||||
yield None
|
||||
|
||||
finally:
|
||||
unet.set_attn_processor(unet_orig_processors)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
Reference in New Issue
Block a user