mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
113 lines
3.6 KiB
Python
113 lines
3.6 KiB
Python
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
|
"""
|
|
Abstract base class for recursive directory search for models.
|
|
"""
|
|
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
from pathlib import Path
|
|
from typing import List, Set, types
|
|
|
|
import invokeai.backend.util.logging as logger
|
|
|
|
|
|
class ModelSearch(ABC):
    """
    Abstract base class for a recursive search of directories for models.

    Subclasses implement the ``on_search_started``, ``on_model_found`` and
    ``on_search_completed`` hooks; ``search()`` walks every directory given to
    the constructor and calls ``on_model_found()`` for each candidate found.
    """

    # A directory containing any one of these files is reported as a model
    # (and its contents are then not reported individually).
    MODEL_DIR_MARKERS = frozenset(
        {
            "config.json",
            "model_index.json",
            "learned_embeds.bin",
            "pytorch_lora_weights.bin",
            "image_encoder.txt",
        }
    )

    # A file with one of these suffixes, outside a found model directory,
    # is reported as a model.
    MODEL_FILE_SUFFIXES = frozenset({".ckpt", ".bin", ".pth", ".safetensors", ".pt"})

    def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
        """
        Initialize a recursive model directory search.

        :param directories: List of directory Paths to recurse through
        :param logger: Logger to use; its ``warning()`` is called when
            ``on_model_found()`` raises
        """
        self.directories = directories
        self.logger = logger
        self._items_scanned = 0
        self._models_found = 0
        # Directories already reported as models (or nested inside one);
        # their contents are not reported again.
        self._scanned_dirs = set()
        # Paths to skip; only read by this class, never populated by it.
        self._scanned_paths = set()
        # Hidden subdirectories that were skipped (as strings).
        self._pruned_paths = set()

    @abstractmethod
    def on_search_started(self):
        """
        Called before the scan starts.
        """
        pass

    @abstractmethod
    def on_model_found(self, model: Path):
        """
        Process a found model. Raise an exception if something goes wrong.

        :param model: Model to process - could be a directory or checkpoint.
        """
        pass

    @abstractmethod
    def on_search_completed(self):
        """
        Perform some activity when the scan is completed. May use instance
        variables, items_scanned and models_found
        """
        pass

    def search(self):
        """
        Search every configured directory, invoking the hooks around the walk.

        Counters and per-search bookkeeping are reset first, so the values
        seen in ``on_search_completed()`` describe only the current search,
        even when ``search()`` is called repeatedly on the same instance.
        """
        self._items_scanned = 0
        self._models_found = 0
        self._scanned_dirs.clear()
        self._pruned_paths.clear()
        self.on_search_started()
        for directory in self.directories:
            self.walk_directory(directory)
        self.on_search_completed()

    def walk_directory(self, path: Path):
        """
        Recursively walk ``path`` and report every model found beneath it.

        Symbolic links are followed. Hidden subdirectories (names starting
        with ".") are neither descended into nor reported as models. A link
        that leads back to a directory already on the current descent path is
        skipped, so link cycles cannot make the walk re-enter the same tree.
        The directory passed in is always searched, even if its own name
        starts with ".".

        :param path: Directory to search.
        """
        # Maps each directory os.walk has yet to visit (keyed exactly as
        # os.walk will yield it) to the real paths of the directories on the
        # way down to it. os.walk itself does not track visited directories.
        ancestors: dict[str, frozenset[str]] = {}

        for root, dirs, files in os.walk(path, followlinks=True):
            on_path = ancestors.pop(root, frozenset()) | {os.path.realpath(root)}
            self._items_scanned += len(dirs) + len(files)

            # Prune `dirs` in place: os.walk (top-down) will then not descend
            # into the removed entries.
            kept = []
            for name in dirs:
                child = os.path.join(root, name)
                if name.startswith("."):
                    self._pruned_paths.add(child)
                    continue
                if os.path.realpath(child) in on_path:
                    continue  # symlink cycle back to an ancestor
                ancestors[child] = on_path
                kept.append(name)
            dirs[:] = kept

            for name in dirs:
                candidate = Path(root) / name
                # Anything nested in an already-reported model directory is
                # marked so that its own contents are skipped too.
                if candidate in self._scanned_paths or candidate.parent in self._scanned_dirs:
                    self._scanned_dirs.add(candidate)
                    continue
                if any((candidate / marker).exists() for marker in self.MODEL_DIR_MARKERS):
                    try:
                        self.on_model_found(candidate)
                        self._models_found += 1
                        self._scanned_dirs.add(candidate)
                    except Exception as e:
                        # Best effort: one bad model must not abort the search.
                        self.logger.warning(f"Failed to process '{candidate}': {e}")

            for name in files:
                candidate = Path(root) / name
                if candidate.parent in self._scanned_dirs:
                    continue
                if candidate.suffix in self.MODEL_FILE_SUFFIXES:
                    try:
                        self.on_model_found(candidate)
                        self._models_found += 1
                    except Exception as e:
                        self.logger.warning(f"Failed to process '{candidate}': {e}")
|
|
|
|
|
|
class FindModels(ModelSearch):
    """
    Concrete model search that simply collects the paths of all models found.
    """

    def on_search_started(self):
        """Start each search with an empty result set."""
        self.models_found: Set[Path] = set()

    def on_model_found(self, model: Path):
        """Record a found model directory or checkpoint file."""
        self.models_found.add(model)

    def on_search_completed(self):
        """Nothing to finalize; results are read through ``list_models()``."""
        pass

    def list_models(self) -> List[Path]:
        """
        Run a fresh search and return the paths of the models found.

        :returns: The found model paths, sorted. Set iteration order is not
            guaranteed, so sorting keeps the result deterministic.
        """
        self.search()
        return sorted(self.models_found)
|