mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add support for SDXL textual inversion/embeddings
This commit is contained in:
parent
a8ef4e5be8
commit
0719a46372
@ -192,10 +192,19 @@ class ModelPatcher:
|
|||||||
trigger += f"-!pad-{i}"
|
trigger += f"-!pad-{i}"
|
||||||
return f"<{trigger}>"
|
return f"<{trigger}>"
|
||||||
|
|
||||||
|
def _get_ti_embedding(model_embeddings, ti):
    """Pick the textual-inversion tensor that fits this text encoder.

    SDXL embedding files provide two tensors (``ti.embedding`` and
    ``ti.embedding_2``); other formats provide only ``ti.embedding`` and
    leave ``ti.embedding_2`` as ``None``.

    Args:
        model_embeddings: The text encoder's input-embedding module; only
            its ``weight`` tensor is read, and the last axis of that tensor
            is taken as the per-token vector width.
        ti: A textual-inversion model exposing ``embedding`` and
            ``embedding_2`` (each ``[n, dim]`` when present).

    Returns:
        ``ti.embedding_2`` when it is set and its second axis equals the
        encoder's token width; otherwise ``ti.embedding``. No check is made
        that the fallback ``ti.embedding`` matches the encoder.
    """
    # Read the width from the weight's shape rather than from row 0, so an
    # embedding table with zero rows does not raise IndexError.
    token_dim = model_embeddings.weight.shape[-1]

    secondary = ti.embedding_2
    if secondary is not None and secondary.shape[1] == token_dim:
        return secondary
    return ti.embedding
