Apply LoRA by patching model weights instead of using forward hooks

Sergey Borisov 2023-06-26 03:57:33 +03:00
parent 1ba94a92b3
commit 5cebf67ee4
4 changed files with 52 additions and 38 deletions

View File

@@ -65,23 +65,20 @@ class CompelInvocation(BaseInvocation):
**self.clip.text_encoder.dict(),
)
with tokenizer_info as orig_tokenizer,\
text_encoder_info as text_encoder,\
ExitStack() as stack:
text_encoder_info as text_encoder:
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
stack.enter_context(
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
)
)
).context.model
)
except Exception:
#print(e)
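Why the ExitStack could be dropped in this hunk: the old hook-based approach had to keep every LoRA's model context open for the whole forward pass so the hooks could read the LoRA layers at call time, while weight patching only needs the already-loaded module once, before inference starts. A toy sketch of the new access pattern, using stand-in classes rather than the real InvokeAI model-manager API:

# Stand-ins for illustration only; the record returned by get_model() is richer in practice.
from dataclasses import dataclass
from typing import Any, List, Tuple

@dataclass
class _ModelContext:
    model: Any                     # the loaded LoRA module

@dataclass
class _ModelRecord:                # rough stand-in for a get_model() result
    context: _ModelContext

def collect_loras(records: List[Tuple[_ModelRecord, float]]) -> List[Tuple[Any, float]]:
    # New pattern: take (raw model, weight) pairs up front; nothing stays open,
    # because ModelPatcher merges the weights before the forward pass runs.
    return [(rec.context.model, weight) for rec, weight in records]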

View File

@@ -362,8 +362,7 @@ class TextToLatentsInvocation(BaseInvocation):
self.dispatch_progress(context, source_node_id, state)
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
with unet_info as unet,\
ExitStack() as stack:
with unet_info as unet:
scheduler = get_scheduler(
context=context,
@@ -374,7 +373,7 @@ class TextToLatentsInvocation(BaseInvocation):
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
@@ -438,8 +437,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
**self.unet.unet.dict(),
)
with unet_info as unet,\
ExitStack() as stack:
with unet_info as unet:
scheduler = get_scheduler(
context=context,
@@ -468,7 +466,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
device=unet.device,
)
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(

View File

@@ -177,9 +177,13 @@ class LoraLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
# TODO: ui rewrite
base_model = BaseModelType.StableDiffusion1
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=self.lora_name,
model_type=SDModelType.Lora,
model_type=ModelType.Lora,
):
raise Exception(f"Unkown lora name: {self.lora_name}!")
@@ -195,8 +199,9 @@ class LoraLoaderInvocation(BaseInvocation):
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=self.lora_name,
model_type=SDModelType.Lora,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
@@ -206,8 +211,9 @@ class LoraLoaderInvocation(BaseInvocation):
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=self.lora_name,
model_type=SDModelType.Lora,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)

View File

@@ -70,7 +70,7 @@ class LoRALayerBase:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight(module)
weight = self.get_weight()
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
@@ -81,7 +81,7 @@ class LoRALayerBase:
**extra_args,
) * multiplier * scale
def get_weight(self, module: torch.nn.Module):
def get_weight(self):
raise NotImplementedError()
def calc_size(self) -> int:
@@ -122,7 +122,7 @@ class LoRALayer(LoRALayerBase):
self.rank = self.down.shape[0]
def get_weight(self, module: torch.nn.Module):
def get_weight(self):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -185,7 +185,7 @@ class LoHALayer(LoRALayerBase):
self.rank = self.w1_b.shape[0]
def get_weight(self, module: torch.nn.Module):
def get_weight(self):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@@ -271,7 +271,7 @@ class LoKRLayer(LoRALayerBase):
else:
self.rank = None # unscaled
def get_weight(self, module: torch.nn.Module):
def get_weight(self):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
@@ -286,7 +286,7 @@ class LoKRLayer(LoRALayerBase):
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape?
weight = torch.kron(w1, w2)
return weight
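Shape note on the torch.kron change above: for 2-D factors the Kronecker product already comes out as a full matrix of shape (rows1*rows2, cols1*cols2), so no reshape is needed at this point; the reshape against module.weight.shape now happens in apply_lora (last file below), where a mismatch can be detected against the target module. A quick check with toy sizes, not taken from any real LoKR checkpoint:

import torch

w1 = torch.randn(4, 8)                     # toy LoKR factors
w2 = torch.randn(16, 32)
weight = torch.kron(w1, w2)
assert weight.shape == (4 * 16, 8 * 32)    # (64, 256): already a usable 2-D weight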
@@ -471,7 +471,7 @@ class ModelPatcher:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = module_key.rstrip(".")
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)
@@ -525,23 +525,36 @@ class ModelPatcher:
loras: List[Tuple[LoraModel, float]],
prefix: str,
):
hooks = dict()
original_weights = dict()
try:
with torch.no_grad():
for lora, lora_weight in loras:
#assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
if module_key not in hooks:
hooks[module_key] = module.register_forward_hook(cls._lora_forward_hook(loras, layer_key))
if module_key not in original_weights:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
# enable autocast to calc fp16 loras on cpu
with torch.autocast(device_type="cpu"):
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
layer_weight = layer_weight.reshape(module.weight.shape)
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
yield # wait for context manager exit
finally:
for module_key, hook in hooks.items():
hook.remove()
hooks.clear()
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)
@classmethod
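The hunk above is the heart of the commit: rather than registering forward hooks, apply_lora now snapshots each affected module's weight to CPU, adds the scaled LoRA delta in place, and restores the originals in the finally block. A minimal standalone sketch of that save/patch/restore pattern, with generic names rather than the ModelPatcher API:

import torch
from contextlib import contextmanager
from typing import Dict

@contextmanager
def patch_weights(model: torch.nn.Module, deltas: Dict[str, torch.Tensor]):
    # deltas maps submodule keys (as in model.named_modules()) to additive weight updates
    originals: Dict[str, torch.Tensor] = {}
    try:
        with torch.no_grad():
            for key, delta in deltas.items():
                module = model.get_submodule(key)
                # keep a CPU copy so the patch can be undone exactly
                originals[key] = module.weight.detach().to(device="cpu", copy=True)
                module.weight += delta.to(device=module.weight.device, dtype=module.weight.dtype)
        yield  # run inference with patched weights
    finally:
        with torch.no_grad():
            for key, weight in originals.items():
                model.get_submodule(key).weight.copy_(weight)

Because the originals are restored exactly on exit, the patch is invisible to any later use of the same cached model.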
@@ -591,7 +604,7 @@ class ModelPatcher:
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
)
model_embeddings.weight.data[token_id] = embedding
model_embeddings.weight.data[token_id] = embedding.to(device=text_encoder.device, dtype=text_encoder.dtype)
ti_tokens.append(token_id)
if len(ti_tokens) > 1:
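The final hunk explicitly matches the textual-inversion vector to the text encoder's device and dtype before writing it into the token embedding table, so the assignment behaves the same when the encoder runs in fp16 or on the GPU. A toy version of that assignment with stand-in sizes, not the real CLIP embedding table:

import torch

token_table = torch.nn.Embedding(8, 4).to(dtype=torch.float16)  # stand-in for model_embeddings
ti_vector = torch.randn(4)                                       # loaded TI embedding, fp32 on CPU
token_id = 3
token_table.weight.data[token_id] = ti_vector.to(
    device=token_table.weight.device, dtype=token_table.weight.dtype
)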