add model directory search route

This commit is contained in:
Lincoln Stein 2023-07-14 11:14:33 -04:00
parent e8d531b987
commit ad076b1174
4 changed files with 172 additions and 55 deletions

View File

@ -1,6 +1,7 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
import pathlib
from typing import Literal, List, Optional, Union from typing import Literal, List, Optional, Union
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
@ -192,6 +193,23 @@ async def convert_model(
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return response return response
@models_router.get(
"/search",
operation_id="search_for_models",
responses={
200: { "description": "Directory searched successfully" },
404: { "description": "Invalid directory path" },
},
status_code = 200,
response_model = List[pathlib.Path]
)
async def search_for_models(
search_path: pathlib.Path = Query(description="Directory path to search for models")
)->List[pathlib.Path]:
if not search_path.is_dir():
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
@models_router.put( @models_router.put(
"/merge/{base_model}", "/merge/{base_model}",
operation_id="merge_models", operation_id="merge_models",

View File

@ -19,7 +19,7 @@ from invokeai.backend.model_management import (
ModelMerger, ModelMerger,
MergeInterpolationMethod, MergeInterpolationMethod,
) )
from invokeai.backend.model_management.model_search import FindModels
import torch import torch
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
@ -231,6 +231,13 @@ class ModelManagerServiceBase(ABC):
""" """
pass pass
@abstractmethod
def search_for_models(self, directory: Path)->List[Path]:
"""
Return list of all models found in the designated directory.
"""
pass
@abstractmethod @abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None: def commit(self, conf_file: Optional[Path] = None) -> None:
""" """
@ -558,3 +565,10 @@ class ModelManagerService(ModelManagerServiceBase):
except AssertionError as e: except AssertionError as e:
raise ValueError(e) raise ValueError(e)
return result return result
def search_for_models(self, directory: Path)->List[Path]:
"""
Return list of all models found in the designated directory.
"""
search = FindModels(directory,self.logger)
return search.list_models()

View File

@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, Chdir from invokeai.backend.util import CUDA_DEVICE, Chdir
from .model_cache import ModelCache, ModelLocker from .model_cache import ModelCache, ModelLocker
from .model_search import ModelSearch
from .models import ( from .models import (
BaseModelType, ModelType, SubModelType, BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES, ModelError, SchedulerPredictionType, MODEL_CLASSES,
@ -823,6 +824,7 @@ class ModelManager(object):
if (new_models_found or imported_models) and self.config_path: if (new_models_found or imported_models) and self.config_path:
self.commit() self.commit()
def autoimport(self)->Dict[str, AddModelResult]: def autoimport(self)->Dict[str, AddModelResult]:
''' '''
Scan the autoimport directory (if defined) and import new models, delete defunct models. Scan the autoimport directory (if defined) and import new models, delete defunct models.
@ -831,62 +833,41 @@ class ModelManager(object):
from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.frontend.install.model_install import ask_user_for_prediction_type from invokeai.frontend.install.model_install import ask_user_for_prediction_type
class ScanAndImport(ModelSearch):
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
super().__init__(directories, logger)
self.installer = installer
self.ignore = ignore
def on_search_started(self):
self.new_models_found = dict()
def on_model_found(self, model: Path):
if model not in self.ignore:
self.new_models_found.update(self.installer.heuristic_import(model))
def on_search_completed(self):
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
def models_found(self):
return self.new_models_found
installer = ModelInstall(config = self.app_config, installer = ModelInstall(config = self.app_config,
model_manager = self, model_manager = self,
prediction_type_helper = ask_user_for_prediction_type, prediction_type_helper = ask_user_for_prediction_type,
) )
scanned_dirs = set()
config = self.app_config config = self.app_config
known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()} known_paths = {config.root_path / x['path'] for x in self.list_models()}
directories = {config.root_path / x for x in [config.autoimport_dir,
for autodir in [config.autoimport_dir,
config.lora_dir, config.lora_dir,
config.embedding_dir, config.embedding_dir,
config.controlnet_dir]: config.controlnet_dir]
if autodir is None: }
continue scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
scanner.search()
self.logger.info(f'Scanning {autodir} for models to import') return scanner.models_found()
installed = dict()
autodir = self.app_config.root_path / autodir
if not autodir.exists():
continue
items_scanned = 0
new_models_found = dict()
for root, dirs, files in os.walk(autodir):
items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in known_paths or path.parent in scanned_dirs:
scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
try:
new_models_found.update(installer.heuristic_import(path))
scanned_dirs.add(path)
except ValueError as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path in known_paths or path.parent in scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
try:
import_result = installer.heuristic_import(path)
new_models_found.update(import_result)
except ValueError as e:
self.logger.warning(str(e))
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
installed.update(new_models_found)
return installed
def heuristic_import(self, def heuristic_import(self,
items_to_import: Set[str], items_to_import: Set[str],
@ -924,3 +905,4 @@ class ModelManager(object):
successfully_installed.update(installed) successfully_installed.update(installed)
self.commit() self.commit()
return successfully_installed return successfully_installed

View File

@ -0,0 +1,103 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class for recursive directory search for models.
"""
import os
from abc import ABC, abstractmethod
from typing import List, Set, types
from pathlib import Path
import invokeai.backend.util.logging as logger
class ModelSearch(ABC):
def __init__(self, directories: List[Path], logger: types.ModuleType=logger):
"""
Initialize a recursive model directory search.
:param directories: List of directory Paths to recurse through
:param logger: Logger to use
"""
self.directories = directories
self.logger = logger
self._items_scanned = 0
self._models_found = 0
self._scanned_dirs = set()
self._scanned_paths = set()
self._pruned_paths = set()
@abstractmethod
def on_search_started(self):
"""
Called before the scan starts.
"""
pass
@abstractmethod
def on_model_found(self, model: Path):
"""
Process a found model. Raise an exception if something goes wrong.
:param model: Model to process - could be a directory or checkpoint.
"""
pass
@abstractmethod
def on_search_completed(self):
"""
Perform some activity when the scan is completed. May use instance
variables, items_scanned and models_found
"""
pass
def search(self):
self.on_search_started()
for dir in self.directories:
self.walk_directory(dir)
self.on_search_completed()
def walk_directory(self, path: Path):
for root, dirs, files in os.walk(path):
if str(Path(root).name).startswith('.'):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
continue
self._items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in self._scanned_paths or path.parent in self._scanned_dirs:
self._scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
try:
self.on_model_found(path)
self._models_found += 1
self._scanned_dirs.add(path)
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self._scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
try:
self.on_model_found(path)
self._models_found += 1
except Exception as e:
self.logger.warning(str(e))
class FindModels(ModelSearch):
def on_search_started(self):
self.models_found: Set[Path] = set()
def on_model_found(self,model: Path):
self.models_found.add(model)
def on_search_completed(self):
pass
def list_models(self) -> List[Path]:
self.search()
return self.models_found