mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add back the heuristic_import()
method and extend repo_ids to arbitrary file paths
This commit is contained in:
parent
d56337f2d8
commit
195768c9ee
@ -446,6 +446,44 @@ required parameters:
|
|||||||
|
|
||||||
Once initialized, the installer will provide the following methods:
|
Once initialized, the installer will provide the following methods:
|
||||||
|
|
||||||
|
#### install_job = installer.heuristic_import(source, [config], [access_token])
|
||||||
|
|
||||||
|
This is a simplified interface to the installer which takes a source
|
||||||
|
string, an optional model configuration dictionary and an optional
|
||||||
|
access token.
|
||||||
|
|
||||||
|
The `source` is a string that can be any of these forms
|
||||||
|
|
||||||
|
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
|
||||||
|
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
|
||||||
|
3. A HuggingFace repo_id with any of the following formats:
|
||||||
|
- `model/name` -- entire model
|
||||||
|
- `model/name:fp32` -- entire model, using the fp32 variant
|
||||||
|
- `model/name:fp16:vae` -- vae submodel, using the fp16 variant
|
||||||
|
- `model/name::vae` -- vae submodel, using default precision
|
||||||
|
- `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
|
||||||
|
- `model/name::path/to/model.safetensors` -- an individual model file, default variant
|
||||||
|
|
||||||
|
Note that by specifying a relative path to the top of the HuggingFace
|
||||||
|
repo, you can download and install arbitrary models files.
|
||||||
|
|
||||||
|
The variant, if not provided, will be automatically filled in with
|
||||||
|
`fp32` if the user has requested full precision, and `fp16`
|
||||||
|
otherwise. If a variant that does not exist is requested, then the
|
||||||
|
method will install whatever HuggingFace returns as its default
|
||||||
|
revision.
|
||||||
|
|
||||||
|
`config` is an optional dict of values that will override the
|
||||||
|
autoprobed values for model type, base, scheduler prediction type, and
|
||||||
|
so forth. See [Model configuration and
|
||||||
|
probing](#Model-configuration-and-probing) for details.
|
||||||
|
|
||||||
|
`access_token` is an optional access token for accessing resources
|
||||||
|
that need authentication.
|
||||||
|
|
||||||
|
The method will return a `ModelInstallJob`. This object is discussed
|
||||||
|
at length in the following section.
|
||||||
|
|
||||||
#### install_job = installer.import_model()
|
#### install_job = installer.import_model()
|
||||||
|
|
||||||
The `import_model()` method is the core of the installer. The
|
The `import_model()` method is the core of the installer. The
|
||||||
@ -464,9 +502,10 @@ source2 = LocalModelSource(path='/opt/models/sushi_diffusers') # a local dif
|
|||||||
source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5') # a repo_id
|
source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5') # a repo_id
|
||||||
source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae') # a subfolder within a repo_id
|
source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae') # a subfolder within a repo_id
|
||||||
source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16') # a named variant of a HF model
|
source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16') # a named variant of a HF model
|
||||||
|
source6 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='OrangeMix/OrangeMix1.ckpt') # path to an individual model file
|
||||||
|
|
||||||
source6 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL
|
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL
|
||||||
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
|
source8 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
|
||||||
|
|
||||||
for source in [source1, source2, source3, source4, source5, source6, source7]:
|
for source in [source1, source2, source3, source4, source5, source6, source7]:
|
||||||
install_job = installer.install_model(source)
|
install_job = installer.install_model(source)
|
||||||
@ -522,7 +561,6 @@ can be passed to `import_model()`.
|
|||||||
attributes returned by the model prober. See the section below for
|
attributes returned by the model prober. See the section below for
|
||||||
details.
|
details.
|
||||||
|
|
||||||
|
|
||||||
#### LocalModelSource
|
#### LocalModelSource
|
||||||
|
|
||||||
This is used for a model that is located on a locally-accessible Posix
|
This is used for a model that is located on a locally-accessible Posix
|
||||||
@ -715,7 +753,7 @@ and `cancelled`, as well as `in_terminal_state`. The last will return
|
|||||||
True if the job is in the complete, errored or cancelled states.
|
True if the job is in the complete, errored or cancelled states.
|
||||||
|
|
||||||
|
|
||||||
#### Model confguration and probing
|
#### Model configuration and probing
|
||||||
|
|
||||||
The install service uses the `invokeai.backend.model_manager.probe`
|
The install service uses the `invokeai.backend.model_manager.probe`
|
||||||
module during import to determine the model's type, base type, and
|
module during import to determine the model's type, base type, and
|
||||||
@ -1106,7 +1144,7 @@ job = queue.create_download_job(
|
|||||||
event_handlers=[my_handler1, my_handler2], # if desired
|
event_handlers=[my_handler1, my_handler2], # if desired
|
||||||
start=True,
|
start=True,
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
The `filename` argument forces the downloader to use the specified
|
The `filename` argument forces the downloader to use the specified
|
||||||
name for the file rather than the name provided by the remote source,
|
name for the file rather than the name provided by the remote source,
|
||||||
@ -1427,9 +1465,9 @@ set of keys to the corresponding model config objects.
|
|||||||
Find all model metadata records that have the given author and return
|
Find all model metadata records that have the given author and return
|
||||||
a set of keys to the corresponding model config objects.
|
a set of keys to the corresponding model config objects.
|
||||||
|
|
||||||
# The remainder of this documentation is provisional, pending implementation of the Load service
|
***
|
||||||
|
|
||||||
## Let's get loaded, the lowdown on ModelLoadService
|
## The Lowdown on the ModelLoadService
|
||||||
|
|
||||||
The `ModelLoadService` is responsible for loading a named model into
|
The `ModelLoadService` is responsible for loading a named model into
|
||||||
memory so that it can be used for inference. Despite the fact that it
|
memory so that it can be used for inference. Despite the fact that it
|
||||||
|
@ -251,9 +251,75 @@ async def add_model_record(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@model_manager_v2_router.post(
|
||||||
|
"/heuristic_import",
|
||||||
|
operation_id="heuristic_import_model",
|
||||||
|
responses={
|
||||||
|
201: {"description": "The model imported successfully"},
|
||||||
|
415: {"description": "Unrecognized file/folder format"},
|
||||||
|
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||||
|
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
)
|
||||||
|
async def heuristic_import(
|
||||||
|
source: str,
|
||||||
|
config: Optional[Dict[str, Any]] = Body(
|
||||||
|
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||||
|
default=None,
|
||||||
|
),
|
||||||
|
access_token: Optional[str] = None,
|
||||||
|
) -> ModelInstallJob:
|
||||||
|
"""Install a model using a string identifier.
|
||||||
|
|
||||||
|
`source` can be any of the following.
|
||||||
|
|
||||||
|
1. A path on the local filesystem ('C:\\users\\fred\\model.safetensors')
|
||||||
|
2. A Url pointing to a single downloadable model file
|
||||||
|
3. A HuggingFace repo_id with any of the following formats:
|
||||||
|
- model/name
|
||||||
|
- model/name:fp16:vae
|
||||||
|
- model/name::vae -- use default precision
|
||||||
|
- model/name:fp16:path/to/model.safetensors
|
||||||
|
- model/name::path/to/model.safetensors
|
||||||
|
|
||||||
|
`config` is an optional dict containing model configuration values that will override
|
||||||
|
the ones that are probed automatically.
|
||||||
|
|
||||||
|
`access_token` is an optional access token for use with Urls that require
|
||||||
|
authentication.
|
||||||
|
|
||||||
|
Models will be downloaded, probed, configured and installed in a
|
||||||
|
series of background threads. The return object has `status` attribute
|
||||||
|
that can be used to monitor progress.
|
||||||
|
|
||||||
|
See the documentation for `import_model_record` for more information on
|
||||||
|
interpreting the job information returned by this route.
|
||||||
|
"""
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
installer = ApiDependencies.invoker.services.model_manager.install
|
||||||
|
result: ModelInstallJob = installer.heuristic_import(
|
||||||
|
source=source,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
logger.info(f"Started installation of {source}")
|
||||||
|
except UnknownModelException as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=424, detail=str(e))
|
||||||
|
except InvalidModelException as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=415)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@model_manager_v2_router.post(
|
@model_manager_v2_router.post(
|
||||||
"/import",
|
"/import",
|
||||||
operation_id="import_model_record",
|
operation_id="import_model",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The model imported successfully"},
|
201: {"description": "The model imported successfully"},
|
||||||
415: {"description": "Unrecognized file/folder format"},
|
415: {"description": "Unrecognized file/folder format"},
|
||||||
@ -269,7 +335,7 @@ async def import_model(
|
|||||||
default=None,
|
default=None,
|
||||||
),
|
),
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
"""Add a model using its local path, repo_id, or remote URL.
|
"""Install a model using its local path, repo_id, or remote URL.
|
||||||
|
|
||||||
Models will be downloaded, probed, configured and installed in a
|
Models will be downloaded, probed, configured and installed in a
|
||||||
series of background threads. The return object has `status` attribute
|
series of background threads. The return object has `status` attribute
|
||||||
|
@ -49,7 +49,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
download_queue,
|
download_queue,
|
||||||
images,
|
images,
|
||||||
model_manager_v2,
|
model_manager_v2,
|
||||||
models,
|
|
||||||
session_queue,
|
session_queue,
|
||||||
sessions,
|
sessions,
|
||||||
utilities,
|
utilities,
|
||||||
|
@ -127,8 +127,8 @@ class HFModelSource(StringLikeSource):
|
|||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""Return string version of repoid when string rep needed."""
|
"""Return string version of repoid when string rep needed."""
|
||||||
base: str = self.repo_id
|
base: str = self.repo_id
|
||||||
|
base += f":{self.variant or ''}"
|
||||||
base += f":{self.subfolder}" if self.subfolder else ""
|
base += f":{self.subfolder}" if self.subfolder else ""
|
||||||
base += f" ({self.variant})" if self.variant else ""
|
|
||||||
return base
|
return base
|
||||||
|
|
||||||
|
|
||||||
@ -324,6 +324,43 @@ class ModelInstallServiceBase(ABC):
|
|||||||
:returns id: The string ID of the registered model.
|
:returns id: The string ID of the registered model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def heuristic_import(
|
||||||
|
self,
|
||||||
|
source: str,
|
||||||
|
config: Optional[Dict[str, Any]] = None,
|
||||||
|
access_token: Optional[str] = None,
|
||||||
|
) -> ModelInstallJob:
|
||||||
|
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||||
|
|
||||||
|
:param source: String source
|
||||||
|
:param config: Optional dict. Any fields in this dict
|
||||||
|
will override corresponding autoassigned probe fields in the
|
||||||
|
model's config record as described in `import_model()`.
|
||||||
|
:param access_token: Optional access token for remote sources.
|
||||||
|
|
||||||
|
The source can be:
|
||||||
|
1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`)
|
||||||
|
2. An http or https URL (`https://foo.bar/foo`)
|
||||||
|
3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`)
|
||||||
|
|
||||||
|
We extend the HuggingFace repo_id syntax to include the variant and the
|
||||||
|
subfolder or path. The following are acceptable alternatives:
|
||||||
|
stabilityai/stable-diffusion-v4
|
||||||
|
stabilityai/stable-diffusion-v4:fp16
|
||||||
|
stabilityai/stable-diffusion-v4:fp16:vae
|
||||||
|
stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
||||||
|
stabilityai/stable-diffusion-v4:onnx:vae
|
||||||
|
|
||||||
|
Because a local file path can look like a huggingface repo_id, the logic
|
||||||
|
first checks whether the path exists on disk, and if not, it is treated as
|
||||||
|
a parseable huggingface repo.
|
||||||
|
|
||||||
|
The previous support for recursing into a local folder and loading all model-like files
|
||||||
|
has been removed.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def import_model(
|
def import_model(
|
||||||
self,
|
self,
|
||||||
|
@ -50,6 +50,7 @@ from .model_install_base import (
|
|||||||
ModelInstallJob,
|
ModelInstallJob,
|
||||||
ModelInstallServiceBase,
|
ModelInstallServiceBase,
|
||||||
ModelSource,
|
ModelSource,
|
||||||
|
StringLikeSource,
|
||||||
URLModelSource,
|
URLModelSource,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -177,6 +178,34 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
info,
|
info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def heuristic_import(
|
||||||
|
self,
|
||||||
|
source: str,
|
||||||
|
config: Optional[Dict[str, Any]] = None,
|
||||||
|
access_token: Optional[str] = None,
|
||||||
|
) -> ModelInstallJob:
|
||||||
|
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||||
|
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||||
|
source_obj: Optional[StringLikeSource] = None
|
||||||
|
|
||||||
|
if Path(source).exists(): # A local file or directory
|
||||||
|
source_obj = LocalModelSource(path=Path(source))
|
||||||
|
elif match := re.match(hf_repoid_re, source):
|
||||||
|
source_obj = HFModelSource(
|
||||||
|
repo_id=match.group(1),
|
||||||
|
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
|
||||||
|
subfolder=Path(match.group(3)) if match.group(3) else None,
|
||||||
|
access_token=access_token,
|
||||||
|
)
|
||||||
|
elif re.match(r"^https?://[^/]+", source):
|
||||||
|
source_obj = URLModelSource(
|
||||||
|
url=AnyHttpUrl(source),
|
||||||
|
access_token=access_token,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported model source: '{source}'")
|
||||||
|
return self.import_model(source_obj, config)
|
||||||
|
|
||||||
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
||||||
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
|
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
|
||||||
if similar_jobs:
|
if similar_jobs:
|
||||||
@ -571,6 +600,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
|
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
|
||||||
# Currently the tmpdir isn't automatically removed at exit because it is
|
# Currently the tmpdir isn't automatically removed at exit because it is
|
||||||
# being held in a daemon thread.
|
# being held in a daemon thread.
|
||||||
|
if len(remote_files) == 0:
|
||||||
|
raise ValueError(f"{source}: No downloadable files found")
|
||||||
tmpdir = Path(
|
tmpdir = Path(
|
||||||
mkdtemp(
|
mkdtemp(
|
||||||
dir=self._app_config.models_path,
|
dir=self._app_config.models_path,
|
||||||
@ -586,6 +617,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
bytes=0,
|
bytes=0,
|
||||||
total_bytes=0,
|
total_bytes=0,
|
||||||
)
|
)
|
||||||
|
# In the event that there is a subfolder specified in the source,
|
||||||
|
# we need to remove it from the destination path in order to avoid
|
||||||
|
# creating unwanted subfolders
|
||||||
|
if hasattr(source, "subfolder") and source.subfolder:
|
||||||
|
root = Path(remote_files[0].path.parts[0])
|
||||||
|
subfolder = root / source.subfolder
|
||||||
|
else:
|
||||||
|
root = Path(".")
|
||||||
|
subfolder = Path(".")
|
||||||
|
|
||||||
# we remember the path up to the top of the tmpdir so that it may be
|
# we remember the path up to the top of the tmpdir so that it may be
|
||||||
# removed safely at the end of the install process.
|
# removed safely at the end of the install process.
|
||||||
install_job._install_tmpdir = tmpdir
|
install_job._install_tmpdir = tmpdir
|
||||||
@ -595,7 +636,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._logger.debug(f"remote_files={remote_files}")
|
self._logger.debug(f"remote_files={remote_files}")
|
||||||
for model_file in remote_files:
|
for model_file in remote_files:
|
||||||
url = model_file.url
|
url = model_file.url
|
||||||
path = model_file.path
|
path = root / model_file.path.relative_to(subfolder)
|
||||||
self._logger.info(f"Downloading {url} => {path}")
|
self._logger.info(f"Downloading {url} => {path}")
|
||||||
install_job.total_bytes += model_file.size
|
install_job.total_bytes += model_file.size
|
||||||
assert hasattr(source, "access_token")
|
assert hasattr(source, "access_token")
|
||||||
|
@ -36,6 +36,11 @@ def filter_files(
|
|||||||
"""
|
"""
|
||||||
variant = variant or ModelRepoVariant.DEFAULT
|
variant = variant or ModelRepoVariant.DEFAULT
|
||||||
paths: List[Path] = []
|
paths: List[Path] = []
|
||||||
|
root = files[0].parts[0]
|
||||||
|
|
||||||
|
# if the subfolder is a single file, then bypass the selection and just return it
|
||||||
|
if subfolder and subfolder.suffix in [".safetensors", ".bin", ".onnx", ".xml", ".pth", ".pt", ".ckpt", ".msgpack"]:
|
||||||
|
return [root / subfolder]
|
||||||
|
|
||||||
# Start by filtering on model file extensions, discarding images, docs, etc
|
# Start by filtering on model file extensions, discarding images, docs, etc
|
||||||
for file in files:
|
for file in files:
|
||||||
@ -61,6 +66,7 @@ def filter_files(
|
|||||||
|
|
||||||
# limit search to subfolder if requested
|
# limit search to subfolder if requested
|
||||||
if subfolder:
|
if subfolder:
|
||||||
|
subfolder = root / subfolder
|
||||||
paths = [x for x in paths if x.parent == Path(subfolder)]
|
paths = [x for x in paths if x.parent == Path(subfolder)]
|
||||||
|
|
||||||
# _filter_by_variant uniquifies the paths and returns a set
|
# _filter_by_variant uniquifies the paths and returns a set
|
||||||
|
Loading…
Reference in New Issue
Block a user