Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
add ability to retrieve current list of embedding trigger strings
This PR adds a new attribute to ldm.generate, `embedding_trigger_strings`:

```
gen = Generate(...)
strings = gen.embedding_trigger_strings
strings = gen.embedding_trigger_strings()
```

The trigger strings change when the model is updated, so that only the strings compatible with the current model are shown. Dynamically-downloaded triggers from the HF Concepts Library will only show up after they have been used for the first time. However, the full list of concepts available for download can be retrieved programmatically like this:

```
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
concepts = HuggingFaceConceptsLibrary()
trigger_strings = concepts.list_concepts()
```
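For readers who want to assemble a comparable list themselves, here is a minimal sketch. It combines the locally loaded triggers exposed by `TextualInversionManager.get_all_trigger_strings()` (see the diff below) with the downloadable concept names from `HuggingFaceConceptsLibrary.list_concepts()`. The helper name and the way the manager is passed in are hypothetical and not part of this PR.

```
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary


def collect_trigger_strings(manager) -> list[str]:
    """Hypothetical helper: gather trigger phrases from already-loaded
    embeddings plus the names of concepts that could still be downloaded.

    `manager` is assumed to be the TextualInversionManager instance shown
    in the diff below.
    """
    local = manager.get_all_trigger_strings()              # e.g. ['<my-style>']
    remote = HuggingFaceConceptsLibrary().list_concepts()  # downloadable concept names
    # keep local triggers first and avoid duplicates
    return local + [name for name in remote if name not in local]
```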
This commit is contained in:
parent 628e82fa79
commit bc18a94d8c
ldm/generate.py (978)
File diff suppressed because it is too large
@@ -751,6 +751,9 @@ class Args(object):
!fix applies upscaling/facefixing to a previously-generated image.
invoke> !fix 0000045.4829112.png -G1 -U4 -ft codeformer

*embeddings*
invoke> !triggers -- return all trigger phrases contained in loaded embedding files

*History manipulation*
!fetch retrieves the command used to generate an earlier image. Provide
a directory wildcard and the name of a file to write and all the commands
@@ -60,7 +60,7 @@ COMMANDS = (
'--text_mask','-tm',
'!fix','!fetch','!replay','!history','!search','!clear',
'!models','!switch','!import_model','!optimize_model','!convert_model','!edit_model','!del_model',
'!mask',
'!mask','!triggers',
)
MODEL_COMMANDS = (
'!switch',
@@ -1,11 +1,12 @@
import os
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from dataclasses import dataclass
from picklescan.scanner import scan_file_path
from transformers import CLIPTokenizer, CLIPTextModel
from transformers import CLIPTextModel, CLIPTokenizer

from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
@@ -21,11 +22,14 @@ class TextualInversion:
def embedding_vector_length(self) -> int:
return self.embedding.shape[0]

class TextualInversionManager():
def __init__(self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
full_precision: bool=True):

class TextualInversionManager:
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
full_precision: bool = True,
):
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.full_precision = full_precision
@@ -38,47 +42,60 @@ class TextualInversionManager():
if concept_name in self.hf_concepts_library.concepts_loaded:
continue
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
if self.has_textual_inversion_for_trigger_string(trigger) \
or self.has_textual_inversion_for_trigger_string(concept_name) \
or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
print(f'>> Loaded local embedding for trigger {concept_name}')
if (
self.has_textual_inversion_for_trigger_string(trigger)
or self.has_textual_inversion_for_trigger_string(concept_name)
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
): # in case a token with literal angle brackets encountered
print(f">> Loaded local embedding for trigger {concept_name}")
continue
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
if not bin_file:
continue
print(f'>> Loaded remote embedding for trigger {concept_name}')
print(f">> Loaded remote embedding for trigger {concept_name}")
self.load_textual_inversion(bin_file)
self.hf_concepts_library.concepts_loaded[concept_name]=True
self.hf_concepts_library.concepts_loaded[concept_name] = True

def get_all_trigger_strings(self) -> list[str]:
return [ti.trigger_string for ti in self.textual_inversions]

def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
if str(ckpt_path).endswith('.DS_Store'):
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool = False):
if str(ckpt_path).endswith(".DS_Store"):
return
try:
scan_result = scan_file_path(ckpt_path)
if scan_result.infected_files == 1:
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
print('### For your safety, InvokeAI will not load this embed.')
print(
f"\n### Security Issues Found in Model: {scan_result.issues_count}"
)
print("### For your safety, InvokeAI will not load this embed.")
return
except Exception:
print(f"### WARNING::: Invalid or corrupt embeddings found. Ignoring: {ckpt_path}")
ckpt_path = Path(ckpt_path)
print(
f"** Notice: {ckpt_path.parents[0].stem}/{ckpt_path.stem} is incompatible with this model"
)
return

embedding_info = self._parse_embedding(ckpt_path)
if embedding_info:
try:
self._add_textual_inversion(embedding_info['name'],
embedding_info['embedding'],
defer_injecting_tokens=defer_injecting_tokens)
self._add_textual_inversion(
embedding_info["name"],
embedding_info["embedding"],
defer_injecting_tokens=defer_injecting_tokens,
)
except ValueError as e:
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
print(f' | The error was {str(e)}')
print(f" | The error was {str(e)}")
else:
print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.')
print(
f">> Failed to load embedding located at {ckpt_path}. Unsupported file."
)

def _add_textual_inversion(self, trigger_str, embedding, defer_injecting_tokens=False) -> TextualInversion:
def _add_textual_inversion(
self, trigger_str, embedding, defer_injecting_tokens=False
) -> TextualInversion:
"""
Add a textual inversion to be recognised.
:param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
@@ -86,46 +103,59 @@ class TextualInversionManager():
:return: The token id for the added embedding, either existing or newly-added.
"""
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
print(f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'")
print(
f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
)
return
if not self.full_precision:
embedding = embedding.half()
if len(embedding.shape) == 1:
embedding = embedding.unsqueeze(0)
elif len(embedding.shape) > 2:
raise ValueError(f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2.")
raise ValueError(
f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
)

try:
ti = TextualInversion(
trigger_string=trigger_str,
embedding=embedding
)
ti = TextualInversion(trigger_string=trigger_str, embedding=embedding)
if not defer_injecting_tokens:
self._inject_tokens_and_assign_embeddings(ti)
self.textual_inversions.append(ti)
return ti

except ValueError as e:
if str(e).startswith('Warning'):
if str(e).startswith("Warning"):
print(f">> {str(e)}")
else:
traceback.print_exc()
print(f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}.")
print(
f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
)
raise

def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:

if ti.trigger_token_id is not None:
raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'")
raise ValueError(
f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'"
)

trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0])
trigger_token_id = self._get_or_create_token_id_and_assign_embedding(
ti.trigger_string, ti.embedding[0]
)

if ti.embedding_vector_length > 1:
# for embeddings with vector length > 1
pad_token_strings = [ti.trigger_string + "-!pad-" + str(pad_index) for pad_index in range(1, ti.embedding_vector_length)]
pad_token_strings = [
ti.trigger_string + "-!pad-" + str(pad_index)
for pad_index in range(1, ti.embedding_vector_length)
]
# todo: batched UI for faster loading when vector length >2
pad_token_ids = [self._get_or_create_token_id_and_assign_embedding(pad_token_str, ti.embedding[1 + i]) \
for (i, pad_token_str) in enumerate(pad_token_strings)]
pad_token_ids = [
self._get_or_create_token_id_and_assign_embedding(
pad_token_str, ti.embedding[1 + i]
)
for (i, pad_token_str) in enumerate(pad_token_strings)
]
else:
pad_token_ids = []

@@ -133,7 +163,6 @@ class TextualInversionManager():
ti.pad_token_ids = pad_token_ids
return ti.trigger_token_id


def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
try:
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
@@ -141,32 +170,43 @@ class TextualInversionManager():
except StopIteration:
return False


def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)

def get_textual_inversion_for_trigger_string(
self, trigger_string: str
) -> TextualInversion:
return next(
ti for ti in self.textual_inversions if ti.trigger_string == trigger_string
)

def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
return next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id)
return next(
ti for ti in self.textual_inversions if ti.trigger_token_id == token_id
)

def create_deferred_token_ids_for_any_trigger_terms(self, prompt_string: str) -> list[int]:
def create_deferred_token_ids_for_any_trigger_terms(
self, prompt_string: str
) -> list[int]:
injected_token_ids = []
for ti in self.textual_inversions:
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
if ti.embedding_vector_length > 1:
print(f">> Preparing tokens for textual inversion {ti.trigger_string}...")
print(
f">> Preparing tokens for textual inversion {ti.trigger_string}..."
)
try:
self._inject_tokens_and_assign_embeddings(ti)
except ValueError as e:
print(f' | Ignoring incompatible embedding trigger {ti.trigger_string}')
print(f' | The error was {str(e)}')
print(
f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
)
print(f" | The error was {str(e)}")
continue
injected_token_ids.append(ti.trigger_token_id)
injected_token_ids.extend(ti.pad_token_ids)
return injected_token_ids


def expand_textual_inversion_token_ids_if_necessary(self, prompt_token_ids: list[int]) -> list[int]:
def expand_textual_inversion_token_ids_if_necessary(
self, prompt_token_ids: list[int]
) -> list[int]:
"""
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
@@ -181,20 +221,31 @@ class TextualInversionManager():
raise ValueError("prompt_token_ids must not start with bos_token_id")
if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("prompt_token_ids must not end with eos_token_id")
textual_inversion_trigger_token_ids = [ti.trigger_token_id for ti in self.textual_inversions]
textual_inversion_trigger_token_ids = [
ti.trigger_token_id for ti in self.textual_inversions
]
prompt_token_ids = prompt_token_ids.copy()
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
if token_id in textual_inversion_trigger_token_ids:
textual_inversion = next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id)
for pad_idx in range(0, textual_inversion.embedding_vector_length-1):
prompt_token_ids.insert(i+pad_idx+1, textual_inversion.pad_token_ids[pad_idx])
textual_inversion = next(
ti
for ti in self.textual_inversions
if ti.trigger_token_id == token_id
)
for pad_idx in range(0, textual_inversion.embedding_vector_length - 1):
prompt_token_ids.insert(
i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx]
)

return prompt_token_ids


def _get_or_create_token_id_and_assign_embedding(self, token_str: str, embedding: torch.Tensor) -> int:
def _get_or_create_token_id_and_assign_embedding(
self, token_str: str, embedding: torch.Tensor
) -> int:
if len(embedding.shape) != 1:
raise ValueError("Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2")
raise ValueError(
"Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2"
)
existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
if existing_token_id == self.tokenizer.unk_token_id:
num_tokens_added = self.tokenizer.add_tokens(token_str)
@@ -207,66 +258,78 @@ class TextualInversionManager():
token_id = self.tokenizer.convert_tokens_to_ids(token_str)
if token_id == self.tokenizer.unk_token_id:
raise RuntimeError(f"Unable to find token id for token '{token_str}'")
if self.text_encoder.get_input_embeddings().weight.data[token_id].shape != embedding.shape:
raise ValueError(f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}.")
if (
self.text_encoder.get_input_embeddings().weight.data[token_id].shape
!= embedding.shape
):
raise ValueError(
f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}."
)
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

return token_id

def _parse_embedding(self, embedding_file: str):
file_type = embedding_file.split('.')[-1]
if file_type == 'pt':
file_type = embedding_file.split(".")[-1]
if file_type == "pt":
return self._parse_embedding_pt(embedding_file)
elif file_type == 'bin':
elif file_type == "bin":
return self._parse_embedding_bin(embedding_file)
else:
print(f'>> Not a recognized embedding file: {embedding_file}')
print(f">> Not a recognized embedding file: {embedding_file}")

def _parse_embedding_pt(self, embedding_file):
embedding_ckpt = torch.load(embedding_file, map_location='cpu')
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
embedding_info = {}

# Check if valid embedding file
if 'string_to_token' and 'string_to_param' in embedding_ckpt:

if "string_to_token" and "string_to_param" in embedding_ckpt:
# Catch variants that do not have the expected keys or values.
try:
embedding_info['name'] = embedding_ckpt['name'] or os.path.basename(os.path.splitext(embedding_file)[0])
embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
os.path.splitext(embedding_file)[0]
)

# Check num of embeddings and warn user only the first will be used
embedding_info['num_of_embeddings'] = len(embedding_ckpt["string_to_token"])
if embedding_info['num_of_embeddings'] > 1:
print('>> More than 1 embedding found. Will use the first one')
embedding_info["num_of_embeddings"] = len(
embedding_ckpt["string_to_token"]
)
if embedding_info["num_of_embeddings"] > 1:
print(">> More than 1 embedding found. Will use the first one")

embedding = list(embedding_ckpt['string_to_param'].values())[0]
except (AttributeError,KeyError):
embedding = list(embedding_ckpt["string_to_param"].values())[0]
except (AttributeError, KeyError):
return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)

embedding_info['embedding'] = embedding
embedding_info['num_vectors_per_token'] = embedding.size()[0]
embedding_info['token_dim'] = embedding.size()[1]
embedding_info["embedding"] = embedding
embedding_info["num_vectors_per_token"] = embedding.size()[0]
embedding_info["token_dim"] = embedding.size()[1]

try:
embedding_info['trained_steps'] = embedding_ckpt['step']
embedding_info['trained_model_name'] = embedding_ckpt['sd_checkpoint_name']
embedding_info['trained_model_checksum'] = embedding_ckpt['sd_checkpoint']
embedding_info["trained_steps"] = embedding_ckpt["step"]
embedding_info["trained_model_name"] = embedding_ckpt[
"sd_checkpoint_name"
]
embedding_info["trained_model_checksum"] = embedding_ckpt[
"sd_checkpoint"
]
except AttributeError:
print(">> No Training Details Found. Passing ...")

# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
# They are actually .bin files
elif len(embedding_ckpt.keys())==1:
print('>> Detected .bin file masquerading as .pt file')
elif len(embedding_ckpt.keys()) == 1:
print(">> Detected .bin file masquerading as .pt file")
embedding_info = self._parse_embedding_bin(embedding_file)

else:
print('>> Invalid embedding format')
print(">> Invalid embedding format")
embedding_info = None

return embedding_info

def _parse_embedding_bin(self, embedding_file):
embedding_ckpt = torch.load(embedding_file, map_location='cpu')
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
embedding_info = {}

if list(embedding_ckpt.keys()) == 0:
@@ -274,27 +337,45 @@ class TextualInversionManager():
embedding_info = None
else:
for token in list(embedding_ckpt.keys()):
embedding_info['name'] = token or os.path.basename(os.path.splitext(embedding_file)[0])
embedding_info['embedding'] = embedding_ckpt[token]
embedding_info['num_vectors_per_token'] = 1 # All Concepts seem to default to 1
embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
embedding_info["name"] = token or os.path.basename(
os.path.splitext(embedding_file)[0]
)
embedding_info["embedding"] = embedding_ckpt[token]
embedding_info[
"num_vectors_per_token"
] = 1 # All Concepts seem to default to 1
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]

return embedding_info

def _handle_broken_pt_variants(self, embedding_ckpt:dict, embedding_file:str)->dict:
'''
def _handle_broken_pt_variants(
self, embedding_ckpt: dict, embedding_file: str
) -> dict:
"""
This handles the broken .pt file variants. We only know of one at present.
'''
"""
embedding_info = {}
if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor):
print('>> Detected .pt file variant 1') # example at https://github.com/invoke-ai/InvokeAI/issues/1829
for token in list(embedding_ckpt['string_to_token'].keys()):
embedding_info['name'] = token if token != '*' else os.path.basename(os.path.splitext(embedding_file)[0])
embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
if isinstance(
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
):
print(
">> Detected .pt file variant 1"
) # example at https://github.com/invoke-ai/InvokeAI/issues/1829
for token in list(embedding_ckpt["string_to_token"].keys()):
embedding_info["name"] = (
token
if token != "*"
else os.path.basename(os.path.splitext(embedding_file)[0])
)
embedding_info["embedding"] = embedding_ckpt[
"string_to_param"
].state_dict()[token]
embedding_info["num_vectors_per_token"] = embedding_info[
"embedding"
].shape[0]
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
else:
print('>> Invalid embedding format')
print(">> Invalid embedding format")
embedding_info = None

return embedding_info
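A note on the multi-vector padding logic above: a trigger trained with more than one embedding vector is registered by `_inject_tokens_and_assign_embeddings` as the trigger token plus extra "-!pad-N" tokens, and `expand_textual_inversion_token_ids_if_necessary` then grows the prompt's token ids so each vector gets its own slot. A small illustration with made-up token ids (these values are invented, not produced by the library):

```
# Illustration only: the ids below are invented, not real tokenizer output.
# Suppose '<my-style>' was trained with an embedding of vector length 3.
# The manager registers '<my-style>' plus '<my-style>-!pad-1' and
# '<my-style>-!pad-2', one token id per embedding row.
trigger_token_id = 49408
pad_token_ids = [49409, 49410]

# Token ids for a prompt such as "a <my-style> painting" (again invented).
prompt_token_ids = [320, 49408, 2867]

# expand_textual_inversion_token_ids_if_necessary inserts the pad ids
# directly after the trigger id, so the prompt grows by vector_length - 1.
expanded = [320, 49408, 49409, 49410, 2867]
```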