mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
bad implementation of diffusers folder download
This commit is contained in:
parent
f211c95dbc
commit
b48d4a049d
@ -177,6 +177,7 @@ class ModelInstallJob(BaseModel):
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_do_install: Optional[bool] = PrivateAttr(default=True)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
@ -407,6 +408,21 @@ class ModelInstallServiceBase(ABC):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def download_diffusers_model(
|
||||
self,
|
||||
source: HFModelSource,
|
||||
download_to: Path,
|
||||
) -> ModelInstallJob:
|
||||
"""
|
||||
Download, but do not install, a diffusers model.
|
||||
|
||||
:param source: An HFModelSource object containing a repo_id
|
||||
:param download_to: Path to directory that will contain the downloaded model.
|
||||
|
||||
Returns: a ModelInstallJob
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
||||
|
@ -249,6 +249,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._install_jobs.append(install_job)
|
||||
return install_job
|
||||
|
||||
def download_diffusers_model(self, source: HFModelSource, download_to: Path) -> ModelInstallJob:
|
||||
return self._import_from_hf(source, download_path=download_to)
|
||||
|
||||
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
||||
return self._install_jobs
|
||||
|
||||
@ -641,7 +644,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
inplace=source.inplace or False,
|
||||
)
|
||||
|
||||
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
def _import_from_hf(
|
||||
self,
|
||||
source: HFModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
download_path: Optional[Path] = None,
|
||||
) -> ModelInstallJob:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
source.access_token = source.access_token or HfFolder.get_token()
|
||||
if not source.access_token:
|
||||
@ -660,9 +668,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config=config,
|
||||
remote_files=remote_files,
|
||||
metadata=metadata,
|
||||
download_path=download_path,
|
||||
)
|
||||
|
||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
def _import_from_url(
|
||||
self,
|
||||
source: URLModelSource,
|
||||
config: Optional[Dict[str, Any]],
|
||||
) -> ModelInstallJob:
|
||||
# URLs from HuggingFace will be handled specially
|
||||
metadata = None
|
||||
fetcher = None
|
||||
@ -676,6 +689,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.debug(f"metadata={metadata}")
|
||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
print(remote_files)
|
||||
else:
|
||||
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
|
||||
return self._import_remote_model(
|
||||
@ -691,13 +705,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
download_path: Optional[Path] = None, # if defined, download only - don't install!
|
||||
) -> ModelInstallJob:
|
||||
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
|
||||
# Currently the tmpdir isn't automatically removed at exit because it is
|
||||
# being held in a daemon thread.
|
||||
if len(remote_files) == 0:
|
||||
raise ValueError(f"{source}: No downloadable files found")
|
||||
tmpdir = Path(
|
||||
destdir = download_path or Path(
|
||||
mkdtemp(
|
||||
dir=self._app_config.models_path,
|
||||
prefix=TMPDIR_PREFIX,
|
||||
@ -708,7 +723,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
source_metadata=metadata,
|
||||
local_path=tmpdir, # local path may change once the download has started due to content-disposition handling
|
||||
local_path=destdir, # local path may change once the download has started due to content-disposition handling
|
||||
bytes=0,
|
||||
total_bytes=0,
|
||||
)
|
||||
@ -722,9 +737,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
root = Path(".")
|
||||
subfolder = Path(".")
|
||||
|
||||
# we remember the path up to the top of the tmpdir so that it may be
|
||||
# we remember the path up to the top of the destdir so that it may be
|
||||
# removed safely at the end of the install process.
|
||||
install_job._install_tmpdir = tmpdir
|
||||
install_job._install_tmpdir = destdir
|
||||
install_job._do_install = download_path is None
|
||||
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
|
||||
|
||||
files_string = "file" if len(remote_files) == 1 else "file"
|
||||
@ -736,7 +752,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.debug(f"Downloading {url} => {path}")
|
||||
install_job.total_bytes += model_file.size
|
||||
assert hasattr(source, "access_token")
|
||||
dest = tmpdir / path.parent
|
||||
dest = destdir / path.parent
|
||||
dest.mkdir(parents=True, exist_ok=True)
|
||||
download_job = DownloadJob(
|
||||
source=url,
|
||||
@ -805,6 +821,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# are there any more active jobs left in this task?
|
||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||
self._signal_job_downloads_done(install_job)
|
||||
if install_job._do_install:
|
||||
self._put_in_queue(install_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
|
@ -115,7 +115,7 @@ class SchedulerPredictionType(str, Enum):
|
||||
class ModelRepoVariant(str, Enum):
|
||||
"""Various hugging face variants on the diffusers format."""
|
||||
|
||||
Default = "" # model files without "fp16" or other qualifier - empty str
|
||||
Default = "" # model files without "fp16" or other qualifier
|
||||
FP16 = "fp16"
|
||||
FP32 = "fp32"
|
||||
ONNX = "onnx"
|
||||
|
@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
assert s.size is not None
|
||||
files.append(
|
||||
RemoteModelFile(
|
||||
url=hf_hub_url(id, s.rfilename, revision=variant),
|
||||
url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
|
||||
path=Path(name, s.rfilename),
|
||||
size=s.size,
|
||||
sha256=s.lfs.get("sha256") if s.lfs else None,
|
||||
|
@ -14,6 +14,7 @@ from pydantic_core import Url
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.model_install import (
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
@ -21,7 +22,13 @@ from invokeai.app.services.model_install import (
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
)
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
OS = platform.uname().system
|
||||
@ -247,7 +254,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
@ -278,6 +285,17 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, tmp_path: Path) -> None:
|
||||
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
|
||||
job = mm2_installer.download_diffusers_model(source, tmp_path)
|
||||
mm2_installer.wait_for_installs(timeout=5)
|
||||
print(job.local_path)
|
||||
assert job.status == InstallStatus.DOWNLOADS_DONE
|
||||
assert (tmp_path / "sdxl-turbo").exists()
|
||||
assert (tmp_path / "sdxl-turbo" / "model_index.json").exists()
|
||||
|
||||
|
||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
|
||||
job = mm2_installer.import_model(source)
|
||||
@ -327,7 +345,7 @@ def test_other_error_during_install(
|
||||
},
|
||||
],
|
||||
)
|
||||
@pytest.mark.timeout(timeout=40, method="thread")
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
||||
assert "name" in model_params and "type" in model_params
|
||||
|
Loading…
Reference in New Issue
Block a user