mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit 2dae5eb7ad (parent 911a24479b)

    more refactoring; HF subfolders not working
@@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the
 following initialization pattern:
 
 ```
-from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.config import get_config
 from invokeai.app.services.model_records import ModelRecordServiceSQL
 from invokeai.app.services.model_install import ModelInstallService
 from invokeai.app.services.download import DownloadQueueService
-from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
 from invokeai.backend.util.logging import InvokeAILogger
 
-config = InvokeAIAppConfig.get_config()
-config.parse_args()
+config = get_config()
 
 logger = InvokeAILogger.get_logger(config=config)
-db = SqliteDatabase(config, logger)
+db = SqliteDatabase(config.db_path, logger)
 record_store = ModelRecordServiceSQL(db)
 queue = DownloadQueueService()
 queue.start()
 
 installer = ModelInstallService(app_config=config,
                                 record_store=record_store,
                                 download_queue=queue
                                 )
 installer.start()
 ```
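For orientation, a minimal sketch of driving the installer once started. `heuristic_import`, `wait_for_job`, `job.complete`, and `job.config_out.key` all appear in hunks later in this diff; the repo_id is illustrative:

```python
# Sketch only, assuming the initialization pattern above has already run.
# heuristic_import() infers the source type from the string (local path,
# HF repo_id, or URL) via the _guess_source() helper added in this commit.
job = installer.heuristic_import(source="stabilityai/sdxl-turbo")
installer.wait_for_job(job, timeout=10)
if job.complete:
    print(f"Installed; model key = {job.config_out.key}")
```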
@@ -466,17 +466,14 @@ class ModelInstallServiceBase(ABC):
     """
 
     @abstractmethod
-    def download_and_cache_ckpt(
+    def download_and_cache_model(
         self,
-        source: str | AnyHttpUrl,
-        access_token: Optional[str] = None,
-        timeout: int = 0,
+        source: str,
     ) -> Path:
         """
         Download the model file located at source to the models cache and return its Path.
 
-        :param source: A Url or a string that can be converted into one.
-        :param access_token: Optional access token to access restricted resources.
+        :param source: A string representing a URL or repo_id.
 
         The model file will be downloaded into the system-wide model cache
         (`models/.cache`) if it isn't already there. Note that the model cache
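A hedged sketch of the narrowed call: with `access_token` and `timeout` dropped from the signature, a caller now passes only a source string (URL or repo_id). The URL below is the placeholder used in this repo's tests:

```python
from pathlib import Path

# Returns the cached file's Path; the download is skipped when the cache
# directory under models/.cache already holds the file.
ckpt: Path = installer.download_and_cache_model(
    "https://www.test.foo/download/test_embedding.safetensors"
)
```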
@@ -9,7 +9,7 @@ from pathlib import Path
 from queue import Empty, Queue
 from shutil import copyfile, copytree, move, rmtree
 from tempfile import mkdtemp
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import yaml
@@ -18,7 +18,7 @@ from pydantic.networks import AnyHttpUrl
 from requests import Session
 
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob, TqdmProgress
+from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
 from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
@@ -208,26 +208,12 @@ class ModelInstallService(ModelInstallServiceBase):
         access_token: Optional[str] = None,
         inplace: Optional[bool] = False,
     ) -> ModelInstallJob:
-        variants = "|".join(ModelRepoVariant.__members__.values())
-        hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
-        source_obj: Optional[StringLikeSource] = None
-
-        if Path(source).exists():  # A local file or directory
-            source_obj = LocalModelSource(path=Path(source), inplace=inplace)
-        elif match := re.match(hf_repoid_re, source):
-            source_obj = HFModelSource(
-                repo_id=match.group(1),
-                variant=match.group(2) if match.group(2) else None,  # pass None rather than ''
-                subfolder=Path(match.group(3)) if match.group(3) else None,
-                access_token=access_token,
-            )
-        elif re.match(r"^https?://[^/]+", source):
-            source_obj = URLModelSource(
-                url=AnyHttpUrl(source),
-                access_token=access_token,
-            )
-        else:
-            raise ValueError(f"Unsupported model source: '{source}'")
+        """Install a model using pattern matching to infer the type of source."""
+        source_obj = self._guess_source(source)
+        if isinstance(source_obj, LocalModelSource):
+            source_obj.inplace = inplace
+        elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource):
+            source_obj.access_token = access_token
         return self.import_model(source_obj, config)
 
     def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob:  # noqa D102
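In effect, `heuristic_import()` now only decorates whatever `_guess_source()` returns. A sketch of the two propagation paths, with an illustrative local path and token value:

```python
src = installer._guess_source("/models/juggernaut.safetensors")  # illustrative path
# LocalModelSource -> only the inplace flag is propagated
src.inplace = True

src = installer._guess_source("stabilityai/sdxl-turbo")
# HFModelSource / URLModelSource -> only the access token is propagated
src.access_token = "hf_xxx"  # illustrative value
```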
@@ -383,37 +369,86 @@ class ModelInstallService(ModelInstallServiceBase):
         escaped_source = slugify(str(source))
         return app_config.download_cache_path / escaped_source
 
-    def download_and_cache_ckpt(
+    def download_and_cache_model(
         self,
-        source: str | AnyHttpUrl,
-        access_token: Optional[str] = None,
-        timeout: int = 0,
+        source: str,
     ) -> Path:
         """Download the model file located at source to the models cache and return its Path."""
-        model_path = self._download_cache_path(source, self._app_config)
+        model_path = self._download_cache_path(str(source), self._app_config)
 
-        # We expect the cache directory to contain one and only one downloaded file.
+        # We expect the cache directory to contain one and only one downloaded file or directory.
         # We don't know the file's name in advance, as it is set by the download
         # content-disposition header.
         if model_path.exists():
-            contents = [x for x in model_path.iterdir() if x.is_file()]
+            contents: List[Path] = list(model_path.iterdir())
             if len(contents) > 0:
                 return contents[0]
 
         model_path.mkdir(parents=True, exist_ok=True)
-        job = self._download_queue.download(
-            source=AnyHttpUrl(str(source)),
+        model_source = self._guess_source(source)
+        remote_files, _ = self._remote_files_from_source(model_source)
+        job = self._download_queue.multifile_download(
+            parts=remote_files,
             dest=model_path,
-            access_token=access_token,
-            on_progress=TqdmProgress().update,
         )
-        self._download_queue.wait_for_job(job, timeout)
+        files_string = "file" if len(remote_files) == 1 else "files"
+        self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
+        self._download_queue.wait_for_job(job)
         if job.complete:
             assert job.download_path is not None
             return job.download_path
         else:
             raise Exception(job.error)
 
+    def _remote_files_from_source(
+        self, source: ModelSource
+    ) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
+        metadata = None
+        if isinstance(source, HFModelSource):
+            metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
+            assert isinstance(metadata, ModelMetadataWithFiles)
+            return metadata.download_urls(
+                variant=source.variant or self._guess_variant(),
+                subfolder=source.subfolder,
+                session=self._session,
+            ), metadata
+
+        if isinstance(source, URLModelSource):
+            try:
+                fetcher = self.get_fetcher_from_url(str(source.url))
+                kwargs: dict[str, Any] = {"session": self._session}
+                metadata = fetcher(**kwargs).from_url(source.url)
+                assert isinstance(metadata, ModelMetadataWithFiles)
+                return metadata.download_urls(session=self._session), metadata
+            except ValueError:
+                pass
+
+            return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None
+
+        raise Exception(f"No files associated with {source}")
+
+    def _guess_source(self, source: str) -> ModelSource:
+        """Turn a source string into a ModelSource object."""
+        variants = "|".join(ModelRepoVariant.__members__.values())
+        hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
+        source_obj: Optional[StringLikeSource] = None
+
+        if Path(source).exists():  # A local file or directory
+            source_obj = LocalModelSource(path=Path(source))
+        elif match := re.match(hf_repoid_re, source):
+            source_obj = HFModelSource(
+                repo_id=match.group(1),
+                variant=match.group(2) if match.group(2) else None,  # pass None rather than ''
+                subfolder=Path(match.group(3)) if match.group(3) else None,
+            )
+        elif re.match(r"^https?://[^/]+", source):
+            source_obj = URLModelSource(
+                url=AnyHttpUrl(source),
+            )
+        else:
+            raise ValueError(f"Unsupported model source: '{source}'")
+        return source_obj
+
     # --------------------------------------------------------------------------------------------
     # Internal functions that manage the installer threads
     # --------------------------------------------------------------------------------------------
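To make the regex concrete, here is how `_guess_source()` should classify the three supported source forms. The `repo_id:variant:subfolder` spellings are inferred from `hf_repoid_re`, and per the commit message the subfolder path is the part not yet working:

```python
installer._guess_source("/tmp/model.safetensors")
# -> LocalModelSource(path=Path("/tmp/model.safetensors"))  (path exists locally)

installer._guess_source("stabilityai/sdxl-turbo")
# -> HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=None, subfolder=None)

installer._guess_source("stabilityai/sdxl-turbo:fp16:vae")
# -> HFModelSource(repo_id="stabilityai/sdxl-turbo", variant="fp16", subfolder=Path("vae"))

installer._guess_source("https://test.com/missing_model.safetensors")
# -> URLModelSource(url=AnyHttpUrl("https://test.com/missing_model.safetensors"))
```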
@@ -650,18 +685,9 @@ class ModelInstallService(ModelInstallServiceBase):
         config: Optional[Dict[str, Any]] = None,
     ) -> ModelInstallJob:
         # Add user's cached access token to HuggingFace requests
-        source.access_token = source.access_token or HfFolder.get_token()
-        if not source.access_token:
-            self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
-
-        metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
-        assert isinstance(metadata, ModelMetadataWithFiles)
-        remote_files = metadata.download_urls(
-            variant=source.variant or self._guess_variant(),
-            subfolder=source.subfolder,
-            session=self._session,
-        )
-
+        if source.access_token is None:
+            source.access_token = HfFolder.get_token()
+        remote_files, metadata = self._remote_files_from_source(source)
         return self._import_remote_model(
             source=source,
             config=config,
@@ -674,21 +700,7 @@ class ModelInstallService(ModelInstallServiceBase):
         source: URLModelSource,
         config: Optional[Dict[str, Any]],
     ) -> ModelInstallJob:
-        # URLs from HuggingFace will be handled specially
-        metadata = None
-        fetcher = None
-        try:
-            fetcher = self.get_fetcher_from_url(str(source.url))
-        except ValueError:
-            pass
-        kwargs: dict[str, Any] = {"session": self._session}
-        if fetcher is not None:
-            metadata = fetcher(**kwargs).from_url(source.url)
-        self._logger.debug(f"metadata={metadata}")
-        if metadata and isinstance(metadata, ModelMetadataWithFiles):
-            remote_files = metadata.download_urls(session=self._session)
-        else:
-            remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
+        remote_files, metadata = self._remote_files_from_source(source)
         return self._import_remote_model(
             source=source,
             config=config,
@@ -733,26 +745,17 @@ class ModelInstallService(ModelInstallServiceBase):
             root = Path(".")
             subfolder = Path(".")
 
-        # we remember the path up to the top of the destdir so that it may be
-        # removed safely at the end of the install process.
-        install_job._install_tmpdir = destdir
-
         parts: List[RemoteModelFile] = []
         for model_file in remote_files:
             assert install_job.total_bytes is not None
             assert model_file.size is not None
             install_job.total_bytes += model_file.size
             parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder)))
-        multifile_job = self._download_queue.multifile_download(
+        multifile_job = self._multifile_download(
            parts=parts,
            dest=destdir,
            access_token=source.access_token,
-            submit_job=False,
-            on_start=self._download_started_callback,
-            on_progress=self._download_progress_callback,
-            on_complete=self._download_complete_callback,
-            on_error=self._download_error_callback,
-            on_cancelled=self._download_cancelled_callback,
+            submit_job=False,  # Important! Don't submit the job until we have set our _download_cache dict
         )
         self._download_cache[multifile_job.id] = install_job
         install_job._download_job = multifile_job
@@ -772,6 +775,21 @@ class ModelInstallService(ModelInstallServiceBase):
             size += sum(self._stat_size(Path(root, x)) for x in files)
         return size
 
+    def _multifile_download(
+        self, parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, submit_job: bool = True
+    ) -> MultiFileDownloadJob:
+        return self._download_queue.multifile_download(
+            parts=parts,
+            dest=dest,
+            access_token=access_token,
+            submit_job=submit_job,
+            on_start=self._download_started_callback,
+            on_progress=self._download_progress_callback,
+            on_complete=self._download_complete_callback,
+            on_error=self._download_error_callback,
+            on_cancelled=self._download_cancelled_callback,
+        )
+
     # ------------------------------------------------------------------
     # Callbacks are executed by the download queue in a separate thread
     # ------------------------------------------------------------------
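The wrapper exists so every install download is wired to the same five callbacks, and `submit_job=False` in the previous hunk pairs with it: the queue must not fire `on_start` before the installer has recorded the job under `self._download_cache[multifile_job.id]`. A sketch of how a callback presumably resolves its install job (the signature and body are assumptions, not the commit's code):

```python
# Hypothetical callback body on ModelInstallService; the real implementation
# is not shown in this diff.
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
    install_job = self._download_cache[download_job.id]  # recorded before submit
    install_job.bytes = download_job.bytes  # mirror download progress onto the install job
```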
@@ -875,10 +893,9 @@ class ModelInstallService(ModelInstallServiceBase):
         assert job.local_path is not None
         assert job.config_out is not None
         key = job.config_out.key
-        self._event_bus.emit_model_install_completed(source=str(job.source),
-                                                     key=key,
-                                                     id=job.id,
-                                                     total_bytes=job.bytes)
+        self._event_bus.emit_model_install_completed(
+            source=str(job.source), key=key, id=job.id, total_bytes=job.bytes
+        )
 
     def _signal_job_errored(self, job: ModelInstallJob) -> None:
         self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
@@ -12,15 +12,13 @@ from pydantic import BaseModel, Field
 
 from invokeai.app.services.shared.pagination import PaginatedResults
 from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
-from invokeai.backend.model_manager import (
+from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
-    ModelFormat,
-    ModelType,
-)
-from invokeai.backend.model_manager.config import (
     ControlAdapterDefaultSettings,
     MainModelDefaultSettings,
+    ModelFormat,
+    ModelType,
     ModelVariantType,
     SchedulerPredictionType,
 )
@@ -222,7 +222,7 @@ def test_delete_register(
     store.get_model(key)
 
 
-@pytest.mark.timeout(timeout=20, method="thread")
+@pytest.mark.timeout(timeout=10, method="thread")
 def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
     source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
 
@@ -253,7 +253,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
     ]
 
 
-@pytest.mark.timeout(timeout=20, method="thread")
+@pytest.mark.timeout(timeout=10, method="thread")
 def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
     source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
 
@@ -285,9 +285,8 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con
     }
 
 
-@pytest.mark.timeout(timeout=20, method="thread")
+@pytest.mark.timeout(timeout=10, method="thread")
 def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
-    # TODO: Test subfolder download
     source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
 
     bus = mm2_installer.event_bus
@@ -323,6 +322,7 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con
     assert job.total_bytes == completed_events[0].payload["total_bytes"]
     assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"])
 
+
 def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
     source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
     job = mm2_installer.import_model(source)
@@ -371,7 +371,7 @@ def test_other_error_during_install(
         },
     ],
 )
-@pytest.mark.timeout(timeout=20, method="thread")
+@pytest.mark.timeout(timeout=10, method="thread")
 def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
     """Test whether or not type is respected on configs when passed to heuristic import."""
     assert "name" in model_params and "type" in model_params
@@ -387,7 +387,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
     }
     assert "repo_id" in model_params
     install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
-    mm2_installer.wait_for_job(install_job1, timeout=20)
+    mm2_installer.wait_for_job(install_job1, timeout=10)
     if model_params["type"] != "embedding":
         assert install_job1.errored
         assert install_job1.error_type == "InvalidModelConfigException"
@@ -396,6 +396,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
     assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
 
     install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
-    mm2_installer.wait_for_job(install_job2, timeout=20)
+    mm2_installer.wait_for_job(install_job2, timeout=10)
     assert install_job2.complete
     assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out