Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Fix merge conflict resolution - restore full/diff layer support
This commit is contained in:
parent 0e9f92b868
commit f0613bb0ef
@@ -391,6 +391,43 @@ class LoKRLayer(LoRALayerBase):
         self.t2 = self.t2.to(device=device, dtype=dtype)


+class FullLayer(LoRALayerBase):
+    # weight: torch.Tensor
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        super().__init__(layer_key, values)
+
+        self.weight = values["diff"]
+
+        if len(values.keys()) > 1:
+            _keys = list(values.keys())
+            _keys.remove("diff")
+            raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
+
+        self.rank = None  # unscaled
+
+    def get_weight(self):
+        return self.weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        model_size += self.weight.nelement() * self.weight.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().to(device=device, dtype=dtype)
+
+        self.weight = self.weight.to(device=device, dtype=dtype)
+
+
 # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
 class LoRAModelRaw:  # (torch.nn.Module):
     _name: str
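Aside (not part of this commit): the full/diff format stores the complete weight delta rather than low-rank factors, so applying it amounts to adding `get_weight()` to the base weight. A minimal sketch, assuming a hypothetical `apply_full_layer` helper, a `base_module` whose weight matches the diff's shape, and a `lora_weight` scale factor:

import torch

def apply_full_layer(base_module: torch.nn.Module, layer: "FullLayer", lora_weight: float = 1.0) -> None:
    # A full/diff layer carries the entire weight delta, so there is no
    # low-rank reconstruction step: the diff is scaled and added in place.
    with torch.no_grad():
        diff = layer.get_weight().to(device=base_module.weight.device, dtype=base_module.weight.dtype)
        base_module.weight += diff * lora_weight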
@@ -510,10 +547,13 @@ class LoRAModelRaw:  # (torch.nn.Module):
             elif "lokr_w1_b" in values or "lokr_w1" in values:
                 layer = LoKRLayer(layer_key, values)

+            elif "diff" in values:
+                layer = FullLayer(layer_key, values)
+
             else:
-                # TODO: diff/ia3/... format
-                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
-                return
+                # TODO: ia3/... format
+                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
+                raise Exception("Unknown lora format!")

             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()
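For illustration only (the layer key and tensor shape below are made up): after the raw state dict is grouped per layer, a full/diff LoRA layer ends up as an entry whose only tensor key is "diff", which is exactly what the restored branch above checks for. A minimal sketch:

import torch

# hypothetical grouped entry for one layer of a "full"/diff-format LoRA
layer_key = "lora_unet_example_to_q"      # made-up layer key
values = {"diff": torch.zeros(320, 320)}  # the whole weight delta, no low-rank factors

if "lokr_w1_b" in values or "lokr_w1" in values:
    layer_cls = "LoKRLayer"
elif "diff" in values:
    layer_cls = "FullLayer"               # the branch restored by this commit
else:
    raise Exception("Unknown lora format!")

print(layer_key, "->", layer_cls)         # lora_unet_example_to_q -> FullLayer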
@@ -536,6 +576,8 @@ class LoRAModelRaw:  # (torch.nn.Module):
         return state_dict_groupped


+# code from
+# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
 def make_sdxl_unet_conversion_map():
     unet_conversion_map_layer = []

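Not part of the diff: make_sdxl_unet_conversion_map() (taken from the linked kohya_ss code) returns pairs of (CompVis prefix, diffusers prefix) strings. A quick way to inspect them, assuming the function is importable from this module:

# Illustrative only: dump the prefix pairs that feed SDXL_UNET_COMPVIS_MAP below.
for sd_prefix, hf_prefix in make_sdxl_unet_conversion_map():
    print(f"{sd_prefix!r:45} -> {hf_prefix!r}")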
@@ -620,7 +662,6 @@ def make_sdxl_unet_conversion_map():
     return unet_conversion_map


-# _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
 SDXL_UNET_COMPVIS_MAP = {
     f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
     for sd, hf in make_sdxl_unet_conversion_map()
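A hypothetical usage sketch (the helper name and key format are assumptions, not this module's API): because the map's keys and values are underscore-joined prefixes, a CompVis-style SDXL UNet LoRA key can be translated to diffusers naming by substring replacement:

def convert_sdxl_unet_key(key: str) -> str:
    # Replace the first matching CompVis prefix with its diffusers counterpart;
    # keys that already use diffusers naming are returned unchanged.
    for sd_prefix, hf_prefix in SDXL_UNET_COMPVIS_MAP.items():
        if sd_prefix in key:
            return key.replace(sd_prefix, hf_prefix)
    return key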