mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
don't even try to load incompatible embeddings
This commit is contained in:
parent
bc18a94d8c
commit
e29399e032
@ -2,7 +2,7 @@ import os
|
|||||||
import traceback
|
import traceback
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
@ -59,11 +59,12 @@ class TextualInversionManager:
|
|||||||
def get_all_trigger_strings(self) -> list[str]:
|
def get_all_trigger_strings(self) -> list[str]:
|
||||||
return [ti.trigger_string for ti in self.textual_inversions]
|
return [ti.trigger_string for ti in self.textual_inversions]
|
||||||
|
|
||||||
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool = False):
|
def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
|
||||||
|
ckpt_path = Path(ckpt_path)
|
||||||
if str(ckpt_path).endswith(".DS_Store"):
|
if str(ckpt_path).endswith(".DS_Store"):
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
scan_result = scan_file_path(ckpt_path)
|
scan_result = scan_file_path(str(ckpt_path))
|
||||||
if scan_result.infected_files == 1:
|
if scan_result.infected_files == 1:
|
||||||
print(
|
print(
|
||||||
f"\n### Security Issues Found in Model: {scan_result.issues_count}"
|
f"\n### Security Issues Found in Model: {scan_result.issues_count}"
|
||||||
@ -71,13 +72,22 @@ class TextualInversionManager:
|
|||||||
print("### For your safety, InvokeAI will not load this embed.")
|
print("### For your safety, InvokeAI will not load this embed.")
|
||||||
return
|
return
|
||||||
except Exception:
|
except Exception:
|
||||||
ckpt_path = Path(ckpt_path)
|
|
||||||
print(
|
print(
|
||||||
f"** Notice: {ckpt_path.parents[0].stem}/{ckpt_path.stem} is incompatible with this model"
|
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
embedding_info = self._parse_embedding(str(ckpt_path))
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
|
||||||
|
!= embedding_info["embedding"].shape[0]
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
embedding_info = self._parse_embedding(ckpt_path)
|
|
||||||
if embedding_info:
|
if embedding_info:
|
||||||
try:
|
try:
|
||||||
self._add_textual_inversion(
|
self._add_textual_inversion(
|
||||||
@ -90,7 +100,7 @@ class TextualInversionManager:
|
|||||||
print(f" | The error was {str(e)}")
|
print(f" | The error was {str(e)}")
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
f">> Failed to load embedding located at {ckpt_path}. Unsupported file."
|
f">> Failed to load embedding located at {str(ckpt_path)}. Unsupported file."
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_textual_inversion(
|
def _add_textual_inversion(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user