mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add model directory search route
This commit is contained in:
parent
e8d531b987
commit
ad076b1174
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
|
||||||
|
|
||||||
|
|
||||||
|
import pathlib
|
||||||
from typing import Literal, List, Optional, Union
|
from typing import Literal, List, Optional, Union
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
from fastapi import Body, Path, Query, Response
|
||||||
@ -192,6 +193,23 @@ async def convert_model(
|
|||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@models_router.get(
|
||||||
|
"/search",
|
||||||
|
operation_id="search_for_models",
|
||||||
|
responses={
|
||||||
|
200: { "description": "Directory searched successfully" },
|
||||||
|
404: { "description": "Invalid directory path" },
|
||||||
|
},
|
||||||
|
status_code = 200,
|
||||||
|
response_model = List[pathlib.Path]
|
||||||
|
)
|
||||||
|
async def search_for_models(
|
||||||
|
search_path: pathlib.Path = Query(description="Directory path to search for models")
|
||||||
|
)->List[pathlib.Path]:
|
||||||
|
if not search_path.is_dir():
|
||||||
|
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
|
||||||
|
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
|
||||||
|
|
||||||
@models_router.put(
|
@models_router.put(
|
||||||
"/merge/{base_model}",
|
"/merge/{base_model}",
|
||||||
operation_id="merge_models",
|
operation_id="merge_models",
|
||||||
|
@ -19,7 +19,7 @@ from invokeai.backend.model_management import (
|
|||||||
ModelMerger,
|
ModelMerger,
|
||||||
MergeInterpolationMethod,
|
MergeInterpolationMethod,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_management.model_search import FindModels
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
@ -231,6 +231,13 @@ class ModelManagerServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search_for_models(self, directory: Path)->List[Path]:
|
||||||
|
"""
|
||||||
|
Return list of all models found in the designated directory.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||||
"""
|
"""
|
||||||
@ -558,3 +565,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
raise ValueError(e)
|
raise ValueError(e)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def search_for_models(self, directory: Path)->List[Path]:
|
||||||
|
"""
|
||||||
|
Return list of all models found in the designated directory.
|
||||||
|
"""
|
||||||
|
search = FindModels(directory,self.logger)
|
||||||
|
return search.list_models()
|
||||||
|
@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util import CUDA_DEVICE, Chdir
|
from invokeai.backend.util import CUDA_DEVICE, Chdir
|
||||||
from .model_cache import ModelCache, ModelLocker
|
from .model_cache import ModelCache, ModelLocker
|
||||||
|
from .model_search import ModelSearch
|
||||||
from .models import (
|
from .models import (
|
||||||
BaseModelType, ModelType, SubModelType,
|
BaseModelType, ModelType, SubModelType,
|
||||||
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||||
@ -823,6 +824,7 @@ class ModelManager(object):
|
|||||||
if (new_models_found or imported_models) and self.config_path:
|
if (new_models_found or imported_models) and self.config_path:
|
||||||
self.commit()
|
self.commit()
|
||||||
|
|
||||||
|
|
||||||
def autoimport(self)->Dict[str, AddModelResult]:
|
def autoimport(self)->Dict[str, AddModelResult]:
|
||||||
'''
|
'''
|
||||||
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
||||||
@ -831,62 +833,41 @@ class ModelManager(object):
|
|||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
||||||
|
|
||||||
|
|
||||||
|
class ScanAndImport(ModelSearch):
|
||||||
|
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
|
||||||
|
super().__init__(directories, logger)
|
||||||
|
self.installer = installer
|
||||||
|
self.ignore = ignore
|
||||||
|
|
||||||
|
def on_search_started(self):
|
||||||
|
self.new_models_found = dict()
|
||||||
|
|
||||||
|
def on_model_found(self, model: Path):
|
||||||
|
if model not in self.ignore:
|
||||||
|
self.new_models_found.update(self.installer.heuristic_import(model))
|
||||||
|
|
||||||
|
def on_search_completed(self):
|
||||||
|
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
|
||||||
|
|
||||||
|
def models_found(self):
|
||||||
|
return self.new_models_found
|
||||||
|
|
||||||
|
|
||||||
installer = ModelInstall(config = self.app_config,
|
installer = ModelInstall(config = self.app_config,
|
||||||
model_manager = self,
|
model_manager = self,
|
||||||
prediction_type_helper = ask_user_for_prediction_type,
|
prediction_type_helper = ask_user_for_prediction_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
scanned_dirs = set()
|
|
||||||
|
|
||||||
config = self.app_config
|
config = self.app_config
|
||||||
known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()}
|
known_paths = {config.root_path / x['path'] for x in self.list_models()}
|
||||||
|
directories = {config.root_path / x for x in [config.autoimport_dir,
|
||||||
for autodir in [config.autoimport_dir,
|
config.lora_dir,
|
||||||
config.lora_dir,
|
config.embedding_dir,
|
||||||
config.embedding_dir,
|
config.controlnet_dir]
|
||||||
config.controlnet_dir]:
|
}
|
||||||
if autodir is None:
|
scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
|
||||||
continue
|
scanner.search()
|
||||||
|
return scanner.models_found()
|
||||||
self.logger.info(f'Scanning {autodir} for models to import')
|
|
||||||
installed = dict()
|
|
||||||
|
|
||||||
autodir = self.app_config.root_path / autodir
|
|
||||||
if not autodir.exists():
|
|
||||||
continue
|
|
||||||
|
|
||||||
items_scanned = 0
|
|
||||||
new_models_found = dict()
|
|
||||||
|
|
||||||
for root, dirs, files in os.walk(autodir):
|
|
||||||
items_scanned += len(dirs) + len(files)
|
|
||||||
for d in dirs:
|
|
||||||
path = Path(root) / d
|
|
||||||
if path in known_paths or path.parent in scanned_dirs:
|
|
||||||
scanned_dirs.add(path)
|
|
||||||
continue
|
|
||||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
|
||||||
try:
|
|
||||||
new_models_found.update(installer.heuristic_import(path))
|
|
||||||
scanned_dirs.add(path)
|
|
||||||
except ValueError as e:
|
|
||||||
self.logger.warning(str(e))
|
|
||||||
|
|
||||||
for f in files:
|
|
||||||
path = Path(root) / f
|
|
||||||
if path in known_paths or path.parent in scanned_dirs:
|
|
||||||
continue
|
|
||||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
|
||||||
try:
|
|
||||||
import_result = installer.heuristic_import(path)
|
|
||||||
new_models_found.update(import_result)
|
|
||||||
except ValueError as e:
|
|
||||||
self.logger.warning(str(e))
|
|
||||||
|
|
||||||
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
|
||||||
installed.update(new_models_found)
|
|
||||||
|
|
||||||
return installed
|
|
||||||
|
|
||||||
def heuristic_import(self,
|
def heuristic_import(self,
|
||||||
items_to_import: Set[str],
|
items_to_import: Set[str],
|
||||||
@ -924,3 +905,4 @@ class ModelManager(object):
|
|||||||
successfully_installed.update(installed)
|
successfully_installed.update(installed)
|
||||||
self.commit()
|
self.commit()
|
||||||
return successfully_installed
|
return successfully_installed
|
||||||
|
|
||||||
|
103
invokeai/backend/model_management/model_search.py
Normal file
103
invokeai/backend/model_management/model_search.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
||||||
|
"""
|
||||||
|
Abstract base class for recursive directory search for models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Set, types
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
|
class ModelSearch(ABC):
|
||||||
|
def __init__(self, directories: List[Path], logger: types.ModuleType=logger):
|
||||||
|
"""
|
||||||
|
Initialize a recursive model directory search.
|
||||||
|
:param directories: List of directory Paths to recurse through
|
||||||
|
:param logger: Logger to use
|
||||||
|
"""
|
||||||
|
self.directories = directories
|
||||||
|
self.logger = logger
|
||||||
|
self._items_scanned = 0
|
||||||
|
self._models_found = 0
|
||||||
|
self._scanned_dirs = set()
|
||||||
|
self._scanned_paths = set()
|
||||||
|
self._pruned_paths = set()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_search_started(self):
|
||||||
|
"""
|
||||||
|
Called before the scan starts.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_model_found(self, model: Path):
|
||||||
|
"""
|
||||||
|
Process a found model. Raise an exception if something goes wrong.
|
||||||
|
:param model: Model to process - could be a directory or checkpoint.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_search_completed(self):
|
||||||
|
"""
|
||||||
|
Perform some activity when the scan is completed. May use instance
|
||||||
|
variables, items_scanned and models_found
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def search(self):
|
||||||
|
self.on_search_started()
|
||||||
|
for dir in self.directories:
|
||||||
|
self.walk_directory(dir)
|
||||||
|
self.on_search_completed()
|
||||||
|
|
||||||
|
def walk_directory(self, path: Path):
|
||||||
|
for root, dirs, files in os.walk(path):
|
||||||
|
if str(Path(root).name).startswith('.'):
|
||||||
|
self._pruned_paths.add(root)
|
||||||
|
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._items_scanned += len(dirs) + len(files)
|
||||||
|
for d in dirs:
|
||||||
|
path = Path(root) / d
|
||||||
|
if path in self._scanned_paths or path.parent in self._scanned_dirs:
|
||||||
|
self._scanned_dirs.add(path)
|
||||||
|
continue
|
||||||
|
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
||||||
|
try:
|
||||||
|
self.on_model_found(path)
|
||||||
|
self._models_found += 1
|
||||||
|
self._scanned_dirs.add(path)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(str(e))
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
path = Path(root) / f
|
||||||
|
if path.parent in self._scanned_dirs:
|
||||||
|
continue
|
||||||
|
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||||
|
try:
|
||||||
|
self.on_model_found(path)
|
||||||
|
self._models_found += 1
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(str(e))
|
||||||
|
|
||||||
|
class FindModels(ModelSearch):
|
||||||
|
def on_search_started(self):
|
||||||
|
self.models_found: Set[Path] = set()
|
||||||
|
|
||||||
|
def on_model_found(self,model: Path):
|
||||||
|
self.models_found.add(model)
|
||||||
|
|
||||||
|
def on_search_completed(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def list_models(self) -> List[Path]:
|
||||||
|
self.search()
|
||||||
|
return self.models_found
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user