diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index f9c40f8386..acd1f6bab6 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -195,7 +195,11 @@ class ModelPatcher: def _get_ti_embedding(model_embeddings, ti): # for SDXL models, select the embedding that matches the text encoder's dimensions if ti.embedding_2 is not None: - return ti.embedding_2 if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0] else ti.embedding + return ( + ti.embedding_2 + if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0] + else ti.embedding + ) else: return ti.embedding @@ -212,7 +216,6 @@ class ModelPatcher: model_embeddings = text_encoder.get_input_embeddings() for ti_name, ti in ti_list: - ti_tokens = [] for i in range(ti_embedding.shape[0]): embedding = ti_embedding[i] @@ -282,7 +285,7 @@ class ModelPatcher: class TextualInversionModel: - embedding: torch.Tensor # [n, 768]|[n, 1280] + embedding: torch.Tensor # [n, 768]|[n, 1280] embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models @classmethod @@ -308,7 +311,7 @@ class TextualInversionModel: if len(state_dict["string_to_param"]) > 1: print( f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first', - " token will be used." + " token will be used.", ) result.embedding = next(iter(state_dict["string_to_param"].values())) @@ -319,8 +322,8 @@ class TextualInversionModel: # v5(sdxl safetensors file) elif "clip_g" in state_dict and "clip_l" in state_dict: - result.embedding = state_dict["clip_g"] - result.embedding_2 = state_dict["clip_l"] + result.embedding = state_dict["clip_g"] + result.embedding_2 = state_dict["clip_l"] # v4(diffusers bin files) else: @@ -332,7 +335,6 @@ class TextualInversionModel: if not isinstance(result.embedding, torch.Tensor): raise ValueError(f"Invalid embeddings file: {file_path.name}") - return result @@ -520,9 +522,10 @@ class ONNXModelPatcher: # modify tokenizer new_tokens_added = 0 for ti_name, ti in ti_list: - if ti.embedding_2 is not None: - ti_embedding = ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding + ti_embedding = ( + ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding + ) else: ti_embedding = ti.embedding