fix model probing for controlnet checkpoint legacy config files
This commit is contained in:
parent: 19baea1883
commit: ec510d34b5
@@ -127,7 +127,7 @@ class ModelInstallService(ModelInstallServiceBase):
         model_path = Path(model_path)
         metadata = metadata or {}
         if metadata.get('source') is None:
-            metadata['source'] = model_path.as_posix()
+            metadata['source'] = model_path.resolve().as_posix()
         return self._register(model_path, metadata)

     def install_path(
@@ -138,7 +138,7 @@ class ModelInstallService(ModelInstallServiceBase):
         model_path = Path(model_path)
         metadata = metadata or {}
         if metadata.get('source') is None:
-            metadata['source'] = model_path.as_posix()
+            metadata['source'] = model_path.resolve().as_posix()

         info: AnyModelConfig = self._probe_model(Path(model_path), metadata)
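Aside: the only functional change in these two hunks is the added `resolve()`. `as_posix()` on its own preserves whatever relative spelling the caller passed in, while `resolve()` first collapses `..` segments and symlinks into an absolute path, so the recorded `source` is stable regardless of how the path was written. A minimal sketch with hypothetical paths:

    from pathlib import Path

    p = Path("models/../models/sd-1/controlnet/canny.safetensors")  # hypothetical input
    print(p.as_posix())            # 'models/../models/sd-1/controlnet/canny.safetensors'
    print(p.resolve().as_posix())  # absolute path with '..' and symlinks collapsed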
@@ -366,6 +366,7 @@ class ModelInstallService(ModelInstallServiceBase):
         # add 'main' specific fields
         if hasattr(info, 'config'):
             # make config relative to our root
-            info.config = self.app_config.legacy_conf_dir / info.config
+            legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
+            info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
         self.record_store.add_model(key, info)
         return key
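The controlnet probe (further down in this commit) emits a legacy config path such as `../controlnet/cldm_v15.yaml`, expressed relative to the legacy config directory. The new two-step normalization anchors that value under the application root, resolves the `..`, and stores it root-relative. Roughly, with assumed directory values (not taken from the commit):

    from pathlib import Path

    root_dir = Path("/opt/invokeai")                    # assumed app root
    legacy_conf_dir = Path("configs/stable-diffusion")  # assumed legacy_conf_dir
    probe_config = "../controlnet/cldm_v15.yaml"        # value emitted by the probe

    legacy_conf = (root_dir / legacy_conf_dir / probe_config).resolve()
    print(legacy_conf)                                   # /opt/invokeai/configs/controlnet/cldm_v15.yaml
    print(legacy_conf.relative_to(root_dir).as_posix())  # configs/controlnet/cldm_v15.yaml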
invokeai/backend/model_manager/__init__.py (new file, 30 lines)
@@ -0,0 +1,30 @@
"""Re-export frequently-used symbols from the Model Manager backend."""

from .probe import ModelProbe
from .config import (
    InvalidModelConfigException,
    DuplicateModelException,
    ModelConfigFactory,
    BaseModelType,
    ModelType,
    SubModelType,
    ModelVariantType,
    ModelFormat,
    SchedulerPredictionType,
    AnyModelConfig,
)
from .search import ModelSearch

__all__ = ['ModelProbe', 'ModelSearch',
           'InvalidModelConfigException',
           'DuplicateModelException',
           'ModelConfigFactory',
           'BaseModelType',
           'ModelType',
           'SubModelType',
           'ModelVariantType',
           'ModelFormat',
           'SchedulerPredictionType',
           'AnyModelConfig',
           ]
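With this re-export module in place, callers (including the probe script at the bottom of this commit) can import from the package root instead of reaching into individual submodules. A short sketch:

    from invokeai.backend.model_manager import (
        ModelProbe,
        ModelSearch,
        InvalidModelConfigException,
    )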
@@ -129,7 +129,6 @@ class ModelProbe(object):
             model_type = cls.get_model_type_from_folder(model_path)
         else:
             model_type = cls.get_model_type_from_checkpoint(model_path)
-            print(f'DEBUG: model_type={model_type}')
         format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type

         probe_class = cls.PROBES[format_type].get(model_type)
@@ -150,14 +149,19 @@ class ModelProbe(object):
         fields['original_hash'] = fields.get('original_hash') or hash
         fields['current_hash'] = fields.get('current_hash') or hash

-        # additional work for main models
-        if fields['type'] == ModelType.Main:
-            if fields['format'] == ModelFormat.Checkpoint:
-                fields['config'] = cls._get_config_path(model_path, fields['base'], fields['variant'], fields['prediction_type']).as_posix()
-            elif fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]:
-                fields['upcast_attention'] = fields.get('upcast_attention') or (
-                    fields['base'] == BaseModelType.StableDiffusion2 and fields['prediction_type'] == SchedulerPredictionType.VPrediction
-                )
+        # additional fields needed for main and controlnet models
+        if fields['type'] in [ModelType.Main, ModelType.ControlNet] and fields['format'] == ModelFormat.Checkpoint:
+            fields['config'] = cls._get_checkpoint_config_path(model_path,
+                                                               model_type=fields['type'],
+                                                               base_type=fields['base'],
+                                                               variant_type=fields['variant'],
+                                                               prediction_type=fields['prediction_type']).as_posix()
+
+        # additional fields needed for main non-checkpoint models
+        elif fields['type'] == ModelType.Main and fields['format'] in [ModelFormat.Onnx, ModelFormat.Olive, ModelFormat.Diffusers]:
+            fields['upcast_attention'] = fields.get('upcast_attention') or (
+                fields['base'] == BaseModelType.StableDiffusion2 and fields['prediction_type'] == SchedulerPredictionType.VPrediction
+            )

         model_info = ModelConfigFactory.make_config(fields)
         return model_info
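The practical effect of routing ControlNet checkpoints through the same `config` lookup as main checkpoints is what the commit title describes: probing a legacy controlnet checkpoint now picks up a legacy config file the same way main checkpoints do. A hedged sketch of the caller-facing behaviour (the file path is hypothetical and the `config` attribute is assumed to be exposed on the returned config object):

    from pathlib import Path
    from invokeai.backend.model_manager import ModelProbe

    info = ModelProbe.probe(Path("/opt/invokeai/models/sd-1/controlnet/canny.safetensors"))  # hypothetical
    # For a checkpoint-format controlnet, the probe result should now carry a
    # legacy config reference such as ../controlnet/cldm_v15.yaml.
    print(info.config)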
@@ -243,18 +247,27 @@ class ModelProbe(object):
         )

     @classmethod
-    def _get_config_path(cls,
-                         model_path: Path,
-                         base_type: BaseModelType,
-                         variant: ModelVariantType,
-                         prediction_type: SchedulerPredictionType) -> Path:
+    def _get_checkpoint_config_path(cls,
+                                    model_path: Path,
+                                    model_type: ModelType,
+                                    base_type: BaseModelType,
+                                    variant_type: ModelVariantType,
+                                    prediction_type: SchedulerPredictionType) -> Path:
+
+        # look for a YAML file adjacent to the model file first
         possible_conf = model_path.with_suffix(".yaml")
         if possible_conf.exists():
             return possible_conf.absolute()
-        config_file = LEGACY_CONFIGS[base_type][variant]
-        if isinstance(config_file, dict):  # need another tier for sd-2.x models
-            config_file = config_file[prediction_type]
+
+        if model_type == ModelType.Main:
+            config_file = LEGACY_CONFIGS[base_type][variant_type]
+            if isinstance(config_file, dict):  # need another tier for sd-2.x models
+                config_file = config_file[prediction_type]
+        elif model_type == ModelType.ControlNet:
+            config_file = "../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
+        else:
+            raise InvalidModelConfigException(f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}")
+        assert isinstance(config_file, str)
         return Path(config_file)

     @classmethod
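For illustration only, here is how the new helper resolves the config for a controlnet checkpoint when no sibling `.yaml` override exists. `_get_checkpoint_config_path` is a private classmethod, and the path and enum values below are assumptions, not part of the commit:

    from pathlib import Path
    from invokeai.backend.model_manager import (
        BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType,
    )
    from invokeai.backend.model_manager.probe import ModelProbe

    cfg = ModelProbe._get_checkpoint_config_path(
        Path("/opt/invokeai/models/sd-1/controlnet/canny.safetensors"),  # hypothetical checkpoint
        model_type=ModelType.ControlNet,
        base_type=BaseModelType("sd-1"),
        variant_type=ModelVariantType.Normal,             # not consulted for controlnets
        prediction_type=SchedulerPredictionType.Epsilon,  # not consulted for controlnets
    )
    print(cfg)  # ../controlnet/cldm_v15.yaml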
invokeai/backend/model_manager/search.py (new file, 195 lines)
@@ -0,0 +1,195 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class and implementation for recursive directory search for models.

Example usage:
```
from invokeai.backend.model_manager import ModelSearch, ModelProbe

def find_main_models(model: Path) -> bool:
    info = ModelProbe.probe(model)
    if info.model_type == 'main' and info.base_type == 'sd-1':
        return True
    else:
        return False

search = ModelSearch(on_model_found=find_main_models)
found = search.search('/tmp/models')
print(found)         # set of matching model paths
print(search.stats)  # search stats
```
"""

import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional, Set, Union

from pydantic import BaseModel, Field

from invokeai.backend.util.logging import InvokeAILogger

default_logger = InvokeAILogger.get_logger()


class SearchStats(BaseModel):
    items_scanned: int = 0
    models_found: int = 0
    models_filtered: int = 0


class ModelSearchBase(ABC, BaseModel):
    """
    Abstract directory traversal model search class

    Usage:
       search = ModelSearchBase(
            on_search_started = search_started_callback,
            on_search_completed = search_completed_callback,
            on_model_found = model_found_callback,
       )
       models_found = search.search('/path/to/directory')
    """

    # fmt: off
    on_search_started   : Optional[Callable[[Path], None]]      = Field(default=None, description="Called just before the search starts.")        # noqa E221
    on_model_found      : Optional[Callable[[Path], bool]]      = Field(default=None, description="Called when a model is found.")                # noqa E221
    on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.")              # noqa E221
    stats               : SearchStats                           = Field(default_factory=SearchStats, description="Summary statistics after search")  # noqa E221
    logger              : InvokeAILogger                        = Field(default=default_logger, description="Logger instance.")                   # noqa E221
    # fmt: on

    class Config:
        arbitrary_types_allowed = True

    @abstractmethod
    def search_started(self) -> None:
        """
        Called before the scan starts.

        Passes the root search directory to the Callable `on_search_started`.
        """
        pass

    @abstractmethod
    def model_found(self, model: Path) -> None:
        """
        Called when a model is found during search.

        :param model: Model to process - could be a directory or checkpoint.

        Passes the model's Path to the Callable `on_model_found`.
        This Callable receives the path to the model and returns a boolean
        to indicate whether the model should be returned in the search
        results.
        """
        pass

    @abstractmethod
    def search_completed(self) -> None:
        """
        Called when the scan is complete.

        Passes the Set of found model Paths to the Callable `on_search_completed`.
        """
        pass

    @abstractmethod
    def search(self, directory: Union[Path, str]) -> Set[Path]:
        """
        Recursively search for models in `directory` and return a set of model paths.

        If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
        Callables will be invoked during the search.
        """
        pass


class ModelSearch(ModelSearchBase):
    """
    Implementation of ModelSearch with callbacks.
    Usage:
       search = ModelSearch()
       search.on_model_found = lambda path : 'anime' in path.as_posix()
       found = search.search('/tmp/models')
       # returns all models that have 'anime' in the path
    """

    directory: Path = Field(default=None)
    models_found: Set[Path] = Field(default=None)
    scanned_dirs: Set[Path] = Field(default=None)
    pruned_paths: Set[Path] = Field(default=None)

    def search_started(self) -> None:
        self.models_found = set()
        self.scanned_dirs = set()
        self.pruned_paths = set()
        if self.on_search_started:
            self.on_search_started(self._directory)

    def model_found(self, model: Path) -> None:
        self.stats.models_found += 1
        if not self.on_model_found:
            self.stats.models_filtered += 1
            self.models_found.add(model)
            return
        if self.on_model_found(model):
            self.stats.models_filtered += 1
            self.models_found.add(model)

    def search_completed(self) -> None:
        if self.on_search_completed:
            self.on_search_completed(self._models_found)

    def search(self, directory: Union[Path, str]) -> Set[Path]:
        self._directory = Path(directory)
        self.stats = SearchStats()  # zero out
        self.search_started()  # This will initialize _models_found to empty
        self._walk_directory(directory)
        self.search_completed()
        return self.models_found

    def _walk_directory(self, path: Union[Path, str]) -> None:
        for root, dirs, files in os.walk(path, followlinks=True):
            # don't descend into directories that start with a "."
            # to avoid the Mac .DS_STORE issue.
            if str(Path(root).name).startswith("."):
                self.pruned_paths.add(Path(root))
            if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
                continue

            self.stats.items_scanned += len(dirs) + len(files)
            for d in dirs:
                path = Path(root) / d
                if path.parent in self.scanned_dirs:
                    self.scanned_dirs.add(path)
                    continue
                if any(
                    (path / x).exists()
                    for x in [
                        "config.json",
                        "model_index.json",
                        "learned_embeds.bin",
                        "pytorch_lora_weights.bin",
                        "image_encoder.txt",
                    ]
                ):
                    self.scanned_dirs.add(path)
                    try:
                        self.model_found(path)
                    except KeyboardInterrupt:
                        raise
                    except Exception as e:
                        self.logger.warning(str(e))

            for f in files:
                path = Path(root) / f
                if path.parent in self.scanned_dirs:
                    continue
                if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
                    try:
                        self.model_found(path)
                    except KeyboardInterrupt:
                        raise
                    except Exception as e:
                        self.logger.warning(str(e))
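A usage sketch for the new search API (the directory is hypothetical, and the filter callback selects by file suffix so it does not depend on probe results):

    from invokeai.backend.model_manager import ModelSearch

    search = ModelSearch(on_model_found=lambda path: path.suffix == ".safetensors")
    found = search.search("/opt/invokeai/models")   # hypothetical models directory
    print(found)          # set of matching checkpoint paths
    print(search.stats)   # items_scanned / models_found / models_filtered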
@@ -1,9 +1,13 @@
 #!/bin/env python

+"""Little command-line utility for probing a model on disk."""
+
 import argparse
 from pathlib import Path

-from invokeai.backend.model_manager.probe import ModelProbe
+from invokeai.backend.model_manager import ModelProbe, InvalidModelConfigException
+
+
 parser = argparse.ArgumentParser(description="Probe model type")
 parser.add_argument(
@@ -14,5 +18,8 @@ parser.add_argument(
 args = parser.parse_args()

 for path in args.model_path:
-    info = ModelProbe().probe(path)
-    print(f"{path}: {info}")
+    try:
+        info = ModelProbe.probe(path)
+        print(f"{path}:{info.model_dump_json(indent=4)}")
+    except InvalidModelConfigException as exc:
+        print(exc)
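For reference, the equivalent programmatic loop outside argparse might look like this (the model path is hypothetical; `model_dump_json` is the pydantic serializer on the config object the probe returns):

    from pathlib import Path
    from invokeai.backend.model_manager import ModelProbe, InvalidModelConfigException

    candidates = [Path("/opt/invokeai/models/sd-1/controlnet/canny.safetensors")]  # hypothetical
    for path in candidates:
        try:
            info = ModelProbe.probe(path)
            print(f"{path}:{info.model_dump_json(indent=4)}")
        except InvalidModelConfigException as exc:
            print(exc)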