feat(model_probe): provide more clues when we fail to load a model.

This commit is contained in:
Kevin Turner 2023-08-17 20:01:21 -07:00
parent 8611ffe32d
commit 26a7b7b66d

View File

@ -1,12 +1,11 @@
import json
import torch
import safetensors.torch
from dataclasses import dataclass
from diffusers import ModelMixin, ConfigMixin
from pathlib import Path
from typing import Callable, Literal, Union, Dict, Optional
import safetensors.torch
import torch
from diffusers import ModelMixin, ConfigMixin
from picklescan.scanner import scan_file_path
from .models import (
@ -17,8 +16,8 @@ from .models import (
SilenceWarnings,
InvalidModelException,
)
from .util import lora_token_vector_length
from .models.base import read_checkpoint_meta
from .util import lora_token_vector_length
@dataclass
@ -171,6 +170,7 @@ class ModelProbe(object):
Get the model type of a hugging-face style folder.
"""
class_name = None
error_hint = None
if model:
class_name = model.__class__.__name__
else:
@ -190,12 +190,18 @@ class ModelProbe(object):
with open(config_path, "r") as file:
conf = json.load(file)
class_name = conf["_class_name"]
else:
error_hint = f"No model_index.json or config.json found in {folder_path}."
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
else:
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
# give up
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
raise InvalidModelException(
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
)
@classmethod
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict: