Tidy spandrel model probe logic, and document the reasons behind the current implementation.

This commit is contained in:
Ryan Dick 2024-07-02 09:51:51 -04:00
parent 9328c17ded
commit 1ab20f43c8

View File

@ -4,6 +4,7 @@ from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union from typing import Any, Dict, Literal, Optional, Union
import safetensors.torch import safetensors.torch
import spandrel
import torch import torch
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
@ -242,15 +243,19 @@ class ModelProbe(object):
return ModelType.TextualInversion return ModelType.TextualInversion
# Check if the model can be loaded as a SpandrelImageToImageModel. # Check if the model can be loaded as a SpandrelImageToImageModel.
# This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
try: try:
# TODO(ryand): Figure out why load_from_state_dict() doesn't work as expected. # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
# _ = SpandrelImageToImageModel.load_from_state_dict(ckpt) # explored to avoid this:
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
# supported on meta tensors.
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
# maintain it, and the risk of false positive detections is higher.
_ = SpandrelImageToImageModel.load_from_file(model_path) _ = SpandrelImageToImageModel.load_from_file(model_path)
return ModelType.SpandrelImageToImage return ModelType.SpandrelImageToImage
except Exception as e: except spandrel.UnsupportedModelError:
# TODO(ryand): Catch a more specific exception type here if we can.
# TODO(ryand): Delete this print statement.
print(e)
pass pass
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}") raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")