Tidy spandrel model probe logic, and document the reasons behind the current implementation.

This commit is contained in:
Ryan Dick 2024-07-02 09:51:51 -04:00
parent 9328c17ded
commit 1ab20f43c8

View File

@ -4,6 +4,7 @@ from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union
import safetensors.torch
import spandrel
import torch
from picklescan.scanner import scan_file_path
@ -242,15 +243,19 @@ class ModelProbe(object):
return ModelType.TextualInversion
# Check if the model can be loaded as a SpandrelImageToImageModel.
# This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
try:
# TODO(ryand): Figure out why load_from_state_dict() doesn't work as expected.
# _ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
# explored to avoid this:
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
# supported on meta tensors.
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we would have to
# maintain it, and the risk of false-positive detections would be higher.
_ = SpandrelImageToImageModel.load_from_file(model_path)
return ModelType.SpandrelImageToImage
except Exception as e:
# TODO(ryand): Catch a more specific exception type here if we can.
# TODO(ryand): Delete this print statement.
print(e)
except spandrel.UnsupportedModelError:
pass
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")