mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix crash during lora application
This commit is contained in:
parent
e4d92da3a9
commit
685a47cc7d
@ -3,15 +3,13 @@ from __future__ import annotations
|
|||||||
import copy
|
import copy
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple, Union, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel.embeddings_provider import BaseTextualInversionManager
|
from compel.embeddings_provider import BaseTextualInversionManager
|
||||||
from diffusers.models import UNet2DConditionModel
|
from diffusers.models import UNet2DConditionModel
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from torch.utils.hooks import RemovableHandle
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
from transformers import CLIPTextModel
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALayerBase:
|
class LoRALayerBase:
|
||||||
#rank: Optional[int]
|
#rank: Optional[int]
|
||||||
@ -123,8 +121,8 @@ class LoRALayer(LoRALayerBase):
|
|||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self):
|
||||||
if self.mid is not None:
|
if self.mid is not None:
|
||||||
up = self.up.reshape(up.shape[0], up.shape[1])
|
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||||
down = self.down.reshape(up.shape[0], up.shape[1])
|
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||||
else:
|
else:
|
||||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||||
@ -410,7 +408,7 @@ class LoRAModel: #(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
# TODO: diff/ia3/... format
|
# TODO: diff/ia3/... format
|
||||||
print(
|
print(
|
||||||
f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
|
f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user