2023-05-29 22:11:00 +00:00
from __future__ import annotations
2023-05-30 23:12:27 +00:00
import copy
2023-05-29 22:11:00 +00:00
from contextlib import contextmanager
2023-07-05 23:18:25 +00:00
from typing import Optional , Dict , Tuple , Any , Union , List
2023-07-05 02:37:16 +00:00
from pathlib import Path
2023-05-29 22:11:00 +00:00
import torch
from safetensors . torch import load_file
from torch . utils . hooks import RemovableHandle
from diffusers . models import UNet2DConditionModel
from transformers import CLIPTextModel
2023-06-20 23:12:21 +00:00
from onnx import numpy_helper
2023-06-22 17:03:17 +00:00
from onnxruntime import OrtValue
2023-06-20 23:12:21 +00:00
import numpy as np
2023-05-29 22:11:00 +00:00
2023-05-30 23:12:27 +00:00
from compel . embeddings_provider import BaseTextualInversionManager
2023-07-05 02:37:16 +00:00
from diffusers . models import UNet2DConditionModel
2023-05-29 22:11:00 +00:00
from safetensors . torch import load_file
2023-07-05 20:40:47 +00:00
from transformers import CLIPTextModel , CLIPTokenizer
2023-05-30 23:12:27 +00:00
2023-06-21 01:24:25 +00:00
# TODO: rename and split this file
2023-07-28 13:46:44 +00:00
2023-05-29 22:11:00 +00:00
class LoRALayerBase:
    """Shared state and behavior for every LoRA-family layer patch.

    Subclasses must implement :meth:`get_weight` and set ``self.rank``.

    Attributes:
        alpha: Scalar read from ``values["alpha"]`` or ``None`` when absent.
        bias: Sparse COO tensor built from the ``bias_*`` entries, or ``None``.
        rank: ``None`` here; assigned by the concrete layer implementation.
        layer_key: Key identifying the layer this patch belongs to.
    """

    def __init__(
        self,
        layer_key: str,
        values: dict,
    ):
        """Read the optional ``alpha`` and sparse ``bias`` entries from ``values``.

        Args:
            layer_key: Key identifying the layer this patch belongs to.
            values: Mapping of tensor names to tensors for this layer.
        """
        if "alpha" in values:
            self.alpha = values["alpha"].item()
        else:
            self.alpha = None

        # The bias is only rebuilt when all three of its parts are present.
        if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
            self.bias = torch.sparse_coo_tensor(
                values["bias_indices"],
                values["bias_values"],
                tuple(values["bias_size"]),
            )
        else:
            self.bias = None

        self.rank = None  # set in layer implementation
        self.layer_key = layer_key

    def forward(
        self,
        module: torch.nn.Module,
        input_h: Any,  # for real looks like Tuple[torch.nn.Tensor] but not sure
        multiplier: float,
    ):
        """Compute this patch's contribution for ``module`` on ``input_h``.

        Args:
            module: The patched module; ``Conv2d`` (or a subclass) is run
                through ``conv2d`` with the module's own stride/padding/
                dilation/groups, anything else through ``linear``.
            input_h: Positional inputs, unpacked into the functional op.
            multiplier: User-supplied strength of the patch.

        Returns:
            The op output multiplied by ``multiplier`` and by ``alpha / rank``
            (or ``1.0`` when either of them is missing or zero).
        """
        # isinstance (not an exact type match) so Conv2d subclasses are not
        # mistakenly routed through the linear op.
        if isinstance(module, torch.nn.Conv2d):
            op = torch.nn.functional.conv2d
            extra_args = dict(
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
            )
        else:
            op = torch.nn.functional.linear
            extra_args = {}

        weight = self.get_weight()

        bias = self.bias if self.bias is not None else 0
        scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
        return (
            op(
                *input_h,
                # The patch weight is reshaped to the module's weight shape.
                (weight + bias).view(module.weight.shape),
                None,
                **extra_args,
            )
            * multiplier
            * scale
        )

    def get_weight(self):
        """Return the patch weight tensor; must be overridden by subclasses."""
        raise NotImplementedError()

    def calc_size(self) -> int:
        """Return the size in bytes of the tensors held by the base class (the bias)."""
        model_size = 0
        for val in [self.bias]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Move/cast the bias (when present) in place to ``device``/``dtype``."""
        if self.bias is not None:
            self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
    """Patch stored as ``lora_up.weight`` / ``lora_down.weight`` with an optional ``lora_mid.weight``."""

    # up: torch.Tensor
    # mid: Optional[torch.Tensor]
    # down: torch.Tensor

    def __init__(
        self,
        layer_key: str,
        values: dict,
    ):
        """Pick the up/down (and optional mid) tensors out of ``values``."""
        super().__init__(layer_key, values)

        self.up = values["lora_up.weight"]
        self.down = values["lora_down.weight"]
        self.mid = values["lora_mid.weight"] if "lora_mid.weight" in values else None

        # The rank is the leading dimension of the down tensor.
        self.rank = self.down.shape[0]

    def get_weight(self):
        """Return the patch weight built from up/down (and mid when present)."""
        if self.mid is None:
            # Flatten both factors to 2-D and multiply them.
            up_2d = self.up.reshape(self.up.shape[0], -1)
            down_2d = self.down.reshape(self.down.shape[0], -1)
            return up_2d @ down_2d

        up_2d = self.up.reshape(self.up.shape[0], self.up.shape[1])
        down_2d = self.down.reshape(self.down.shape[0], self.down.shape[1])
        return torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up_2d, down_2d)

    def calc_size(self) -> int:
        """Return the size in bytes of all tensors held by this layer."""
        own_tensors = (self.up, self.mid, self.down)
        return super().calc_size() + sum(
            tensor.nelement() * tensor.element_size() for tensor in own_tensors if tensor is not None
        )

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Move/cast every held tensor in place to ``device``/``dtype``."""
        super().to(device=device, dtype=dtype)

        self.up = self.up.to(device=device, dtype=dtype)
        self.down = self.down.to(device=device, dtype=dtype)

        if self.mid is not None:
            self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
    """Patch stored as two ``hada_w*_a`` / ``hada_w*_b`` pairs with optional ``hada_t1`` / ``hada_t2``."""

    # w1_a: torch.Tensor
    # w1_b: torch.Tensor
    # w2_a: torch.Tensor
    # w2_b: torch.Tensor
    # t1: Optional[torch.Tensor] = None
    # t2: Optional[torch.Tensor] = None

    def __init__(
        self,
        layer_key: str,
        values: dict,
    ):
        """Pick the ``hada_*`` tensors out of ``values``."""
        super().__init__(layer_key, values)

        self.w1_a = values["hada_w1_a"]
        self.w1_b = values["hada_w1_b"]
        self.w2_a = values["hada_w2_a"]
        self.w2_b = values["hada_w2_b"]

        self.t1 = values["hada_t1"] if "hada_t1" in values else None
        self.t2 = values["hada_t2"] if "hada_t2" in values else None

        # The rank is the leading dimension of w1_b.
        self.rank = self.w1_b.shape[0]

    def get_weight(self):
        """Return the element-wise product of the two rebuilt factors."""
        if self.t1 is None:
            return (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

        # The presence of t1 alone selects this path; t2 is used without a separate check.
        first = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
        second = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
        return first * second

    def calc_size(self) -> int:
        """Return the size in bytes of all tensors held by this layer."""
        own_tensors = (self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2)
        return super().calc_size() + sum(
            tensor.nelement() * tensor.element_size() for tensor in own_tensors if tensor is not None
        )

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Move/cast every held tensor in place to ``device``/``dtype``."""
        super().to(device=device, dtype=dtype)

        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
        if self.t1 is not None:
            self.t1 = self.t1.to(device=device, dtype=dtype)

        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
        if self.t2 is not None:
            self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
    """Patch whose weight is ``torch.kron(w1, w2)``; each factor is stored whole or as an ``_a``/``_b`` pair."""

    # w1: Optional[torch.Tensor] = None
    # w1_a: Optional[torch.Tensor] = None
    # w1_b: Optional[torch.Tensor] = None
    # w2: Optional[torch.Tensor] = None
    # w2_a: Optional[torch.Tensor] = None
    # w2_b: Optional[torch.Tensor] = None
    # t2: Optional[torch.Tensor] = None

    def __init__(
        self,
        layer_key: str,
        values: dict,
    ):
        """Pick the ``lokr_*`` tensors out of ``values`` and derive the rank."""
        super().__init__(layer_key, values)

        # Each factor is either given whole or as an (a, b) pair, never both.
        has_full_w1 = "lokr_w1" in values
        self.w1 = values["lokr_w1"] if has_full_w1 else None
        self.w1_a = None if has_full_w1 else values["lokr_w1_a"]
        self.w1_b = None if has_full_w1 else values["lokr_w1_b"]

        has_full_w2 = "lokr_w2" in values
        self.w2 = values["lokr_w2"] if has_full_w2 else None
        self.w2_a = None if has_full_w2 else values["lokr_w2_a"]
        self.w2_b = None if has_full_w2 else values["lokr_w2_b"]

        self.t2 = values["lokr_t2"] if "lokr_t2" in values else None

        # The rank comes from whichever "_b" tensor exists, w1_b taking precedence.
        if "lokr_w1_b" in values:
            self.rank = values["lokr_w1_b"].shape[0]
        elif "lokr_w2_b" in values:
            self.rank = values["lokr_w2_b"].shape[0]
        else:
            self.rank = None  # unscaled

    def get_weight(self):
        """Return the Kronecker product of the (possibly rebuilt) w1 and w2."""
        first = self.w1 if self.w1 is not None else self.w1_a @ self.w1_b

        if self.w2 is not None:
            second = self.w2
        elif self.t2 is None:
            second = self.w2_a @ self.w2_b
        else:
            second = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)

        # For a 4-D second factor, the first gets two trailing singleton dims.
        if len(second.shape) == 4:
            first = first.unsqueeze(2).unsqueeze(2)
        second = second.contiguous()

        return torch.kron(first, second)

    def calc_size(self) -> int:
        """Return the size in bytes of all tensors held by this layer."""
        own_tensors = (self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2)
        return super().calc_size() + sum(
            tensor.nelement() * tensor.element_size() for tensor in own_tensors if tensor is not None
        )

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Move/cast every held tensor in place to ``device``/``dtype``."""
        super().to(device=device, dtype=dtype)

        if self.w1 is None:
            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
            self.w1_b = self.w1_b.to(device=device, dtype=dtype)
        else:
            self.w1 = self.w1.to(device=device, dtype=dtype)

        if self.w2 is None:
            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
            self.w2_b = self.w2_b.to(device=device, dtype=dtype)
        else:
            self.w2 = self.w2.to(device=device, dtype=dtype)

        if self.t2 is not None:
            self.t2 = self.t2.to(device=device, dtype=dtype)
2023-08-01 14:02:57 +00:00
class FullLayer(LoRALayerBase):
    """Patch whose weight is the ``diff`` tensor taken as-is."""

    # weight: torch.Tensor

    def __init__(
        self,
        layer_key: str,
        values: dict,
    ):
        """Store ``values["diff"]``; any additional key is rejected.

        Raises:
            NotImplementedError: If ``values`` holds keys other than ``diff``.
        """
        super().__init__(layer_key, values)

        self.weight = values["diff"]

        if len(values) > 1:
            _keys = [key for key in values if key != "diff"]
            raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")

        self.rank = None  # unscaled

    def get_weight(self):
        """Return the stored diff tensor unchanged."""
        return self.weight

    def calc_size(self) -> int:
        """Return the size in bytes of all tensors held by this layer."""
        weight_bytes = self.weight.nelement() * self.weight.element_size()
        return super().calc_size() + weight_bytes

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Move/cast the held tensors in place to ``device``/``dtype``."""
        super().to(device=device, dtype=dtype)

        self.weight = self.weight.to(device=device, dtype=dtype)
2023-07-28 13:46:44 +00:00
class LoRAModel:  # (torch.nn.Module):
    """A named collection of LoRA-family layer patches loaded from one checkpoint.

    Attributes:
        layers: Layer patches keyed by their layer key.
    """

    _name: str
    layers: Dict[str, LoRALayerBase]
    _device: torch.device
    _dtype: torch.dtype

    def __init__(
        self,
        name: str,
        layers: Dict[str, LoRALayerBase],
        device: Optional[torch.device],
        dtype: Optional[torch.dtype],
    ):
        """Store the layers; falsy ``device``/``dtype`` fall back to CPU / float32."""
        self._name = name
        # torch.device("cpu") is an actual device object (torch.cpu is a module).
        self._device = device or torch.device("cpu")
        self._dtype = dtype or torch.float32
        self.layers = layers

    @property
    def name(self):
        """Name given at construction (the file stem when loaded from a checkpoint)."""
        return self._name

    @property
    def device(self):
        """Device recorded for this model."""
        return self._device

    @property
    def dtype(self):
        """Dtype recorded for this model."""
        return self._dtype

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> LoRAModel:
        """Move/cast every layer in place and return ``self``.

        A ``None`` argument leaves the corresponding recorded value untouched.
        """
        # TODO: try revert if exception?
        for layer in self.layers.values():
            layer.to(device=device, dtype=dtype)

        if device is not None:
            self._device = device
        if dtype is not None:
            self._dtype = dtype
        return self

    def calc_size(self) -> int:
        """Return the summed size in bytes of all layers."""
        return sum(layer.calc_size() for layer in self.layers.values())

    @classmethod
    def from_checkpoint(
        cls,
        file_path: Union[str, Path],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Load a checkpoint file and build one layer patch per grouped key.

        Args:
            file_path: Path to a ``.safetensors`` file or a ``torch.load``-able file.
            device: Target device for the layers; defaults to CPU.
            dtype: Target dtype for the layers; defaults to float32.

        Returns:
            The populated model.

        Raises:
            Exception: If a layer matches none of the known key layouts.
        """
        device = device or torch.device("cpu")
        dtype = dtype or torch.float32

        if isinstance(file_path, str):
            file_path = Path(file_path)

        model = cls(
            device=device,
            dtype=dtype,
            name=file_path.stem,  # TODO:
            layers=dict(),
        )

        if file_path.suffix == ".safetensors":
            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
        else:
            # NOTE(security): torch.load unpickles the file; only use it on trusted checkpoints.
            state_dict = torch.load(file_path, map_location="cpu")

        state_dict = cls._group_state(state_dict)

        for layer_key, values in state_dict.items():
            # lora and locon
            if "lora_down.weight" in values:
                layer = LoRALayer(layer_key, values)

            # loha
            elif "hada_w1_b" in values:
                layer = LoHALayer(layer_key, values)

            # lokr
            elif "lokr_w1_b" in values or "lokr_w1" in values:
                layer = LoKRLayer(layer_key, values)

            elif "diff" in values:
                layer = FullLayer(layer_key, values)

            else:
                # TODO: ia3/... format
                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
                raise Exception("Unknown lora format!")

            # lower memory consumption by removing already parsed layer values
            state_dict[layer_key].clear()

            layer.to(device=device, dtype=dtype)
            model.layers[layer_key] = layer

        return model

    @staticmethod
    def _group_state(state_dict: dict):
        """Group a flat ``"stem.leaf" -> value`` mapping into ``stem -> {leaf: value}``.

        Keys are split at the first ``"."``; a key without one raises ``ValueError``.
        """
        state_dict_groupped = dict()

        for key, value in state_dict.items():
            stem, leaf = key.split(".", 1)
            state_dict_groupped.setdefault(stem, dict())[leaf] = value

        return state_dict_groupped
"""
loras = [
( lora_model1 , 0.7 ) ,
( lora_model2 , 0.4 ) ,
]
with ModelPatcher.apply_lora_unet(unet, loras):
# unet with applied loras
# unmodified unet
"""
2023-07-28 13:46:44 +00:00
2023-05-29 22:11:00 +00:00
# TODO: rename smth like ModelPatcher and add TI method?
2023-05-30 23:12:27 +00:00
class ModelPatcher:
    """Context managers that temporarily patch torch models and undo the patch on exit."""

    @staticmethod
    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
        """Map an underscore-joined LoRA key onto a submodule of ``model``.

        Underscores in the key may be either path separators or part of a
        submodule name, so name parts are accumulated until
        ``get_submodule`` accepts them.

        Args:
            model: Root module to search in.
            lora_key: Key starting with ``prefix`` and containing no ``"."``.
            prefix: Prefix stripped from the key before resolution.

        Returns:
            ``(dotted module path, module)``.

        Raises:
            Exception: If ``lora_key`` does not start with ``prefix``.
            AttributeError: If the final name does not resolve to a submodule.
        """
        assert "." not in lora_key

        if not lora_key.startswith(prefix):
            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")

        module = model
        module_key = ""
        key_parts = lora_key[len(prefix) :].split("_")

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
            except AttributeError:
                # Not a submodule yet: the underscore was part of the name, so
                # extend the candidate with the next part and retry.
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return (module_key, module)

    @staticmethod
    def _lora_forward_hook(
        applied_loras: List[Tuple[LoRAModel, float]],
        layer_name: str,
    ):
        """Build a forward hook adding each applied LoRA's contribution for ``layer_name``.

        The returned callable has the ``(module, input, output)`` signature
        and adds ``layer.forward(...)`` in place to ``output``.
        """

        def lora_forward(module, input_h, output):
            if len(applied_loras) == 0:
                return output

            for lora, weight in applied_loras:
                layer = lora.layers.get(layer_name, None)
                if layer is None:
                    continue
                output += layer.forward(module, input_h, weight)
            return output

        return lora_forward

    @classmethod
    @contextmanager
    def apply_lora_unet(
        cls,
        unet: UNet2DConditionModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply ``loras`` to ``unet`` using the ``lora_unet_`` key prefix."""
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(
        cls,
        text_encoder: CLIPTextModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply ``loras`` to ``text_encoder`` using the ``lora_te_`` key prefix."""
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora(
        cls,
        model: torch.nn.Module,
        loras: List[Tuple[LoRAModel, float]],
        prefix: str,
    ):
        """Add the LoRA weights to ``model`` in place and restore the originals on exit.

        Args:
            model: Module whose submodule weights are patched.
            loras: ``(lora, strength)`` pairs; only layers whose key starts
                with ``prefix`` are used.
            prefix: Key prefix selecting the layers meant for this model.
        """
        original_weights = dict()
        try:
            with torch.no_grad():
                for lora, lora_weight in loras:
                    # assert lora.device.type == "cpu"
                    for layer_key, layer in lora.layers.items():
                        if not layer_key.startswith(prefix):
                            continue

                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
                        # Keep a CPU copy of the untouched weight the first
                        # time a module is patched.
                        if module_key not in original_weights:
                            original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)

                        # enable autocast to calc fp16 loras on cpu
                        # with torch.autocast(device_type="cpu"):
                        layer.to(dtype=torch.float32)
                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
                        layer_weight = layer.get_weight() * lora_weight * layer_scale

                        if module.weight.shape != layer_weight.shape:
                            # TODO: debug on lycoris
                            layer_weight = layer_weight.reshape(module.weight.shape)

                        module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)

            yield  # wait for context manager exit

        finally:
            with torch.no_grad():
                for module_key, weight in original_weights.items():
                    model.get_submodule(module_key).weight.copy_(weight)

    @classmethod
    @contextmanager
    def apply_ti(
        cls,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        ti_list: List[Any],
    ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
        """Temporarily register textual-inversion embeddings.

        A deep copy of ``tokenizer`` receives one trigger token per embedding
        row; ``text_encoder``'s token embeddings are resized and filled with
        the embedding rows, then resized back on exit.

        Args:
            tokenizer: Tokenizer to copy; the original is not modified.
            text_encoder: Encoder whose input embeddings are extended.
            ti_list: Objects exposing ``name`` and ``embedding``.

        Yields:
            ``(patched tokenizer copy, TextualInversionManager)``.

        Raises:
            RuntimeError: If a trigger token resolves to the unknown token id.
            ValueError: If an embedding row's shape differs from the model's.
        """
        init_tokens_count = None
        new_tokens_added = None

        try:
            ti_tokenizer = copy.deepcopy(tokenizer)
            ti_manager = TextualInversionManager(ti_tokenizer)
            init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings

            def _get_trigger(ti, index):
                # Row 0 uses the bare name; later rows get a "-!pad-N" suffix.
                trigger = ti.name
                if index > 0:
                    trigger += f"-!pad-{index}"
                return f"<{trigger}>"

            # modify tokenizer
            new_tokens_added = 0
            for ti in ti_list:
                for i in range(ti.embedding.shape[0]):
                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))

            # modify text_encoder
            text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
            model_embeddings = text_encoder.get_input_embeddings()

            for ti in ti_list:
                ti_tokens = []
                for i in range(ti.embedding.shape[0]):
                    embedding = ti.embedding[i]
                    trigger = _get_trigger(ti, i)

                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                    if token_id == ti_tokenizer.unk_token_id:
                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")

                    if model_embeddings.weight.data[token_id].shape != embedding.shape:
                        raise ValueError(
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
                        )

                    model_embeddings.weight.data[token_id] = embedding.to(
                        device=text_encoder.device, dtype=text_encoder.dtype
                    )
                    ti_tokens.append(token_id)

                # The first token id maps to the ids of the remaining rows.
                if len(ti_tokens) > 1:
                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]

            yield ti_tokenizer, ti_manager

        finally:
            # Shrink back only if the embeddings were actually grown.
            if init_tokens_count and new_tokens_added:
                text_encoder.resize_token_embeddings(init_tokens_count)

    @classmethod
    @contextmanager
    def apply_clip_skip(
        cls,
        text_encoder: CLIPTextModel,
        clip_skip: int,
    ):
        """Temporarily remove the last ``clip_skip`` encoder layers of ``text_encoder``.

        The popped layers are re-appended in their original order on exit.
        """
        skipped_layers = []
        try:
            for _ in range(clip_skip):
                skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))

            yield

        finally:
            while len(skipped_layers) > 0:
                text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
2023-07-28 13:46:44 +00:00
2023-05-30 23:12:27 +00:00
class TextualInversionModel:
    """A textual-inversion embedding loaded from a checkpoint file."""

    name: str
    embedding: torch.Tensor  # [n, 768]|[n, 1280]

    @classmethod
    def from_checkpoint(
        cls,
        file_path: Union[str, Path],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Load an embedding from ``file_path``.

        Args:
            file_path: Path to a ``.safetensors`` file or a ``torch.load``-able file.
            device: Accepted but not used by this loader.
            dtype: Accepted but not used by this loader.

        Returns:
            An instance whose ``name`` is the file stem and whose
            ``embedding`` has at least two dimensions.

        Raises:
            ValueError: If the selected payload is not a ``torch.Tensor``.
        """
        if not isinstance(file_path, Path):
            file_path = Path(file_path)

        result = cls()  # TODO:
        result.name = file_path.stem  # TODO:

        if file_path.suffix == ".safetensors":
            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
        else:
            # NOTE(security): torch.load unpickles the file; only use it on trusted checkpoints.
            state_dict = torch.load(file_path, map_location="cpu")

        # both v1 and v2 format embeddings
        # difference mostly in metadata
        if "string_to_param" in state_dict:
            if len(state_dict["string_to_param"]) > 1:
                print(
                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first token will be used.'
                )

            result.embedding = next(iter(state_dict["string_to_param"].values()))

        # v3 (easynegative)
        elif "emb_params" in state_dict:
            result.embedding = state_dict["emb_params"]

        # v4(diffusers bin files)
        else:
            result.embedding = next(iter(state_dict.values()))

        # Validate before touching ``.shape`` so a bad payload produces the
        # intended ValueError rather than an AttributeError.
        if not isinstance(result.embedding, torch.Tensor):
            raise ValueError(f"Invalid embeddings file: {file_path.name}")

        # A single vector is promoted to a one-row matrix.
        if len(result.embedding.shape) == 1:
            result.embedding = result.embedding.unsqueeze(0)

        return result
class TextualInversionManager(BaseTextualInversionManager):
    """Expands trigger token ids into the full id sequence of a multi-row embedding.

    Attributes:
        pad_tokens: Maps a trigger token id to the extra token ids inserted after it.
        tokenizer: Tokenizer whose ``bos_token_id`` / ``eos_token_id`` are used for input checks.
    """

    pad_tokens: Dict[int, List[int]]
    tokenizer: CLIPTokenizer

    def __init__(self, tokenizer: CLIPTokenizer):
        """Start with no pad tokens registered."""
        self.pad_tokens = dict()
        self.tokenizer = tokenizer

    def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
        """Insert the registered pad token ids after each trigger token id.

        Args:
            token_ids: Token ids without leading BOS / trailing EOS markers.

        Returns:
            ``token_ids`` itself when there is nothing to expand (no pad
            tokens registered, or empty input); otherwise a new list.

        Raises:
            ValueError: If ``token_ids`` starts with BOS or ends with EOS.
        """
        # An empty input has nothing to expand and would fail the
        # first/last element checks below.
        if not self.pad_tokens or not token_ids:
            return token_ids

        if token_ids[0] == self.tokenizer.bos_token_id:
            raise ValueError("token_ids must not start with bos_token_id")
        if token_ids[-1] == self.tokenizer.eos_token_id:
            raise ValueError("token_ids must not end with eos_token_id")

        new_token_ids = []
        for token_id in token_ids:
            new_token_ids.append(token_id)
            if token_id in self.pad_tokens:
                new_token_ids.extend(self.pad_tokens[token_id])

        return new_token_ids
2023-06-20 23:12:21 +00:00
class ONNXModelPatcher:
    """Context managers that temporarily patch ``IAIOnnxRuntimeModel`` tensors and restore them on exit."""

    from .models.base import IAIOnnxRuntimeModel, OnnxRuntimeModel

    @classmethod
    @contextmanager
    def apply_lora_unet(
        cls,
        unet: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply ``loras`` to ``unet`` using the ``lora_unet_`` key prefix."""
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(
        cls,
        text_encoder: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply ``loras`` to ``text_encoder`` using the ``lora_te_`` key prefix."""
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield

    # based on
    # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
    @classmethod
    @contextmanager
    def apply_lora(
        cls,
        model: IAIOnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
        prefix: str,
    ):
        """Add blended LoRA weights to the model's tensors and restore them on exit.

        Args:
            model: Model exposing ``nodes`` and ``tensors`` mappings.
            loras: ``(lora, strength)`` pairs; only layers whose key starts
                with ``prefix`` are used.
            prefix: Key prefix selecting the layers meant for this model.

        Raises:
            Exception: If ``model`` is not an ``IAIOnnxRuntimeModel``.
        """
        from .models.base import IAIOnnxRuntimeModel

        if not isinstance(model, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        orig_weights = dict()

        try:
            # Sum the weighted contributions of all loras per layer key.
            blended_loras = dict()

            for lora, lora_weight in loras:
                for layer_key, layer in lora.layers.items():
                    if not layer_key.startswith(prefix):
                        continue

                    layer.to(dtype=torch.float32)
                    # Strip only the leading prefix (startswith is checked above).
                    layer_key = layer_key[len(prefix) :]
                    # NOTE(review): unlike ModelPatcher.apply_lora, no
                    # alpha / rank scale is applied here - confirm whether
                    # that is intentional.
                    layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
                    # Membership test, so several loras on one layer accumulate.
                    if layer_key in blended_loras:
                        blended_loras[layer_key] += layer_weight
                    else:
                        blended_loras[layer_key] = layer_weight

            # Map normalized node names ("/" and "." -> "_") to real node names.
            node_names = dict()
            for node in model.nodes.values():
                node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name

            for layer_key, blended_weight in blended_loras.items():
                conv_key = layer_key + "_Conv"
                gemm_key = layer_key + "_Gemm"
                matmul_key = layer_key + "_MatMul"

                if conv_key in node_names or gemm_key in node_names:
                    if conv_key in node_names:
                        conv_node = model.nodes[node_names[conv_key]]
                    else:
                        conv_node = model.nodes[node_names[gemm_key]]

                    weight_name = [n for n in conv_node.input if ".weight" in n][0]
                    orig_weight = model.tensors[weight_name]

                    if orig_weight.shape[-2:] == (1, 1):
                        # Trailing 1x1 dims: add in 2-D, then restore the dims.
                        if blended_weight.shape[-2:] == (1, 1):
                            new_weight = orig_weight.squeeze((3, 2)) + blended_weight.squeeze((3, 2))
                        else:
                            new_weight = orig_weight.squeeze((3, 2)) + blended_weight

                        new_weight = np.expand_dims(new_weight, (2, 3))
                    else:
                        if orig_weight.shape != blended_weight.shape:
                            new_weight = orig_weight + blended_weight.reshape(orig_weight.shape)
                        else:
                            new_weight = orig_weight + blended_weight

                    orig_weights[weight_name] = orig_weight
                    model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)

                elif matmul_key in node_names:
                    weight_node = model.nodes[node_names[matmul_key]]
                    matmul_name = [n for n in weight_node.input if "MatMul" in n][0]

                    orig_weight = model.tensors[matmul_name]
                    # The MatMul tensor takes the transposed LoRA weight.
                    new_weight = orig_weight + blended_weight.transpose()

                    orig_weights[matmul_name] = orig_weight
                    model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)

                else:
                    # warn? err?
                    pass

            yield

        finally:
            # restore original weights
            for name, orig_weight in orig_weights.items():
                model.tensors[name] = orig_weight

    @classmethod
    @contextmanager
    def apply_ti(
        cls,
        tokenizer: CLIPTokenizer,
        text_encoder: IAIOnnxRuntimeModel,
        ti_list: List[Any],
    ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
        """Temporarily register textual-inversion embeddings on an ONNX text encoder.

        A deep copy of ``tokenizer`` receives one trigger token per embedding
        row; the ``text_model.embeddings.token_embedding.weight`` tensor is
        replaced by an extended copy and put back on exit.

        Args:
            tokenizer: Tokenizer to copy; the original is not modified.
            text_encoder: Model exposing a ``tensors`` mapping.
            ti_list: Objects exposing ``name`` and ``embedding``.

        Yields:
            ``(patched tokenizer copy, TextualInversionManager)``.

        Raises:
            Exception: If ``text_encoder`` is not an ``IAIOnnxRuntimeModel``.
            RuntimeError: If a trigger token resolves to the unknown token id.
            ValueError: If an embedding row's shape differs from the model's.
        """
        from .models.base import IAIOnnxRuntimeModel

        if not isinstance(text_encoder, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        orig_embeddings = None

        try:
            ti_tokenizer = copy.deepcopy(tokenizer)
            ti_manager = TextualInversionManager(ti_tokenizer)

            def _get_trigger(ti, index):
                # Row 0 uses the bare name; later rows get a "-!pad-N" suffix.
                trigger = ti.name
                if index > 0:
                    trigger += f"-!pad-{index}"
                return f"<{trigger}>"

            # modify tokenizer
            new_tokens_added = 0
            for ti in ti_list:
                for i in range(ti.embedding.shape[0]):
                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))

            # modify text_encoder
            orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]

            # Work on a copy extended with one zero row per added token.
            embeddings = np.concatenate(
                (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
                axis=0,
            )

            for ti in ti_list:
                ti_tokens = []
                for i in range(ti.embedding.shape[0]):
                    # .cpu() so embeddings living on another device convert too.
                    embedding = ti.embedding[i].detach().cpu().numpy()
                    trigger = _get_trigger(ti, i)

                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                    if token_id == ti_tokenizer.unk_token_id:
                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")

                    if embeddings[token_id].shape != embedding.shape:
                        raise ValueError(
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
                        )

                    embeddings[token_id] = embedding
                    ti_tokens.append(token_id)

                # The first token id maps to the ids of the remaining rows.
                if len(ti_tokens) > 1:
                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]

            text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
                orig_embeddings.dtype
            )

            yield ti_tokenizer, ti_manager

        finally:
            # restore
            if orig_embeddings is not None:
                text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings