mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix ruff format check
This commit is contained in:
parent
0719a46372
commit
f95ce1870c
@ -195,7 +195,11 @@ class ModelPatcher:
|
||||
def _get_ti_embedding(model_embeddings, ti):
|
||||
# for SDXL models, select the embedding that matches the text encoder's dimensions
|
||||
if ti.embedding_2 is not None:
|
||||
return ti.embedding_2 if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0] else ti.embedding
|
||||
return (
|
||||
ti.embedding_2
|
||||
if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0]
|
||||
else ti.embedding
|
||||
)
|
||||
else:
|
||||
return ti.embedding
|
||||
|
||||
@ -212,7 +216,6 @@ class ModelPatcher:
|
||||
model_embeddings = text_encoder.get_input_embeddings()
|
||||
|
||||
for ti_name, ti in ti_list:
|
||||
|
||||
ti_tokens = []
|
||||
for i in range(ti_embedding.shape[0]):
|
||||
embedding = ti_embedding[i]
|
||||
@ -282,7 +285,7 @@ class ModelPatcher:
|
||||
|
||||
|
||||
class TextualInversionModel:
|
||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
|
||||
|
||||
@classmethod
|
||||
@ -308,7 +311,7 @@ class TextualInversionModel:
|
||||
if len(state_dict["string_to_param"]) > 1:
|
||||
print(
|
||||
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
|
||||
" token will be used."
|
||||
" token will be used.",
|
||||
)
|
||||
|
||||
result.embedding = next(iter(state_dict["string_to_param"].values()))
|
||||
@ -319,8 +322,8 @@ class TextualInversionModel:
|
||||
|
||||
# v5(sdxl safetensors file)
|
||||
elif "clip_g" in state_dict and "clip_l" in state_dict:
|
||||
result.embedding = state_dict["clip_g"]
|
||||
result.embedding_2 = state_dict["clip_l"]
|
||||
result.embedding = state_dict["clip_g"]
|
||||
result.embedding_2 = state_dict["clip_l"]
|
||||
|
||||
# v4(diffusers bin files)
|
||||
else:
|
||||
@ -332,7 +335,6 @@ class TextualInversionModel:
|
||||
if not isinstance(result.embedding, torch.Tensor):
|
||||
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@ -520,9 +522,10 @@ class ONNXModelPatcher:
|
||||
# modify tokenizer
|
||||
new_tokens_added = 0
|
||||
for ti_name, ti in ti_list:
|
||||
|
||||
if ti.embedding_2 is not None:
|
||||
ti_embedding = ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
|
||||
ti_embedding = (
|
||||
ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
|
||||
)
|
||||
else:
|
||||
ti_embedding = ti.embedding
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user