|
||||||
|
|
||||||
# modify tokenizer
|
# modify tokenizer
|
||||||
new_tokens_added = 0
|
new_tokens_added = 0
|
||||||
for ti_name, ti in ti_list:
|
for ti_name, ti in ti_list:
|
||||||
for i in range(ti.embedding.shape[0]):
|
ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
|
||||||
|
|
||||||
|
for i in range(ti_embedding.shape[0]):
|
||||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
||||||
|
|
||||||
# modify text_encoder
|
# modify text_encoder
|
||||||
@ -203,9 +212,10 @@ class ModelPatcher:
|
|||||||
model_embeddings = text_encoder.get_input_embeddings()
|
model_embeddings = text_encoder.get_input_embeddings()
|
||||||
|
|
||||||
for ti_name, ti in ti_list:
|
for ti_name, ti in ti_list:
|
||||||
|
|
||||||
ti_tokens = []
|
ti_tokens = []
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti_embedding.shape[0]):
|
||||||
embedding = ti.embedding[i]
|
embedding = ti_embedding[i]
|
||||||
trigger = _get_trigger(ti_name, i)
|
trigger = _get_trigger(ti_name, i)
|
||||||
|
|
||||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
||||||
@ -273,6 +283,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
class TextualInversionModel:
|
class TextualInversionModel:
|
||||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||||
|
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_checkpoint(
|
def from_checkpoint(
|
||||||
@ -296,7 +307,7 @@ class TextualInversionModel:
|
|||||||
if "string_to_param" in state_dict:
|
if "string_to_param" in state_dict:
|
||||||
if len(state_dict["string_to_param"]) > 1:
|
if len(state_dict["string_to_param"]) > 1:
|
||||||
print(
|
print(
|
||||||
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first'
|
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
|
||||||
" token will be used."
|
" token will be used."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -306,6 +317,11 @@ class TextualInversionModel:
|
|||||||
elif "emb_params" in state_dict:
|
elif "emb_params" in state_dict:
|
||||||
result.embedding = state_dict["emb_params"]
|
result.embedding = state_dict["emb_params"]
|
||||||
|
|
||||||
|
# v5(sdxl safetensors file)
|
||||||
|
elif "clip_g" in state_dict and "clip_l" in state_dict:
|
||||||
|
result.embedding = state_dict["clip_g"]
|
||||||
|
result.embedding_2 = state_dict["clip_l"]
|
||||||
|
|
||||||
# v4(diffusers bin files)
|
# v4(diffusers bin files)
|
||||||
else:
|
else:
|
||||||
result.embedding = next(iter(state_dict.values()))
|
result.embedding = next(iter(state_dict.values()))
|
||||||
@ -316,6 +332,7 @@ class TextualInversionModel:
|
|||||||
if not isinstance(result.embedding, torch.Tensor):
|
if not isinstance(result.embedding, torch.Tensor):
|
||||||
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
||||||
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@ -342,6 +359,13 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
if token_id in self.pad_tokens:
|
if token_id in self.pad_tokens:
|
||||||
new_token_ids.extend(self.pad_tokens[token_id])
|
new_token_ids.extend(self.pad_tokens[token_id])
|
||||||
|
|
||||||
|
# Do not exceed the max model input size
|
||||||
|
# The -2 here compensates for compel.embeddings_provider.get_token_ids(),
|
||||||
|
# which first removes and then adds back the start and end tokens.
|
||||||
|
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
|
||||||
|
if len(new_token_ids) > max_length:
|
||||||
|
new_token_ids = new_token_ids[0:max_length]
|
||||||
|
|
||||||
return new_token_ids
|
return new_token_ids
|
||||||
|
|
||||||
|
|
||||||
@ -490,14 +514,20 @@ class ONNXModelPatcher:
|
|||||||
trigger += f"-!pad-{i}"
|
trigger += f"-!pad-{i}"
|
||||||
return f"<{trigger}>"
|
return f"<{trigger}>"
|
||||||
|
|
||||||
|
# modify text_encoder
|
||||||
|
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
||||||
|
|
||||||
# modify tokenizer
|
# modify tokenizer
|
||||||
new_tokens_added = 0
|
new_tokens_added = 0
|
||||||
for ti_name, ti in ti_list:
|
for ti_name, ti in ti_list:
|
||||||
for i in range(ti.embedding.shape[0]):
|
|
||||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
|
||||||
|
|
||||||
# modify text_encoder
|
if ti.embedding_2 is not None:
|
||||||
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
ti_embedding = ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
|
||||||
|
else:
|
||||||
|
ti_embedding = ti.embedding
|
||||||
|
|
||||||
|
for i in range(ti_embedding.shape[0]):
|
||||||
|
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
||||||
|
|
||||||
embeddings = np.concatenate(
|
embeddings = np.concatenate(
|
||||||
(np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
|
(np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
|
||||||
@ -506,8 +536,8 @@ class ONNXModelPatcher:
|
|||||||
|
|
||||||
for ti_name, ti in ti_list:
|
for ti_name, ti in ti_list:
|
||||||
ti_tokens = []
|
ti_tokens = []
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti_embedding.shape[0]):
|
||||||
embedding = ti.embedding[i].detach().numpy()
|
embedding = ti_embedding[i].detach().numpy()
|
||||||
trigger = _get_trigger(ti_name, i)
|
trigger = _get_trigger(ti_name, i)
|
||||||
|
|
||||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
||||||
|
@ -373,12 +373,16 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
|||||||
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
|
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
|
||||||
elif "emb_params" in checkpoint:
|
elif "emb_params" in checkpoint:
|
||||||
token_dim = checkpoint["emb_params"].shape[-1]
|
token_dim = checkpoint["emb_params"].shape[-1]
|
||||||
|
elif "clip_g" in checkpoint:
|
||||||
|
token_dim = checkpoint["clip_g"].shape[-1]
|
||||||
else:
|
else:
|
||||||
token_dim = list(checkpoint.values())[0].shape[0]
|
token_dim = list(checkpoint.values())[0].shape[0]
|
||||||
if token_dim == 768:
|
if token_dim == 768:
|
||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
elif token_dim == 1024:
|
elif token_dim == 1024:
|
||||||
return BaseModelType.StableDiffusion2
|
return BaseModelType.StableDiffusion2
|
||||||
|
elif token_dim == 1280:
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user