From d4550b30597cd0a6e082c64df629f169c5594462 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Wed, 5 Jul 2023 19:18:25 -0400
Subject: [PATCH] clean up lint errors in lora.py

---
 invokeai/backend/model_management/lora.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py
index 6cfcb8dd8d..b92020189d 100644
--- a/invokeai/backend/model_management/lora.py
+++ b/invokeai/backend/model_management/lora.py
@@ -3,14 +3,12 @@ from __future__ import annotations
 import copy
 from pathlib import Path
 from contextlib import contextmanager
-from typing import Optional, Dict, Tuple, Any
-
+from typing import Optional, Dict, Tuple, Any, Union, List
 import torch
 from safetensors.torch import load_file
-from torch.utils.hooks import RemovableHandle
 
 from diffusers.models import UNet2DConditionModel
-from transformers import CLIPTextModel
+from transformers import CLIPTextModel, CLIPTokenizer
 
 from compel.embeddings_provider import BaseTextualInversionManager
 
@@ -124,8 +122,8 @@ class LoRALayer(LoRALayerBase):
 
     def get_weight(self):
         if self.mid is not None:
-            up = self.up.reshape(up.shape[0], up.shape[1])
-            down = self.down.reshape(up.shape[0], up.shape[1])
+            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
+            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
             weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
         else:
             weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
@@ -411,7 +409,7 @@ class LoRAModel: #(torch.nn.Module):
             else:
                 # TODO: diff/ia3/... format
                 print(
-                    f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
+                    f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
                 )
                 return
 
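Note on the second hunk: the original lines read `up.shape` before `up` was assigned, so loading any LoRA that carries a mid (Tucker/LoCon-style) tensor would raise UnboundLocalError at runtime; this fix repairs a real bug, not just a lint warning. Below is a minimal, self-contained sketch of the corrected weight composition; the shape values (rank, channel counts, kernel size) are illustrative assumptions, not taken from the patch:

    import torch

    # Hypothetical shapes for a LoRA layer that carries a mid (Tucker) tensor.
    rank, in_ch, out_ch, k = 4, 8, 16, 3     # illustrative values only
    up = torch.randn(out_ch, rank, 1, 1)     # plays the role of self.up
    down = torch.randn(rank, in_ch, 1, 1)    # plays the role of self.down
    mid = torch.randn(rank, rank, k, k)      # plays the role of self.mid

    # Fixed logic: each tensor is flattened using its OWN shape. The old code
    # referenced `up.shape` before `up` existed and would have crashed here.
    up_2d = up.reshape(up.shape[0], up.shape[1])          # (out_ch, rank)
    down_2d = down.reshape(down.shape[0], down.shape[1])  # (rank, in_ch)

    # Contract mid with the flattened up/down factors to recover the full
    # (out_ch, in_ch, k, k) convolution weight delta.
    weight = torch.einsum("m n w h, i m, n j -> i j w h", mid, up_2d, down_2d)
    assert weight.shape == (out_ch, in_ch, k, k)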