mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(model_probe): provide more clues when we fail to load a model.
This commit is contained in:
parent
8611ffe32d
commit
26a7b7b66d
@ -1,12 +1,11 @@
|
||||
import json
|
||||
import torch
|
||||
import safetensors.torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from diffusers import ModelMixin, ConfigMixin
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal, Union, Dict, Optional
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from diffusers import ModelMixin, ConfigMixin
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from .models import (
|
||||
@ -17,8 +16,8 @@ from .models import (
|
||||
SilenceWarnings,
|
||||
InvalidModelException,
|
||||
)
|
||||
from .util import lora_token_vector_length
|
||||
from .models.base import read_checkpoint_meta
|
||||
from .util import lora_token_vector_length
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -171,6 +170,7 @@ class ModelProbe(object):
|
||||
Get the model type of a hugging-face style folder.
|
||||
"""
|
||||
class_name = None
|
||||
error_hint = None
|
||||
if model:
|
||||
class_name = model.__class__.__name__
|
||||
else:
|
||||
@ -190,12 +190,18 @@ class ModelProbe(object):
|
||||
with open(config_path, "r") as file:
|
||||
conf = json.load(file)
|
||||
class_name = conf["_class_name"]
|
||||
else:
|
||||
error_hint = f"No model_index.json or config.json found in {folder_path}."
|
||||
|
||||
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
||||
return type
|
||||
else:
|
||||
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
|
||||
|
||||
# give up
|
||||
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
|
||||
raise InvalidModelException(
|
||||
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
|
||||
|
Loading…
Reference in New Issue
Block a user