mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[Fix] Lora double hook (#3471)
Currently hooks registers multiple time for some modules. As result - lora applies multiple time to this modules on generation and images looks weird. If have any other minds how to fix it better - feel free to push.
This commit is contained in:
commit
aa1538bd70
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, Dict, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models import UNet2DConditionModel
|
from diffusers.models import UNet2DConditionModel
|
||||||
@ -166,12 +166,12 @@ class LoKRLayer:
|
|||||||
class LoRAModuleWrapper:
|
class LoRAModuleWrapper:
|
||||||
unet: UNet2DConditionModel
|
unet: UNet2DConditionModel
|
||||||
text_encoder: CLIPTextModel
|
text_encoder: CLIPTextModel
|
||||||
hooks: list[RemovableHandle]
|
hooks: Dict[str, Tuple[torch.nn.Module, RemovableHandle]]
|
||||||
|
|
||||||
def __init__(self, unet, text_encoder):
|
def __init__(self, unet, text_encoder):
|
||||||
self.unet = unet
|
self.unet = unet
|
||||||
self.text_encoder = text_encoder
|
self.text_encoder = text_encoder
|
||||||
self.hooks = []
|
self.hooks = dict()
|
||||||
self.text_modules = None
|
self.text_modules = None
|
||||||
self.unet_modules = None
|
self.unet_modules = None
|
||||||
|
|
||||||
@ -228,7 +228,7 @@ class LoRAModuleWrapper:
|
|||||||
wrapper = self
|
wrapper = self
|
||||||
|
|
||||||
def lora_forward(module, input_h, output):
|
def lora_forward(module, input_h, output):
|
||||||
if len(wrapper.loaded_loras) == 0:
|
if len(wrapper.applied_loras) == 0:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
for lora in wrapper.applied_loras.values():
|
for lora in wrapper.applied_loras.values():
|
||||||
@ -241,11 +241,18 @@ class LoRAModuleWrapper:
|
|||||||
return lora_forward
|
return lora_forward
|
||||||
|
|
||||||
def apply_module_forward(self, module, name):
|
def apply_module_forward(self, module, name):
|
||||||
handle = module.register_forward_hook(self.lora_forward_hook(name))
|
if name in self.hooks:
|
||||||
self.hooks.append(handle)
|
registered_module, _ = self.hooks[name]
|
||||||
|
if registered_module != module:
|
||||||
|
raise Exception(f"Trying to register multiple modules to lora key: {name}")
|
||||||
|
# else it's just double hook creation - nothing to do
|
||||||
|
|
||||||
|
else:
|
||||||
|
handle = module.register_forward_hook(self.lora_forward_hook(name))
|
||||||
|
self.hooks[name] = (module, handle)
|
||||||
|
|
||||||
def clear_hooks(self):
|
def clear_hooks(self):
|
||||||
for hook in self.hooks:
|
for _, hook in self.hooks.values():
|
||||||
hook.remove()
|
hook.remove()
|
||||||
|
|
||||||
self.hooks.clear()
|
self.hooks.clear()
|
||||||
|
Loading…
Reference in New Issue
Block a user