diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 8dbeaa3d05..8d97a1bda4 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -1,6 +1,7 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein +import pathlib from typing import Literal, List, Optional, Union from fastapi import Body, Path, Query, Response @@ -191,6 +192,23 @@ async def convert_model( except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return response + +@models_router.get( + "/search", + operation_id="search_for_models", + responses={ + 200: { "description": "Directory searched successfully" }, + 404: { "description": "Invalid directory path" }, + }, + status_code = 200, + response_model = List[pathlib.Path] +) +async def search_for_models( + search_path: pathlib.Path = Query(description="Directory path to search for models") +)->List[pathlib.Path]: + if not search_path.is_dir(): + raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory") + return ApiDependencies.invoker.services.model_manager.search_for_models([search_path]) @models_router.put( "/merge/{base_model}", diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 1b1c43dc11..9a6ba77c13 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -19,7 +19,7 @@ from invokeai.backend.model_management import ( ModelMerger, MergeInterpolationMethod, ) - +from invokeai.backend.model_management.model_search import FindModels import torch from invokeai.app.models.exceptions import CanceledException @@ -230,7 +230,14 @@ class ModelManagerServiceBase(ABC): :param interp: Interpolation method. None (default) """ pass - + + @abstractmethod + def search_for_models(self, directory: Path)->List[Path]: + """ + Return list of all models found in the designated directory. + """ + pass + @abstractmethod def commit(self, conf_file: Optional[Path] = None) -> None: """ @@ -558,3 +565,10 @@ class ModelManagerService(ModelManagerServiceBase): except AssertionError as e: raise ValueError(e) return result + + def search_for_models(self, directory: Path)->List[Path]: + """ + Return list of all models found in the designated directory. + """ + search = FindModels(directory,self.logger) + return search.list_models() diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 0476425c8b..0363b858cf 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util import CUDA_DEVICE, Chdir from .model_cache import ModelCache, ModelLocker +from .model_search import ModelSearch from .models import ( BaseModelType, ModelType, SubModelType, ModelError, SchedulerPredictionType, MODEL_CLASSES, @@ -823,6 +824,7 @@ class ModelManager(object): if (new_models_found or imported_models) and self.config_path: self.commit() + def autoimport(self)->Dict[str, AddModelResult]: ''' Scan the autoimport directory (if defined) and import new models, delete defunct models. @@ -830,63 +832,42 @@ class ModelManager(object): # avoid circular import from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.frontend.install.model_install import ask_user_for_prediction_type - + + + class ScanAndImport(ModelSearch): + def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall): + super().__init__(directories, logger) + self.installer = installer + self.ignore = ignore + + def on_search_started(self): + self.new_models_found = dict() + + def on_model_found(self, model: Path): + if model not in self.ignore: + self.new_models_found.update(self.installer.heuristic_import(model)) + + def on_search_completed(self): + self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models') + + def models_found(self): + return self.new_models_found + + installer = ModelInstall(config = self.app_config, model_manager = self, prediction_type_helper = ask_user_for_prediction_type, ) - - scanned_dirs = set() - config = self.app_config - known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()} - - for autodir in [config.autoimport_dir, - config.lora_dir, - config.embedding_dir, - config.controlnet_dir]: - if autodir is None: - continue - - self.logger.info(f'Scanning {autodir} for models to import') - installed = dict() - - autodir = self.app_config.root_path / autodir - if not autodir.exists(): - continue - - items_scanned = 0 - new_models_found = dict() - - for root, dirs, files in os.walk(autodir): - items_scanned += len(dirs) + len(files) - for d in dirs: - path = Path(root) / d - if path in known_paths or path.parent in scanned_dirs: - scanned_dirs.add(path) - continue - if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]): - try: - new_models_found.update(installer.heuristic_import(path)) - scanned_dirs.add(path) - except ValueError as e: - self.logger.warning(str(e)) - - for f in files: - path = Path(root) / f - if path in known_paths or path.parent in scanned_dirs: - continue - if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: - try: - import_result = installer.heuristic_import(path) - new_models_found.update(import_result) - except ValueError as e: - self.logger.warning(str(e)) - - self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models') - installed.update(new_models_found) - - return installed + known_paths = {config.root_path / x['path'] for x in self.list_models()} + directories = {config.root_path / x for x in [config.autoimport_dir, + config.lora_dir, + config.embedding_dir, + config.controlnet_dir] + } + scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer) + scanner.search() + return scanner.models_found() def heuristic_import(self, items_to_import: Set[str], @@ -924,3 +905,4 @@ class ModelManager(object): successfully_installed.update(installed) self.commit() return successfully_installed + diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py new file mode 100644 index 0000000000..1e282b4bb8 --- /dev/null +++ b/invokeai/backend/model_management/model_search.py @@ -0,0 +1,103 @@ +# Copyright 2023, Lincoln D. Stein and the InvokeAI Team +""" +Abstract base class for recursive directory search for models. +""" + +import os +from abc import ABC, abstractmethod +from typing import List, Set, types +from pathlib import Path + +import invokeai.backend.util.logging as logger + +class ModelSearch(ABC): + def __init__(self, directories: List[Path], logger: types.ModuleType=logger): + """ + Initialize a recursive model directory search. + :param directories: List of directory Paths to recurse through + :param logger: Logger to use + """ + self.directories = directories + self.logger = logger + self._items_scanned = 0 + self._models_found = 0 + self._scanned_dirs = set() + self._scanned_paths = set() + self._pruned_paths = set() + + @abstractmethod + def on_search_started(self): + """ + Called before the scan starts. + """ + pass + + @abstractmethod + def on_model_found(self, model: Path): + """ + Process a found model. Raise an exception if something goes wrong. + :param model: Model to process - could be a directory or checkpoint. + """ + pass + + @abstractmethod + def on_search_completed(self): + """ + Perform some activity when the scan is completed. May use instance + variables, items_scanned and models_found + """ + pass + + def search(self): + self.on_search_started() + for dir in self.directories: + self.walk_directory(dir) + self.on_search_completed() + + def walk_directory(self, path: Path): + for root, dirs, files in os.walk(path): + if str(Path(root).name).startswith('.'): + self._pruned_paths.add(root) + if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): + continue + + self._items_scanned += len(dirs) + len(files) + for d in dirs: + path = Path(root) / d + if path in self._scanned_paths or path.parent in self._scanned_dirs: + self._scanned_dirs.add(path) + continue + if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]): + try: + self.on_model_found(path) + self._models_found += 1 + self._scanned_dirs.add(path) + except Exception as e: + self.logger.warning(str(e)) + + for f in files: + path = Path(root) / f + if path.parent in self._scanned_dirs: + continue + if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: + try: + self.on_model_found(path) + self._models_found += 1 + except Exception as e: + self.logger.warning(str(e)) + +class FindModels(ModelSearch): + def on_search_started(self): + self.models_found: Set[Path] = set() + + def on_model_found(self,model: Path): + self.models_found.add(model) + + def on_search_completed(self): + pass + + def list_models(self) -> List[Path]: + self.search() + return self.models_found + +