bad implementation of diffusers folder download

This commit is contained in:
Lincoln Stein 2024-05-08 21:21:01 -07:00
parent f211c95dbc
commit b48d4a049d
5 changed files with 64 additions and 13 deletions

View File

@ -177,6 +177,7 @@ class ModelInstallJob(BaseModel):
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_do_install: Optional[bool] = PrivateAttr(default=True)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
@ -407,6 +408,21 @@ class ModelInstallServiceBase(ABC):
"""
@abstractmethod
def download_diffusers_model(
self,
source: HFModelSource,
download_to: Path,
) -> ModelInstallJob:
"""
Download, but do not install, a diffusers model.
:param source: An HFModelSource object containing a repo_id
:param download_to: Path to directory that will contain the downloaded model.
Returns: a ModelInstallJob
"""
@abstractmethod
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
"""Return the ModelInstallJob(s) corresponding to the provided source."""

View File

@ -249,6 +249,9 @@ class ModelInstallService(ModelInstallServiceBase):
self._install_jobs.append(install_job)
return install_job
def download_diffusers_model(self, source: HFModelSource, download_to: Path) -> ModelInstallJob:
return self._import_from_hf(source, download_path=download_to)
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
return self._install_jobs
@ -641,7 +644,12 @@ class ModelInstallService(ModelInstallServiceBase):
inplace=source.inplace or False,
)
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
def _import_from_hf(
self,
source: HFModelSource,
config: Optional[Dict[str, Any]] = None,
download_path: Optional[Path] = None,
) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests
source.access_token = source.access_token or HfFolder.get_token()
if not source.access_token:
@ -660,9 +668,14 @@ class ModelInstallService(ModelInstallServiceBase):
config=config,
remote_files=remote_files,
metadata=metadata,
download_path=download_path,
)
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
def _import_from_url(
self,
source: URLModelSource,
config: Optional[Dict[str, Any]],
) -> ModelInstallJob:
# URLs from HuggingFace will be handled specially
metadata = None
fetcher = None
@ -676,6 +689,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.debug(f"metadata={metadata}")
if metadata and isinstance(metadata, ModelMetadataWithFiles):
remote_files = metadata.download_urls(session=self._session)
print(remote_files)
else:
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
return self._import_remote_model(
@ -691,13 +705,14 @@ class ModelInstallService(ModelInstallServiceBase):
remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
download_path: Optional[Path] = None, # if defined, download only - don't install!
) -> ModelInstallJob:
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
# Currently the tmpdir isn't automatically removed at exit because it is
# being held in a daemon thread.
if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found")
tmpdir = Path(
destdir = download_path or Path(
mkdtemp(
dir=self._app_config.models_path,
prefix=TMPDIR_PREFIX,
@ -708,7 +723,7 @@ class ModelInstallService(ModelInstallServiceBase):
source=source,
config_in=config or {},
source_metadata=metadata,
local_path=tmpdir, # local path may change once the download has started due to content-disposition handling
local_path=destdir, # local path may change once the download has started due to content-disposition handling
bytes=0,
total_bytes=0,
)
@ -722,9 +737,10 @@ class ModelInstallService(ModelInstallServiceBase):
root = Path(".")
subfolder = Path(".")
# we remember the path up to the top of the tmpdir so that it may be
# we remember the path up to the top of the destdir so that it may be
# removed safely at the end of the install process.
install_job._install_tmpdir = tmpdir
install_job._install_tmpdir = destdir
install_job._do_install = download_path is None
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
files_string = "file" if len(remote_files) == 1 else "file"
@ -736,7 +752,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.debug(f"Downloading {url} => {path}")
install_job.total_bytes += model_file.size
assert hasattr(source, "access_token")
dest = tmpdir / path.parent
dest = destdir / path.parent
dest.mkdir(parents=True, exist_ok=True)
download_job = DownloadJob(
source=url,
@ -805,6 +821,7 @@ class ModelInstallService(ModelInstallServiceBase):
# are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts):
self._signal_job_downloads_done(install_job)
if install_job._do_install:
self._put_in_queue(install_job)
# Let other threads know that the number of downloads has changed

View File

@ -115,7 +115,7 @@ class SchedulerPredictionType(str, Enum):
class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format."""
Default = "" # model files without "fp16" or other qualifier - empty str
Default = "" # model files without "fp16" or other qualifier
FP16 = "fp16"
FP32 = "fp32"
ONNX = "onnx"

View File

@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
assert s.size is not None
files.append(
RemoteModelFile(
url=hf_hub_url(id, s.rfilename, revision=variant),
url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
path=Path(name, s.rfilename),
size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None,

View File

@ -14,6 +14,7 @@ from pydantic_core import Url
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import (
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
@ -21,7 +22,13 @@ from invokeai.app.services.model_install import (
URLModelSource,
)
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
from invokeai.backend.model_manager.config import (
BaseModelType,
InvalidModelConfigException,
ModelFormat,
ModelRepoVariant,
ModelType,
)
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
OS = platform.uname().system
@ -247,7 +254,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
bus = mm2_installer.event_bus
@ -278,6 +285,17 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
}
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, tmp_path: Path) -> None:
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
job = mm2_installer.download_diffusers_model(source, tmp_path)
mm2_installer.wait_for_installs(timeout=5)
print(job.local_path)
assert job.status == InstallStatus.DOWNLOADS_DONE
assert (tmp_path / "sdxl-turbo").exists()
assert (tmp_path / "sdxl-turbo" / "model_index.json").exists()
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
job = mm2_installer.import_model(source)
@ -327,7 +345,7 @@ def test_other_error_during_install(
},
],
)
@pytest.mark.timeout(timeout=40, method="thread")
@pytest.mark.timeout(timeout=20, method="thread")
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
"""Test whether or not type is respected on configs when passed to heuristic import."""
assert "name" in model_params and "type" in model_params