don't even try to load incompatible embeddings

This commit is contained in:
Lincoln Stein 2023-02-13 17:00:52 -05:00
parent bc18a94d8c
commit e29399e032

View File

@ -2,7 +2,7 @@ import os
import traceback import traceback
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, Union
import torch import torch
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
@ -59,11 +59,12 @@ class TextualInversionManager:
def get_all_trigger_strings(self) -> list[str]: def get_all_trigger_strings(self) -> list[str]:
return [ti.trigger_string for ti in self.textual_inversions] return [ti.trigger_string for ti in self.textual_inversions]
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool = False): def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
ckpt_path = Path(ckpt_path)
if str(ckpt_path).endswith(".DS_Store"): if str(ckpt_path).endswith(".DS_Store"):
return return
try: try:
scan_result = scan_file_path(ckpt_path) scan_result = scan_file_path(str(ckpt_path))
if scan_result.infected_files == 1: if scan_result.infected_files == 1:
print( print(
f"\n### Security Issues Found in Model: {scan_result.issues_count}" f"\n### Security Issues Found in Model: {scan_result.issues_count}"
@ -71,13 +72,22 @@ class TextualInversionManager:
print("### For your safety, InvokeAI will not load this embed.") print("### For your safety, InvokeAI will not load this embed.")
return return
except Exception: except Exception:
ckpt_path = Path(ckpt_path)
print( print(
f"** Notice: {ckpt_path.parents[0].stem}/{ckpt_path.stem} is incompatible with this model" f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
)
return
embedding_info = self._parse_embedding(str(ckpt_path))
if (
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
!= embedding_info["embedding"].shape[0]
):
print(
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."
) )
return return
embedding_info = self._parse_embedding(ckpt_path)
if embedding_info: if embedding_info:
try: try:
self._add_textual_inversion( self._add_textual_inversion(
@ -90,7 +100,7 @@ class TextualInversionManager:
print(f" | The error was {str(e)}") print(f" | The error was {str(e)}")
else: else:
print( print(
f">> Failed to load embedding located at {ckpt_path}. Unsupported file." f">> Failed to load embedding located at {str(ckpt_path)}. Unsupported file."
) )
def _add_textual_inversion( def _add_textual_inversion(