Re-enable classification API as fallback

Author: Billy
Date: 2025-05-16 16:01:40 +10:00
Committed by: psychedelicious
Parent: 19ecdb196e
Commit: a17e771eba
6 changed files with 90 additions and 39 deletions
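
The change below makes the legacy probe the primary path again and keeps the newer classification API only as a fallback. A minimal sketch of that flow, assuming the names imported in the diffs below (the helper function itself is hypothetical, not part of the commit):

# Hypothetical helper illustrating the fallback order restored by this commit.
from pathlib import Path

from invokeai.backend.model_manager.config import InvalidModelConfigException, ModelConfigBase
from invokeai.backend.model_manager.legacy_probe import ModelProbe


def probe_with_classify_fallback(model_path: Path, fields: dict, hash_algo: str = "blake3_single"):
    try:
        # The legacy probe runs first, as it did before the refactor began.
        return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)
    except InvalidModelConfigException:
        # Only when the legacy probe cannot identify the model does the new
        # classify API get a chance to match it against its config classes.
        return ModelConfigBase.classify(model_path, hash_algo, **fields)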

View File

@@ -40,6 +40,7 @@ from invokeai.backend.model_manager.config import (
     InvalidModelConfigException,
 )
 from invokeai.backend.model_manager.legacy_probe import ModelProbe
+from invokeai.backend.model_manager.config import ModelConfigBase
 from invokeai.backend.model_manager.metadata import (
     AnyModelRepoMetadata,
     HuggingFaceMetadataFetch,
@@ -646,14 +647,10 @@ class ModelInstallService(ModelInstallServiceBase):
         hash_algo = self._app_config.hashing_algorithm
         fields = config.model_dump()
-        return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)
-        # New model probe API is disabled pending resolution of issue caused by a change of the ordering of checks.
-        # See commit message for details.
-        # try:
-        #     return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
-        # except InvalidModelConfigException:
-        #     return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)  # type: ignore
+        try:
+            return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)  # type: ignore
+        except InvalidModelConfigException:
+            return ModelConfigBase.classify(model_path, hash_algo, **fields)

     def _register(
         self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None

View File

@@ -146,33 +146,39 @@ class ModelConfigBase(ABC, BaseModel):
     )
     usage_info: Optional[str] = Field(default=None, description="Usage information for this model")

-    _USING_LEGACY_PROBE: ClassVar[set] = set()
-    _USING_CLASSIFY_API: ClassVar[set] = set()
+    USING_LEGACY_PROBE: ClassVar[set] = set()
+    USING_CLASSIFY_API: ClassVar[set] = set()
     _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED

     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
         if issubclass(cls, LegacyProbeMixin):
-            ModelConfigBase._USING_LEGACY_PROBE.add(cls)
+            ModelConfigBase.USING_LEGACY_PROBE.add(cls)
         else:
-            ModelConfigBase._USING_CLASSIFY_API.add(cls)
+            ModelConfigBase.USING_CLASSIFY_API.add(cls)

     @staticmethod
     def all_config_classes():
-        subclasses = ModelConfigBase._USING_LEGACY_PROBE | ModelConfigBase._USING_CLASSIFY_API
+        subclasses = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API
         concrete = {cls for cls in subclasses if not isabstract(cls)}
         return concrete

     @staticmethod
-    def classify(model_path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
+    def classify(
+        mod: str | Path | ModelOnDisk,
+        hash_algo: HASHING_ALGORITHMS = "blake3_single",
+        **overrides,
+    ):
         """
         Returns the best matching ModelConfig instance from a model's file/folder path.
         Raises InvalidModelConfigException if no valid configuration is found.
         Created to deprecate ModelProbe.probe
         """
-        candidates = ModelConfigBase._USING_CLASSIFY_API
+        if isinstance(mod, Path | str):
+            mod = ModelOnDisk(mod, hash_algo)
+
+        candidates = ModelConfigBase.USING_CLASSIFY_API
         sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
-        mod = ModelOnDisk(model_path, hash_algo)
         for config_cls in sorted_by_match_speed:
             try:
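
classify now accepts either a filesystem path or an already constructed ModelOnDisk. A hedged usage sketch of both call styles (the model path is invented for illustration):

# Illustrative only: both call styles accepted by the reworked classify().
from pathlib import Path

from invokeai.backend.model_manager.config import ModelConfigBase
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk

model_path = Path("/models/example.safetensors")  # hypothetical path

# Passing a path (or str): classify wraps it in a ModelOnDisk internally.
config = ModelConfigBase.classify(model_path, hash_algo="blake3_single")

# Passing a ModelOnDisk directly, e.g. to reuse it across several calls.
mod = ModelOnDisk(model_path, "blake3_single")
config = ModelConfigBase.classify(mod)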

View File

@@ -4,6 +4,7 @@ from typing import Any, Optional, TypeAlias
 import safetensors.torch
 import torch
 from picklescan.scanner import scan_file_path
+from safetensors import safe_open

 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
 from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
@@ -35,12 +36,21 @@ class ModelOnDisk:
             return self.path.stat().st_size
         return sum(file.stat().st_size for file in self.path.rglob("*"))

-    def component_paths(self) -> set[Path]:
+    def weight_files(self) -> set[Path]:
         if self.path.is_file():
             return {self.path}
         extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
         return {f for f in self.path.rglob("*") if f.suffix in extensions}

+    def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
+        try:
+            with safe_open(self.path, framework="pt", device="cpu") as f:
+                metadata = f.metadata()
+                assert isinstance(metadata, dict)
+                return metadata
+        except Exception:
+            return {}
+
     def repo_variant(self) -> Optional[ModelRepoVariant]:
         if self.path.is_file():
             return None
@@ -64,18 +74,7 @@ class ModelOnDisk:
         if path in sd_cache:
             return sd_cache[path]

-        if not path:
-            components = list(self.component_paths())
-            match components:
-                case []:
-                    raise ValueError("No weight files found for this model")
-                case [p]:
-                    path = p
-                case ps if len(ps) >= 2:
-                    raise ValueError(
-                        f"Multiple weight files found for this model: {ps}. "
-                        f"Please specify the intended file using the 'path' argument"
-                    )
+        path = self.resolve_weight_file(path)

         with SilenceWarnings():
             if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
@@ -94,3 +93,18 @@ class ModelOnDisk:
             state_dict = checkpoint.get("state_dict", checkpoint)
             sd_cache[path] = state_dict
             return state_dict
+
+    def resolve_weight_file(self, path: Optional[Path] = None) -> Path:
+        if not path:
+            weight_files = list(self.weight_files())
+            match weight_files:
+                case []:
+                    raise ValueError("No weight files found for this model")
+                case [p]:
+                    return p
+                case ps if len(ps) >= 2:
+                    raise ValueError(
+                        f"Multiple weight files found for this model: {ps}. "
+                        f"Please specify the intended file using the 'path' argument"
+                    )
+        return path
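
The weight-file lookup that used to live inline in load_state_dict is now a reusable resolve_weight_file helper, which the stripped-model subclass further down also relies on. A hedged sketch of how a caller sees it (the folder layout is invented):

# Illustrative only: the folder path below is made up.
from pathlib import Path

from invokeai.backend.model_manager.model_on_disk import ModelOnDisk

mod = ModelOnDisk(Path("/models/example-model-folder"))  # hypothetical folder model

weight_files = mod.weight_files()  # all .safetensors/.pt/.pth/.ckpt/.bin/.gguf files found
if len(weight_files) == 1:
    state_dict = mod.load_state_dict()  # resolve_weight_file picks the single file
else:
    # With several candidates, resolve_weight_file refuses to guess and raises
    # ValueError, so the caller must name the intended file explicitly.
    state_dict = mod.load_state_dict(sorted(weight_files)[0])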

View File

@@ -28,10 +28,9 @@ args = parser.parse_args()

 def classify_with_fallback(path: Path, hash_algo: HASHING_ALGORITHMS):
     try:
-        return ModelConfigBase.classify(path, hash_algo)
-    except InvalidModelConfigException:
         return ModelProbe.probe(path, hash_algo=hash_algo)
+    except InvalidModelConfigException:
+        return ModelConfigBase.classify(path, hash_algo)


 for path in args.model_path:
     try:

View File

@@ -18,13 +18,16 @@ import json
 import shutil
 import sys
 from pathlib import Path
+from typing import Optional

 import humanize
 import torch

-from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
+from invokeai.backend.model_manager.model_on_disk import ModelOnDisk, StateDict
 from invokeai.backend.model_manager.search import ModelSearch

+METADATA_KEY = "metadata_key_for_stripped_models"
+

 def strip(v):
     match v:
@@ -57,9 +60,22 @@ def dress(v):
 def load_stripped_model(path: Path, *args, **kwargs):
     with open(path, "r") as f:
         contents = json.load(f)
+        contents.pop(METADATA_KEY, None)
         return dress(contents)


+class StrippedModelOnDisk(ModelOnDisk):
+    def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
+        path = self.resolve_weight_file(path)
+        return load_stripped_model(path)
+
+    def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
+        path = self.resolve_weight_file(path)
+        with open(path, "r") as f:
+            contents = json.load(f)
+        return contents.get(METADATA_KEY, {})
+
+
 def create_stripped_model(original_model_path: Path, stripped_model_path: Path) -> ModelOnDisk:
     original = ModelOnDisk(original_model_path)
     if original.path.is_file():
@@ -69,11 +85,14 @@ def create_stripped_model(original_model_path: Path, stripped_model_path: Path)
     stripped = ModelOnDisk(stripped_model_path)
     print(f"Created clone of {original.name} at {stripped.path}")

-    for component_path in stripped.component_paths():
+    for component_path in stripped.weight_files():
         original_state_dict = stripped.load_state_dict(component_path)
         stripped_state_dict = strip(original_state_dict)  # type: ignore
+        metadata = stripped.metadata()
+        contents = {**stripped_state_dict, METADATA_KEY: metadata}
         with open(component_path, "w") as f:
-            json.dump(stripped_state_dict, f, indent=4)
+            json.dump(contents, f, indent=4)

     before_size = humanize.naturalsize(original.size())
     after_size = humanize.naturalsize(stripped.size())
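
Stripped fixtures are therefore plain JSON files whose tensor entries have been reduced by strip() and whose safetensors metadata is stored under METADATA_KEY; StrippedModelOnDisk (defined above) reads both parts back. A hedged sketch of unpacking such a file by hand (the fixture path is invented):

# Illustrative only: manually unpacking a stripped fixture, mirroring what
# load_stripped_model() and StrippedModelOnDisk.metadata() do.
import json
from pathlib import Path

METADATA_KEY = "metadata_key_for_stripped_models"
fixture = Path("stripped_models/example.safetensors")  # hypothetical stripped file

with open(fixture, "r") as f:
    contents = json.load(f)

metadata = contents.get(METADATA_KEY, {})  # what StrippedModelOnDisk.metadata() returns
contents.pop(METADATA_KEY, None)           # the rest is the stripped state-dict stub,
                                           # which load_stripped_model() "dresses" back up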

View File

@@ -29,6 +29,9 @@ from invokeai.backend.model_manager.legacy_probe import (
 from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
 from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.util.logging import InvokeAILogger
+from scripts.strip_models import StrippedModelOnDisk
+
+logger = InvokeAILogger.get_logger(__file__)


 @pytest.mark.parametrize(
@@ -156,7 +159,8 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
         pass

     try:
-        new_config = ModelConfigBase.classify(path, hash=fake_hash, key=fake_key)
+        stripped_mod = StrippedModelOnDisk(path)
+        new_config = ModelConfigBase.classify(stripped_mod, hash=fake_hash, key=fake_key)
     except InvalidModelConfigException:
         pass
@@ -165,10 +169,10 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
         assert legacy_config.model_dump_json() == new_config.model_dump_json()

     elif legacy_config:
-        assert type(legacy_config) in ModelConfigBase._USING_LEGACY_PROBE
+        assert type(legacy_config) in ModelConfigBase.USING_LEGACY_PROBE
     elif new_config:
-        assert type(new_config) in ModelConfigBase._USING_CLASSIFY_API
+        assert type(new_config) in ModelConfigBase.USING_CLASSIFY_API
     else:
         raise ValueError(f"Both probe and classify failed to classify model at path {path}.")
@@ -177,7 +181,6 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
         configs_with_tests.add(config_type)

     untested_configs = ModelConfigBase.all_config_classes() - configs_with_tests - {MinimalConfigExample}
-    logger = InvokeAILogger.get_logger(__file__)
     logger.warning(f"Function test_regression_against_model_probe missing test case for: {untested_configs}")
@@ -255,3 +258,16 @@ def test_any_model_config_includes_all_config_classes():
     expected = set(ModelConfigBase.all_config_classes()) - {MinimalConfigExample}

     assert extracted == expected
+
+
+def test_config_uniquely_matches_model(datadir: Path):
+    model_paths = ModelSearch().search(datadir / "stripped_models")
+    for path in model_paths:
+        mod = StrippedModelOnDisk(path)
+        matches = {cls for cls in ModelConfigBase.USING_CLASSIFY_API if cls.matches(mod)}
+        assert len(matches) <= 1, f"Model at path {path} matches multiple config classes: {matches}"
+        if not matches:
+            logger.warning(f"Model at path {path} does not match any config classes using classify API.")