mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
import model returns 404 for invalid path, 409 for duplicate model
This commit is contained in:
parent
92b163e95c
commit
c1c49d9a76
@ -116,19 +116,23 @@ async def update_model(
|
|||||||
responses= {
|
responses= {
|
||||||
201: {"description" : "The model imported successfully"},
|
201: {"description" : "The model imported successfully"},
|
||||||
404: {"description" : "The model could not be found"},
|
404: {"description" : "The model could not be found"},
|
||||||
|
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
response_model=ImportModelResponse
|
response_model=ImportModelResponse
|
||||||
)
|
)
|
||||||
async def import_model(
|
async def import_model(
|
||||||
name: str = Query(description="A model path, repo_id or URL to import"),
|
name: str = Body(description="A model path, repo_id or URL to import"),
|
||||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||||
|
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||||
) -> ImportModelResponse:
|
) -> ImportModelResponse:
|
||||||
""" Add a model using its local path, repo_id, or remote URL """
|
""" Add a model using its local path, repo_id, or remote URL """
|
||||||
|
|
||||||
items_to_import = {name}
|
items_to_import = {name}
|
||||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
try:
|
||||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||||
items_to_import = items_to_import,
|
items_to_import = items_to_import,
|
||||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||||
@ -140,9 +144,13 @@ async def import_model(
|
|||||||
info = info,
|
info = info,
|
||||||
status = "success",
|
status = "success",
|
||||||
)
|
)
|
||||||
else:
|
except KeyError as e:
|
||||||
logger.error(f'Model {name} not imported')
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=f'Model {name} not found')
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@models_router.delete(
|
@models_router.delete(
|
||||||
"/{model_name}",
|
"/{model_name}",
|
||||||
|
@ -166,14 +166,18 @@ class ModelInstall(object):
|
|||||||
# add requested models
|
# add requested models
|
||||||
for path in selections.install_models:
|
for path in selections.install_models:
|
||||||
logger.info(f'Installing {path} [{job}/{jobs}]')
|
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||||
|
try:
|
||||||
self.heuristic_import(path)
|
self.heuristic_import(path)
|
||||||
|
except (ValueError, KeyError) as e:
|
||||||
|
logger.error(str(e))
|
||||||
job += 1
|
job += 1
|
||||||
|
|
||||||
self.mgr.commit()
|
self.mgr.commit()
|
||||||
|
|
||||||
def heuristic_import(self,
|
def heuristic_import(self,
|
||||||
model_path_id_or_url: Union[str,Path],
|
model_path_id_or_url: Union[str,Path],
|
||||||
models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
|
models_installed: Set[Path]=None,
|
||||||
|
)->Dict[str, AddModelResult]:
|
||||||
'''
|
'''
|
||||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||||
:param models_installed: Set of installed models, used for recursive invocation
|
:param models_installed: Set of installed models, used for recursive invocation
|
||||||
@ -187,14 +191,13 @@ class ModelInstall(object):
|
|||||||
self.current_id = model_path_id_or_url
|
self.current_id = model_path_id_or_url
|
||||||
path = Path(model_path_id_or_url)
|
path = Path(model_path_id_or_url)
|
||||||
|
|
||||||
try:
|
|
||||||
# checkpoint file, or similar
|
# checkpoint file, or similar
|
||||||
if path.is_file():
|
if path.is_file():
|
||||||
models_installed.update(self._install_path(path))
|
models_installed.update({str(path):self._install_path(path)})
|
||||||
|
|
||||||
# folders style or similar
|
# folders style or similar
|
||||||
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||||
models_installed.update(self._install_path(path))
|
models_installed.update({str(path): self._install_path(path)})
|
||||||
|
|
||||||
# recursive scan
|
# recursive scan
|
||||||
elif path.is_dir():
|
elif path.is_dir():
|
||||||
@ -202,42 +205,34 @@ class ModelInstall(object):
|
|||||||
self.heuristic_import(child, models_installed=models_installed)
|
self.heuristic_import(child, models_installed=models_installed)
|
||||||
|
|
||||||
# huggingface repo
|
# huggingface repo
|
||||||
elif len(str(path).split('/')) == 2:
|
elif str(model_path_id_or_url).split('/') == 2:
|
||||||
models_installed.update(self._install_repo(str(path)))
|
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
||||||
|
|
||||||
# a URL
|
# a URL
|
||||||
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
|
||||||
models_installed.update(self._install_url(model_path_id_or_url))
|
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
|
|
||||||
return models_installed
|
return models_installed
|
||||||
|
|
||||||
# install a model from a local path. The optional info parameter is there to prevent
|
# install a model from a local path. The optional info parameter is there to prevent
|
||||||
# the model from being probed twice in the event that it has already been probed.
|
# the model from being probed twice in the event that it has already been probed.
|
||||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
|
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
|
||||||
try:
|
|
||||||
model_result = None
|
model_result = None
|
||||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||||
model_name = path.stem if info.format=='checkpoint' else path.name
|
model_name = path.stem if info.format=='checkpoint' else path.name
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||||
attributes = self._make_attributes(path,info)
|
attributes = self._make_attributes(path,info)
|
||||||
model_result = self.mgr.add_model(model_name = model_name,
|
return self.mgr.add_model(model_name = model_name,
|
||||||
base_model = info.base_type,
|
base_model = info.base_type,
|
||||||
model_type = info.model_type,
|
model_type = info.model_type,
|
||||||
model_attributes = attributes,
|
model_attributes = attributes,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f'{str(e)} Skipping registration.')
|
|
||||||
return {}
|
|
||||||
return {str(path): model_result}
|
|
||||||
|
|
||||||
def _install_url(self, url: str)->dict:
|
def _install_url(self, url: str)->AddModelResult:
|
||||||
# copy to a staging area, probe, import and delete
|
# copy to a staging area, probe, import and delete
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
location = download_with_resume(url,Path(staging))
|
location = download_with_resume(url,Path(staging))
|
||||||
@ -250,7 +245,7 @@ class ModelInstall(object):
|
|||||||
# staged version will be garbage-collected at this time
|
# staged version will be garbage-collected at this time
|
||||||
return self._install_path(Path(models_path), info)
|
return self._install_path(Path(models_path), info)
|
||||||
|
|
||||||
def _install_repo(self, repo_id: str)->dict:
|
def _install_repo(self, repo_id: str)->AddModelResult:
|
||||||
hinfo = HfApi().model_info(repo_id)
|
hinfo = HfApi().model_info(repo_id)
|
||||||
|
|
||||||
# we try to figure out how to download this most economically
|
# we try to figure out how to download this most economically
|
||||||
|
@ -820,6 +820,10 @@ class ModelManager(object):
|
|||||||
The result is a set of successfully installed models. Each element
|
The result is a set of successfully installed models. Each element
|
||||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
that model.
|
that model.
|
||||||
|
|
||||||
|
May return the following exceptions:
|
||||||
|
- KeyError - one or more of the items to import is not a valid path, repo_id or URL
|
||||||
|
- ValueError - a corresponding model already exists
|
||||||
'''
|
'''
|
||||||
# avoid circular import here
|
# avoid circular import here
|
||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
@ -829,11 +833,7 @@ class ModelManager(object):
|
|||||||
prediction_type_helper = prediction_type_helper,
|
prediction_type_helper = prediction_type_helper,
|
||||||
model_manager = self)
|
model_manager = self)
|
||||||
for thing in items_to_import:
|
for thing in items_to_import:
|
||||||
try:
|
|
||||||
installed = installer.heuristic_import(thing)
|
installed = installer.heuristic_import(thing)
|
||||||
successfully_installed.update(installed)
|
successfully_installed.update(installed)
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(f'{thing} could not be imported: {str(e)}')
|
|
||||||
|
|
||||||
self.commit()
|
self.commit()
|
||||||
return successfully_installed
|
return successfully_installed
|
||||||
|
Loading…
Reference in New Issue
Block a user