add model directory search route
This commit is contained in:
parent e8d531b987
commit ad076b1174
@@ -1,6 +1,7 @@
-# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
+# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
 
+import pathlib
 from typing import Literal, List, Optional, Union
 
 from fastapi import Body, Path, Query, Response
@@ -192,6 +193,23 @@ async def convert_model(
         raise HTTPException(status_code=400, detail=str(e))
     return response
 
+@models_router.get(
+    "/search",
+    operation_id="search_for_models",
+    responses={
+        200: { "description": "Directory searched successfully" },
+        404: { "description": "Invalid directory path" },
+    },
+    status_code = 200,
+    response_model = List[pathlib.Path]
+)
+async def search_for_models(
+    search_path: pathlib.Path = Query(description="Directory path to search for models")
+)->List[pathlib.Path]:
+    if not search_path.is_dir():
+        raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not a directory")
+    return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
+
 @models_router.put(
     "/merge/{base_model}",
     operation_id="merge_models",
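For context, a minimal client-side sketch of exercising the new route. It is not part of the commit; it assumes the models router is mounted under /api/v1/models (the prefix is not shown in this diff) on a local server at port 9090, and the search directory is a placeholder.

# Hypothetical client call for the new search route -- adjust the host, port,
# and route prefix to your installation; the directory below is a placeholder.
import requests

resp = requests.get(
    "http://127.0.0.1:9090/api/v1/models/search",
    params={"search_path": "/home/user/models"},
)
if resp.ok:
    for model_path in resp.json():  # response_model is List[pathlib.Path], serialized as strings
        print(model_path)
else:
    print(resp.status_code, resp.json().get("detail"))  # 404 if the path is not a directory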
@@ -19,7 +19,7 @@ from invokeai.backend.model_management import (
     ModelMerger,
     MergeInterpolationMethod,
 )
+from invokeai.backend.model_management.model_search import FindModels
 
 import torch
 from invokeai.app.models.exceptions import CanceledException
@@ -231,6 +231,13 @@ class ModelManagerServiceBase(ABC):
         """
         pass
 
+    @abstractmethod
+    def search_for_models(self, directory: Path)->List[Path]:
+        """
+        Return a list of all models found in the designated directory.
+        """
+        pass
+
     @abstractmethod
     def commit(self, conf_file: Optional[Path] = None) -> None:
         """
@@ -558,3 +565,10 @@ class ModelManagerService(ModelManagerServiceBase):
         except AssertionError as e:
             raise ValueError(e)
         return result
+
+    def search_for_models(self, directory: Path)->List[Path]:
+        """
+        Return a list of all models found in the designated directory.
+        """
+        search = FindModels(directory, self.logger)
+        return search.list_models()
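A small usage sketch (not part of the commit) for the service method added above. The `model_manager` argument is an assumed, already-constructed ModelManagerService instance, obtained the same way the router obtains ApiDependencies.invoker.services.model_manager.

# Sketch: drive ModelManagerService.search_for_models directly, mirroring the new route.
from pathlib import Path
from typing import List

def list_models_in(model_manager, search_dir: Path) -> List[Path]:
    """Return model paths found under search_dir, as the /search route does."""
    if not search_dir.is_dir():
        raise ValueError(f"{search_dir} is not a directory")
    # The route passes a one-element list; FindModels accepts a list of directories.
    return model_manager.search_for_models([search_dir])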
@@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util import CUDA_DEVICE, Chdir
 from .model_cache import ModelCache, ModelLocker
+from .model_search import ModelSearch
 from .models import (
     BaseModelType, ModelType, SubModelType,
     ModelError, SchedulerPredictionType, MODEL_CLASSES,
@@ -823,6 +824,7 @@ class ModelManager(object):
         if (new_models_found or imported_models) and self.config_path:
             self.commit()
 
+
     def autoimport(self)->Dict[str, AddModelResult]:
         '''
         Scan the autoimport directory (if defined) and import new models, delete defunct models.
@@ -831,62 +833,41 @@
         from invokeai.backend.install.model_install_backend import ModelInstall
         from invokeai.frontend.install.model_install import ask_user_for_prediction_type
 
+        class ScanAndImport(ModelSearch):
+            def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
+                super().__init__(directories, logger)
+                self.installer = installer
+                self.ignore = ignore
+
+            def on_search_started(self):
+                self.new_models_found = dict()
+
+            def on_model_found(self, model: Path):
+                if model not in self.ignore:
+                    self.new_models_found.update(self.installer.heuristic_import(model))
+
+            def on_search_completed(self):
+                self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
+
+            def models_found(self):
+                return self.new_models_found
+
         installer = ModelInstall(config = self.app_config,
                                  model_manager = self,
                                  prediction_type_helper = ask_user_for_prediction_type,
                                  )
-        scanned_dirs = set()
         config = self.app_config
-        known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()}
-
-        for autodir in [config.autoimport_dir,
-                        config.lora_dir,
-                        config.embedding_dir,
-                        config.controlnet_dir]:
-            if autodir is None:
-                continue
-
-            self.logger.info(f'Scanning {autodir} for models to import')
-            installed = dict()
-
-            autodir = self.app_config.root_path / autodir
-            if not autodir.exists():
-                continue
-
-            items_scanned = 0
-            new_models_found = dict()
-
-            for root, dirs, files in os.walk(autodir):
-                items_scanned += len(dirs) + len(files)
-                for d in dirs:
-                    path = Path(root) / d
-                    if path in known_paths or path.parent in scanned_dirs:
-                        scanned_dirs.add(path)
-                        continue
-                    if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
-                        try:
-                            new_models_found.update(installer.heuristic_import(path))
-                            scanned_dirs.add(path)
-                        except ValueError as e:
-                            self.logger.warning(str(e))
-
-                for f in files:
-                    path = Path(root) / f
-                    if path in known_paths or path.parent in scanned_dirs:
-                        continue
-                    if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
-                        try:
-                            import_result = installer.heuristic_import(path)
-                            new_models_found.update(import_result)
-                        except ValueError as e:
-                            self.logger.warning(str(e))
-
-            self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
-            installed.update(new_models_found)
-
-        return installed
+        known_paths = {config.root_path / x['path'] for x in self.list_models()}
+        directories = {config.root_path / x for x in [config.autoimport_dir,
+                                                      config.lora_dir,
+                                                      config.embedding_dir,
+                                                      config.controlnet_dir]
+        }
+        scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
+        scanner.search()
+        return scanner.models_found()
 
     def heuristic_import(self,
                          items_to_import: Set[str],
@@ -924,3 +905,4 @@ class ModelManager(object):
         successfully_installed.update(installed)
         self.commit()
         return successfully_installed
+
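The refactor above swaps the inline os.walk scan for the callback-based ModelSearch class introduced in the new file below. For illustration only (not part of the commit, all names hypothetical), here is a toy standard-library sketch of that scan-with-callbacks pattern: the walker owns the traversal, and the caller only supplies what to do with each hit.

# Toy sketch of the pattern ScanAndImport uses: a reusable walker that
# invokes a callback for every candidate model file it encounters.
import os
from pathlib import Path
from typing import Callable, Set

MODEL_SUFFIXES = {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}

def scan_for_models(directory: Path, on_model_found: Callable[[Path], None]) -> int:
    """Walk `directory`, calling on_model_found for each model-like file; return items scanned."""
    items_scanned = 0
    for root, dirs, files in os.walk(directory):
        items_scanned += len(dirs) + len(files)
        for f in files:
            path = Path(root) / f
            if path.suffix in MODEL_SUFFIXES:
                on_model_found(path)
    return items_scanned

# Usage: collect everything found under a placeholder directory.
found: Set[Path] = set()
scan_for_models(Path("/path/to/models"), found.add)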
invokeai/backend/model_management/model_search.py (new file, 103 lines)
@@ -0,0 +1,103 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class for recursive directory search for models.
"""

import os
from abc import ABC, abstractmethod
from typing import List, Set, types
from pathlib import Path

import invokeai.backend.util.logging as logger

class ModelSearch(ABC):
    def __init__(self, directories: List[Path], logger: types.ModuleType=logger):
        """
        Initialize a recursive model directory search.
        :param directories: List of directory Paths to recurse through
        :param logger: Logger to use
        """
        self.directories = directories
        self.logger = logger
        self._items_scanned = 0
        self._models_found = 0
        self._scanned_dirs = set()
        self._scanned_paths = set()
        self._pruned_paths = set()

    @abstractmethod
    def on_search_started(self):
        """
        Called before the scan starts.
        """
        pass

    @abstractmethod
    def on_model_found(self, model: Path):
        """
        Process a found model. Raise an exception if something goes wrong.
        :param model: Model to process - could be a directory or checkpoint.
        """
        pass

    @abstractmethod
    def on_search_completed(self):
        """
        Perform some activity when the scan is completed. May use the instance
        variables _items_scanned and _models_found.
        """
        pass

    def search(self):
        self.on_search_started()
        for dir in self.directories:
            self.walk_directory(dir)
        self.on_search_completed()

    def walk_directory(self, path: Path):
        for root, dirs, files in os.walk(path):
            if str(Path(root).name).startswith('.'):
                self._pruned_paths.add(root)
            if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
                continue

            self._items_scanned += len(dirs) + len(files)
            for d in dirs:
                path = Path(root) / d
                if path in self._scanned_paths or path.parent in self._scanned_dirs:
                    self._scanned_dirs.add(path)
                    continue
                if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
                    try:
                        self.on_model_found(path)
                        self._models_found += 1
                        self._scanned_dirs.add(path)
                    except Exception as e:
                        self.logger.warning(str(e))

            for f in files:
                path = Path(root) / f
                if path.parent in self._scanned_dirs:
                    continue
                if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
                    try:
                        self.on_model_found(path)
                        self._models_found += 1
                    except Exception as e:
                        self.logger.warning(str(e))

class FindModels(ModelSearch):
    def on_search_started(self):
        self.models_found: Set[Path] = set()

    def on_model_found(self, model: Path):
        self.models_found.add(model)

    def on_search_completed(self):
        pass

    def list_models(self) -> List[Path]:
        self.search()
        return self.models_found
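A minimal usage sketch for the new helper (not part of the commit); InvokeAI is assumed to be importable and the directory below is a placeholder.

# Sketch: list model files and diffusers folders under a directory with FindModels.
from pathlib import Path

from invokeai.backend.model_management.model_search import FindModels

finder = FindModels([Path("/path/to/checkpoints")])  # takes a list of directories to walk
for model in finder.list_models():                   # runs the recursive search and returns the hits
    print(model)

Custom behaviour comes from subclassing ModelSearch and overriding the three on_* callbacks, which is exactly how ScanAndImport hooks heuristic_import into the walk in the autoimport refactor above.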