Apply lora by model patching (#3583)

Rewrite lora to be applied by model patching as it gives us benefits:
1) On model execution calculates result only on model weight, while with
hooks we need to calculate on model and each lora
2) As lora now patched in model weights, there no need to store lora in
vram

Results:
Speed:
| loras count | hook | patch |
| --- | --- | --- |
| 0 | ~4.92 it/s | ~4.92 it/s |
| 1 | ~3.51 it/s | ~4.89 it/s |
| 2 | ~2.76 it/s | ~4.92 it/s |

VRAM:
| loras count | hook | patch |
| --- | --- | --- |
| 0 | ~3.6 gb | ~3.6 gb |
| 1 | ~4.0 gb | ~3.6 gb |
| 2 | ~4.4 gb | ~3.7 gb |

As based on #3547 wait to merge.
This commit is contained in:
Lincoln Stein 2023-06-28 15:48:57 -04:00 committed by GitHub
commit 8a90e51408
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 54 additions and 40 deletions

View File

@ -65,23 +65,20 @@ class CompelInvocation(BaseInvocation):
**self.clip.text_encoder.dict(),
)
with tokenizer_info as orig_tokenizer,\
text_encoder_info as text_encoder,\
ExitStack() as stack:
text_encoder_info as text_encoder:
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
stack.enter_context(
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
)
)
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
except Exception:
#print(e)

View File

@ -285,8 +285,7 @@ class TextToLatentsInvocation(BaseInvocation):
self.dispatch_progress(context, source_node_id, state)
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
with unet_info as unet,\
ExitStack() as stack:
with unet_info as unet:
scheduler = get_scheduler(
context=context,
@ -297,7 +296,7 @@ class TextToLatentsInvocation(BaseInvocation):
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
@ -361,8 +360,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
**self.unet.unet.dict(),
)
with unet_info as unet,\
ExitStack() as stack:
with unet_info as unet:
scheduler = get_scheduler(
context=context,
@ -391,7 +389,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
device=unet.device,
)
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(

View File

@ -177,9 +177,13 @@ class LoraLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
# TODO: ui rewrite
base_model = BaseModelType.StableDiffusion1
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=self.lora_name,
model_type=SDModelType.Lora,
model_type=ModelType.Lora,
):
raise Exception(f"Unkown lora name: {self.lora_name}!")
@ -195,8 +199,9 @@ class LoraLoaderInvocation(BaseInvocation):
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=self.lora_name,
model_type=SDModelType.Lora,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
@ -206,8 +211,9 @@ class LoraLoaderInvocation(BaseInvocation):
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=self.lora_name,
model_type=SDModelType.Lora,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)

View File

@ -70,7 +70,7 @@ class LoRALayerBase:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight(module)
weight = self.get_weight()
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
@ -81,7 +81,7 @@ class LoRALayerBase:
**extra_args,
) * multiplier * scale
def get_weight(self, module: torch.nn.Module):
def get_weight(self):
raise NotImplementedError()
def calc_size(self) -> int:
@ -122,7 +122,7 @@ class LoRALayer(LoRALayerBase):
self.rank = self.down.shape[0]
def get_weight(self, module: torch.nn.Module):
def get_weight(self):
if self.mid is not None:
up = self.up.reshape(up.shape[0], up.shape[1])
down = self.down.reshape(up.shape[0], up.shape[1])
@ -166,7 +166,7 @@ class LoHALayer(LoRALayerBase):
layer_key: str,
values: dict,
):
super().__init__(module_key, rank, alpha, bias)
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
@ -185,7 +185,7 @@ class LoHALayer(LoRALayerBase):
self.rank = self.w1_b.shape[0]
def get_weight(self, module: torch.nn.Module):
def get_weight(self):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@ -239,7 +239,7 @@ class LoKRLayer(LoRALayerBase):
layer_key: str,
values: dict,
):
super().__init__(module_key, rank, alpha, bias)
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
@ -271,7 +271,7 @@ class LoKRLayer(LoRALayerBase):
else:
self.rank = None # unscaled
def get_weight(self, module: torch.nn.Module):
def get_weight(self):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
@ -286,7 +286,7 @@ class LoKRLayer(LoRALayerBase):
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape?
weight = torch.kron(w1, w2)
return weight
@ -471,7 +471,7 @@ class ModelPatcher:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = module_key.rstrip(".")
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)
@ -525,23 +525,36 @@ class ModelPatcher:
loras: List[Tuple[LoraModel, float]],
prefix: str,
):
hooks = dict()
original_weights = dict()
try:
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
with torch.no_grad():
for lora, lora_weight in loras:
#assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
if module_key not in hooks:
hooks[module_key] = module.register_forward_hook(cls._lora_forward_hook(loras, layer_key))
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
if module_key not in original_weights:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
# enable autocast to calc fp16 loras on cpu
with torch.autocast(device_type="cpu"):
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
layer_weight = layer_weight.reshape(module.weight.shape)
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
yield # wait for context manager exit
finally:
for module_key, hook in hooks.items():
hook.remove()
hooks.clear()
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)
@classmethod
@ -591,7 +604,7 @@ class ModelPatcher:
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
)
model_embeddings.weight.data[token_id] = embedding
model_embeddings.weight.data[token_id] = embedding.to(device=text_encoder.device, dtype=text_encoder.dtype)
ti_tokens.append(token_id)
if len(ti_tokens) > 1: