add back the heuristic_import() method and extend repo_ids to arbitrary file paths

This commit is contained in:
Lincoln Stein
2024-02-11 23:37:49 -05:00
committed by psychedelicious
parent 94e8d1b6d5
commit 13a9ea35b5
6 changed files with 199 additions and 12 deletions

View File

@ -127,8 +127,8 @@ class HFModelSource(StringLikeSource):
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
base += f":{self.variant or ''}"
base += f":{self.subfolder}" if self.subfolder else ""
base += f" ({self.variant})" if self.variant else ""
return base
@ -324,6 +324,43 @@ class ModelInstallServiceBase(ABC):
:returns id: The string ID of the registered model.
"""
@abstractmethod
def heuristic_import(
self,
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
) -> ModelInstallJob:
r"""Install the indicated model using heuristics to interpret user intentions.
:param source: String source
:param config: Optional dict. Any fields in this dict
will override corresponding autoassigned probe fields in the
model's config record as described in `import_model()`.
:param access_token: Optional access token for remote sources.
The source can be:
1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`)
2. An http or https URL (`https://foo.bar/foo`)
3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`)
We extend the HuggingFace repo_id syntax to include the variant and the
subfolder or path. The following are acceptable alternatives:
stabilityai/stable-diffusion-v4
stabilityai/stable-diffusion-v4:fp16
stabilityai/stable-diffusion-v4:fp16:vae
stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
stabilityai/stable-diffusion-v4:onnx:vae
Because a local file path can look like a huggingface repo_id, the logic
first checks whether the path exists on disk, and if not, it is treated as
a parseable huggingface repo.
The previous support for recursing into a local folder and loading all model-like files
has been removed.
"""
pass
@abstractmethod
def import_model(
self,

View File

@ -50,6 +50,7 @@ from .model_install_base import (
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
StringLikeSource,
URLModelSource,
)
@ -177,6 +178,34 @@ class ModelInstallService(ModelInstallServiceBase):
info,
)
def heuristic_import(
self,
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
) -> ModelInstallJob:
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source))
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
subfolder=Path(match.group(3)) if match.group(3) else None,
access_token=access_token,
)
elif re.match(r"^https?://[^/]+", source):
source_obj = URLModelSource(
url=AnyHttpUrl(source),
access_token=access_token,
)
else:
raise ValueError(f"Unsupported model source: '{source}'")
return self.import_model(source_obj, config)
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
if similar_jobs:
@ -571,6 +600,8 @@ class ModelInstallService(ModelInstallServiceBase):
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
# Currently the tmpdir isn't automatically removed at exit because it is
# being held in a daemon thread.
if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found")
tmpdir = Path(
mkdtemp(
dir=self._app_config.models_path,
@ -586,6 +617,16 @@ class ModelInstallService(ModelInstallServiceBase):
bytes=0,
total_bytes=0,
)
# In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid
# creating unwanted subfolders
if hasattr(source, "subfolder") and source.subfolder:
root = Path(remote_files[0].path.parts[0])
subfolder = root / source.subfolder
else:
root = Path(".")
subfolder = Path(".")
# we remember the path up to the top of the tmpdir so that it may be
# removed safely at the end of the install process.
install_job._install_tmpdir = tmpdir
@ -595,7 +636,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.debug(f"remote_files={remote_files}")
for model_file in remote_files:
url = model_file.url
path = model_file.path
path = root / model_file.path.relative_to(subfolder)
self._logger.info(f"Downloading {url} => {path}")
install_job.total_bytes += model_file.size
assert hasattr(source, "access_token")