Merge branch 'main' into feat/refactor_generation_backend

This commit is contained in:
blessedcoolant
2023-08-11 20:53:38 +12:00
33 changed files with 385 additions and 187 deletions

View File

@ -21,7 +21,6 @@ from argparse import Namespace
from enum import Enum
from pathlib import Path
from shutil import get_terminal_size
from typing import get_type_hints
from urllib import request
import npyscreen
@ -396,13 +395,23 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
max_width=80,
scroll_exit=True,
)
self.max_cache_size = self.add_widget_intelligent(
IntTitleSlider,
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="RAM cache size (GB). Make this at least large enough to hold a single full model.",
value=old_opts.max_cache_size,
out_of=MAX_RAM,
lowest=3,
begin_entry_at=6,
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.max_cache_size = self.add_widget_intelligent(
npyscreen.Slider,
value=clip(old_opts.max_cache_size, range=(3.0, MAX_RAM), step=0.5),
out_of=round(MAX_RAM),
lowest=0.0,
step=0.5,
relx=8,
scroll_exit=True,
)
if HAS_CUDA:
@ -418,7 +427,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
self.nextrely -= 1
self.max_vram_cache_size = self.add_widget_intelligent(
npyscreen.Slider,
value=old_opts.max_vram_cache_size,
value=clip(old_opts.max_vram_cache_size, range=(0, MAX_VRAM), step=0.25),
out_of=round(MAX_VRAM * 2) / 2,
lowest=0.0,
relx=8,
@ -596,6 +605,16 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
)
# -------------------------------------
def clip(value: float, range: tuple[float, float], step: float) -> float:
minimum, maximum = range
if value < minimum:
value = minimum
if value > maximum:
value = maximum
return round(value / step) * step
# -------------------------------------
def initialize_rootdir(root: Path, yes_to_all: bool = False):
logger.info("Initializing InvokeAI runtime directory")

View File

@ -591,7 +591,6 @@ script, which will perform a full upgrade in place.""",
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
if not dest_is_setup:
import invokeai.frontend.install.invokeai_configure
from invokeai.backend.install.invokeai_configure import initialize_rootdir
initialize_rootdir(dest_root, True)

View File

@ -143,7 +143,7 @@ class ModelPatcher:
# with torch.autocast(device_type="cpu"):
layer.to(dtype=torch.float32)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
@ -361,7 +361,8 @@ class ONNXModelPatcher:
layer.to(dtype=torch.float32)
layer_key = layer_key.replace(prefix, "")
layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
# TODO: rewrite to pass original tensor weight(required by ia3)
layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
if layer_key is blended_loras:
blended_loras[layer_key] += layer_weight
else:

View File

@ -526,7 +526,7 @@ class ModelManager(object):
# Does the config explicitly override the submodel?
if submodel_type is not None and hasattr(model_config, submodel_type):
submodel_path = getattr(model_config, submodel_type)
if submodel_path is not None:
if submodel_path is not None and len(submodel_path) > 0:
model_path = getattr(model_config, submodel_type)
is_submodel_override = True

View File

@ -17,6 +17,7 @@ from .models import (
SilenceWarnings,
InvalidModelException,
)
from .util import lora_token_vector_length
from .models.base import read_checkpoint_meta
@ -315,38 +316,16 @@ class LoRACheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint)
# SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
# There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
# misclassified as SD-1
key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
if key in checkpoint and checkpoint[key].shape[0] == 320:
return BaseModelType.StableDiffusion2
key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
if key in checkpoint:
return BaseModelType.StableDiffusionXL
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
lora_token_vector_length = (
checkpoint[key1].shape[1]
if key1 in checkpoint
else checkpoint[key2].shape[1]
if key2 in checkpoint
else checkpoint[key3].shape[0]
if key3 in checkpoint
else None
)
if lora_token_vector_length == 768:
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif lora_token_vector_length == 1024:
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"Unknown LoRA type")
raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
class TextualInversionCheckpointProbe(CheckpointProbeBase):

View File

@ -122,41 +122,7 @@ class LoRALayerBase:
self.rank = None # set in layer implementation
self.layer_key = layer_key
def forward(
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
else:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight()
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return (
op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
)
* multiplier
* scale
)
def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
raise NotImplementedError()
def calc_size(self) -> int:
@ -197,7 +163,7 @@ class LoRALayer(LoRALayerBase):
self.rank = self.down.shape[0]
def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@ -260,7 +226,7 @@ class LoHALayer(LoRALayerBase):
self.rank = self.w1_b.shape[0]
def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@ -342,7 +308,7 @@ class LoKRLayer(LoRALayerBase):
else:
self.rank = None # unscaled
def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
@ -410,7 +376,7 @@ class FullLayer(LoRALayerBase):
self.rank = None # unscaled
def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
return self.weight
def calc_size(self) -> int:
@ -428,6 +394,45 @@ class FullLayer(LoRALayerBase):
self.weight = self.weight.to(device=device, dtype=dtype)
class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.weight = values["weight"]
self.on_input = values["on_input"]
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
return orig_weight * weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
_name: str
@ -547,11 +552,15 @@ class LoRAModelRaw: # (torch.nn.Module):
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "weight" in values and "on_input" in values:
layer = IA3Layer(layer_key, values)
else:
# TODO: ia3/... format
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")

View File

@ -0,0 +1,75 @@
# Copyright (c) 2023 The InvokeAI Development Team
"""Utilities used by the Model Manager"""
def lora_token_vector_length(checkpoint: dict) -> int:
"""
Given a checkpoint in memory, return the lora token vector length
:param checkpoint: The checkpoint
"""
def _get_shape_1(key, tensor, checkpoint):
lora_token_vector_length = None
if "." not in key:
return lora_token_vector_length # wrong key format
model_key, lora_key = key.split(".", 1)
# check lora/locon
if lora_key == "lora_down.weight":
lora_token_vector_length = tensor.shape[1]
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
lora_token_vector_length = tensor.shape[1]
# check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
elif "lokr_" in lora_key:
if model_key + ".lokr_w1" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
elif model_key + "lokr_w1_b" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
else:
return lora_token_vector_length # unknown format
if model_key + ".lokr_w2" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
elif model_key + "lokr_w2_b" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
else:
return lora_token_vector_length # unknown format
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
elif lora_key == "diff":
lora_token_vector_length = tensor.shape[1]
# ia3 can be detected only by shape[0] in text encoder
elif lora_key == "weight" and "lora_unet_" not in model_key:
lora_token_vector_length = tensor.shape[0]
return lora_token_vector_length
lora_token_vector_length = None
lora_te1_length = None
lora_te2_length = None
for key, tensor in checkpoint.items():
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_te") and "_self_attn_" in key:
tmp_length = _get_shape_1(key, tensor, checkpoint)
if key.startswith("lora_te_"):
lora_token_vector_length = tmp_length
elif key.startswith("lora_te1_"):
lora_te1_length = tmp_length
elif key.startswith("lora_te2_"):
lora_te2_length = tmp_length
if lora_te1_length is not None and lora_te2_length is not None:
lora_token_vector_length = lora_te1_length + lora_te2_length
if lora_token_vector_length is not None:
break
return lora_token_vector_length