bad implementation of diffusers folder download

This commit is contained in:
Lincoln Stein 2024-05-08 21:21:01 -07:00
parent f211c95dbc
commit b48d4a049d
5 changed files with 64 additions and 13 deletions

View File

@@ -177,6 +177,7 @@ class ModelInstallJob(BaseModel):
) )
# internal flags and transitory settings # internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None) _install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_do_install: Optional[bool] = PrivateAttr(default=True)
_exception: Optional[Exception] = PrivateAttr(default=None) _exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None: def set_error(self, e: Exception) -> None:
@@ -407,6 +408,21 @@ class ModelInstallServiceBase(ABC):
""" """
@abstractmethod
def download_diffusers_model(
self,
source: HFModelSource,
download_to: Path,
) -> ModelInstallJob:
"""
Download, but do not install, a diffusers model.
:param source: An HFModelSource object containing a repo_id
:param download_to: Path to directory that will contain the downloaded model.
Returns: a ModelInstallJob
"""
@abstractmethod @abstractmethod
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]: def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
"""Return the ModelInstallJob(s) corresponding to the provided source.""" """Return the ModelInstallJob(s) corresponding to the provided source."""

View File

@@ -249,6 +249,9 @@ class ModelInstallService(ModelInstallServiceBase):
self._install_jobs.append(install_job) self._install_jobs.append(install_job)
return install_job return install_job
def download_diffusers_model(self, source: HFModelSource, download_to: Path) -> ModelInstallJob:
return self._import_from_hf(source, download_path=download_to)
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102 def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
return self._install_jobs return self._install_jobs
@@ -641,7 +644,12 @@ class ModelInstallService(ModelInstallServiceBase):
inplace=source.inplace or False, inplace=source.inplace or False,
) )
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: def _import_from_hf(
self,
source: HFModelSource,
config: Optional[Dict[str, Any]] = None,
download_path: Optional[Path] = None,
) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests # Add user's cached access token to HuggingFace requests
source.access_token = source.access_token or HfFolder.get_token() source.access_token = source.access_token or HfFolder.get_token()
if not source.access_token: if not source.access_token:
@@ -660,9 +668,14 @@ class ModelInstallService(ModelInstallServiceBase):
config=config, config=config,
remote_files=remote_files, remote_files=remote_files,
metadata=metadata, metadata=metadata,
download_path=download_path,
) )
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: def _import_from_url(
self,
source: URLModelSource,
config: Optional[Dict[str, Any]],
) -> ModelInstallJob:
# URLs from HuggingFace will be handled specially # URLs from HuggingFace will be handled specially
metadata = None metadata = None
fetcher = None fetcher = None
@@ -676,6 +689,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.debug(f"metadata={metadata}") self._logger.debug(f"metadata={metadata}")
if metadata and isinstance(metadata, ModelMetadataWithFiles): if metadata and isinstance(metadata, ModelMetadataWithFiles):
remote_files = metadata.download_urls(session=self._session) remote_files = metadata.download_urls(session=self._session)
print(remote_files)
else: else:
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)] remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
return self._import_remote_model( return self._import_remote_model(
@@ -691,13 +705,14 @@ class ModelInstallService(ModelInstallServiceBase):
remote_files: List[RemoteModelFile], remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata], metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]], config: Optional[Dict[str, Any]],
download_path: Optional[Path] = None, # if defined, download only - don't install!
) -> ModelInstallJob: ) -> ModelInstallJob:
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up. # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
# Currently the tmpdir isn't automatically removed at exit because it is # Currently the tmpdir isn't automatically removed at exit because it is
# being held in a daemon thread. # being held in a daemon thread.
if len(remote_files) == 0: if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found") raise ValueError(f"{source}: No downloadable files found")
tmpdir = Path( destdir = download_path or Path(
mkdtemp( mkdtemp(
dir=self._app_config.models_path, dir=self._app_config.models_path,
prefix=TMPDIR_PREFIX, prefix=TMPDIR_PREFIX,
@@ -708,7 +723,7 @@ class ModelInstallService(ModelInstallServiceBase):
source=source, source=source,
config_in=config or {}, config_in=config or {},
source_metadata=metadata, source_metadata=metadata,
local_path=tmpdir, # local path may change once the download has started due to content-disposition handling local_path=destdir, # local path may change once the download has started due to content-disposition handling
bytes=0, bytes=0,
total_bytes=0, total_bytes=0,
) )
@@ -722,9 +737,10 @@ class ModelInstallService(ModelInstallServiceBase):
root = Path(".") root = Path(".")
subfolder = Path(".") subfolder = Path(".")
# we remember the path up to the top of the tmpdir so that it may be # we remember the path up to the top of the destdir so that it may be
# removed safely at the end of the install process. # removed safely at the end of the install process.
install_job._install_tmpdir = tmpdir install_job._install_tmpdir = destdir
install_job._do_install = download_path is None
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
files_string = "file" if len(remote_files) == 1 else "files" files_string = "file" if len(remote_files) == 1 else "files"
@@ -736,7 +752,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.debug(f"Downloading {url} => {path}") self._logger.debug(f"Downloading {url} => {path}")
install_job.total_bytes += model_file.size install_job.total_bytes += model_file.size
assert hasattr(source, "access_token") assert hasattr(source, "access_token")
dest = tmpdir / path.parent dest = destdir / path.parent
dest.mkdir(parents=True, exist_ok=True) dest.mkdir(parents=True, exist_ok=True)
download_job = DownloadJob( download_job = DownloadJob(
source=url, source=url,
@@ -805,7 +821,8 @@ class ModelInstallService(ModelInstallServiceBase):
# are there any more active jobs left in this task? # are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts): if install_job.downloading and all(x.complete for x in install_job.download_parts):
self._signal_job_downloads_done(install_job) self._signal_job_downloads_done(install_job)
self._put_in_queue(install_job) if install_job._do_install:
self._put_in_queue(install_job)
# Let other threads know that the number of downloads has changed # Let other threads know that the number of downloads has changed
self._download_cache.pop(download_job.source, None) self._download_cache.pop(download_job.source, None)

View File

@@ -115,7 +115,7 @@ class SchedulerPredictionType(str, Enum):
class ModelRepoVariant(str, Enum): class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format.""" """Various hugging face variants on the diffusers format."""
Default = "" # model files without "fp16" or other qualifier - empty str Default = "" # model files without "fp16" or other qualifier
FP16 = "fp16" FP16 = "fp16"
FP32 = "fp32" FP32 = "fp32"
ONNX = "onnx" ONNX = "onnx"

View File

@@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
assert s.size is not None assert s.size is not None
files.append( files.append(
RemoteModelFile( RemoteModelFile(
url=hf_hub_url(id, s.rfilename, revision=variant), url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
path=Path(name, s.rfilename), path=Path(name, s.rfilename),
size=s.size, size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None, sha256=s.lfs.get("sha256") if s.lfs else None,

View File

@@ -14,6 +14,7 @@ from pydantic_core import Url
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import ( from invokeai.app.services.model_install import (
HFModelSource,
InstallStatus, InstallStatus,
LocalModelSource, LocalModelSource,
ModelInstallJob, ModelInstallJob,
@@ -21,7 +22,13 @@ from invokeai.app.services.model_install import (
URLModelSource, URLModelSource,
) )
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType from invokeai.backend.model_manager.config import (
BaseModelType,
InvalidModelConfigException,
ModelFormat,
ModelRepoVariant,
ModelType,
)
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
OS = platform.uname().system OS = platform.uname().system
@@ -247,7 +254,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
@pytest.mark.timeout(timeout=20, method="thread") @pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo")) source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
bus = mm2_installer.event_bus bus = mm2_installer.event_bus
@@ -278,6 +285,17 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
} }
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, tmp_path: Path) -> None:
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
job = mm2_installer.download_diffusers_model(source, tmp_path)
mm2_installer.wait_for_installs(timeout=5)
print(job.local_path)
assert job.status == InstallStatus.DOWNLOADS_DONE
assert (tmp_path / "sdxl-turbo").exists()
assert (tmp_path / "sdxl-turbo" / "model_index.json").exists()
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors")) source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
job = mm2_installer.import_model(source) job = mm2_installer.import_model(source)
@@ -327,7 +345,7 @@ def test_other_error_during_install(
}, },
], ],
) )
@pytest.mark.timeout(timeout=40, method="thread") @pytest.mark.timeout(timeout=20, method="thread")
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
"""Test whether or not type is respected on configs when passed to heuristic import.""" """Test whether or not type is respected on configs when passed to heuristic import."""
assert "name" in model_params and "type" in model_params assert "name" in model_params and "type" in model_params