[Fix] Lora double hook (#3471)

Currently, hooks get registered multiple times for some modules.
As a result, LoRA is applied multiple times to these modules during
generation, and the images look weird.
If you have any other ideas on how to fix this better, feel free to push.
Lincoln Stein 2023-05-29 20:53:27 -04:00 committed by GitHub
commit aa1538bd70

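To see the failure mode outside the codebase, here is a minimal PyTorch sketch (a toy module and offset, not the project code): registering the same forward hook twice makes it fire twice, so an additive patch lands on the output twice.

```python
import torch

# Toy module with zeroed weights so the hook's effect is easy to read off.
linear = torch.nn.Linear(2, 2)
with torch.no_grad():
    linear.weight.zero_()
    linear.bias.zero_()

def add_offset(module, inputs, output):
    # Returning a value from a forward hook replaces the module's output.
    return output + 1.0

# The bug in miniature: the same hook registered twice on one module...
linear.register_forward_hook(add_offset)
linear.register_forward_hook(add_offset)

# ...fires twice, so this prints 2.0s rather than the intended 1.0s.
print(linear(torch.zeros(1, 2)))
```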
@@ -1,6 +1,6 @@
 import json
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Dict, Tuple

 import torch
 from diffusers.models import UNet2DConditionModel
@@ -166,12 +166,12 @@ class LoKRLayer:
 class LoRAModuleWrapper:
     unet: UNet2DConditionModel
     text_encoder: CLIPTextModel
-    hooks: list[RemovableHandle]
+    hooks: Dict[str, Tuple[torch.nn.Module, RemovableHandle]]

     def __init__(self, unet, text_encoder):
         self.unet = unet
         self.text_encoder = text_encoder
-        self.hooks = []
+        self.hooks = dict()

         self.text_modules = None
         self.unet_modules = None
@@ -228,7 +228,7 @@ class LoRAModuleWrapper:
         wrapper = self

         def lora_forward(module, input_h, output):
-            if len(wrapper.loaded_loras) == 0:
+            if len(wrapper.applied_loras) == 0:
                 return output

             for lora in wrapper.applied_loras.values():
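For context on the hunk above, a hedged sketch of the closure pattern it touches (`applied_loras` follows the diff; the residual call and `multiplier` are hypothetical): `lora_forward_hook(name)` acts as a factory, returning a `lora_forward` closed over the wrapper and the module name, so each registered hook knows which LoRA key it serves.

```python
def lora_forward_hook(self, name):
    wrapper = self

    def lora_forward(module, input_h, output):
        # No active LoRAs: pass the original output through untouched.
        if len(wrapper.applied_loras) == 0:
            return output
        for lora in wrapper.applied_loras.values():
            # Hypothetical residual: each LoRA adds a correction for this
            # module, scaled by its multiplier.
            output = output + lora.layer_forward(name, input_h) * lora.multiplier
        return output

    return lora_forward
```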
@@ -241,11 +241,18 @@ class LoRAModuleWrapper:
         return lora_forward

     def apply_module_forward(self, module, name):
-        handle = module.register_forward_hook(self.lora_forward_hook(name))
-        self.hooks.append(handle)
+        if name in self.hooks:
+            registered_module, _ = self.hooks[name]
+            if registered_module != module:
+                raise Exception(f"Trying to register multiple modules to lora key: {name}")
+            # otherwise it's just a duplicate hook registration - nothing to do
+        else:
+            handle = module.register_forward_hook(self.lora_forward_hook(name))
+            self.hooks[name] = (module, handle)

     def clear_hooks(self):
-        for hook in self.hooks:
+        for _, hook in self.hooks.values():
             hook.remove()

         self.hooks.clear()
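Taken together, the fix amounts to the registry pattern below, a self-contained sketch under assumed names (not the InvokeAI class): hooks are keyed by module name, re-registering the same module is a no-op, and a name collision between different modules raises.

```python
from typing import Dict, Tuple

import torch
from torch.utils.hooks import RemovableHandle

class HookRegistry:
    def __init__(self) -> None:
        # name -> (module, handle), mirroring the hooks dict in the diff.
        self.hooks: Dict[str, Tuple[torch.nn.Module, RemovableHandle]] = {}

    def register(self, module, name, hook) -> None:
        if name in self.hooks:
            registered_module, _ = self.hooks[name]
            if registered_module is not module:
                raise Exception(f"Trying to register multiple modules to lora key: {name}")
            return  # same module hooked again: nothing to do
        handle = module.register_forward_hook(hook)
        self.hooks[name] = (module, handle)

    def clear(self) -> None:
        for _, handle in self.hooks.values():
            handle.remove()
        self.hooks.clear()

# Usage: the second registration for the same key is a no-op, so exactly
# one hook stays live (checked via torch's internal _forward_hooks dict).
registry = HookRegistry()
linear = torch.nn.Linear(2, 2)
registry.register(linear, "unet.block.0", lambda m, i, o: o + 1.0)
registry.register(linear, "unet.block.0", lambda m, i, o: o + 1.0)
assert len(linear._forward_hooks) == 1
registry.clear()
assert len(linear._forward_hooks) == 0
```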