fix crash during lora application

This commit is contained in:
Lincoln Stein 2023-07-05 16:40:47 -04:00
parent e4d92da3a9
commit 685a47cc7d

View File

@ -3,15 +3,13 @@ from __future__ import annotations
import copy import copy
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple, Union, List
import torch import torch
from compel.embeddings_provider import BaseTextualInversionManager from compel.embeddings_provider import BaseTextualInversionManager
from diffusers.models import UNet2DConditionModel from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextModel
class LoRALayerBase: class LoRALayerBase:
#rank: Optional[int] #rank: Optional[int]
@ -123,8 +121,8 @@ class LoRALayer(LoRALayerBase):
def get_weight(self): def get_weight(self):
if self.mid is not None: if self.mid is not None:
up = self.up.reshape(up.shape[0], up.shape[1]) up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(up.shape[0], up.shape[1]) down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else: else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
@ -410,7 +408,7 @@ class LoRAModel: #(torch.nn.Module):
else: else:
# TODO: diff/ia3/... format # TODO: diff/ia3/... format
print( print(
f">> Encountered unknown lora layer module in {self.name}: {layer_key}" f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
) )
return return