mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
bad implementation of diffusers folder download
This commit is contained in:
parent
f211c95dbc
commit
b48d4a049d
@ -177,6 +177,7 @@ class ModelInstallJob(BaseModel):
|
|||||||
)
|
)
|
||||||
# internal flags and transitory settings
|
# internal flags and transitory settings
|
||||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||||
|
_do_install: Optional[bool] = PrivateAttr(default=True)
|
||||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||||
|
|
||||||
def set_error(self, e: Exception) -> None:
|
def set_error(self, e: Exception) -> None:
|
||||||
@ -407,6 +408,21 @@ class ModelInstallServiceBase(ABC):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def download_diffusers_model(
|
||||||
|
self,
|
||||||
|
source: HFModelSource,
|
||||||
|
download_to: Path,
|
||||||
|
) -> ModelInstallJob:
|
||||||
|
"""
|
||||||
|
Download, but do not install, a diffusers model.
|
||||||
|
|
||||||
|
:param source: An HFModelSource object containing a repo_id
|
||||||
|
:param download_to: Path to directory that will contain the downloaded model.
|
||||||
|
|
||||||
|
Returns: a ModelInstallJob
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
|
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||||
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
||||||
|
@ -249,6 +249,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._install_jobs.append(install_job)
|
self._install_jobs.append(install_job)
|
||||||
return install_job
|
return install_job
|
||||||
|
|
||||||
|
def download_diffusers_model(self, source: HFModelSource, download_to: Path) -> ModelInstallJob:
|
||||||
|
return self._import_from_hf(source, download_path=download_to)
|
||||||
|
|
||||||
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
||||||
return self._install_jobs
|
return self._install_jobs
|
||||||
|
|
||||||
@ -641,7 +644,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
inplace=source.inplace or False,
|
inplace=source.inplace or False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
def _import_from_hf(
|
||||||
|
self,
|
||||||
|
source: HFModelSource,
|
||||||
|
config: Optional[Dict[str, Any]] = None,
|
||||||
|
download_path: Optional[Path] = None,
|
||||||
|
) -> ModelInstallJob:
|
||||||
# Add user's cached access token to HuggingFace requests
|
# Add user's cached access token to HuggingFace requests
|
||||||
source.access_token = source.access_token or HfFolder.get_token()
|
source.access_token = source.access_token or HfFolder.get_token()
|
||||||
if not source.access_token:
|
if not source.access_token:
|
||||||
@ -660,9 +668,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
config=config,
|
config=config,
|
||||||
remote_files=remote_files,
|
remote_files=remote_files,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
|
download_path=download_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
def _import_from_url(
|
||||||
|
self,
|
||||||
|
source: URLModelSource,
|
||||||
|
config: Optional[Dict[str, Any]],
|
||||||
|
) -> ModelInstallJob:
|
||||||
# URLs from HuggingFace will be handled specially
|
# URLs from HuggingFace will be handled specially
|
||||||
metadata = None
|
metadata = None
|
||||||
fetcher = None
|
fetcher = None
|
||||||
@ -676,6 +689,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._logger.debug(f"metadata={metadata}")
|
self._logger.debug(f"metadata={metadata}")
|
||||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||||
remote_files = metadata.download_urls(session=self._session)
|
remote_files = metadata.download_urls(session=self._session)
|
||||||
|
print(remote_files)
|
||||||
else:
|
else:
|
||||||
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
|
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
|
||||||
return self._import_remote_model(
|
return self._import_remote_model(
|
||||||
@ -691,13 +705,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
remote_files: List[RemoteModelFile],
|
remote_files: List[RemoteModelFile],
|
||||||
metadata: Optional[AnyModelRepoMetadata],
|
metadata: Optional[AnyModelRepoMetadata],
|
||||||
config: Optional[Dict[str, Any]],
|
config: Optional[Dict[str, Any]],
|
||||||
|
download_path: Optional[Path] = None, # if defined, download only - don't install!
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
|
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
|
||||||
# Currently the tmpdir isn't automatically removed at exit because it is
|
# Currently the tmpdir isn't automatically removed at exit because it is
|
||||||
# being held in a daemon thread.
|
# being held in a daemon thread.
|
||||||
if len(remote_files) == 0:
|
if len(remote_files) == 0:
|
||||||
raise ValueError(f"{source}: No downloadable files found")
|
raise ValueError(f"{source}: No downloadable files found")
|
||||||
tmpdir = Path(
|
destdir = download_path or Path(
|
||||||
mkdtemp(
|
mkdtemp(
|
||||||
dir=self._app_config.models_path,
|
dir=self._app_config.models_path,
|
||||||
prefix=TMPDIR_PREFIX,
|
prefix=TMPDIR_PREFIX,
|
||||||
@ -708,7 +723,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
source=source,
|
source=source,
|
||||||
config_in=config or {},
|
config_in=config or {},
|
||||||
source_metadata=metadata,
|
source_metadata=metadata,
|
||||||
local_path=tmpdir, # local path may change once the download has started due to content-disposition handling
|
local_path=destdir, # local path may change once the download has started due to content-disposition handling
|
||||||
bytes=0,
|
bytes=0,
|
||||||
total_bytes=0,
|
total_bytes=0,
|
||||||
)
|
)
|
||||||
@ -722,9 +737,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
root = Path(".")
|
root = Path(".")
|
||||||
subfolder = Path(".")
|
subfolder = Path(".")
|
||||||
|
|
||||||
# we remember the path up to the top of the tmpdir so that it may be
|
# we remember the path up to the top of the destdir so that it may be
|
||||||
# removed safely at the end of the install process.
|
# removed safely at the end of the install process.
|
||||||
install_job._install_tmpdir = tmpdir
|
install_job._install_tmpdir = destdir
|
||||||
|
install_job._do_install = download_path is None
|
||||||
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
|
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
|
||||||
|
|
||||||
files_string = "file" if len(remote_files) == 1 else "file"
|
files_string = "file" if len(remote_files) == 1 else "file"
|
||||||
@ -736,7 +752,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._logger.debug(f"Downloading {url} => {path}")
|
self._logger.debug(f"Downloading {url} => {path}")
|
||||||
install_job.total_bytes += model_file.size
|
install_job.total_bytes += model_file.size
|
||||||
assert hasattr(source, "access_token")
|
assert hasattr(source, "access_token")
|
||||||
dest = tmpdir / path.parent
|
dest = destdir / path.parent
|
||||||
dest.mkdir(parents=True, exist_ok=True)
|
dest.mkdir(parents=True, exist_ok=True)
|
||||||
download_job = DownloadJob(
|
download_job = DownloadJob(
|
||||||
source=url,
|
source=url,
|
||||||
@ -805,7 +821,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
# are there any more active jobs left in this task?
|
# are there any more active jobs left in this task?
|
||||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||||
self._signal_job_downloads_done(install_job)
|
self._signal_job_downloads_done(install_job)
|
||||||
self._put_in_queue(install_job)
|
if install_job._do_install:
|
||||||
|
self._put_in_queue(install_job)
|
||||||
|
|
||||||
# Let other threads know that the number of downloads has changed
|
# Let other threads know that the number of downloads has changed
|
||||||
self._download_cache.pop(download_job.source, None)
|
self._download_cache.pop(download_job.source, None)
|
||||||
|
@ -115,7 +115,7 @@ class SchedulerPredictionType(str, Enum):
|
|||||||
class ModelRepoVariant(str, Enum):
|
class ModelRepoVariant(str, Enum):
|
||||||
"""Various hugging face variants on the diffusers format."""
|
"""Various hugging face variants on the diffusers format."""
|
||||||
|
|
||||||
Default = "" # model files without "fp16" or other qualifier - empty str
|
Default = "" # model files without "fp16" or other qualifier
|
||||||
FP16 = "fp16"
|
FP16 = "fp16"
|
||||||
FP32 = "fp32"
|
FP32 = "fp32"
|
||||||
ONNX = "onnx"
|
ONNX = "onnx"
|
||||||
|
@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
|||||||
assert s.size is not None
|
assert s.size is not None
|
||||||
files.append(
|
files.append(
|
||||||
RemoteModelFile(
|
RemoteModelFile(
|
||||||
url=hf_hub_url(id, s.rfilename, revision=variant),
|
url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
|
||||||
path=Path(name, s.rfilename),
|
path=Path(name, s.rfilename),
|
||||||
size=s.size,
|
size=s.size,
|
||||||
sha256=s.lfs.get("sha256") if s.lfs else None,
|
sha256=s.lfs.get("sha256") if s.lfs else None,
|
||||||
|
@ -14,6 +14,7 @@ from pydantic_core import Url
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.model_install import (
|
from invokeai.app.services.model_install import (
|
||||||
|
HFModelSource,
|
||||||
InstallStatus,
|
InstallStatus,
|
||||||
LocalModelSource,
|
LocalModelSource,
|
||||||
ModelInstallJob,
|
ModelInstallJob,
|
||||||
@ -21,7 +22,13 @@ from invokeai.app.services.model_install import (
|
|||||||
URLModelSource,
|
URLModelSource,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
|
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
|
||||||
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
|
from invokeai.backend.model_manager.config import (
|
||||||
|
BaseModelType,
|
||||||
|
InvalidModelConfigException,
|
||||||
|
ModelFormat,
|
||||||
|
ModelRepoVariant,
|
||||||
|
ModelType,
|
||||||
|
)
|
||||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||||
|
|
||||||
OS = platform.uname().system
|
OS = platform.uname().system
|
||||||
@ -247,7 +254,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=20, method="thread")
|
@pytest.mark.timeout(timeout=20, method="thread")
|
||||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||||
|
|
||||||
bus = mm2_installer.event_bus
|
bus = mm2_installer.event_bus
|
||||||
@ -278,6 +285,17 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(timeout=20, method="thread")
|
||||||
|
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, tmp_path: Path) -> None:
|
||||||
|
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
|
||||||
|
job = mm2_installer.download_diffusers_model(source, tmp_path)
|
||||||
|
mm2_installer.wait_for_installs(timeout=5)
|
||||||
|
print(job.local_path)
|
||||||
|
assert job.status == InstallStatus.DOWNLOADS_DONE
|
||||||
|
assert (tmp_path / "sdxl-turbo").exists()
|
||||||
|
assert (tmp_path / "sdxl-turbo" / "model_index.json").exists()
|
||||||
|
|
||||||
|
|
||||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||||
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
|
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
|
||||||
job = mm2_installer.import_model(source)
|
job = mm2_installer.import_model(source)
|
||||||
@ -327,7 +345,7 @@ def test_other_error_during_install(
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.timeout(timeout=40, method="thread")
|
@pytest.mark.timeout(timeout=20, method="thread")
|
||||||
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||||
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
||||||
assert "name" in model_params and "type" in model_params
|
assert "name" in model_params and "type" in model_params
|
||||||
|
Loading…
Reference in New Issue
Block a user