mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(model_probe): provide more clues when we fail to load a model.
This commit is contained in:
parent
8611ffe32d
commit
26a7b7b66d
@ -1,12 +1,11 @@
|
|||||||
import json
|
import json
|
||||||
import torch
|
|
||||||
import safetensors.torch
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from diffusers import ModelMixin, ConfigMixin
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Literal, Union, Dict, Optional
|
from typing import Callable, Literal, Union, Dict, Optional
|
||||||
|
|
||||||
|
import safetensors.torch
|
||||||
|
import torch
|
||||||
|
from diffusers import ModelMixin, ConfigMixin
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
from .models import (
|
from .models import (
|
||||||
@ -17,8 +16,8 @@ from .models import (
|
|||||||
SilenceWarnings,
|
SilenceWarnings,
|
||||||
InvalidModelException,
|
InvalidModelException,
|
||||||
)
|
)
|
||||||
from .util import lora_token_vector_length
|
|
||||||
from .models.base import read_checkpoint_meta
|
from .models.base import read_checkpoint_meta
|
||||||
|
from .util import lora_token_vector_length
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -171,6 +170,7 @@ class ModelProbe(object):
|
|||||||
Get the model type of a hugging-face style folder.
|
Get the model type of a hugging-face style folder.
|
||||||
"""
|
"""
|
||||||
class_name = None
|
class_name = None
|
||||||
|
error_hint = None
|
||||||
if model:
|
if model:
|
||||||
class_name = model.__class__.__name__
|
class_name = model.__class__.__name__
|
||||||
else:
|
else:
|
||||||
@ -190,12 +190,18 @@ class ModelProbe(object):
|
|||||||
with open(config_path, "r") as file:
|
with open(config_path, "r") as file:
|
||||||
conf = json.load(file)
|
conf = json.load(file)
|
||||||
class_name = conf["_class_name"]
|
class_name = conf["_class_name"]
|
||||||
|
else:
|
||||||
|
error_hint = f"No model_index.json or config.json found in {folder_path}."
|
||||||
|
|
||||||
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
||||||
return type
|
return type
|
||||||
|
else:
|
||||||
|
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
|
||||||
|
|
||||||
# give up
|
# give up
|
||||||
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
|
raise InvalidModelException(
|
||||||
|
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
|
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
|
||||||
|
Loading…
Reference in New Issue
Block a user