[Fix] Lora double hook (#3471)

Currently hooks registers multiple time for some modules.
As result - lora applies multiple time to this modules on generation and
images looks weird.
If have any other minds how to fix it better - feel free to push.
This commit is contained in:
Lincoln Stein 2023-05-29 20:53:27 -04:00 committed by GitHub
commit aa1538bd70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,6 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, Dict, Tuple
import torch import torch
from diffusers.models import UNet2DConditionModel from diffusers.models import UNet2DConditionModel
@ -166,12 +166,12 @@ class LoKRLayer:
class LoRAModuleWrapper: class LoRAModuleWrapper:
unet: UNet2DConditionModel unet: UNet2DConditionModel
text_encoder: CLIPTextModel text_encoder: CLIPTextModel
hooks: list[RemovableHandle] hooks: Dict[str, Tuple[torch.nn.Module, RemovableHandle]]
def __init__(self, unet, text_encoder): def __init__(self, unet, text_encoder):
self.unet = unet self.unet = unet
self.text_encoder = text_encoder self.text_encoder = text_encoder
self.hooks = [] self.hooks = dict()
self.text_modules = None self.text_modules = None
self.unet_modules = None self.unet_modules = None
@ -228,7 +228,7 @@ class LoRAModuleWrapper:
wrapper = self wrapper = self
def lora_forward(module, input_h, output): def lora_forward(module, input_h, output):
if len(wrapper.loaded_loras) == 0: if len(wrapper.applied_loras) == 0:
return output return output
for lora in wrapper.applied_loras.values(): for lora in wrapper.applied_loras.values():
@ -241,11 +241,18 @@ class LoRAModuleWrapper:
return lora_forward return lora_forward
def apply_module_forward(self, module, name): def apply_module_forward(self, module, name):
handle = module.register_forward_hook(self.lora_forward_hook(name)) if name in self.hooks:
self.hooks.append(handle) registered_module, _ = self.hooks[name]
if registered_module != module:
raise Exception(f"Trying to register multiple modules to lora key: {name}")
# else it's just double hook creation - nothing to do
else:
handle = module.register_forward_hook(self.lora_forward_hook(name))
self.hooks[name] = (module, handle)
def clear_hooks(self): def clear_hooks(self):
for hook in self.hooks: for _, hook in self.hooks.values():
hook.remove() hook.remove()
self.hooks.clear() self.hooks.clear()