2023-05-29 22:11:00 +00:00
from __future__ import annotations
2023-05-30 23:12:27 +00:00
import copy
2023-05-29 22:11:00 +00:00
from contextlib import contextmanager
2023-07-05 23:18:25 +00:00
from typing import Optional , Dict , Tuple , Any , Union , List
2023-07-05 02:37:16 +00:00
from pathlib import Path
2023-05-29 22:11:00 +00:00
import torch
from safetensors . torch import load_file
from torch . utils . hooks import RemovableHandle
from diffusers . models import UNet2DConditionModel
from transformers import CLIPTextModel
2023-06-20 23:12:21 +00:00
from onnx import numpy_helper
2023-06-22 17:03:17 +00:00
from onnxruntime import OrtValue
2023-06-20 23:12:21 +00:00
import numpy as np
2023-05-29 22:11:00 +00:00
2023-05-30 23:12:27 +00:00
from compel . embeddings_provider import BaseTextualInversionManager
2023-07-05 02:37:16 +00:00
from diffusers . models import UNet2DConditionModel
2023-05-29 22:11:00 +00:00
from safetensors . torch import load_file
2023-07-05 20:40:47 +00:00
from transformers import CLIPTextModel , CLIPTokenizer
2023-05-30 23:12:27 +00:00
2023-05-29 22:11:00 +00:00
"""
loras = [
( lora_model1 , 0.7 ) ,
( lora_model2 , 0.4 ) ,
]
with LoRAHelper . apply_lora_unet ( unet , loras ) :
# unet with applied loras
# unmodified unet
"""
2023-07-28 13:46:44 +00:00
2023-05-29 22:11:00 +00:00
# NOTE: formerly a TODO to rename this helper to ModelPatcher and add a TI method; both are done (see apply_ti).
2023-05-30 23:12:27 +00:00
class ModelPatcher:
    """Temporarily patch torch models with LoRAs, textual inversions or clip skip.

    Every public ``apply_*`` method is a context manager: the patch is applied
    on entry and reverted in a ``finally`` clause on exit, so the model is
    restored even when the body of the ``with`` statement raises.
    """

    @staticmethod
    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
        """Map a flattened LoRA layer key onto the matching submodule of ``model``.

        LoRA keys join the module path with ``_``, but module names may contain
        ``_`` themselves. The key is therefore consumed piece by piece: a piece
        is extended with the next one until it names an existing submodule.

        Args:
            model: Root module to search in.
            lora_key: Flattened key; must start with ``prefix`` and contain no ``.``.
            prefix: Prefix to strip from ``lora_key`` before resolving.

        Returns:
            ``(module_key, module)`` where ``module_key`` is the dotted path of
            ``module`` relative to ``model``.

        Raises:
            Exception: If ``lora_key`` does not start with ``prefix``.
            AttributeError: If the remaining key does not name a submodule.
        """
        assert "." not in lora_key

        if not lora_key.startswith(prefix):
            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")

        module = model
        module_key = ""
        key_parts = lora_key[len(prefix) :].split("_")

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
            except AttributeError:
                # Not a submodule (yet): the real name contains "_", keep extending it.
                submodule_name += "_" + key_parts.pop(0)
            else:
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)

        # The last accumulated name must resolve; otherwise AttributeError propagates.
        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return (module_key, module)

    @staticmethod
    def _lora_forward_hook(
        applied_loras: List[Tuple[LoRAModel, float]],
        layer_name: str,
    ):
        """Build a forward hook that adds the LoRA contributions for ``layer_name``.

        The returned callable has the ``(module, input, output)`` hook signature.
        For every ``(lora, weight)`` pair that has a layer called ``layer_name``
        it adds ``layer.forward(module, input, weight)`` to ``output`` in place.
        """

        def lora_forward(module, input_h, output):
            if len(applied_loras) == 0:
                return output

            for lora, weight in applied_loras:
                layer = lora.layers.get(layer_name, None)
                if layer is None:
                    continue
                output += layer.forward(module, input_h, weight)
            return output

        return lora_forward

    @classmethod
    @contextmanager
    def apply_lora_unet(
        cls,
        unet: UNet2DConditionModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply the ``lora_unet_``-prefixed layers of ``loras`` to ``unet``."""
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(
        cls,
        text_encoder: CLIPTextModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply the ``lora_te_``-prefixed layers of ``loras`` to ``text_encoder``."""
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield

    @classmethod
    @contextmanager
    def apply_sdxl_lora_text_encoder(
        cls,
        text_encoder: CLIPTextModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply the ``lora_te1_``-prefixed layers of ``loras`` to ``text_encoder``."""
        with cls.apply_lora(text_encoder, loras, "lora_te1_"):
            yield

    @classmethod
    @contextmanager
    def apply_sdxl_lora_text_encoder2(
        cls,
        text_encoder: CLIPTextModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply the ``lora_te2_``-prefixed layers of ``loras`` to ``text_encoder``."""
        with cls.apply_lora(text_encoder, loras, "lora_te2_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora(
        cls,
        model: torch.nn.Module,
        loras: List[Tuple[LoRAModel, float]],
        prefix: str,
    ):
        """Add LoRA weight deltas to ``model`` for the duration of the context.

        Args:
            model: Module whose submodule weights are patched in place.
            loras: ``(lora, weight)`` pairs; only layers whose key starts with
                ``prefix`` are used.
            prefix: Key prefix selecting the layers that belong to ``model``.

        On exit every touched weight is copied back from a CPU copy that was
        taken before the first modification.
        """
        original_weights = dict()

        try:
            with torch.no_grad():
                for lora, lora_weight in loras:
                    for layer_key, layer in lora.layers.items():
                        if not layer_key.startswith(prefix):
                            continue

                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
                        # Back up only once so stacked LoRAs restore the pristine weight.
                        if module_key not in original_weights:
                            original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)

                        # Compute the delta in float32; it is cast to the module dtype below.
                        layer.to(dtype=torch.float32)
                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
                        layer_weight = layer.get_weight() * lora_weight * layer_scale

                        if module.weight.shape != layer_weight.shape:
                            # TODO: debug on lycoris
                            layer_weight = layer_weight.reshape(module.weight.shape)

                        module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)

            yield  # wait for context manager exit

        finally:
            with torch.no_grad():
                for module_key, weight in original_weights.items():
                    model.get_submodule(module_key).weight.copy_(weight)

    @classmethod
    @contextmanager
    def apply_ti(
        cls,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        ti_list: List[Any],
    ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
        """Register textual-inversion embeddings for the duration of the context.

        A deep copy of ``tokenizer`` receives one trigger token per embedding
        row (``<name>``, ``<name-!pad-1>``, ...). ``text_encoder`` is resized to
        hold the new tokens and the embedding rows are written into it.

        Args:
            tokenizer: Source tokenizer; it is copied, not modified.
            text_encoder: Encoder whose token embedding table is extended.
            ti_list: Objects exposing ``name`` and a 2-D ``embedding``.

        Yields:
            ``(ti_tokenizer, ti_manager)``: the extended tokenizer copy and a
            manager that maps each first trigger token to its pad tokens.

        Raises:
            RuntimeError: If a trigger token cannot be found after adding it.
            ValueError: If an embedding row does not match the model's token
                dimension.

        On exit the token embedding table is resized back to its initial size.
        """
        init_tokens_count = None
        new_tokens_added = None

        try:
            ti_tokenizer = copy.deepcopy(tokenizer)
            ti_manager = TextualInversionManager(ti_tokenizer)
            # resize_token_embeddings(None) returns the embedding module unchanged.
            init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings

            def _get_trigger(ti, index):
                trigger = ti.name
                if index > 0:
                    # Use the parameter, not the enclosing loop variable.
                    trigger += f"-!pad-{index}"
                return f"<{trigger}>"

            # modify tokenizer
            new_tokens_added = 0
            for ti in ti_list:
                for i in range(ti.embedding.shape[0]):
                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))

            # modify text_encoder
            text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
            model_embeddings = text_encoder.get_input_embeddings()

            for ti in ti_list:
                ti_tokens = []
                for i in range(ti.embedding.shape[0]):
                    embedding = ti.embedding[i]
                    trigger = _get_trigger(ti, i)

                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                    if token_id == ti_tokenizer.unk_token_id:
                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")

                    if model_embeddings.weight.data[token_id].shape != embedding.shape:
                        raise ValueError(
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
                        )

                    model_embeddings.weight.data[token_id] = embedding.to(
                        device=text_encoder.device, dtype=text_encoder.dtype
                    )
                    ti_tokens.append(token_id)

                # A multi-vector embedding expands its first token into the rest.
                if len(ti_tokens) > 1:
                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]

            yield ti_tokenizer, ti_manager

        finally:
            # Shrink only if the table was measured and tokens were actually added.
            if init_tokens_count and new_tokens_added:
                text_encoder.resize_token_embeddings(init_tokens_count)

    @classmethod
    @contextmanager
    def apply_clip_skip(
        cls,
        text_encoder: CLIPTextModel,
        clip_skip: int,
    ):
        """Remove the last ``clip_skip`` encoder layers for the duration of the context.

        Layers are popped from ``text_encoder.text_model.encoder.layers`` on
        entry and appended back in their original order on exit.
        """
        skipped_layers = []
        try:
            for _ in range(clip_skip):
                skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))

            yield

        finally:
            # Popped last-to-first, so re-append in reverse to restore the order.
            while len(skipped_layers) > 0:
                text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
