Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
fix: Slow loading of Loras

Co-Authored-By: StAlKeR7779 <7768370+StAlKeR7779@users.noreply.github.com>

Commit c0501ed5c2, parent 0f0336b6ef, committed by psychedelicious.
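In short: the LoRA patching loops in ModelPatcher switch from torch.no_grad() to the lower-overhead torch.inference_mode(), and fp16 LoRA layers are now upcast to float32 once (layer.to(dtype=torch.float32)) instead of computing their weight deltas under torch.autocast(device_type="cpu").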
@@ -1,18 +1,17 @@
 from __future__ import annotations
 
 import copy
-from pathlib import Path
 from contextlib import contextmanager
-from typing import Optional, Dict, Tuple, Any
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple
 
 import torch
+from compel.embeddings_provider import BaseTextualInversionManager
+from diffusers.models import UNet2DConditionModel
 from safetensors.torch import load_file
 from torch.utils.hooks import RemovableHandle
-
-from diffusers.models import UNet2DConditionModel
 from transformers import CLIPTextModel
-
-from compel.embeddings_provider import BaseTextualInversionManager
 
+
 class LoRALayerBase:
     #rank: Optional[int]
@@ -527,7 +526,7 @@ class ModelPatcher:
     ):
         original_weights = dict()
         try:
-            with torch.no_grad():
+            with torch.inference_mode():
                 for lora, lora_weight in loras:
                     #assert lora.device.type == "cpu"
                     for layer_key, layer in lora.layers.items():
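For context on this change: torch.inference_mode() behaves like torch.no_grad() but additionally skips view tracking and version-counter bookkeeping, so every op inside the patching loop pays less overhead. A minimal standalone sketch of the two context managers (not code from this repo):

import torch

w = torch.ones(3)

# no_grad(): autograd recording is off, but version counters are still
# bumped on every in-place op, so some bookkeeping overhead remains.
with torch.no_grad():
    w.add_(1.0)

# inference_mode(): additionally skips view and version-counter tracking;
# tensors created inside can never be used in autograd-recorded code later.
with torch.inference_mode():
    w.add_(1.0)

print(w)  # tensor([3., 3., 3.])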
@@ -539,9 +538,10 @@ class ModelPatcher:
                             original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
 
                         # enable autocast to calc fp16 loras on cpu
-                        with torch.autocast(device_type="cpu"):
-                            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-                            layer_weight = layer.get_weight() * lora_weight * layer_scale
+                        #with torch.autocast(device_type="cpu"):
+                        layer.to(dtype=torch.float32)
+                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+                        layer_weight = layer.get_weight() * lora_weight * layer_scale
 
                         if module.weight.shape != layer_weight.shape:
                             # TODO: debug on lycoris
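This hunk is the core of the speedup: rather than wrapping the weight math in torch.autocast(device_type="cpu"), which inserts a dtype check and cast around every op, the layer is upcast to float32 once and the delta is computed in plain fp32. A rough sketch of the new pattern, where the up/down factors and lora_weight are hypothetical stand-ins for what layer.get_weight() and the loop variables provide:

import torch

# Hypothetical low-rank factors of one LoRA layer, stored in fp16.
up = torch.randn(320, 8, dtype=torch.float16)
down = torch.randn(8, 320, dtype=torch.float16)
alpha, rank = 8.0, 8
lora_weight = 1.0  # overall strength multiplier, as in the loop

# One-time upcast, mirroring layer.to(dtype=torch.float32) in the diff.
up32, down32 = up.to(torch.float32), down.to(torch.float32)

# Same scaling rule as the patched code.
layer_scale = alpha / rank if (alpha and rank) else 1.0
layer_weight = (up32 @ down32) * lora_weight * layer_scale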
@@ -552,7 +552,7 @@ class ModelPatcher:
 
             yield # wait for context manager exit
 
         finally:
-            with torch.no_grad():
+            with torch.inference_mode():
                 for module_key, weight in original_weights.items():
                     model.get_submodule(module_key).weight.copy_(weight)
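The finally block restores the weights that were saved before patching, so the context manager gives apply-then-restore semantics around the yield. A self-contained sketch of that pattern, with a hypothetical patch_weight helper standing in for ModelPatcher.apply_lora:

import torch
from contextlib import contextmanager

@contextmanager
def patch_weight(module: torch.nn.Module, delta: torch.Tensor):
    # Save a CPU copy of the original weight, as the diff does.
    original = module.weight.detach().to(device="cpu", copy=True)
    try:
        with torch.inference_mode():
            module.weight += delta.to(module.weight.device, module.weight.dtype)
        yield  # run patched inference here
    finally:
        with torch.inference_mode():
            module.weight.copy_(original)

lin = torch.nn.Linear(4, 4).requires_grad_(False)
with patch_weight(lin, torch.full((4, 4), 0.1)):
    lin(torch.zeros(1, 4))  # forward pass sees the patched weight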