fix crash during lora application

Lincoln Stein 2023-07-05 16:40:47 -04:00
parent e4d92da3a9
commit 685a47cc7d


@@ -3,15 +3,13 @@ from __future__ import annotations
 import copy
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union, List
 import torch
 from compel.embeddings_provider import BaseTextualInversionManager
 from diffusers.models import UNet2DConditionModel
 from safetensors.torch import load_file
 from torch.utils.hooks import RemovableHandle
-from transformers import CLIPTextModel
+from transformers import CLIPTextModel, CLIPTokenizer
 class LoRALayerBase:
     #rank: Optional[int]
@@ -123,8 +121,8 @@ class LoRALayer(LoRALayerBase):
     def get_weight(self):
         if self.mid is not None:
-            up = self.up.reshape(up.shape[0], up.shape[1])
-            down = self.down.reshape(up.shape[0], up.shape[1])
+            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
+            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
             weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
         else:
             weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
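For context on the first fix: inside a function, `up = self.up.reshape(up.shape[0], up.shape[1])` reads the local name `up` before it is assigned, so the old lines raised UnboundLocalError whenever a LoRA layer with a mid tensor was applied. Below is a minimal sketch with hypothetical tensor shapes (not taken from the repository) of what the corrected reshape and einsum compute:

import torch

# Hypothetical LoRA factors for a 3x3 conv layer: rank 4, 8 in / 16 out channels.
up = torch.randn(16, 4, 1, 1)    # (out_channels, rank, 1, 1)
mid = torch.randn(4, 4, 3, 3)    # (rank, rank, kernel_h, kernel_w)
down = torch.randn(4, 8, 1, 1)   # (rank, in_channels, 1, 1)

# Corrected logic: each factor is reshaped from its own shape, then combined.
up_2d = up.reshape(up.shape[0], up.shape[1])          # (16, 4)
down_2d = down.reshape(down.shape[0], down.shape[1])  # (4, 8)
weight = torch.einsum("m n w h, i m, n j -> i j w h", mid, up_2d, down_2d)
print(weight.shape)  # torch.Size([16, 8, 3, 3]) -- the full conv weight delta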
@@ -410,7 +408,7 @@ class LoRAModel: #(torch.nn.Module):
             else:
                 # TODO: diff/ia3/... format
                 print(
-                    f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
+                    f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
                 )
                 return
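The second change addresses the same class of crash: the surrounding loader builds the LoRA model as a local `model` object, so the f-string referencing `self.name` fails there. A stripped-down, hypothetical reproduction (names and structure are illustrative, not the repository's actual method):

class LoRAModel:
    def __init__(self, name: str):
        self.name = name

    @classmethod
    def from_checkpoint(cls, name: str, layer_keys: list) -> "LoRAModel":
        model = cls(name)
        for layer_key in layer_keys:
            # Referencing self.name here would raise NameError: name 'self' is not defined,
            # because a classmethod has no bound instance; the fix uses model.name instead.
            print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
        return model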