2023-07-28 13:46:44 +00:00
2023-05-30 23:12:27 +00:00
class TextualInversionModel:
    """A textual-inversion embedding loaded from a checkpoint file."""

    name: str
    embedding: torch.Tensor  # [n, 768]|[n, 1280]

    @classmethod
    def from_checkpoint(
        cls,
        file_path: Union[str, Path],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Load an embedding from ``file_path``.

        Args:
            file_path: Path to a ``.safetensors`` file or a ``torch.save`` checkpoint.
            device: Accepted for interface compatibility; currently unused
                (tensors are always loaded onto the CPU).
            dtype: Accepted for interface compatibility; currently unused.

        Returns:
            An instance whose ``name`` is the file stem and whose ``embedding``
            is a tensor with at least two dimensions (a 1-D embedding gets a
            leading axis of size 1).

        Raises:
            ValueError: If the selected entry of the checkpoint is not a tensor.
        """
        if not isinstance(file_path, Path):
            file_path = Path(file_path)

        result = cls()  # TODO:
        result.name = file_path.stem  # TODO:

        if file_path.suffix == ".safetensors":
            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
        else:
            # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(file_path, map_location="cpu")

        # both v1 and v2 format embeddings
        # difference mostly in metadata
        if "string_to_param" in state_dict:
            if len(state_dict["string_to_param"]) > 1:
                print(
                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first token will be used.'
                )

            result.embedding = next(iter(state_dict["string_to_param"].values()))

        # v3 (easynegative)
        elif "emb_params" in state_dict:
            result.embedding = state_dict["emb_params"]

        # v4(diffusers bin files)
        else:
            result.embedding = next(iter(state_dict.values()))

        # Validate before touching ``.shape`` so a bad file yields the intended
        # ValueError rather than an AttributeError.
        if not isinstance(result.embedding, torch.Tensor):
            raise ValueError(f"Invalid embeddings file: {file_path.name}")

        if len(result.embedding.shape) == 1:
            result.embedding = result.embedding.unsqueeze(0)

        return result
class TextualInversionManager(BaseTextualInversionManager):
    """Expand textual-inversion trigger tokens into their full token sequences."""

    # Maps the first token id of a multi-vector embedding to its remaining token ids.
    pad_tokens: Dict[int, List[int]]
    # Tokenizer used to look up the bos/eos token ids.
    tokenizer: CLIPTokenizer

    def __init__(self, tokenizer: CLIPTokenizer):
        self.pad_tokens = dict()
        self.tokenizer = tokenizer

    def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
        """Insert the registered pad tokens after each trigger token.

        Args:
            token_ids: Token ids without leading bos / trailing eos markers.

        Returns:
            ``token_ids`` itself when there is nothing to expand (no pad tokens
            registered, or an empty input); otherwise a new list in which every
            id found in ``pad_tokens`` is followed by its pad token ids.

        Raises:
            ValueError: If ``token_ids`` starts with the bos token id or ends
                with the eos token id.
        """
        # An empty prompt has nothing to expand; also avoids indexing [0] / [-1] below.
        if len(self.pad_tokens) == 0 or len(token_ids) == 0:
            return token_ids

        if token_ids[0] == self.tokenizer.bos_token_id:
            raise ValueError("token_ids must not start with bos_token_id")
        if token_ids[-1] == self.tokenizer.eos_token_id:
            raise ValueError("token_ids must not end with eos_token_id")

        new_token_ids = []
        for token_id in token_ids:
            new_token_ids.append(token_id)
            if token_id in self.pad_tokens:
                new_token_ids.extend(self.pad_tokens[token_id])

        return new_token_ids
2023-06-20 23:12:21 +00:00
class ONNXModelPatcher:
    """Temporarily patch ONNX runtime models with LoRAs or textual inversions.

    Every ``apply_*`` method is a context manager: tensors are replaced on entry
    and the original tensors are put back in a ``finally`` clause on exit.
    """

    # Kept at class scope as in the original; the names appear in the annotations below.
    from .models.base import IAIOnnxRuntimeModel, OnnxRuntimeModel

    @classmethod
    @contextmanager
    def apply_lora_unet(
        cls,
        unet: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply the ``lora_unet_``-prefixed layers of ``loras`` to ``unet``."""
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(
        cls,
        text_encoder: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        """Apply the ``lora_te_``-prefixed layers of ``loras`` to ``text_encoder``."""
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield

    # based on
    # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
    @classmethod
    @contextmanager
    def apply_lora(
        cls,
        model: IAIOnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
        prefix: str,
    ):
        """Add blended LoRA weights to the tensors of ``model`` for the context.

        Args:
            model: Model exposing ``nodes`` and ``tensors`` mappings.
            loras: ``(lora, weight)`` pairs; only layers whose key starts with
                ``prefix`` are used.
            prefix: Key prefix selecting the layers that belong to ``model``.

        Raises:
            Exception: If ``model`` is not an ``IAIOnnxRuntimeModel``.

        Layers with no matching Conv / Gemm / MatMul node are silently skipped.
        """
        from .models.base import IAIOnnxRuntimeModel

        if not isinstance(model, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        orig_weights = dict()

        try:
            # Sum the weighted deltas of all LoRAs per layer before touching the model.
            # NOTE(review): unlike ModelPatcher.apply_lora, no alpha/rank scale is
            # applied here — confirm whether that is intended.
            blended_loras = dict()

            for lora, lora_weight in loras:
                for layer_key, layer in lora.layers.items():
                    if not layer_key.startswith(prefix):
                        continue

                    layer.to(dtype=torch.float32)
                    # Strip only the leading prefix, not later occurrences.
                    layer_key = layer_key[len(prefix) :]
                    layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
                    # Membership test (`in`), so LoRAs on the same layer accumulate.
                    if layer_key in blended_loras:
                        blended_loras[layer_key] += layer_weight
                    else:
                        blended_loras[layer_key] = layer_weight

            # Map normalized node names (LoRA key style) to real node names.
            node_names = dict()
            for node in model.nodes.values():
                node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name

            for layer_key, blended_weight in blended_loras.items():
                conv_key = layer_key + "_Conv"
                gemm_key = layer_key + "_Gemm"
                matmul_key = layer_key + "_MatMul"

                if conv_key in node_names or gemm_key in node_names:
                    if conv_key in node_names:
                        conv_node = model.nodes[node_names[conv_key]]
                    else:
                        conv_node = model.nodes[node_names[gemm_key]]

                    weight_name = [n for n in conv_node.input if ".weight" in n][0]
                    orig_weight = model.tensors[weight_name]

                    if orig_weight.shape[-2:] == (1, 1):
                        # 1x1 kernel: add on the squeezed 2-D view, then restore the kernel axes.
                        if blended_weight.shape[-2:] == (1, 1):
                            new_weight = orig_weight.squeeze((3, 2)) + blended_weight.squeeze((3, 2))
                        else:
                            new_weight = orig_weight.squeeze((3, 2)) + blended_weight

                        new_weight = np.expand_dims(new_weight, (2, 3))
                    else:
                        if orig_weight.shape != blended_weight.shape:
                            new_weight = orig_weight + blended_weight.reshape(orig_weight.shape)
                        else:
                            new_weight = orig_weight + blended_weight

                    # Keep the first backup so the pristine tensor is what gets restored.
                    if weight_name not in orig_weights:
                        orig_weights[weight_name] = orig_weight
                    model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)

                elif matmul_key in node_names:
                    weight_node = model.nodes[node_names[matmul_key]]
                    matmul_name = [n for n in weight_node.input if "MatMul" in n][0]

                    orig_weight = model.tensors[matmul_name]
                    new_weight = orig_weight + blended_weight.transpose()

                    if matmul_name not in orig_weights:
                        orig_weights[matmul_name] = orig_weight
                    model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)

                else:
                    # warn? err?
                    pass

            yield

        finally:
            # restore original weights
            for name, orig_weight in orig_weights.items():
                model.tensors[name] = orig_weight

    @classmethod
    @contextmanager
    def apply_ti(
        cls,
        tokenizer: CLIPTokenizer,
        text_encoder: IAIOnnxRuntimeModel,
        ti_list: List[Any],
    ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
        """Register textual-inversion embeddings for the duration of the context.

        A deep copy of ``tokenizer`` receives one trigger token per embedding
        row (``<name>``, ``<name-!pad-1>``, ...). The token embedding tensor of
        ``text_encoder`` is replaced by an extended copy holding the new rows.

        Args:
            tokenizer: Source tokenizer; it is copied, not modified.
            text_encoder: Model whose ``tensors`` mapping holds the token embeddings.
            ti_list: Objects exposing ``name`` and a 2-D ``embedding`` tensor.

        Yields:
            ``(ti_tokenizer, ti_manager)``: the extended tokenizer copy and a
            manager that maps each first trigger token to its pad tokens.

        Raises:
            Exception: If ``text_encoder`` is not an ``IAIOnnxRuntimeModel``.
            RuntimeError: If a trigger token cannot be found after adding it.
            ValueError: If an embedding row does not match the model's token
                dimension.
        """
        from .models.base import IAIOnnxRuntimeModel

        if not isinstance(text_encoder, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        orig_embeddings = None

        try:
            ti_tokenizer = copy.deepcopy(tokenizer)
            ti_manager = TextualInversionManager(ti_tokenizer)

            def _get_trigger(ti, index):
                trigger = ti.name
                if index > 0:
                    # Use the parameter, not the enclosing loop variable.
                    trigger += f"-!pad-{index}"
                return f"<{trigger}>"

            # modify tokenizer
            new_tokens_added = 0
            for ti in ti_list:
                for i in range(ti.embedding.shape[0]):
                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))

            # modify text_encoder
            orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]

            # Work on a copy extended with zero rows for the new tokens.
            embeddings = np.concatenate(
                (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
                axis=0,
            )

            for ti in ti_list:
                ti_tokens = []
                for i in range(ti.embedding.shape[0]):
                    # .cpu() so embeddings living on another device can be converted.
                    embedding = ti.embedding[i].detach().cpu().numpy()
                    trigger = _get_trigger(ti, i)

                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                    if token_id == ti_tokenizer.unk_token_id:
                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")

                    if embeddings[token_id].shape != embedding.shape:
                        raise ValueError(
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
                        )

                    embeddings[token_id] = embedding
                    ti_tokens.append(token_id)

                # A multi-vector embedding expands its first token into the rest.
                if len(ti_tokens) > 1:
                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]

            text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
                orig_embeddings.dtype
            )

            yield ti_tokenizer, ti_manager

        finally:
            # restore
            if orig_embeddings is not None:
                text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings