mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Avoid double hook registration in lora
This commit is contained in:
parent
f3b2e02921
commit
9e87a080a8
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
@ -166,12 +166,12 @@ class LoKRLayer:
|
||||
class LoRAModuleWrapper:
|
||||
unet: UNet2DConditionModel
|
||||
text_encoder: CLIPTextModel
|
||||
hooks: list[RemovableHandle]
|
||||
hooks: Dict[str, Tuple[torch.nn.Module, RemovableHandle]]
|
||||
|
||||
def __init__(self, unet, text_encoder):
|
||||
self.unet = unet
|
||||
self.text_encoder = text_encoder
|
||||
self.hooks = []
|
||||
self.hooks = dict()
|
||||
self.text_modules = None
|
||||
self.unet_modules = None
|
||||
|
||||
@ -228,7 +228,7 @@ class LoRAModuleWrapper:
|
||||
wrapper = self
|
||||
|
||||
def lora_forward(module, input_h, output):
|
||||
if len(wrapper.loaded_loras) == 0:
|
||||
if len(wrapper.applied_loras) == 0:
|
||||
return output
|
||||
|
||||
for lora in wrapper.applied_loras.values():
|
||||
@ -241,11 +241,18 @@ class LoRAModuleWrapper:
|
||||
return lora_forward
|
||||
|
||||
def apply_module_forward(self, module, name):
|
||||
handle = module.register_forward_hook(self.lora_forward_hook(name))
|
||||
self.hooks.append(handle)
|
||||
if name in self.hooks:
|
||||
registered_module, _ = self.hooks[name]
|
||||
if registered_module != module:
|
||||
raise Exception(f"Trying to register multiple modules to lora key: {name}")
|
||||
# else it's just double hook creation - nothing to do
|
||||
|
||||
else:
|
||||
handle = module.register_forward_hook(self.lora_forward_hook(name))
|
||||
self.hooks[name] = (module, handle)
|
||||
|
||||
def clear_hooks(self):
|
||||
for hook in self.hooks:
|
||||
for _, hook in self.hooks.values():
|
||||
hook.remove()
|
||||
|
||||
self.hooks.clear()
|
||||
|
Loading…
Reference in New Issue
Block a user