don't even try to load incompatible embeddings

This commit is contained in:
Lincoln Stein 2023-02-13 17:00:52 -05:00
parent bc18a94d8c
commit e29399e032

View File

@ -2,7 +2,7 @@ import os
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from typing import Optional, Union
import torch
from picklescan.scanner import scan_file_path
@ -59,11 +59,12 @@ class TextualInversionManager:
def get_all_trigger_strings(self) -> list[str]:
return [ti.trigger_string for ti in self.textual_inversions]
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool = False):
def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
ckpt_path = Path(ckpt_path)
if str(ckpt_path).endswith(".DS_Store"):
return
try:
scan_result = scan_file_path(ckpt_path)
scan_result = scan_file_path(str(ckpt_path))
if scan_result.infected_files == 1:
print(
f"\n### Security Issues Found in Model: {scan_result.issues_count}"
@ -71,13 +72,22 @@ class TextualInversionManager:
print("### For your safety, InvokeAI will not load this embed.")
return
except Exception:
ckpt_path = Path(ckpt_path)
print(
f"** Notice: {ckpt_path.parents[0].stem}/{ckpt_path.stem} is incompatible with this model"
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
)
return
embedding_info = self._parse_embedding(str(ckpt_path))
if (
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
!= embedding_info["embedding"].shape[0]
):
print(
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."
)
return
embedding_info = self._parse_embedding(ckpt_path)
if embedding_info:
try:
self._add_textual_inversion(
@ -90,7 +100,7 @@ class TextualInversionManager:
print(f" | The error was {str(e)}")
else:
print(
f">> Failed to load embedding located at {ckpt_path}. Unsupported file."
f">> Failed to load embedding located at {str(ckpt_path)}. Unsupported file."
)
def _add_textual_inversion(