fix: Slow loading of Loras

Co-Authored-By: StAlKeR7779 <7768370+StAlKeR7779@users.noreply.github.com>
This commit is contained in:
blessedcoolant
2023-07-05 14:37:16 +12:00
committed by psychedelicious
parent 0f0336b6ef
commit c0501ed5c2
4 changed files with 253 additions and 206 deletions

View File

@ -1,18 +1,17 @@
from __future__ import annotations
import copy
from pathlib import Path
from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Any
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle
from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel
from compel.embeddings_provider import BaseTextualInversionManager
class LoRALayerBase:
#rank: Optional[int]
@ -527,7 +526,7 @@ class ModelPatcher:
):
original_weights = dict()
try:
with torch.no_grad():
with torch.inference_mode():
for lora, lora_weight in loras:
#assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
@ -539,9 +538,10 @@ class ModelPatcher:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
# enable autocast to calc fp16 loras on cpu
with torch.autocast(device_type="cpu"):
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
#with torch.autocast(device_type="cpu"):
layer.to(dtype=torch.float32)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
@ -552,7 +552,7 @@ class ModelPatcher:
yield # wait for context manager exit
finally:
with torch.no_grad():
with torch.inference_mode():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)