From 685a47cc7de17a762873a3f393e8013af4158ee5 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Wed, 5 Jul 2023 16:40:47 -0400
Subject: [PATCH] fix crash during lora application

---
 invokeai/backend/model_management/lora.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py
index 5d27555ab3..864905d6cf 100644
--- a/invokeai/backend/model_management/lora.py
+++ b/invokeai/backend/model_management/lora.py
@@ -3,15 +3,13 @@ from __future__ import annotations
 import copy
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union, List
 
 import torch
 from compel.embeddings_provider import BaseTextualInversionManager
 from diffusers.models import UNet2DConditionModel
 from safetensors.torch import load_file
-from torch.utils.hooks import RemovableHandle
-from transformers import CLIPTextModel
-
+from transformers import CLIPTextModel, CLIPTokenizer
 
 class LoRALayerBase:
     #rank: Optional[int]
@@ -123,8 +121,8 @@ class LoRALayer(LoRALayerBase):
 
     def get_weight(self):
         if self.mid is not None:
-            up = self.up.reshape(up.shape[0], up.shape[1])
-            down = self.down.reshape(up.shape[0], up.shape[1])
+            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
+            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
             weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
         else:
             weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
@@ -410,7 +408,7 @@ class LoRAModel: #(torch.nn.Module):
             else:
                 # TODO: diff/ia3/... format
                 print(
-                    f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
+                    f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
                 )
                 return
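
Note on the fix: the crash was a NameError in LoRALayer.get_weight. In the
self.mid branch, the old code called self.up.reshape(up.shape[0], up.shape[1]),
reading the local name `up` on the same line that was supposed to bind it. A
minimal standalone sketch of the corrected computation follows; the tensor
shapes (rank, kernel size, channel counts) are assumptions chosen to satisfy
the einsum, not values taken from the patch:

    import torch

    # Illustrative shapes only, assumed for the sketch.
    rank, n_in, n_out, kh, kw = 4, 8, 16, 3, 3
    up = torch.randn(n_out, rank, 1, 1)    # lora up-projection, trailing 1x1 dims
    mid = torch.randn(rank, rank, kh, kw)  # mid weight carrying the conv kernel
    down = torch.randn(rank, n_in, 1, 1)   # lora down-projection

    # The fix: reshape the instance tensors (self.up / self.down), not the
    # locals `up` and `down`, which were still unbound at that point.
    up2d = up.reshape(up.shape[0], up.shape[1])          # (n_out, rank)
    down2d = down.reshape(down.shape[0], down.shape[1])  # (rank, n_in)

    weight = torch.einsum("m n w h, i m, n j -> i j w h", mid, up2d, down2d)
    print(weight.shape)  # torch.Size([16, 8, 3, 3])

The einsum contracts the rank dimensions m and n against up2d and down2d,
producing a conv-shaped weight of (out_channels, in_channels, kh, kw).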