Re-enable classification API as fallback
@@ -40,6 +40,7 @@ from invokeai.backend.model_manager.config import (
     InvalidModelConfigException,
 )
 from invokeai.backend.model_manager.legacy_probe import ModelProbe
+from invokeai.backend.model_manager.config import ModelConfigBase
 from invokeai.backend.model_manager.metadata import (
     AnyModelRepoMetadata,
     HuggingFaceMetadataFetch,
@@ -646,14 +647,10 @@ class ModelInstallService(ModelInstallServiceBase):
         hash_algo = self._app_config.hashing_algorithm
         fields = config.model_dump()

-        return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)
-
-        # New model probe API is disabled pending resolution of issue caused by a change of the ordering of checks.
-        # See commit message for details.
-        # try:
-        #     return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
-        # except InvalidModelConfigException:
-        #     return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)  # type: ignore
+        try:
+            return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)  # type: ignore
+        except InvalidModelConfigException:
+            return ModelConfigBase.classify(model_path, hash_algo, **fields)

     def _register(
         self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
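The restored behaviour keeps the legacy ModelProbe as the primary path and only consults the newer classify API when the probe raises InvalidModelConfigException. A minimal sketch of the same pattern outside the install service, assuming a standalone helper (the function name is illustrative, not part of the commit):

from pathlib import Path

from invokeai.backend.model_manager.config import InvalidModelConfigException, ModelConfigBase
from invokeai.backend.model_manager.legacy_probe import ModelProbe


def probe_with_classify_fallback(path: Path, hash_algo: str = "blake3_single"):
    """Illustrative helper: legacy probe first, classify API only as fallback."""
    try:
        # The legacy probe keeps its historical ordering of checks, so it runs first.
        return ModelProbe.probe(path, hash_algo=hash_algo)
    except InvalidModelConfigException:
        # Only models the legacy probe cannot identify reach the newer classify API.
        return ModelConfigBase.classify(path, hash_algo)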
@@ -146,33 +146,39 @@ class ModelConfigBase(ABC, BaseModel):
     )
     usage_info: Optional[str] = Field(default=None, description="Usage information for this model")

-    _USING_LEGACY_PROBE: ClassVar[set] = set()
-    _USING_CLASSIFY_API: ClassVar[set] = set()
+    USING_LEGACY_PROBE: ClassVar[set] = set()
+    USING_CLASSIFY_API: ClassVar[set] = set()
     _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED

     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
         if issubclass(cls, LegacyProbeMixin):
-            ModelConfigBase._USING_LEGACY_PROBE.add(cls)
+            ModelConfigBase.USING_LEGACY_PROBE.add(cls)
         else:
-            ModelConfigBase._USING_CLASSIFY_API.add(cls)
+            ModelConfigBase.USING_CLASSIFY_API.add(cls)

     @staticmethod
     def all_config_classes():
-        subclasses = ModelConfigBase._USING_LEGACY_PROBE | ModelConfigBase._USING_CLASSIFY_API
+        subclasses = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API
         concrete = {cls for cls in subclasses if not isabstract(cls)}
         return concrete

     @staticmethod
-    def classify(model_path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
+    def classify(
+        mod: str | Path | ModelOnDisk,
+        hash_algo: HASHING_ALGORITHMS = "blake3_single",
+        **overrides
+    ):
         """
         Returns the best matching ModelConfig instance from a model's file/folder path.
         Raises InvalidModelConfigException if no valid configuration is found.
         Created to deprecate ModelProbe.probe
         """
-        candidates = ModelConfigBase._USING_CLASSIFY_API
+        if isinstance(mod, Path | str):
+            mod = ModelOnDisk(mod, hash_algo)
+
+        candidates = ModelConfigBase.USING_CLASSIFY_API
         sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
-        mod = ModelOnDisk(model_path, hash_algo)

         for config_cls in sorted_by_match_speed:
             try:
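classify now accepts either a filesystem path (str or Path) or an already constructed ModelOnDisk, wrapping bare paths itself, and the renamed USING_CLASSIFY_API registry is public. A short sketch of both call styles, assuming an illustrative model path:

from pathlib import Path

from invokeai.backend.model_manager.config import ModelConfigBase
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk

model_path = Path("models/sd15.safetensors")  # illustrative path, not from the commit

# Path or str input: classify builds the ModelOnDisk wrapper internally.
config_a = ModelConfigBase.classify(model_path, hash_algo="blake3_single")

# Pre-built ModelOnDisk input: lets a caller reuse one instance (and its cached state dict)
# across several classification attempts.
mod = ModelOnDisk(model_path, "blake3_single")
config_b = ModelConfigBase.classify(mod)

print(type(config_a).__name__, type(config_b).__name__)

Either way, candidates are still tried in _MATCH_SPEED order, so cheap matchers run before expensive ones.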
@@ -4,6 +4,7 @@ from typing import Any, Optional, TypeAlias
 import safetensors.torch
 import torch
 from picklescan.scanner import scan_file_path
+from safetensors import safe_open

 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
 from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
@@ -35,12 +36,21 @@ class ModelOnDisk:
             return self.path.stat().st_size
         return sum(file.stat().st_size for file in self.path.rglob("*"))

-    def component_paths(self) -> set[Path]:
+    def weight_files(self) -> set[Path]:
         if self.path.is_file():
             return {self.path}
         extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
         return {f for f in self.path.rglob("*") if f.suffix in extensions}

+    def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
+        try:
+            with safe_open(self.path, framework="pt", device="cpu") as f:
+                metadata = f.metadata()
+                assert isinstance(metadata, dict)
+                return metadata
+        except Exception:
+            return {}
+
     def repo_variant(self) -> Optional[ModelRepoVariant]:
         if self.path.is_file():
             return None
@@ -64,18 +74,7 @@ class ModelOnDisk:
         if path in sd_cache:
             return sd_cache[path]

-        if not path:
-            components = list(self.component_paths())
-            match components:
-                case []:
-                    raise ValueError("No weight files found for this model")
-                case [p]:
-                    path = p
-                case ps if len(ps) >= 2:
-                    raise ValueError(
-                        f"Multiple weight files found for this model: {ps}. "
-                        f"Please specify the intended file using the 'path' argument"
-                    )
+        path = self.resolve_weight_file(path)

         with SilenceWarnings():
             if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
@@ -94,3 +93,18 @@ class ModelOnDisk:
             state_dict = checkpoint.get("state_dict", checkpoint)
         sd_cache[path] = state_dict
         return state_dict
+
+    def resolve_weight_file(self, path: Optional[Path] = None) -> Path:
+        if not path:
+            weight_files = list(self.weight_files())
+            match weight_files:
+                case []:
+                    raise ValueError("No weight files found for this model")
+                case [p]:
+                    return p
+                case ps if len(ps) >= 2:
+                    raise ValueError(
+                        f"Multiple weight files found for this model: {ps}. "
+                        f"Please specify the intended file using the 'path' argument"
+                    )
+        return path
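The weight-file selection that previously lived inside load_state_dict now sits in resolve_weight_file: with no explicit path it succeeds only when the model has exactly one weight file, and raises ValueError for zero or several. A rough usage sketch, assuming an illustrative model folder:

from pathlib import Path

from invokeai.backend.model_manager.model_on_disk import ModelOnDisk

mod = ModelOnDisk(Path("models/my-lora"))  # illustrative folder containing weight files

weights = mod.weight_files()  # files matching .safetensors/.pt/.pth/.ckpt/.bin/.gguf
print(sorted(weights))

if len(weights) == 1:
    # Unambiguous case: resolve_weight_file picks the single file.
    target = mod.resolve_weight_file()
else:
    # Ambiguous case: the caller must name the file, otherwise ValueError is raised.
    target = mod.resolve_weight_file(sorted(weights)[0])

print(mod.metadata())             # safetensors header metadata, or {} if unavailable
state_dict = mod.load_state_dict(target)
print(len(state_dict))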
@@ -28,10 +28,9 @@ args = parser.parse_args()

 def classify_with_fallback(path: Path, hash_algo: HASHING_ALGORITHMS):
     try:
-        return ModelConfigBase.classify(path, hash_algo)
+        return ModelProbe.probe(path, hash_algo=hash_algo)
     except InvalidModelConfigException:
-        return ModelProbe.probe(path, hash_algo=hash_algo)
-
+        return ModelConfigBase.classify(path, hash_algo)

 for path in args.model_path:
     try:
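The hunk cuts the loop body off, but its job is simply to report one config (or one failure) per path passed on the command line. A hypothetical completion of that loop for illustration only; the argument name and output format are assumptions, not the script's actual code:

for path in args.model_path:
    try:
        config = classify_with_fallback(path, args.hash_algo)  # hypothetical arg name
        print(f"{path}: {type(config).__name__}")
    except InvalidModelConfigException:
        print(f"{path}: not recognised by either the legacy probe or the classify API")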
@@ -18,13 +18,16 @@ import json
 import shutil
 import sys
 from pathlib import Path
+from typing import Optional

 import humanize
 import torch

-from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
+from invokeai.backend.model_manager.model_on_disk import ModelOnDisk, StateDict
 from invokeai.backend.model_manager.search import ModelSearch

+METADATA_KEY = "metadata_key_for_stripped_models"
+

 def strip(v):
     match v:
@@ -57,9 +60,22 @@ def dress(v):
 def load_stripped_model(path: Path, *args, **kwargs):
     with open(path, "r") as f:
         contents = json.load(f)
+        contents.pop(METADATA_KEY, None)
         return dress(contents)


+class StrippedModelOnDisk(ModelOnDisk):
+    def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
+        path = self.resolve_weight_file(path)
+        return load_stripped_model(path)
+
+    def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
+        path = self.resolve_weight_file(path)
+        with open(path, "r") as f:
+            contents = json.load(f)
+        return contents.get(METADATA_KEY, {})
+
+
 def create_stripped_model(original_model_path: Path, stripped_model_path: Path) -> ModelOnDisk:
     original = ModelOnDisk(original_model_path)
     if original.path.is_file():
@@ -69,11 +85,14 @@ def create_stripped_model(original_model_path: Path, stripped_model_path: Path)
     stripped = ModelOnDisk(stripped_model_path)
     print(f"Created clone of {original.name} at {stripped.path}")

-    for component_path in stripped.component_paths():
+    for component_path in stripped.weight_files():
         original_state_dict = stripped.load_state_dict(component_path)
+
         stripped_state_dict = strip(original_state_dict)  # type: ignore
+        metadata = stripped.metadata()
+        contents = {**stripped_state_dict, METADATA_KEY: metadata}
         with open(component_path, "w") as f:
-            json.dump(stripped_state_dict, f, indent=4)
+            json.dump(contents, f, indent=4)

     before_size = humanize.naturalsize(original.size())
     after_size = humanize.naturalsize(stripped.size())
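Stripped models are JSON stand-ins for real checkpoints used by the tests; with this change each stripped weight file also carries the source file's safetensors metadata under METADATA_KEY, so StrippedModelOnDisk.metadata() can serve it back without the original weights. A rough sketch of reading one, assuming an illustrative stripped file path:

import json
from pathlib import Path

from scripts.strip_models import METADATA_KEY, StrippedModelOnDisk

stripped_file = Path("stripped_models/example/model.safetensors")  # illustrative; the file holds JSON, not tensors

with open(stripped_file, "r") as f:
    contents = json.load(f)

# Every key except METADATA_KEY belongs to the stripped state dict.
print([k for k in contents if k != METADATA_KEY][:5])

mod = StrippedModelOnDisk(stripped_file)
print(mod.metadata())               # the dict stored under METADATA_KEY
print(len(mod.load_state_dict()))   # reconstituted ("dressed") stand-in state dict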
@@ -29,6 +29,9 @@ from invokeai.backend.model_manager.legacy_probe import (
 from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
 from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.util.logging import InvokeAILogger
+from scripts.strip_models import StrippedModelOnDisk

+logger = InvokeAILogger.get_logger(__file__)
+

 @pytest.mark.parametrize(
@@ -156,7 +159,8 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
         pass

     try:
-        new_config = ModelConfigBase.classify(path, hash=fake_hash, key=fake_key)
+        stripped_mod = StrippedModelOnDisk(path)
+        new_config = ModelConfigBase.classify(stripped_mod, hash=fake_hash, key=fake_key)
     except InvalidModelConfigException:
         pass
@@ -165,10 +169,10 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
         assert legacy_config.model_dump_json() == new_config.model_dump_json()

     elif legacy_config:
-        assert type(legacy_config) in ModelConfigBase._USING_LEGACY_PROBE
+        assert type(legacy_config) in ModelConfigBase.USING_LEGACY_PROBE

     elif new_config:
-        assert type(new_config) in ModelConfigBase._USING_CLASSIFY_API
+        assert type(new_config) in ModelConfigBase.USING_CLASSIFY_API

     else:
         raise ValueError(f"Both probe and classify failed to classify model at path {path}.")
@@ -177,7 +181,6 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
    configs_with_tests.add(config_type)

    untested_configs = ModelConfigBase.all_config_classes() - configs_with_tests - {MinimalConfigExample}
-    logger = InvokeAILogger.get_logger(__file__)
    logger.warning(f"Function test_regression_against_model_probe missing test case for: {untested_configs}")
@@ -255,3 +258,16 @@ def test_any_model_config_includes_all_config_classes():

     expected = set(ModelConfigBase.all_config_classes()) - {MinimalConfigExample}
     assert extracted == expected
+
+
+def test_config_uniquely_matches_model(datadir: Path):
+    model_paths = ModelSearch().search(datadir / "stripped_models")
+    for path in model_paths:
+        mod = StrippedModelOnDisk(path)
+        matches = { cls for cls in ModelConfigBase.USING_CLASSIFY_API if cls.matches(mod) }
+        assert len(matches) <= 1, f"Model at path {path} matches multiple config classes: {matches}"
+        if not matches:
+            logger.warning(f"Model at path {path} does not match any config classes using classify API.")
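The new test guards against overlapping matchers: if two config classes both claimed the same stripped model, the result of classify would depend on _MATCH_SPEED ordering rather than on the model itself. A condensed sketch of the same check for a single model, assuming an illustrative path:

from pathlib import Path

from invokeai.backend.model_manager.config import ModelConfigBase
from scripts.strip_models import StrippedModelOnDisk

mod = StrippedModelOnDisk(Path("stripped_models/example"))  # illustrative stripped model
matches = {cls for cls in ModelConfigBase.USING_CLASSIFY_API if cls.matches(mod)}

assert len(matches) <= 1, f"ambiguous match: {matches}"
print(matches or "no classify-API config matched; the legacy probe would still handle this model")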