mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
don't even try to load incompatible embeddings
This commit is contained in:
parent
bc18a94d8c
commit
e29399e032
@ -2,7 +2,7 @@ import os
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
@ -59,11 +59,12 @@ class TextualInversionManager:
|
||||
def get_all_trigger_strings(self) -> list[str]:
|
||||
return [ti.trigger_string for ti in self.textual_inversions]
|
||||
|
||||
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool = False):
|
||||
def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
|
||||
ckpt_path = Path(ckpt_path)
|
||||
if str(ckpt_path).endswith(".DS_Store"):
|
||||
return
|
||||
try:
|
||||
scan_result = scan_file_path(ckpt_path)
|
||||
scan_result = scan_file_path(str(ckpt_path))
|
||||
if scan_result.infected_files == 1:
|
||||
print(
|
||||
f"\n### Security Issues Found in Model: {scan_result.issues_count}"
|
||||
@ -71,13 +72,22 @@ class TextualInversionManager:
|
||||
print("### For your safety, InvokeAI will not load this embed.")
|
||||
return
|
||||
except Exception:
|
||||
ckpt_path = Path(ckpt_path)
|
||||
print(
|
||||
f"** Notice: {ckpt_path.parents[0].stem}/{ckpt_path.stem} is incompatible with this model"
|
||||
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
|
||||
)
|
||||
return
|
||||
|
||||
embedding_info = self._parse_embedding(str(ckpt_path))
|
||||
|
||||
if (
|
||||
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
|
||||
!= embedding_info["embedding"].shape[0]
|
||||
):
|
||||
print(
|
||||
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."
|
||||
)
|
||||
return
|
||||
|
||||
embedding_info = self._parse_embedding(ckpt_path)
|
||||
if embedding_info:
|
||||
try:
|
||||
self._add_textual_inversion(
|
||||
@ -90,7 +100,7 @@ class TextualInversionManager:
|
||||
print(f" | The error was {str(e)}")
|
||||
else:
|
||||
print(
|
||||
f">> Failed to load embedding located at {ckpt_path}. Unsupported file."
|
||||
f">> Failed to load embedding located at {str(ckpt_path)}. Unsupported file."
|
||||
)
|
||||
|
||||
def _add_textual_inversion(
|
||||
|
Loading…
Reference in New Issue
Block a user