add back the heuristic_import() method and extend repo_ids to arbitrary file paths

This commit is contained in:
Lincoln Stein 2024-02-11 23:37:49 -05:00 committed by Brandon Rising
parent d56337f2d8
commit 195768c9ee
6 changed files with 199 additions and 12 deletions


@@ -446,6 +446,44 @@ required parameters:
Once initialized, the installer will provide the following methods:
#### install_job = installer.heuristic_import(source, [config], [access_token])
This is a simplified interface to the installer which takes a source
string, an optional model configuration dictionary and an optional
access token.
The `source` is a string that can take any of the following forms:
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
2. A URL pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
3. A HuggingFace repo_id with any of the following formats:
- `model/name` -- entire model
- `model/name:fp32` -- entire model, using the fp32 variant
- `model/name:fp16:vae` -- vae submodel, using the fp16 variant
- `model/name::vae` -- vae submodel, using default precision
- `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
- `model/name::path/to/model.safetensors` -- an individual model file, default variant
Note that by specifying a path relative to the top of the HuggingFace
repo, you can download and install arbitrary model files.
The variant, if not provided, will be automatically filled in with
`fp32` if the user has requested full precision, and `fp16`
otherwise. If a variant that does not exist is requested, then the
method will install whatever HuggingFace returns as its default
revision.
`config` is an optional dict of values that will override the
autoprobed values for model type, base, scheduler prediction type, and
so forth. See [Model configuration and
probing](#Model-configuration-and-probing) for details.
`access_token` is an optional access token for accessing resources
that need authentication.
The method will return a `ModelInstallJob`. This object is discussed
at length in the following section.
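As a sketch, the three-part source syntax above can be split with a small helper; this is hypothetical (not part of the installer API) and applies only once the string is known not to be a local path or URL:

```python
from typing import Optional, Tuple

def split_hf_source(source: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Split 'model/name[:variant[:subpath]]' into (repo_id, variant, subpath).

    An empty variant field ('model/name::vae') yields variant=None,
    meaning the default precision is used.
    """
    parts = source.split(":", 2)
    repo_id = parts[0]
    variant = (parts[1] or None) if len(parts) > 1 else None
    subpath = (parts[2] or None) if len(parts) > 2 else None
    return repo_id, variant, subpath

print(split_hf_source("model/name:fp16:vae"))  # ('model/name', 'fp16', 'vae')
print(split_hf_source("model/name::vae"))      # ('model/name', None, 'vae')
print(split_hf_source("model/name"))           # ('model/name', None, None)
```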
#### install_job = installer.import_model()
The `import_model()` method is the core of the installer. The
@@ -464,9 +502,10 @@ source2 = LocalModelSource(path='/opt/models/sushi_diffusers') # a local dif
source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5') # a repo_id
source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae') # a subfolder within a repo_id
source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16') # a named variant of a HF model
source6 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='OrangeMix/OrangeMix1.ckpt') # path to an individual model file
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL
source8 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
for source in [source1, source2, source3, source4, source5, source6, source7, source8]:
install_job = installer.install_model(source)
@@ -522,7 +561,6 @@ can be passed to `import_model()`.
attributes returned by the model prober. See the section below for
details.
#### LocalModelSource
This is used for a model that is located on a locally-accessible Posix
@@ -715,7 +753,7 @@ and `cancelled`, as well as `in_terminal_state`. The last will return
True if the job is in the complete, errored or cancelled states.
#### Model configuration and probing
The install service uses the `invokeai.backend.model_manager.probe`
module during import to determine the model's type, base type, and
@@ -1106,7 +1144,7 @@ job = queue.create_download_job(
event_handlers=[my_handler1, my_handler2], # if desired
start=True,
)
```
The `filename` argument forces the downloader to use the specified
name for the file rather than the name provided by the remote source,
@@ -1427,9 +1465,9 @@ set of keys to the corresponding model config objects.
Find all model metadata records that have the given author and return
a set of keys to the corresponding model config objects.
***
## The Lowdown on the ModelLoadService
The `ModelLoadService` is responsible for loading a named model into
memory so that it can be used for inference. Despite the fact that it


@@ -251,9 +251,75 @@ async def add_model_record(
return result
@model_manager_v2_router.post(
"/heuristic_import",
operation_id="heuristic_import_model",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
)
async def heuristic_import(
source: str,
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
),
access_token: Optional[str] = None,
) -> ModelInstallJob:
"""Install a model using a string identifier.
`source` can be any of the following.
1. A path on the local filesystem ('C:\\users\\fred\\model.safetensors')
2. A URL pointing to a single downloadable model file
3. A HuggingFace repo_id with any of the following formats:
- model/name
- model/name:fp16:vae
- model/name::vae -- use default precision
- model/name:fp16:path/to/model.safetensors
- model/name::path/to/model.safetensors
`config` is an optional dict containing model configuration values that will override
the ones that are probed automatically.
`access_token` is an optional access token for use with URLs that require
authentication.
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has a `status` attribute
that can be used to monitor progress.
See the documentation for `import_model_record` for more information on
interpreting the job information returned by this route.
"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_manager.install
result: ModelInstallJob = installer.heuristic_import(
source=source,
config=config,
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return result
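A client call to this route might look like the following sketch; the host, port, and route prefix are assumptions (they come from where `model_manager_v2_router` is mounted, which is not shown in this diff):

```python
import json
from urllib.parse import urlencode

# Assumed mount point for model_manager_v2_router (hypothetical).
base = "http://localhost:9090/api/v2/models/heuristic_import"

# `source` travels as a query parameter; the optional config dict is the JSON body.
query = urlencode({"source": "stabilityai/stable-diffusion-2-1:fp16:vae"})
body = json.dumps({"name": "sd21-vae", "description": "SD 2.1 fp16 VAE"})

print(f"POST {base}?{query}")
print(f"payload: {body}")
```

On success the route returns a serialized `ModelInstallJob` with HTTP status 201.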
@model_manager_v2_router.post(
"/import",
operation_id="import_model",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
@@ -269,7 +335,7 @@ async def import_model(
default=None,
),
) -> ModelInstallJob:
"""Add a model using its local path, repo_id, or remote URL. """Install a model using its local path, repo_id, or remote URL.
Models will be downloaded, probed, configured and installed in a Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has `status` attribute series of background threads. The return object has `status` attribute


@@ -49,7 +49,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
download_queue,
images,
model_manager_v2,
models,
session_queue,
sessions,
utilities,


@@ -127,8 +127,8 @@ class HFModelSource(StringLikeSource):
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
base += f":{self.variant or ''}"
base += f":{self.subfolder}" if self.subfolder else "" base += f":{self.subfolder}" if self.subfolder else ""
base += f" ({self.variant})" if self.variant else ""
return base
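The effect of the change is that the variant slot is always present (empty when unset) and the subfolder is appended only when given; a standalone mirror of the new `__str__`:

```python
from typing import Optional

def hf_source_str(repo_id: str, variant: Optional[str] = None, subfolder: Optional[str] = None) -> str:
    # Simplified mirror of the updated HFModelSource.__str__.
    base = repo_id
    base += f":{variant or ''}"                    # variant slot always present
    base += f":{subfolder}" if subfolder else ""   # subfolder only when given
    return base

print(hf_source_str("runwayml/stable-diffusion-v1-5", "fp16", "vae"))
# runwayml/stable-diffusion-v1-5:fp16:vae
print(hf_source_str("runwayml/stable-diffusion-v1-5"))
# runwayml/stable-diffusion-v1-5:
```

The point of the new ordering is that the string form round-trips with the `repo_id:variant:subfolder` syntax accepted by `heuristic_import`.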
@@ -324,6 +324,43 @@ class ModelInstallServiceBase(ABC):
:returns id: The string ID of the registered model.
"""
@abstractmethod
def heuristic_import(
self,
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
) -> ModelInstallJob:
r"""Install the indicated model using heuristics to interpret user intentions.
:param source: String source
:param config: Optional dict. Any fields in this dict
will override corresponding autoassigned probe fields in the
model's config record as described in `import_model()`.
:param access_token: Optional access token for remote sources.
The source can be:
1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`)
2. An http or https URL (`https://foo.bar/foo`)
3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`)
We extend the HuggingFace repo_id syntax to include the variant and the
subfolder or path. The following are acceptable alternatives:
stabilityai/stable-diffusion-v4
stabilityai/stable-diffusion-v4:fp16
stabilityai/stable-diffusion-v4:fp16:vae
stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
stabilityai/stable-diffusion-v4:onnx:vae
Because a local file path can look like a huggingface repo_id, the logic
first checks whether the path exists on disk, and if not, it is treated as
a parseable huggingface repo.
The previous support for recursing into a local folder and loading all model-like files
has been removed.
"""
pass
@abstractmethod
def import_model(
self,


@@ -50,6 +50,7 @@ from .model_install_base import (
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
StringLikeSource,
URLModelSource,
)
@@ -177,6 +178,34 @@ class ModelInstallService(ModelInstallServiceBase):
info,
)
def heuristic_import(
self,
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
) -> ModelInstallJob:
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source))
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
subfolder=Path(match.group(3)) if match.group(3) else None,
access_token=access_token,
)
elif re.match(r"^https?://[^/]+", source):
source_obj = URLModelSource(
url=AnyHttpUrl(source),
access_token=access_token,
)
else:
raise ValueError(f"Unsupported model source: '{source}'")
return self.import_model(source_obj, config)
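The repo_id regex above can be exercised in isolation; the variant list here is an assumption standing in for `ModelRepoVariant.__members__.values()`:

```python
import re
from typing import Optional, Tuple

# Assumed variant names; the real list comes from ModelRepoVariant.
variants = "|".join(["fp16", "fp32", "onnx"])
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"

def parse_repo_id(source: str) -> Optional[Tuple[Optional[str], ...]]:
    """Return (repo_id, variant, subpath) if source parses as a repo_id, else None."""
    match = re.match(hf_repoid_re, source)
    return match.groups() if match else None

print(parse_repo_id("foo/bar:fp16:vae"))                       # ('foo/bar', 'fp16', 'vae')
print(parse_repo_id("foo/bar::/checkpoints/sd4.safetensors"))  # ('foo/bar', None, 'checkpoints/sd4.safetensors')
print(parse_repo_id("https://foo.bar/foo"))                    # None (the ':' after 'https' breaks the match)
```

Note that URLs and Windows drive paths fail the match because the first two fields may not contain `:`, which is why the URL and local-path checks can safely coexist with this pattern.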
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob:  # noqa D102
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
if similar_jobs:
@@ -571,6 +600,8 @@ class ModelInstallService(ModelInstallServiceBase):
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
# Currently the tmpdir isn't automatically removed at exit because it is
# being held in a daemon thread.
if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found")
tmpdir = Path(
mkdtemp(
dir=self._app_config.models_path,
@@ -586,6 +617,16 @@ class ModelInstallService(ModelInstallServiceBase):
bytes=0,
total_bytes=0,
)
# In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid
# creating unwanted subfolders
if hasattr(source, "subfolder") and source.subfolder:
root = Path(remote_files[0].path.parts[0])
subfolder = root / source.subfolder
else:
root = Path(".")
subfolder = Path(".")
# we remember the path up to the top of the tmpdir so that it may be
# removed safely at the end of the install process.
install_job._install_tmpdir = tmpdir
@@ -595,7 +636,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.debug(f"remote_files={remote_files}")
for model_file in remote_files:
url = model_file.url
path = root / model_file.path.relative_to(subfolder)
self._logger.info(f"Downloading {url} => {path}")
install_job.total_bytes += model_file.size
assert hasattr(source, "access_token")
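The subfolder-stripping arithmetic can be checked on its own with hypothetical paths:

```python
from pathlib import Path

# A remote file path as listed by the metadata fetcher: repo root + subfolder + file.
remote = Path("stable-diffusion-v1-5/vae/diffusion_pytorch_model.safetensors")

root = Path(remote.parts[0])  # repo root directory inside the tmpdir
subfolder = root / "vae"      # the subfolder requested in the source

# Removing the subfolder component keeps the file directly under the root,
# so the install does not create an extra 'vae/' directory level.
dest = root / remote.relative_to(subfolder)
print(dest.as_posix())  # stable-diffusion-v1-5/diffusion_pytorch_model.safetensors
```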


@@ -36,6 +36,11 @@ def filter_files(
"""
variant = variant or ModelRepoVariant.DEFAULT
paths: List[Path] = []
root = files[0].parts[0]
# if the subfolder is a single file, then bypass the selection and just return it
if subfolder and subfolder.suffix in [".safetensors", ".bin", ".onnx", ".xml", ".pth", ".pt", ".ckpt", ".msgpack"]:
return [root / subfolder]
# Start by filtering on model file extensions, discarding images, docs, etc
for file in files:
@@ -61,6 +66,7 @@ def filter_files(
# limit search to subfolder if requested
if subfolder:
subfolder = root / subfolder
paths = [x for x in paths if x.parent == Path(subfolder)]
# _filter_by_variant uniquifies the paths and returns a set
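A self-contained sketch of the two subfolder behaviours added here (abbreviated suffix list; this stands in for, but is not, the real `filter_files` implementation):

```python
from pathlib import Path
from typing import List, Optional

# Abbreviated stand-in for the suffix list in filter_files.
MODEL_SUFFIXES = {".safetensors", ".bin", ".onnx", ".ckpt"}

def select_files(files: List[Path], subfolder: Optional[Path] = None) -> List[Path]:
    root = Path(files[0].parts[0])
    # If the "subfolder" is itself a model file, bypass filtering and return it.
    if subfolder and subfolder.suffix in MODEL_SUFFIXES:
        return [root / subfolder]
    paths = [f for f in files if f.suffix in MODEL_SUFFIXES]
    if subfolder:
        # Anchor the subfolder at the repo root before comparing parents.
        target = root / subfolder
        paths = [p for p in paths if p.parent == target]
    return paths

files = [
    Path("repo/vae/model.safetensors"),
    Path("repo/unet/model.safetensors"),
    Path("repo/README.md"),
]
print([p.as_posix() for p in select_files(files, Path("vae"))])
# ['repo/vae/model.safetensors']
print([p.as_posix() for p in select_files(files, Path("unet/model.safetensors"))])
# ['repo/unet/model.safetensors']
```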