Merge branch 'main' into perf/lowmem_sequential_guidance

commit fd27948c36
Lincoln Stein, 2023-02-20 17:15:33 -05:00, committed by GitHub
3 changed files with 10 additions and 7 deletions

View File

@@ -964,6 +964,7 @@ class Generate:
         seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
         if self.embedding_path is not None:
+            print(f'>> Loading embeddings from {self.embedding_path}')
             for root, _, files in os.walk(self.embedding_path):
                 for name in files:
                     ti_path = os.path.join(root, name)
@@ -971,7 +972,7 @@ class Generate:
                         ti_path, defer_injecting_tokens=True
                     )
             print(
-                f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}'
+                f'>> Textual inversion triggers: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}'
             )
         self.model_name = model_name
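
For context, a minimal standalone sketch of the loading loop this hunk touches (the manager argument and directory path are hypothetical stand-ins for self.model.textual_inversion_manager and self.embedding_path):

    import os

    def load_embeddings(manager, embedding_path):
        # Walk the embeddings directory recursively and hand every file to the
        # manager; load_textual_inversion() is expected to skip non-embeddings.
        print(f'>> Loading embeddings from {embedding_path}')
        for root, _, files in os.walk(embedding_path):
            for name in files:
                manager.load_textual_inversion(
                    os.path.join(root, name), defer_injecting_tokens=True
                )
        print(
            f'>> Textual inversion triggers: '
            f'{", ".join(manager.get_all_trigger_strings())}'
        )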

View File

@@ -441,6 +441,7 @@ class TextualInversionDataset(Dataset):
         self.image_paths = [
             os.path.join(self.data_root, file_path)
             for file_path in os.listdir(self.data_root)
+            if os.path.isfile(file_path) and file_path.endswith(('.png','.PNG','.jpg','.JPG','.jpeg','.JPEG','.gif','.GIF'))
         ]
         self.num_images = len(self.image_paths)
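
A standalone sketch of the same filter; one caveat worth noting is that os.path.isfile is only reliable on the joined path, while the added line applies it to the bare listdir name (which resolves against the current working directory):

    import os

    # Lowercase suffixes; lower() folds '.PNG', '.JPG', etc. into this list.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif')

    def list_image_paths(data_root):
        paths = []
        for file_path in os.listdir(data_root):
            full_path = os.path.join(data_root, file_path)
            # Join before the isfile() check so it is independent of the CWD.
            if os.path.isfile(full_path) and file_path.lower().endswith(IMAGE_EXTENSIONS):
                paths.append(full_path)
        return paths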

View File

@@ -61,8 +61,13 @@ class TextualInversionManager:
     def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
         ckpt_path = Path(ckpt_path)
+        if not ckpt_path.is_file():
+            return
         if str(ckpt_path).endswith(".DS_Store"):
             return
         try:
             scan_result = scan_file_path(str(ckpt_path))
             if scan_result.infected_files == 1:
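
The new is_file() guard lets callers pass in anything found by a directory walk. A sketch of the guard-then-scan pattern, assuming scan_file_path comes from the picklescan package (the hunk flags a file when infected_files == 1; this sketch inverts that into a boolean):

    from pathlib import Path
    from picklescan.scanner import scan_file_path  # assumed import source

    def safe_to_load(ckpt_path) -> bool:
        ckpt_path = Path(ckpt_path)
        # Skip directories, dangling paths, and macOS .DS_Store entries.
        if not ckpt_path.is_file() or ckpt_path.name == ".DS_Store":
            return False
        # Malware-scan the pickle before ever unpickling it.
        scan_result = scan_file_path(str(ckpt_path))
        return scan_result.infected_files == 0
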
@@ -87,7 +92,7 @@ class TextualInversionManager:
             != embedding_info['token_dim']
         ):
             print(
-                f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."
+                f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
             )
             return
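
The reworded notice now prints both dimensions. The check itself boils down to comparing the width of one row of the text encoder's input-embedding matrix (e.g. 768 for SD 1.x CLIP, 1024 for SD 2.x) against the checkpoint's token_dim; a minimal sketch:

    def token_dim_matches(text_encoder, embedding_info) -> bool:
        # Width of a single input-embedding row defines the model's token dim.
        model_dim = text_encoder.get_input_embeddings().weight.data[0].shape[0]
        return model_dim == embedding_info['token_dim']
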
@@ -333,7 +338,6 @@ class TextualInversionManager:
         # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
         # They are actually .bin files
         elif len(embedding_ckpt.keys()) == 1:
-            print(">> Detected .bin file masquerading as .pt file")
             embedding_info = self._parse_embedding_bin(embedding_file)
         else:
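
For reference, the format distinction behind this branch: a diffusers-style .bin embedding is a single {trigger: tensor} mapping (sometimes distributed with a .pt extension), while a classic .pt checkpoint carries string_to_token/string_to_param keys. A runnable sketch of telling them apart:

    import torch

    def detect_embedding_format(embedding_file: str) -> str:
        # Both formats are plain pickled dicts; load on CPU for inspection.
        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
        if "string_to_token" in embedding_ckpt:
            return "pt"       # classic textual-inversion checkpoint
        if len(embedding_ckpt.keys()) == 1:
            return "bin"      # diffusers-style, even when named *.pt
        return "unknown"
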
@@ -372,9 +376,6 @@ class TextualInversionManager:
         if isinstance(
             list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
         ):
-            print(
-                ">> Detected .pt file variant 1"
-            )  # example at https://github.com/invoke-ai/InvokeAI/issues/1829
             for token in list(embedding_ckpt["string_to_token"].keys()):
                 embedding_info["name"] = (
                     token
@@ -387,7 +388,7 @@ class TextualInversionManager:
                 embedding_info["num_vectors_per_token"] = embedding_info[
                     "embedding"
                 ].shape[0]
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
+                embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
         else:
             print(">> Invalid embedding format")
             embedding_info = None
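
The last hunk is the substantive fix: a multi-vector .pt embedding stores its tensor as [num_vectors_per_token, token_dim], so index 0 is the vector count and index 1 the per-token width. A small illustration with hypothetical values:

    import torch

    # Hypothetical 2-vector embedding for an SD 1.x model (token_dim = 768).
    embedding = torch.zeros(2, 768)

    num_vectors_per_token = embedding.shape[0]  # 2   -- what .size()[0] returned
    token_dim = embedding.size()[1]             # 768 -- the corrected index

    assert (num_vectors_per_token, token_dim) == (2, 768)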