mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
more refactoring; HF subfolders not working
This commit is contained in:
parent
911a24479b
commit
2dae5eb7ad
@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the
|
||||
following initialization pattern:
|
||||
|
||||
```
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_install import ModelInstallService
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
config = get_config()
|
||||
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = SqliteDatabase(config, logger)
|
||||
db = SqliteDatabase(config.db_path, logger)
|
||||
record_store = ModelRecordServiceSQL(db)
|
||||
queue = DownloadQueueService()
|
||||
queue.start()
|
||||
|
||||
installer = ModelInstallService(app_config=config,
|
||||
record_store=record_store,
|
||||
download_queue=queue
|
||||
)
|
||||
download_queue=queue
|
||||
)
|
||||
installer.start()
|
||||
```
|
||||
|
||||
|
@ -466,17 +466,14 @@ class ModelInstallServiceBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def download_and_cache_ckpt(
|
||||
def download_and_cache_model(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
access_token: Optional[str] = None,
|
||||
timeout: int = 0,
|
||||
source: str,
|
||||
) -> Path:
|
||||
"""
|
||||
Download the model file located at source to the models cache and return its Path.
|
||||
|
||||
:param source: A Url or a string that can be converted into one.
|
||||
:param access_token: Optional access token to access restricted resources.
|
||||
:param source: A string representing a URL or repo_id.
|
||||
|
||||
The model file will be downloaded into the system-wide model cache
|
||||
(`models/.cache`) if it isn't already there. Note that the model cache
|
||||
|
@ -9,7 +9,7 @@ from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
@ -18,7 +18,7 @@ from pydantic.networks import AnyHttpUrl
|
||||
from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob, TqdmProgress
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
@ -208,26 +208,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
|
||||
subfolder=Path(match.group(3)) if match.group(3) else None,
|
||||
access_token=access_token,
|
||||
)
|
||||
elif re.match(r"^https?://[^/]+", source):
|
||||
source_obj = URLModelSource(
|
||||
url=AnyHttpUrl(source),
|
||||
access_token=access_token,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{source}'")
|
||||
"""Install a model using pattern matching to infer the type of source."""
|
||||
source_obj = self._guess_source(source)
|
||||
if isinstance(source_obj, LocalModelSource):
|
||||
source_obj.inplace = inplace
|
||||
elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource):
|
||||
source_obj.access_token = access_token
|
||||
return self.import_model(source_obj, config)
|
||||
|
||||
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
||||
@ -383,37 +369,86 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
escaped_source = slugify(str(source))
|
||||
return app_config.download_cache_path / escaped_source
|
||||
|
||||
def download_and_cache_ckpt(
|
||||
def download_and_cache_model(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
access_token: Optional[str] = None,
|
||||
timeout: int = 0,
|
||||
source: str,
|
||||
) -> Path:
|
||||
"""Download the model file located at source to the models cache and return its Path."""
|
||||
model_path = self._download_cache_path(source, self._app_config)
|
||||
model_path = self._download_cache_path(str(source), self._app_config)
|
||||
|
||||
# We expect the cache directory to contain one and only one downloaded file.
|
||||
# We expect the cache directory to contain one and only one downloaded file or directory.
|
||||
# We don't know the file's name in advance, as it is set by the download
|
||||
# content-disposition header.
|
||||
if model_path.exists():
|
||||
contents = [x for x in model_path.iterdir() if x.is_file()]
|
||||
contents: List[Path] = list(model_path.iterdir())
|
||||
if len(contents) > 0:
|
||||
return contents[0]
|
||||
|
||||
model_path.mkdir(parents=True, exist_ok=True)
|
||||
job = self._download_queue.download(
|
||||
source=AnyHttpUrl(str(source)),
|
||||
model_source = self._guess_source(source)
|
||||
remote_files, _ = self._remote_files_from_source(model_source)
|
||||
job = self._download_queue.multifile_download(
|
||||
parts=remote_files,
|
||||
dest=model_path,
|
||||
access_token=access_token,
|
||||
on_progress=TqdmProgress().update,
|
||||
)
|
||||
self._download_queue.wait_for_job(job, timeout)
|
||||
files_string = "file" if len(remote_files) == 1 else "file"
|
||||
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
|
||||
self._download_queue.wait_for_job(job)
|
||||
if job.complete:
|
||||
assert job.download_path is not None
|
||||
return job.download_path
|
||||
else:
|
||||
raise Exception(job.error)
|
||||
|
||||
def _remote_files_from_source(
|
||||
self, source: ModelSource
|
||||
) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
|
||||
metadata = None
|
||||
if isinstance(source, HFModelSource):
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
return metadata.download_urls(
|
||||
variant=source.variant or self._guess_variant(),
|
||||
subfolder=source.subfolder,
|
||||
session=self._session,
|
||||
), metadata
|
||||
|
||||
if isinstance(source, URLModelSource):
|
||||
try:
|
||||
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||
kwargs: dict[str, Any] = {"session": self._session}
|
||||
metadata = fetcher(**kwargs).from_url(source.url)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
return metadata.download_urls(session=self._session), metadata
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None
|
||||
|
||||
raise Exception(f"No files associated with {source}")
|
||||
|
||||
def _guess_source(self, source: str) -> ModelSource:
|
||||
"""Turn a source string into a ModelSource object."""
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
|
||||
subfolder=Path(match.group(3)) if match.group(3) else None,
|
||||
)
|
||||
elif re.match(r"^https?://[^/]+", source):
|
||||
source_obj = URLModelSource(
|
||||
url=AnyHttpUrl(source),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{source}'")
|
||||
return source_obj
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the installer threads
|
||||
# --------------------------------------------------------------------------------------------
|
||||
@ -650,18 +685,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> ModelInstallJob:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
source.access_token = source.access_token or HfFolder.get_token()
|
||||
if not source.access_token:
|
||||
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
|
||||
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(
|
||||
variant=source.variant or self._guess_variant(),
|
||||
subfolder=source.subfolder,
|
||||
session=self._session,
|
||||
)
|
||||
|
||||
if source.access_token is None:
|
||||
source.access_token = HfFolder.get_token()
|
||||
remote_files, metadata = self._remote_files_from_source(source)
|
||||
return self._import_remote_model(
|
||||
source=source,
|
||||
config=config,
|
||||
@ -674,21 +700,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: URLModelSource,
|
||||
config: Optional[Dict[str, Any]],
|
||||
) -> ModelInstallJob:
|
||||
# URLs from HuggingFace will be handled specially
|
||||
metadata = None
|
||||
fetcher = None
|
||||
try:
|
||||
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||
except ValueError:
|
||||
pass
|
||||
kwargs: dict[str, Any] = {"session": self._session}
|
||||
if fetcher is not None:
|
||||
metadata = fetcher(**kwargs).from_url(source.url)
|
||||
self._logger.debug(f"metadata={metadata}")
|
||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
else:
|
||||
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
|
||||
remote_files, metadata = self._remote_files_from_source(source)
|
||||
return self._import_remote_model(
|
||||
source=source,
|
||||
config=config,
|
||||
@ -733,26 +745,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
root = Path(".")
|
||||
subfolder = Path(".")
|
||||
|
||||
# we remember the path up to the top of the destdir so that it may be
|
||||
# removed safely at the end of the install process.
|
||||
install_job._install_tmpdir = destdir
|
||||
|
||||
parts: List[RemoteModelFile] = []
|
||||
for model_file in remote_files:
|
||||
assert install_job.total_bytes is not None
|
||||
assert model_file.size is not None
|
||||
install_job.total_bytes += model_file.size
|
||||
parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder)))
|
||||
multifile_job = self._download_queue.multifile_download(
|
||||
multifile_job = self._multifile_download(
|
||||
parts=parts,
|
||||
dest=destdir,
|
||||
access_token=source.access_token,
|
||||
submit_job=False,
|
||||
on_start=self._download_started_callback,
|
||||
on_progress=self._download_progress_callback,
|
||||
on_complete=self._download_complete_callback,
|
||||
on_error=self._download_error_callback,
|
||||
on_cancelled=self._download_cancelled_callback,
|
||||
submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict
|
||||
)
|
||||
self._download_cache[multifile_job.id] = install_job
|
||||
install_job._download_job = multifile_job
|
||||
@ -772,6 +775,21 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
size += sum(self._stat_size(Path(root, x)) for x in files)
|
||||
return size
|
||||
|
||||
def _multifile_download(
|
||||
self, parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, submit_job: bool = True
|
||||
) -> MultiFileDownloadJob:
|
||||
return self._download_queue.multifile_download(
|
||||
parts=parts,
|
||||
dest=dest,
|
||||
access_token=access_token,
|
||||
submit_job=submit_job,
|
||||
on_start=self._download_started_callback,
|
||||
on_progress=self._download_progress_callback,
|
||||
on_complete=self._download_complete_callback,
|
||||
on_error=self._download_error_callback,
|
||||
on_cancelled=self._download_cancelled_callback,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Callbacks are executed by the download queue in a separate thread
|
||||
# ------------------------------------------------------------------
|
||||
@ -875,10 +893,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
assert job.local_path is not None
|
||||
assert job.config_out is not None
|
||||
key = job.config_out.key
|
||||
self._event_bus.emit_model_install_completed(source=str(job.source),
|
||||
key=key,
|
||||
id=job.id,
|
||||
total_bytes=job.bytes)
|
||||
self._event_bus.emit_model_install_completed(
|
||||
source=str(job.source), key=key, id=job.id, total_bytes=job.bytes
|
||||
)
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob) -> None:
|
||||
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
|
||||
|
@ -12,15 +12,13 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
ControlAdapterDefaultSettings,
|
||||
MainModelDefaultSettings,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
|
@ -222,7 +222,7 @@ def test_delete_register(
|
||||
store.get_model(key)
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
|
||||
|
||||
@ -253,7 +253,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
|
||||
@ -285,9 +285,8 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
# TODO: Test subfolder download
|
||||
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
@ -323,6 +322,7 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con
|
||||
assert job.total_bytes == completed_events[0].payload["total_bytes"]
|
||||
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"])
|
||||
|
||||
|
||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
|
||||
job = mm2_installer.import_model(source)
|
||||
@ -371,7 +371,7 @@ def test_other_error_during_install(
|
||||
},
|
||||
],
|
||||
)
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
||||
assert "name" in model_params and "type" in model_params
|
||||
@ -387,7 +387,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
|
||||
}
|
||||
assert "repo_id" in model_params
|
||||
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
|
||||
mm2_installer.wait_for_job(install_job1, timeout=20)
|
||||
mm2_installer.wait_for_job(install_job1, timeout=10)
|
||||
if model_params["type"] != "embedding":
|
||||
assert install_job1.errored
|
||||
assert install_job1.error_type == "InvalidModelConfigException"
|
||||
@ -396,6 +396,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
|
||||
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
|
||||
|
||||
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
|
||||
mm2_installer.wait_for_job(install_job2, timeout=20)
|
||||
mm2_installer.wait_for_job(install_job2, timeout=10)
|
||||
assert install_job2.complete
|
||||
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
|
||||
|
Loading…
Reference in New Issue
Block a user