diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py
index e8e2b3f51f..3da27988dc 100644
--- a/invokeai/backend/model_management/lora.py
+++ b/invokeai/backend/model_management/lora.py
@@ -143,7 +143,7 @@ class ModelPatcher:
                         # with torch.autocast(device_type="cpu"):
                         layer.to(dtype=torch.float32)
                         layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-                        layer_weight = layer.get_weight() * lora_weight * layer_scale
+                        layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
 
                         if module.weight.shape != layer_weight.shape:
                             # TODO: debug on lycoris
@@ -361,7 +361,8 @@ class ONNXModelPatcher:
 
                     layer.to(dtype=torch.float32)
                     layer_key = layer_key.replace(prefix, "")
-                    layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
+                    # TODO: rewrite to pass original tensor weight(required by ia3)
+                    layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
                     if layer_key is blended_loras:
                         blended_loras[layer_key] += layer_weight
                     else:
diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py
index 1983c05503..c3f25e6852 100644
--- a/invokeai/backend/model_management/models/lora.py
+++ b/invokeai/backend/model_management/models/lora.py
@@ -122,41 +122,7 @@ class LoRALayerBase:
         self.rank = None  # set in layer implementation
         self.layer_key = layer_key
 
-    def forward(
-        self,
-        module: torch.nn.Module,
-        input_h: Any,  # for real looks like Tuple[torch.nn.Tensor] but not sure
-        multiplier: float,
-    ):
-        if type(module) == torch.nn.Conv2d:
-            op = torch.nn.functional.conv2d
-            extra_args = dict(
-                stride=module.stride,
-                padding=module.padding,
-                dilation=module.dilation,
-                groups=module.groups,
-            )
-
-        else:
-            op = torch.nn.functional.linear
-            extra_args = {}
-
-        weight = self.get_weight()
-
-        bias = self.bias if self.bias is not None else 0
-        scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
-        return (
-            op(
-                *input_h,
-                (weight + bias).view(module.weight.shape),
-                None,
-                **extra_args,
-            )
-            * multiplier
-            * scale
-        )
-
-    def get_weight(self):
+    def get_weight(self, orig_weight: torch.Tensor):
         raise NotImplementedError()
 
     def calc_size(self) -> int:
@@ -197,7 +163,7 @@ class LoRALayer(LoRALayerBase):
 
         self.rank = self.down.shape[0]
 
-    def get_weight(self):
+    def get_weight(self, orig_weight: torch.Tensor):
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
             down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -260,7 +226,7 @@ class LoHALayer(LoRALayerBase):
 
         self.rank = self.w1_b.shape[0]
 
-    def get_weight(self):
+    def get_weight(self, orig_weight: torch.Tensor):
         if self.t1 is None:
             weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
 
@@ -342,7 +308,7 @@ class LoKRLayer(LoRALayerBase):
         else:
             self.rank = None  # unscaled
 
-    def get_weight(self):
+    def get_weight(self, orig_weight: torch.Tensor):
         w1 = self.w1
         if w1 is None:
             w1 = self.w1_a @ self.w1_b
@@ -410,7 +376,7 @@ class FullLayer(LoRALayerBase):
 
         self.rank = None  # unscaled
 
-    def get_weight(self):
+    def get_weight(self, orig_weight: torch.Tensor):
         return self.weight
 
     def calc_size(self) -> int:
@@ -427,6 +393,44 @@ class FullLayer(LoRALayerBase):
 
         self.weight = self.weight.to(device=device, dtype=dtype)
 
+class IA3Layer(LoRALayerBase):
+    # weight: torch.Tensor
+    # on_input: torch.Tensor
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        super().__init__(layer_key, values)
+
+        self.weight = values["weight"]
+        self.on_input = values["on_input"]
+
+        self.rank = None  # unscaled
+
+    def get_weight(self, orig_weight: torch.Tensor):
+        weight = self.weight
+        if not self.on_input:
+            weight = weight.reshape(-1, 1)
+        return orig_weight * weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        model_size += self.weight.nelement() * self.weight.element_size()
+        model_size += self.on_input.nelement() * self.on_input.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().to(device=device, dtype=dtype)
+
+        self.weight = self.weight.to(device=device, dtype=dtype)
+        self.on_input = self.on_input.to(device=device, dtype=dtype)
+
 
 # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
 class LoRAModelRaw:  # (torch.nn.Module):
@@ -547,11 +551,15 @@ class LoRAModelRaw:  # (torch.nn.Module):
             elif "lokr_w1_b" in values or "lokr_w1" in values:
                 layer = LoKRLayer(layer_key, values)
 
+            # diff
             elif "diff" in values:
                 layer = FullLayer(layer_key, values)
 
+            # ia3
+            elif "weight" in values and "on_input" in values:
+                layer = IA3Layer(layer_key, values)
+
             else:
-                # TODO: ia3/... format
                 print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
                 raise Exception("Unknown lora format!")
 
diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_management/util.py
index f435ab79b6..0702224bc7 100644
--- a/invokeai/backend/model_management/util.py
+++ b/invokeai/backend/model_management/util.py
@@ -12,37 +12,43 @@ def lora_token_vector_length(checkpoint: dict) -> int:
     def _get_shape_1(key, tensor, checkpoint):
         lora_token_vector_length = None
 
+        if "." not in key:
+            return lora_token_vector_length # wrong key format
+        model_key, lora_key = key.split(".", 1)
+
         # check lora/locon
-        if ".lora_down.weight" in key:
+        if lora_key == "lora_down.weight":
             lora_token_vector_length = tensor.shape[1]
 
         # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
-        elif ".hada_w1_b" in key or ".hada_w2_b" in key:
+        elif lora_key in ["hada_w1_b", "hada_w2_b"]:
             lora_token_vector_length = tensor.shape[1]
 
         # check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
-        elif ".lokr_" in key:
-            _lokr_key = key.split(".")[0]
-
-            if _lokr_key + ".lokr_w1" in checkpoint:
-                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"]
-            elif _lokr_key + "lokr_w1_b" in checkpoint:
-                _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
+        elif "lokr_" in lora_key:
+            if model_key + ".lokr_w1" in checkpoint:
+                _lokr_w1 = checkpoint[model_key + ".lokr_w1"]
+            elif model_key + "lokr_w1_b" in checkpoint:
+                _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
             else:
                 return lora_token_vector_length  # unknown format
 
-            if _lokr_key + ".lokr_w2" in checkpoint:
-                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"]
-            elif _lokr_key + "lokr_w2_b" in checkpoint:
-                _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"]
+            if model_key + ".lokr_w2" in checkpoint:
+                _lokr_w2 = checkpoint[model_key + ".lokr_w2"]
+            elif model_key + "lokr_w2_b" in checkpoint:
+                _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
             else:
                 return lora_token_vector_length  # unknown format
 
             lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
 
-        elif ".diff" in key:
+        elif lora_key == "diff":
             lora_token_vector_length = tensor.shape[1]
 
+        # ia3 can be detected only by shape[0] in text encoder
+        elif lora_key == "weight" and "lora_unet_" not in model_key:
+            lora_token_vector_length = tensor.shape[0]
+
         return lora_token_vector_length
 
     lora_token_vector_length = None