add the import model router

This commit is contained in:
Lincoln Stein
2023-07-03 19:32:54 -04:00
committed by psychedelicious
parent 0988725c1b
commit 96bf92ead4
8 changed files with 233 additions and 1489 deletions

View File

@ -18,7 +18,7 @@ from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
from invokeai.backend.util import download_with_resume
from ..util.logging import InvokeAILogger
@ -166,17 +166,22 @@ class ModelInstall(object):
# add requested models
for path in selections.install_models:
logger.info(f'Installing {path} [{job}/{jobs}]')
self.heuristic_install(path)
self.heuristic_import(path)
job += 1
self.mgr.commit()
def heuristic_install(self,
def heuristic_import(self,
model_path_id_or_url: Union[str,Path],
models_installed: Set[Path]=None)->Set[Path]:
models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
'''
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
:param models_installed: Set of installed models, used for recursive invocation
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
'''
if not models_installed:
models_installed = set()
models_installed = dict()
# A little hack to allow nested routines to retrieve info on the requested ID
self.current_id = model_path_id_or_url
@ -185,24 +190,24 @@ class ModelInstall(object):
try:
# checkpoint file, or similar
if path.is_file():
models_installed.add(self._install_path(path))
models_installed.update(self._install_path(path))
# folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
models_installed.add(self._install_path(path))
models_installed.update(self._install_path(path))
# recursive scan
elif path.is_dir():
for child in path.iterdir():
self.heuristic_install(child, models_installed=models_installed)
self.heuristic_import(child, models_installed=models_installed)
# huggingface repo
elif len(str(path).split('/')) == 2:
models_installed.add(self._install_repo(str(path)))
models_installed.update(self._install_repo(str(path)))
# a URL
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
models_installed.add(self._install_url(model_path_id_or_url))
models_installed.update(self._install_url(model_path_id_or_url))
else:
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
@ -214,24 +219,25 @@ class ModelInstall(object):
# install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path:
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
try:
# logger.debug(f'Probing {path}')
model_result = None
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
model_name = path.stem if info.format=='checkpoint' else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path,info)
self.mgr.add_model(model_name = model_name,
base_model = info.base_type,
model_type = info.model_type,
model_attributes = attributes,
)
model_result = self.mgr.add_model(model_name = model_name,
base_model = info.base_type,
model_type = info.model_type,
model_attributes = attributes,
)
except Exception as e:
logger.warning(f'{str(e)} Skipping registration.')
return path
return {}
return {str(path): model_result}
def _install_url(self, url: str)->Path:
def _install_url(self, url: str)->dict:
# copy to a staging area, probe, import and delete
with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url,Path(staging))
@ -244,7 +250,7 @@ class ModelInstall(object):
# staged version will be garbage-collected at this time
return self._install_path(Path(models_path), info)
def _install_repo(self, repo_id: str)->Path:
def _install_repo(self, repo_id: str)->dict:
hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically