mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add the import model router
This commit is contained in:
committed by
psychedelicious
parent
0988725c1b
commit
96bf92ead4
@ -18,7 +18,7 @@ from tqdm import tqdm
|
||||
import invokeai.configs as configs
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
||||
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
||||
from invokeai.backend.util import download_with_resume
|
||||
from ..util.logging import InvokeAILogger
|
||||
@ -166,17 +166,22 @@ class ModelInstall(object):
|
||||
# add requested models
|
||||
for path in selections.install_models:
|
||||
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||
self.heuristic_install(path)
|
||||
self.heuristic_import(path)
|
||||
job += 1
|
||||
|
||||
self.mgr.commit()
|
||||
|
||||
def heuristic_install(self,
|
||||
def heuristic_import(self,
|
||||
model_path_id_or_url: Union[str,Path],
|
||||
models_installed: Set[Path]=None)->Set[Path]:
|
||||
models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||
:param models_installed: Set of installed models, used for recursive invocation
|
||||
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
||||
'''
|
||||
|
||||
if not models_installed:
|
||||
models_installed = set()
|
||||
models_installed = dict()
|
||||
|
||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||
self.current_id = model_path_id_or_url
|
||||
@ -185,24 +190,24 @@ class ModelInstall(object):
|
||||
try:
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.add(self._install_path(path))
|
||||
models_installed.update(self._install_path(path))
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||
models_installed.add(self._install_path(path))
|
||||
models_installed.update(self._install_path(path))
|
||||
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_install(child, models_installed=models_installed)
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
|
||||
# huggingface repo
|
||||
elif len(str(path).split('/')) == 2:
|
||||
models_installed.add(self._install_repo(str(path)))
|
||||
models_installed.update(self._install_repo(str(path)))
|
||||
|
||||
# a URL
|
||||
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.add(self._install_url(model_path_id_or_url))
|
||||
models_installed.update(self._install_url(model_path_id_or_url))
|
||||
|
||||
else:
|
||||
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||
@ -214,24 +219,25 @@ class ModelInstall(object):
|
||||
|
||||
# install a model from a local path. The optional info parameter is there to prevent
|
||||
# the model from being probed twice in the event that it has already been probed.
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path:
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
|
||||
try:
|
||||
# logger.debug(f'Probing {path}')
|
||||
model_result = None
|
||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||
model_name = path.stem if info.format=='checkpoint' else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path,info)
|
||||
self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
model_attributes = attributes,
|
||||
)
|
||||
model_result = self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
model_attributes = attributes,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'{str(e)} Skipping registration.')
|
||||
return path
|
||||
return {}
|
||||
return {str(path): model_result}
|
||||
|
||||
def _install_url(self, url: str)->Path:
|
||||
def _install_url(self, url: str)->dict:
|
||||
# copy to a staging area, probe, import and delete
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
location = download_with_resume(url,Path(staging))
|
||||
@ -244,7 +250,7 @@ class ModelInstall(object):
|
||||
# staged version will be garbage-collected at this time
|
||||
return self._install_path(Path(models_path), info)
|
||||
|
||||
def _install_repo(self, repo_id: str)->Path:
|
||||
def _install_repo(self, repo_id: str)->dict:
|
||||
hinfo = HfApi().model_info(repo_id)
|
||||
|
||||
# we try to figure out how to download this most economically
|
||||
|
Reference in New Issue
Block a user