import model returns 404 for invalid path, 409 for duplicate model

Lincoln Stein 2023-07-04 09:59:11 -04:00
parent 92b163e95c
commit c1c49d9a76
3 changed files with 69 additions and 66 deletions

View File

@@ -116,19 +116,23 @@ async def update_model(
     responses= {
         201: {"description" : "The model imported successfully"},
         404: {"description" : "The model could not be found"},
+        409: {"description" : "There is already a model corresponding to this path or repo_id"},
     },
     status_code=201,
     response_model=ImportModelResponse
 )
 async def import_model(
-        name: str = Query(description="A model path, repo_id or URL to import"),
-        prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
+        name: str = Body(description="A model path, repo_id or URL to import"),
+        prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
+            Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
 ) -> ImportModelResponse:
     """ Add a model using its local path, repo_id, or remote URL """
     items_to_import = {name}
     prediction_types = { x.value: x for x in SchedulerPredictionType }
     logger = ApiDependencies.invoker.services.logger
+    try:
         installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
             items_to_import = items_to_import,
             prediction_type_helper = lambda x: prediction_types.get(prediction_type)
@@ -140,9 +144,13 @@ async def import_model(
             info = info,
             status = "success",
         )
-    else:
-        logger.error(f'Model {name} not imported')
-        raise HTTPException(status_code=404, detail=f'Model {name} not found')
+    except KeyError as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=404, detail=str(e))
+    except ValueError as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=409, detail=str(e))

 @models_router.delete(
     "/{model_name}",

View File

@@ -166,14 +166,18 @@ class ModelInstall(object):
         # add requested models
         for path in selections.install_models:
             logger.info(f'Installing {path} [{job}/{jobs}]')
+            try:
                 self.heuristic_import(path)
+            except (ValueError, KeyError) as e:
+                logger.error(str(e))
             job += 1
         self.mgr.commit()

     def heuristic_import(self,
                          model_path_id_or_url: Union[str,Path],
-                         models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
+                         models_installed: Set[Path]=None,
+                         )->Dict[str, AddModelResult]:
         '''
         :param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
         :param models_installed: Set of installed models, used for recursive invocation
@@ -187,14 +191,13 @@ class ModelInstall(object):
         self.current_id = model_path_id_or_url
         path = Path(model_path_id_or_url)
-        try:
         # checkpoint file, or similar
         if path.is_file():
-            models_installed.update(self._install_path(path))
+            models_installed.update({str(path):self._install_path(path)})
         # folders style or similar
         elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
-            models_installed.update(self._install_path(path))
+            models_installed.update({str(path): self._install_path(path)})
         # recursive scan
         elif path.is_dir():
@@ -202,42 +205,34 @@ class ModelInstall(object):
                 self.heuristic_import(child, models_installed=models_installed)
         # huggingface repo
-        elif len(str(path).split('/')) == 2:
-            models_installed.update(self._install_repo(str(path)))
+        elif str(model_path_id_or_url).split('/') == 2:
+            models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
         # a URL
-        elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
-            models_installed.update(self._install_url(model_path_id_or_url))
+        elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
+            models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
         else:
-            logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
+            raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
-        except ValueError as e:
-            logger.error(str(e))
         return models_installed

     # install a model from a local path. The optional info parameter is there to prevent
     # the model from being probed twice in the event that it has already been probed.
-    def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
+    def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
-        try:
         model_result = None
         info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
         model_name = path.stem if info.format=='checkpoint' else path.name
         if self.mgr.model_exists(model_name, info.base_type, info.model_type):
             raise ValueError(f'A model named "{model_name}" is already installed.')
         attributes = self._make_attributes(path,info)
-        model_result = self.mgr.add_model(model_name = model_name,
+        return self.mgr.add_model(model_name = model_name,
                                   base_model = info.base_type,
                                   model_type = info.model_type,
                                   model_attributes = attributes,
                                   )
-        except Exception as e:
-            logger.warning(f'{str(e)} Skipping registration.')
-            return {}
-        return {str(path): model_result}

-    def _install_url(self, url: str)->dict:
+    def _install_url(self, url: str)->AddModelResult:
         # copy to a staging area, probe, import and delete
         with TemporaryDirectory(dir=self.config.models_path) as staging:
             location = download_with_resume(url,Path(staging))
@@ -250,7 +245,7 @@ class ModelInstall(object):
             # staged version will be garbage-collected at this time
             return self._install_path(Path(models_path), info)

-    def _install_repo(self, repo_id: str)->dict:
+    def _install_repo(self, repo_id: str)->AddModelResult:
         hinfo = HfApi().model_info(repo_id)

         # we try to figure out how to download this most economically
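
Under the new contract, ModelInstall.heuristic_import() returns a dict mapping each requested path/repo_id/URL to the AddModelResult produced by the model manager, and it raises KeyError for an unrecognized source or ValueError for a duplicate instead of logging and continuing. A caller-side sketch; `installer` is assumed to be an already-constructed ModelInstall (its constructor arguments are not shown here), and the repo_id is only an example:

try:
    results = installer.heuristic_import("runwayml/stable-diffusion-v1-5")
    for source, add_result in results.items():
        # source is the path/repo_id/URL that was requested;
        # add_result is the AddModelResult returned by ModelManager.add_model()
        print(f"installed: {source}")
except KeyError as e:
    # not recognized as a local path, repo_id, or URL
    print(f"unrecognized source: {e}")
except ValueError as e:
    # a model with the same name/base/type is already installed
    print(f"duplicate model: {e}")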

View File

@@ -820,6 +820,10 @@ class ModelManager(object):
         The result is a set of successfully installed models. Each element
         of the set is a dict corresponding to the newly-created OmegaConf stanza for
         that model.
+
+        May return the following exceptions:
+        - KeyError - one or more of the items to import is not a valid path, repo_id or URL
+        - ValueError - a corresponding model already exists
         '''
         # avoid circular import here
         from invokeai.backend.install.model_install_backend import ModelInstall
@@ -829,11 +833,7 @@ class ModelManager(object):
                                  prediction_type_helper = prediction_type_helper,
                                  model_manager = self)
         for thing in items_to_import:
-            try:
-                installed = installer.heuristic_import(thing)
-                successfully_installed.update(installed)
-            except Exception as e:
-                self.logger.warning(f'{thing} could not be imported: {str(e)}')
+            installed = installer.heuristic_import(thing)
+            successfully_installed.update(installed)
         self.commit()
         return successfully_installed
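
ModelManager.heuristic_import() no longer swallows per-item failures; KeyError and ValueError now propagate to the caller, which is what lets the API route above translate them into 404 and 409. A direct-call sketch; `mgr` is assumed to be an already-initialized ModelManager, the repo_id is only an example, and the trivial prediction_type_helper stands in for the real mapping that the API route builds from SchedulerPredictionType:

try:
    installed = mgr.heuristic_import(
        items_to_import={"stabilityai/stable-diffusion-2-1"},  # example repo_id
        prediction_type_helper=lambda path: None,  # real callers return a SchedulerPredictionType for SDv2 checkpoints
    )
    print(f"installed {len(installed)} model(s)")
except KeyError as e:
    # invalid path, repo_id, or URL -> the route layer maps this to HTTP 404
    print(f"not found: {e}")
except ValueError as e:
    # a corresponding model already exists -> the route layer maps this to HTTP 409
    print(f"conflict: {e}")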