mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
restore ability of textual inversion manager to read .pt files (#2746)
- Fixes longstanding bug in the token vector size code which caused .pt files to be assigned the wrong token vector length. These were then tossed out during directory scanning.
This commit is contained in:
commit
17294bfa55
@ -964,6 +964,7 @@ class Generate:
|
|||||||
|
|
||||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||||
if self.embedding_path is not None:
|
if self.embedding_path is not None:
|
||||||
|
print(f'>> Loading embeddings from {self.embedding_path}')
|
||||||
for root, _, files in os.walk(self.embedding_path):
|
for root, _, files in os.walk(self.embedding_path):
|
||||||
for name in files:
|
for name in files:
|
||||||
ti_path = os.path.join(root, name)
|
ti_path = os.path.join(root, name)
|
||||||
@ -971,7 +972,7 @@ class Generate:
|
|||||||
ti_path, defer_injecting_tokens=True
|
ti_path, defer_injecting_tokens=True
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}'
|
f'>> Textual inversion triggers: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}'
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
@ -61,8 +61,13 @@ class TextualInversionManager:
|
|||||||
|
|
||||||
def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
|
def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
|
||||||
ckpt_path = Path(ckpt_path)
|
ckpt_path = Path(ckpt_path)
|
||||||
|
|
||||||
|
if not ckpt_path.is_file():
|
||||||
|
return
|
||||||
|
|
||||||
if str(ckpt_path).endswith(".DS_Store"):
|
if str(ckpt_path).endswith(".DS_Store"):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
scan_result = scan_file_path(str(ckpt_path))
|
scan_result = scan_file_path(str(ckpt_path))
|
||||||
if scan_result.infected_files == 1:
|
if scan_result.infected_files == 1:
|
||||||
@ -87,7 +92,7 @@ class TextualInversionManager:
|
|||||||
!= embedding_info['token_dim']
|
!= embedding_info['token_dim']
|
||||||
):
|
):
|
||||||
print(
|
print(
|
||||||
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."
|
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -333,7 +338,6 @@ class TextualInversionManager:
|
|||||||
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
|
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
|
||||||
# They are actually .bin files
|
# They are actually .bin files
|
||||||
elif len(embedding_ckpt.keys()) == 1:
|
elif len(embedding_ckpt.keys()) == 1:
|
||||||
print(">> Detected .bin file masquerading as .pt file")
|
|
||||||
embedding_info = self._parse_embedding_bin(embedding_file)
|
embedding_info = self._parse_embedding_bin(embedding_file)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -372,9 +376,6 @@ class TextualInversionManager:
|
|||||||
if isinstance(
|
if isinstance(
|
||||||
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
|
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
|
||||||
):
|
):
|
||||||
print(
|
|
||||||
">> Detected .pt file variant 1"
|
|
||||||
) # example at https://github.com/invoke-ai/InvokeAI/issues/1829
|
|
||||||
for token in list(embedding_ckpt["string_to_token"].keys()):
|
for token in list(embedding_ckpt["string_to_token"].keys()):
|
||||||
embedding_info["name"] = (
|
embedding_info["name"] = (
|
||||||
token
|
token
|
||||||
@ -387,7 +388,7 @@ class TextualInversionManager:
|
|||||||
embedding_info["num_vectors_per_token"] = embedding_info[
|
embedding_info["num_vectors_per_token"] = embedding_info[
|
||||||
"embedding"
|
"embedding"
|
||||||
].shape[0]
|
].shape[0]
|
||||||
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
|
embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
|
||||||
else:
|
else:
|
||||||
print(">> Invalid embedding format")
|
print(">> Invalid embedding format")
|
||||||
embedding_info = None
|
embedding_info = None
|
||||||
|
Loading…
Reference in New Issue
Block a user