fix ruff format check

This commit is contained in:
Lincoln Stein 2023-12-01 01:46:12 -05:00
parent 0719a46372
commit f95ce1870c

View File

@ -195,7 +195,11 @@ class ModelPatcher:
def _get_ti_embedding(model_embeddings, ti): def _get_ti_embedding(model_embeddings, ti):
# for SDXL models, select the embedding that matches the text encoder's dimensions # for SDXL models, select the embedding that matches the text encoder's dimensions
if ti.embedding_2 is not None: if ti.embedding_2 is not None:
return ti.embedding_2 if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0] else ti.embedding return (
ti.embedding_2
if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0]
else ti.embedding
)
else: else:
return ti.embedding return ti.embedding
@ -212,7 +216,6 @@ class ModelPatcher:
model_embeddings = text_encoder.get_input_embeddings() model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list: for ti_name, ti in ti_list:
ti_tokens = [] ti_tokens = []
for i in range(ti_embedding.shape[0]): for i in range(ti_embedding.shape[0]):
embedding = ti_embedding[i] embedding = ti_embedding[i]
@ -308,7 +311,7 @@ class TextualInversionModel:
if len(state_dict["string_to_param"]) > 1: if len(state_dict["string_to_param"]) > 1:
print( print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first', f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
" token will be used." " token will be used.",
) )
result.embedding = next(iter(state_dict["string_to_param"].values())) result.embedding = next(iter(state_dict["string_to_param"].values()))
@ -332,7 +335,6 @@ class TextualInversionModel:
if not isinstance(result.embedding, torch.Tensor): if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}") raise ValueError(f"Invalid embeddings file: {file_path.name}")
return result return result
@ -520,9 +522,10 @@ class ONNXModelPatcher:
# modify tokenizer # modify tokenizer
new_tokens_added = 0 new_tokens_added = 0
for ti_name, ti in ti_list: for ti_name, ti in ti_list:
if ti.embedding_2 is not None: if ti.embedding_2 is not None:
ti_embedding = ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding ti_embedding = (
ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
)
else: else:
ti_embedding = ti.embedding ti_embedding = ti.embedding