mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
30 Commits
v4.0.0rc2
...
bug-instal
Author | SHA1 | Date | |
---|---|---|---|
7d46e8430b | |||
b3f8f22998 | |||
6a75c5ba08 | |||
dcbb1ff894 | |||
5751455618 | |||
a2232c2e09 | |||
30da11998b | |||
1f000306f3 | |||
c778d74a42 | |||
97bcd54408 | |||
62d7e38030 | |||
9b6f3bded9 | |||
aec567179d | |||
28bfc1c935 | |||
39f62ac63c | |||
9d0952c2ef | |||
902e26507d | |||
b83427d7ce | |||
7387b0bdc9 | |||
7ea9cac9a3 | |||
ea5bc94b9c | |||
a1743647b7 | |||
a6d64f69e1 | |||
e74e78894f | |||
71a1740740 | |||
b79f2f337e | |||
a0420d1442 | |||
a17021ba0c | |||
faa1ffb06f | |||
8c04eec210 |
1
.github/workflows/python-tests.yml
vendored
1
.github/workflows/python-tests.yml
vendored
@ -9,6 +9,7 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- 'main'
|
- 'main'
|
||||||
|
- 'bug-install-job-running-multiple-times'
|
||||||
pull_request:
|
pull_request:
|
||||||
types:
|
types:
|
||||||
- 'ready_for_review'
|
- 'ready_for_review'
|
||||||
|
@ -15,7 +15,7 @@ from invokeai.app.invocations.model import ModelIdentifierField
|
|||||||
from invokeai.app.invocations.primitives import ImageField
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterField(BaseModel):
|
class IPAdapterField(BaseModel):
|
||||||
@ -89,17 +89,32 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
assert isinstance(ip_adapter_info, IPAdapterConfig)
|
assert isinstance(ip_adapter_info, IPAdapterConfig)
|
||||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||||
image_encoder_models = context.models.search_by_attrs(
|
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
|
||||||
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
|
||||||
)
|
|
||||||
assert len(image_encoder_models) == 1
|
|
||||||
return IPAdapterOutput(
|
return IPAdapterOutput(
|
||||||
ip_adapter=IPAdapterField(
|
ip_adapter=IPAdapterField(
|
||||||
image=self.image,
|
image=self.image,
|
||||||
ip_adapter_model=self.ip_adapter_model,
|
ip_adapter_model=self.ip_adapter_model,
|
||||||
image_encoder_model=ModelIdentifierField.from_config(image_encoder_models[0]),
|
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
begin_step_percent=self.begin_step_percent,
|
begin_step_percent=self.begin_step_percent,
|
||||||
end_step_percent=self.end_step_percent,
|
end_step_percent=self.end_step_percent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
|
||||||
|
found = False
|
||||||
|
while not found:
|
||||||
|
image_encoder_models = context.models.search_by_attrs(
|
||||||
|
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
||||||
|
)
|
||||||
|
found = len(image_encoder_models) > 0
|
||||||
|
if not found:
|
||||||
|
context.logger.warning(
|
||||||
|
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
|
||||||
|
)
|
||||||
|
context.logger.warning("Downloading and installing now. This may take a while.")
|
||||||
|
installer = context._services.model_manager.install
|
||||||
|
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
|
||||||
|
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException
|
||||||
|
assert len(image_encoder_models) == 1
|
||||||
|
return image_encoder_models[0]
|
||||||
|
@ -133,6 +133,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._download_cache.clear()
|
self._download_cache.clear()
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
|
def _put_in_queue(self, job: ModelInstallJob) -> None:
|
||||||
|
print(f'DEBUG: in _put_in_queue(job={job.id})')
|
||||||
|
if self._stop_event.is_set():
|
||||||
|
self.cancel_job(job)
|
||||||
|
else:
|
||||||
|
print(f'DEBUG: putting {job.id} into the install queue')
|
||||||
|
self._install_queue.put(job)
|
||||||
|
|
||||||
def register_path(
|
def register_path(
|
||||||
self,
|
self,
|
||||||
model_path: Union[Path, str],
|
model_path: Union[Path, str],
|
||||||
@ -218,7 +226,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
if isinstance(source, LocalModelSource):
|
if isinstance(source, LocalModelSource):
|
||||||
install_job = self._import_local_model(source, config)
|
install_job = self._import_local_model(source, config)
|
||||||
self._install_queue.put(install_job) # synchronously install
|
self._put_in_queue(install_job) # synchronously install
|
||||||
elif isinstance(source, HFModelSource):
|
elif isinstance(source, HFModelSource):
|
||||||
install_job = self._import_from_hf(source, config)
|
install_job = self._import_from_hf(source, config)
|
||||||
elif isinstance(source, URLModelSource):
|
elif isinstance(source, URLModelSource):
|
||||||
@ -328,7 +336,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
yaml_path.rename(yaml_path.with_suffix(".yaml.bak"))
|
yaml_path.rename(yaml_path.with_suffix(".yaml.bak"))
|
||||||
|
|
||||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||||
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
|
||||||
callback = self._scan_install if install else self._scan_register
|
callback = self._scan_install if install else self._scan_register
|
||||||
search = ModelSearch(on_model_found=callback)
|
search = ModelSearch(on_model_found=callback)
|
||||||
self._models_installed.clear()
|
self._models_installed.clear()
|
||||||
@ -342,7 +350,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||||
model = self.record_store.get_model(key)
|
model = self.record_store.get_model(key)
|
||||||
models_dir = self.app_config.models_path
|
models_dir = self.app_config.models_path
|
||||||
model_path = Path(model.path)
|
model_path = models_dir / Path(model.path) # handle legacy relative model paths
|
||||||
if model_path.is_relative_to(models_dir):
|
if model_path.is_relative_to(models_dir):
|
||||||
self.unconditionally_delete(key)
|
self.unconditionally_delete(key)
|
||||||
else:
|
else:
|
||||||
@ -350,7 +358,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
def unconditionally_delete(self, key: str) -> None: # noqa D102
|
def unconditionally_delete(self, key: str) -> None: # noqa D102
|
||||||
model = self.record_store.get_model(key)
|
model = self.record_store.get_model(key)
|
||||||
model_path = Path(model.path)
|
model_path = self.app_config.models_path / model.path
|
||||||
if model_path.is_dir():
|
if model_path.is_dir():
|
||||||
rmtree(model_path)
|
rmtree(model_path)
|
||||||
else:
|
else:
|
||||||
@ -403,10 +411,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
done = True
|
done = True
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
|
print(f'DEBUG: _install_next_item() checking for a job to install')
|
||||||
job = self._install_queue.get(timeout=1)
|
job = self._install_queue.get(timeout=1)
|
||||||
except Empty:
|
except Empty:
|
||||||
continue
|
continue
|
||||||
|
print(f'DEBUG: _install_next_item() got job {job.id}, status={job.status}')
|
||||||
assert job.local_path is not None
|
assert job.local_path is not None
|
||||||
try:
|
try:
|
||||||
if job.cancelled:
|
if job.cancelled:
|
||||||
@ -432,6 +441,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
else:
|
else:
|
||||||
key = self.install_path(job.local_path, job.config_in)
|
key = self.install_path(job.local_path, job.config_in)
|
||||||
job.config_out = self.record_store.get_model(key)
|
job.config_out = self.record_store.get_model(key)
|
||||||
|
print(f'DEBUG: _install_next_item() signaling completion for job={job.id}, status={job.status}')
|
||||||
self._signal_job_completed(job)
|
self._signal_job_completed(job)
|
||||||
|
|
||||||
except InvalidModelConfigException as excp:
|
except InvalidModelConfigException as excp:
|
||||||
@ -492,6 +502,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
for cur_base_model in BaseModelType:
|
for cur_base_model in BaseModelType:
|
||||||
for cur_model_type in ModelType:
|
for cur_model_type in ModelType:
|
||||||
models_dir = self._app_config.models_path / Path(cur_base_model.value, cur_model_type.value)
|
models_dir = self._app_config.models_path / Path(cur_base_model.value, cur_model_type.value)
|
||||||
|
if not models_dir.exists():
|
||||||
|
continue
|
||||||
installed.update(self.scan_directory(models_dir))
|
installed.update(self.scan_directory(models_dir))
|
||||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||||
|
|
||||||
@ -518,7 +530,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
new_path = models_dir / model.base.value / model.type.value / old_path.name
|
new_path = models_dir / model.base.value / model.type.value / old_path.name
|
||||||
|
|
||||||
if old_path == new_path:
|
if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
|
||||||
return model
|
return model
|
||||||
|
|
||||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||||
@ -779,14 +791,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
||||||
self._logger.info(f"{download_job.source}: model download complete")
|
self._logger.info(f"{download_job.source}: model download complete")
|
||||||
|
print(f'DEBUG: _download_complete_callback(download_job={download_job.source}')
|
||||||
with self._lock:
|
with self._lock:
|
||||||
install_job = self._download_cache[download_job.source]
|
install_job = self._download_cache.pop(download_job.source, None)
|
||||||
self._download_cache.pop(download_job.source, None)
|
print(f'DEBUG: download_job={download_job.source} / install_job={install_job}')
|
||||||
|
|
||||||
# are there any more active jobs left in this task?
|
# are there any more active jobs left in this task?
|
||||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
if install_job and install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||||
|
print(f'DEBUG: setting job {install_job.id} to DOWNLOADS_DONE')
|
||||||
install_job.status = InstallStatus.DOWNLOADS_DONE
|
install_job.status = InstallStatus.DOWNLOADS_DONE
|
||||||
self._install_queue.put(install_job)
|
print(f'DEBUG: putting {install_job.id} into the install queue')
|
||||||
|
self._put_in_queue(install_job)
|
||||||
|
|
||||||
# Let other threads know that the number of downloads has changed
|
# Let other threads know that the number of downloads has changed
|
||||||
self._downloads_changed_event.set()
|
self._downloads_changed_event.set()
|
||||||
@ -828,7 +842,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
if all(x.in_terminal_state for x in install_job.download_parts):
|
if all(x.in_terminal_state for x in install_job.download_parts):
|
||||||
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
||||||
self._install_queue.put(install_job)
|
self._put_in_queue(install_job)
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------------
|
||||||
# Internal methods that put events on the event bus
|
# Internal methods that put events on the event bus
|
||||||
|
@ -11,17 +11,6 @@ def check_invokeai_root(config: InvokeAIAppConfig):
|
|||||||
try:
|
try:
|
||||||
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
|
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
|
||||||
assert config.models_path.exists(), f"{config.models_path} not found"
|
assert config.models_path.exists(), f"{config.models_path} not found"
|
||||||
if not config.ignore_missing_core_models:
|
|
||||||
for model in [
|
|
||||||
"CLIP-ViT-bigG-14-laion2B-39B-b160k",
|
|
||||||
"bert-base-uncased",
|
|
||||||
"clip-vit-large-patch14",
|
|
||||||
"sd-vae-ft-mse",
|
|
||||||
"stable-diffusion-2-clip",
|
|
||||||
"stable-diffusion-safety-checker",
|
|
||||||
]:
|
|
||||||
path = config.models_path / f"core/convert/{model}"
|
|
||||||
assert path.exists(), f"{path} is missing"
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print()
|
print()
|
||||||
print(f"An exception has occurred: {str(e)}")
|
print(f"An exception has occurred: {str(e)}")
|
||||||
@ -32,10 +21,5 @@ def check_invokeai_root(config: InvokeAIAppConfig):
|
|||||||
print(
|
print(
|
||||||
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
|
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
|
||||||
)
|
)
|
||||||
print(
|
|
||||||
'** (To skip this check completely, add "--ignore_missing_core_models" to your CLI args. Not installing '
|
|
||||||
"these core models will prevent the loading of some or all .safetensors and .ckpt files. However, you can "
|
|
||||||
"always come back and install these core models in the future.)"
|
|
||||||
)
|
|
||||||
input("Press any key to continue...")
|
input("Press any key to continue...")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
@ -25,20 +25,20 @@ import npyscreen
|
|||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from diffusers import AutoencoderKL, ModelMixin
|
from diffusers import ModelMixin
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from huggingface_hub import HfFolder
|
from huggingface_hub import HfFolder
|
||||||
from huggingface_hub import login as hf_hub_login
|
from huggingface_hub import login as hf_hub_login
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pydantic.error_wrappers import ValidationError
|
from pydantic.error_wrappers import ValidationError
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
|
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
|
||||||
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
||||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
from invokeai.backend.model_manager import ModelType
|
||||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.frontend.install.model_install import addModelsForm
|
from invokeai.frontend.install.model_install import addModelsForm
|
||||||
@ -210,51 +210,15 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
|||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def download_conversion_models():
|
def download_safety_checker():
|
||||||
target_dir = config.models_path / "core/convert"
|
target_dir = config.models_path / "core/convert"
|
||||||
kwargs = {} # for future use
|
kwargs = {} # for future use
|
||||||
try:
|
try:
|
||||||
logger.info("Downloading core tokenizers and text encoders")
|
|
||||||
|
|
||||||
# bert
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
||||||
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
|
|
||||||
bert.save_pretrained(target_dir / "bert-base-uncased", safe_serialization=True)
|
|
||||||
|
|
||||||
# sd-1
|
|
||||||
repo_id = "openai/clip-vit-large-patch14"
|
|
||||||
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / "clip-vit-large-patch14")
|
|
||||||
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / "clip-vit-large-patch14")
|
|
||||||
|
|
||||||
# sd-2
|
|
||||||
repo_id = "stabilityai/stable-diffusion-2"
|
|
||||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "tokenizer", safe_serialization=True)
|
|
||||||
|
|
||||||
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "text_encoder", safe_serialization=True)
|
|
||||||
|
|
||||||
# sd-xl - tokenizer_2
|
|
||||||
repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
|
||||||
_, model_name = repo_id.split("/")
|
|
||||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
|
|
||||||
|
|
||||||
pipeline = CLIPTextConfig.from_pretrained(repo_id, **kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
|
|
||||||
|
|
||||||
# VAE
|
|
||||||
logger.info("Downloading stable diffusion VAE")
|
|
||||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", **kwargs)
|
|
||||||
vae.save_pretrained(target_dir / "sd-vae-ft-mse", safe_serialization=True)
|
|
||||||
|
|
||||||
# safety checking
|
# safety checking
|
||||||
logger.info("Downloading safety checker")
|
logger.info("Downloading safety checker")
|
||||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs)
|
pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs)
|
||||||
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
||||||
|
|
||||||
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs)
|
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs)
|
||||||
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@ -307,7 +271,7 @@ def download_lama():
|
|||||||
def download_support_models() -> None:
|
def download_support_models() -> None:
|
||||||
download_realesrgan()
|
download_realesrgan()
|
||||||
download_lama()
|
download_lama()
|
||||||
download_conversion_models()
|
download_safety_checker()
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
@ -744,12 +708,7 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
|||||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||||
|
|
||||||
dest = root / "models"
|
dest = root / "models"
|
||||||
for model_base in BaseModelType:
|
dest.mkdir(parents=True, exist_ok=True)
|
||||||
for model_type in ModelType:
|
|
||||||
path = dest / model_base.value / model_type.value
|
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
|
||||||
path = dest / "core"
|
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -3,9 +3,6 @@
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
|
||||||
from safetensors.torch import load_file as safetensors_load_file
|
|
||||||
|
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@ -37,27 +34,25 @@ class ControlNetLoader(GenericDiffusersLoader):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
assert isinstance(config, CheckpointConfigBase)
|
||||||
raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
|
config_file = config.config_path
|
||||||
else:
|
|
||||||
assert isinstance(config, CheckpointConfigBase)
|
|
||||||
config_file = config.config_path
|
|
||||||
|
|
||||||
if model_path.suffix == ".safetensors":
|
image_size = (
|
||||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
512
|
||||||
else:
|
if config.base == BaseModelType.StableDiffusion1
|
||||||
checkpoint = torch.load(model_path, map_location="cpu")
|
else 768
|
||||||
|
if config.base == BaseModelType.StableDiffusion2
|
||||||
# sometimes weights are hidden under "state_dict", and sometimes not
|
else 1024
|
||||||
if "state_dict" in checkpoint:
|
|
||||||
checkpoint = checkpoint["state_dict"]
|
|
||||||
|
|
||||||
convert_controlnet_to_diffusers(
|
|
||||||
model_path,
|
|
||||||
output_path,
|
|
||||||
original_config_file=self._app_config.root_path / config_file,
|
|
||||||
image_size=512,
|
|
||||||
scan_needed=True,
|
|
||||||
from_safetensors=model_path.suffix == ".safetensors",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||||
|
with open(self._app_config.root_path / config_file, "r") as config_stream:
|
||||||
|
convert_controlnet_to_diffusers(
|
||||||
|
model_path,
|
||||||
|
output_path,
|
||||||
|
original_config_file=config_stream,
|
||||||
|
image_size=image_size,
|
||||||
|
precision=self._torch_dtype,
|
||||||
|
from_safetensors=model_path.suffix == ".safetensors",
|
||||||
|
)
|
||||||
return output_path
|
return output_path
|
||||||
|
@ -4,9 +4,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
|
||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
|
||||||
|
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
AnyModel,
|
AnyModel,
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
@ -14,7 +11,7 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelRepoVariant,
|
ModelRepoVariant,
|
||||||
ModelType,
|
ModelType,
|
||||||
ModelVariantType,
|
SchedulerPredictionType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
|
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
|
||||||
@ -68,27 +65,31 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
|||||||
|
|
||||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||||
assert isinstance(config, MainCheckpointConfig)
|
assert isinstance(config, MainCheckpointConfig)
|
||||||
variant = config.variant
|
|
||||||
base = config.base
|
base = config.base
|
||||||
pipeline_class = (
|
|
||||||
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
|
|
||||||
)
|
|
||||||
|
|
||||||
config_file = config.config_path
|
config_file = config.config_path
|
||||||
|
prediction_type = config.prediction_type.value
|
||||||
|
upcast_attention = config.upcast_attention
|
||||||
|
image_size = (
|
||||||
|
1024
|
||||||
|
if base == BaseModelType.StableDiffusionXL
|
||||||
|
else 768
|
||||||
|
if config.prediction_type == SchedulerPredictionType.VPrediction and base == BaseModelType.StableDiffusion2
|
||||||
|
else 512
|
||||||
|
)
|
||||||
|
|
||||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||||
convert_ckpt_to_diffusers(
|
convert_ckpt_to_diffusers(
|
||||||
model_path,
|
model_path,
|
||||||
output_path,
|
output_path,
|
||||||
model_type=self.model_base_to_model_type[base],
|
model_type=self.model_base_to_model_type[base],
|
||||||
model_version=base,
|
|
||||||
model_variant=variant,
|
|
||||||
original_config_file=self._app_config.root_path / config_file,
|
original_config_file=self._app_config.root_path / config_file,
|
||||||
extract_ema=True,
|
extract_ema=True,
|
||||||
scan_needed=True,
|
|
||||||
pipeline_class=pipeline_class,
|
|
||||||
from_safetensors=model_path.suffix == ".safetensors",
|
from_safetensors=model_path.suffix == ".safetensors",
|
||||||
precision=self._torch_dtype,
|
precision=self._torch_dtype,
|
||||||
|
prediction_type=prediction_type,
|
||||||
|
image_size=image_size,
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
load_safety_checker=False,
|
load_safety_checker=False,
|
||||||
)
|
)
|
||||||
return output_path
|
return output_path
|
||||||
|
@ -57,12 +57,12 @@ class VAELoader(GenericDiffusersLoader):
|
|||||||
|
|
||||||
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
|
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
|
||||||
assert isinstance(ckpt_config, DictConfig)
|
assert isinstance(ckpt_config, DictConfig)
|
||||||
|
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||||
vae_model = convert_ldm_vae_to_diffusers(
|
vae_model = convert_ldm_vae_to_diffusers(
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
vae_config=ckpt_config,
|
vae_config=ckpt_config,
|
||||||
image_size=512,
|
image_size=512,
|
||||||
|
precision=self._torch_dtype,
|
||||||
)
|
)
|
||||||
vae_model.to(self._torch_dtype) # set precision appropriately
|
|
||||||
vae_model.save_pretrained(output_path, safe_serialization=True)
|
vae_model.save_pretrained(output_path, safe_serialization=True)
|
||||||
return output_path
|
return output_path
|
||||||
|
@ -132,11 +132,12 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||||
model_info = None
|
model_info = None
|
||||||
model_type = None
|
model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
|
||||||
if format_type is ModelFormat.Diffusers:
|
if not model_type:
|
||||||
model_type = cls.get_model_type_from_folder(model_path)
|
if format_type is ModelFormat.Diffusers:
|
||||||
else:
|
model_type = cls.get_model_type_from_folder(model_path)
|
||||||
model_type = cls.get_model_type_from_checkpoint(model_path)
|
else:
|
||||||
|
model_type = cls.get_model_type_from_checkpoint(model_path)
|
||||||
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
|
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
|
||||||
|
|
||||||
probe_class = cls.PROBES[format_type].get(model_type)
|
probe_class = cls.PROBES[format_type].get(model_type)
|
||||||
@ -156,7 +157,7 @@ class ModelProbe(object):
|
|||||||
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
|
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
|
||||||
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
||||||
fields["description"] = (
|
fields["description"] = (
|
||||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
|
||||||
)
|
)
|
||||||
fields["format"] = fields.get("format") or probe.get_format()
|
fields["format"] = fields.get("format") or probe.get_format()
|
||||||
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||||
@ -319,7 +320,7 @@ class ModelProbe(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
|
def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
|
||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
|
||||||
cls._scan_model(model_path.name, model_path)
|
cls._scan_model(model_path.name, model_path)
|
||||||
model = torch.load(model_path)
|
model = torch.load(model_path)
|
||||||
assert isinstance(model, dict)
|
assert isinstance(model, dict)
|
||||||
|
@ -62,40 +62,72 @@ sd-1/main/trinart_stable_diffusion_v2:
|
|||||||
recommended: False
|
recommended: False
|
||||||
sd-1/controlnet/qrcode_monster:
|
sd-1/controlnet/qrcode_monster:
|
||||||
source: monster-labs/control_v1p_sd15_qrcode_monster
|
source: monster-labs/control_v1p_sd15_qrcode_monster
|
||||||
|
description: Controlnet model that generates scannable creative QR codes
|
||||||
subfolder: v2
|
subfolder: v2
|
||||||
sd-1/controlnet/canny:
|
sd-1/controlnet/canny:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with canny conditioning.
|
||||||
source: lllyasviel/control_v11p_sd15_canny
|
source: lllyasviel/control_v11p_sd15_canny
|
||||||
recommended: True
|
recommended: True
|
||||||
sd-1/controlnet/inpaint:
|
sd-1/controlnet/inpaint:
|
||||||
source: lllyasviel/control_v11p_sd15_inpaint
|
source: lllyasviel/control_v11p_sd15_inpaint
|
||||||
|
description: Controlnet weights trained on sd-1.5 with canny conditioning, inpaint version
|
||||||
sd-1/controlnet/mlsd:
|
sd-1/controlnet/mlsd:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with canny conditioning, MLSD version
|
||||||
source: lllyasviel/control_v11p_sd15_mlsd
|
source: lllyasviel/control_v11p_sd15_mlsd
|
||||||
sd-1/controlnet/depth:
|
sd-1/controlnet/depth:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with depth conditioning
|
||||||
source: lllyasviel/control_v11f1p_sd15_depth
|
source: lllyasviel/control_v11f1p_sd15_depth
|
||||||
recommended: True
|
recommended: True
|
||||||
sd-1/controlnet/normal_bae:
|
sd-1/controlnet/normal_bae:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with normalbae image conditioning
|
||||||
source: lllyasviel/control_v11p_sd15_normalbae
|
source: lllyasviel/control_v11p_sd15_normalbae
|
||||||
sd-1/controlnet/seg:
|
sd-1/controlnet/seg:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with seg image conditioning
|
||||||
source: lllyasviel/control_v11p_sd15_seg
|
source: lllyasviel/control_v11p_sd15_seg
|
||||||
sd-1/controlnet/lineart:
|
sd-1/controlnet/lineart:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with lineart image conditioning
|
||||||
source: lllyasviel/control_v11p_sd15_lineart
|
source: lllyasviel/control_v11p_sd15_lineart
|
||||||
recommended: True
|
recommended: True
|
||||||
sd-1/controlnet/lineart_anime:
|
sd-1/controlnet/lineart_anime:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with anime image conditioning
|
||||||
source: lllyasviel/control_v11p_sd15s2_lineart_anime
|
source: lllyasviel/control_v11p_sd15s2_lineart_anime
|
||||||
sd-1/controlnet/openpose:
|
sd-1/controlnet/openpose:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with openpose image conditioning
|
||||||
source: lllyasviel/control_v11p_sd15_openpose
|
source: lllyasviel/control_v11p_sd15_openpose
|
||||||
recommended: True
|
recommended: True
|
||||||
sd-1/controlnet/scribble:
|
sd-1/controlnet/scribble:
|
||||||
source: lllyasviel/control_v11p_sd15_scribble
|
source: lllyasviel/control_v11p_sd15_scribble
|
||||||
|
description: Controlnet weights trained on sd-1.5 with scribble image conditioning
|
||||||
recommended: False
|
recommended: False
|
||||||
sd-1/controlnet/softedge:
|
sd-1/controlnet/softedge:
|
||||||
source: lllyasviel/control_v11p_sd15_softedge
|
source: lllyasviel/control_v11p_sd15_softedge
|
||||||
|
description: Controlnet weights trained on sd-1.5 with soft edge conditioning
|
||||||
sd-1/controlnet/shuffle:
|
sd-1/controlnet/shuffle:
|
||||||
source: lllyasviel/control_v11e_sd15_shuffle
|
source: lllyasviel/control_v11e_sd15_shuffle
|
||||||
|
description: Controlnet weights trained on sd-1.5 with shuffle image conditioning
|
||||||
sd-1/controlnet/tile:
|
sd-1/controlnet/tile:
|
||||||
source: lllyasviel/control_v11f1e_sd15_tile
|
source: lllyasviel/control_v11f1e_sd15_tile
|
||||||
|
description: Controlnet weights trained on sd-1.5 with tiled image conditioning
|
||||||
sd-1/controlnet/ip2p:
|
sd-1/controlnet/ip2p:
|
||||||
source: lllyasviel/control_v11e_sd15_ip2p
|
source: lllyasviel/control_v11e_sd15_ip2p
|
||||||
|
description: Controlnet weights trained on sd-1.5 with ip2p conditioning.
|
||||||
|
sdxl/controlnet/canny-sdxl:
|
||||||
|
description: Controlnet weights trained on sdxl-1.0 with canny conditioning.
|
||||||
|
source: diffusers/controlnet-canny-sdxl-1.0
|
||||||
|
recommended: True
|
||||||
|
sdxl/controlnet/depth-sdxl:
|
||||||
|
description: Controlnet weights trained on sdxl-1.0 with depth conditioning.
|
||||||
|
source: diffusers/controlnet-depth-sdxl-1.0
|
||||||
|
recommended: True
|
||||||
|
sdxl/controlnet/softedge-dexined-sdxl:
|
||||||
|
description: Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.
|
||||||
|
source: SargeZT/controlnet-sd-xl-1.0-softedge-dexined
|
||||||
|
sdxl/controlnet/depth-16bit-zoe-sdxl:
|
||||||
|
description: Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).
|
||||||
|
source: SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe
|
||||||
|
sdxl/controlnet/depth-zoe-sdxl:
|
||||||
|
description: Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).
|
||||||
|
source: diffusers/controlnet-zoe-depth-sdxl-1.0
|
||||||
sd-1/t2i_adapter/canny-sd15:
|
sd-1/t2i_adapter/canny-sd15:
|
||||||
source: TencentARC/t2iadapter_canny_sd15v2
|
source: TencentARC/t2iadapter_canny_sd15v2
|
||||||
sd-1/t2i_adapter/sketch-sd15:
|
sd-1/t2i_adapter/sketch-sd15:
|
||||||
|
@ -55,14 +55,13 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const isCurrentMainModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
|
const isCurrentMainModelAvailable = currentModel ? mainModels.some((m) => m.key === currentModel.key) : false;
|
||||||
|
|
||||||
if (isCurrentMainModelAvailable) {
|
if (isCurrentMainModelAvailable) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultModel = state.config.sd.defaultModel;
|
const defaultModel = state.config.sd.defaultModel;
|
||||||
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
|
const defaultModelInList = defaultModel ? mainModels.find((m) => m.key === defaultModel) : false;
|
||||||
|
|
||||||
if (defaultModelInList) {
|
if (defaultModelInList) {
|
||||||
const result = zParameterModel.safeParse(defaultModelInList);
|
const result = zParameterModel.safeParse(defaultModelInList);
|
||||||
@ -84,7 +83,7 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = zParameterModel.safeParse(models[0]);
|
const result = zParameterModel.safeParse(mainModels[0]);
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
log.error({ error: result.error.format() }, 'Failed to parse main model');
|
log.error({ error: result.error.format() }, 'Failed to parse main model');
|
||||||
|
@ -33,11 +33,11 @@ classifiers = [
|
|||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
# Core generation dependencies, pinned for reproducible builds.
|
# Core generation dependencies, pinned for reproducible builds.
|
||||||
"accelerate==0.27.2",
|
"accelerate==0.28.0",
|
||||||
"clip_anytorch==2.5.2", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
"clip_anytorch==2.5.2", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||||
"compel==2.0.2",
|
"compel==2.0.2",
|
||||||
"controlnet-aux==0.0.7",
|
"controlnet-aux==0.0.7",
|
||||||
"diffusers[torch]==0.26.3",
|
"diffusers[torch]==0.27.0",
|
||||||
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||||
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
|
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
|
||||||
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
|
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
|
||||||
@ -56,7 +56,7 @@ dependencies = [
|
|||||||
# Core application dependencies, pinned for reproducible builds.
|
# Core application dependencies, pinned for reproducible builds.
|
||||||
"fastapi-events==0.11.0",
|
"fastapi-events==0.11.0",
|
||||||
"fastapi==0.110.0",
|
"fastapi==0.110.0",
|
||||||
"huggingface-hub==0.21.3",
|
"huggingface-hub==0.21.4",
|
||||||
"pydantic-settings==2.2.1",
|
"pydantic-settings==2.2.1",
|
||||||
"pydantic==2.6.3",
|
"pydantic==2.6.3",
|
||||||
"python-socketio==5.11.1",
|
"python-socketio==5.11.1",
|
||||||
|
@ -5,10 +5,11 @@ Test the model installer
|
|||||||
import platform
|
import platform
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from pydantic.networks import Url
|
from pydantic_core import Url
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
@ -20,7 +21,7 @@ from invokeai.app.services.model_install import (
|
|||||||
URLModelSource,
|
URLModelSource,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.model_records import UnknownModelException
|
from invokeai.app.services.model_records import UnknownModelException
|
||||||
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
|
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
|
||||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||||
|
|
||||||
OS = platform.uname().system
|
OS = platform.uname().system
|
||||||
@ -53,7 +54,7 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
|
|||||||
|
|
||||||
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||||
key = None
|
key = None
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises((ValidationError, InvalidModelConfigException)):
|
||||||
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
|
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
|
||||||
assert key is None
|
assert key is None
|
||||||
|
|
||||||
@ -263,3 +264,50 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
|
|||||||
assert job.error
|
assert job.error
|
||||||
assert "NOT FOUND" in job.error
|
assert "NOT FOUND" in job.error
|
||||||
assert job.error_traceback.startswith("Traceback")
|
assert job.error_traceback.startswith("Traceback")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_params",
|
||||||
|
[
|
||||||
|
# SDXL, Lora
|
||||||
|
{
|
||||||
|
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||||
|
"name": "test_lora",
|
||||||
|
"type": "embedding",
|
||||||
|
},
|
||||||
|
# SDXL, Lora - incorrect type
|
||||||
|
{
|
||||||
|
"repo_id": "InvokeAI-test/textual_inversion_tests::learned_embeds-steps-1000.safetensors",
|
||||||
|
"name": "test_lora",
|
||||||
|
"type": "lora",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.timeout(timeout=40, method="thread")
|
||||||
|
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||||
|
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
||||||
|
assert "name" in model_params and "type" in model_params
|
||||||
|
config1: Dict[str, Any] = {
|
||||||
|
"name": f"{model_params['name']}_1",
|
||||||
|
"type": model_params["type"],
|
||||||
|
"hash": "placeholder1",
|
||||||
|
}
|
||||||
|
config2: Dict[str, Any] = {
|
||||||
|
"name": f"{model_params['name']}_2",
|
||||||
|
"type": ModelType(model_params["type"]),
|
||||||
|
"hash": "placeholder2",
|
||||||
|
}
|
||||||
|
assert "repo_id" in model_params
|
||||||
|
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
|
||||||
|
mm2_installer.wait_for_job(install_job1, timeout=20)
|
||||||
|
if model_params["type"] != "embedding":
|
||||||
|
assert install_job1.errored
|
||||||
|
assert install_job1.error_type == "InvalidModelConfigException"
|
||||||
|
return
|
||||||
|
assert install_job1.complete
|
||||||
|
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
|
||||||
|
|
||||||
|
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
|
||||||
|
mm2_installer.wait_for_job(install_job2, timeout=20)
|
||||||
|
assert install_job2.complete
|
||||||
|
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
|
||||||
|
@ -33,6 +33,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from tests.backend.model_manager.model_metadata.metadata_examples import (
|
from tests.backend.model_manager.model_metadata.metadata_examples import (
|
||||||
|
HFTestLoraMetadata,
|
||||||
RepoCivitaiModelMetadata1,
|
RepoCivitaiModelMetadata1,
|
||||||
RepoCivitaiVersionMetadata1,
|
RepoCivitaiVersionMetadata1,
|
||||||
RepoHFMetadata1,
|
RepoHFMetadata1,
|
||||||
@ -300,6 +301,20 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
|||||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
sess.mount(
|
||||||
|
"https://huggingface.co/api/models/InvokeAI-test/textual_inversion_tests?blobs=True",
|
||||||
|
TestAdapter(
|
||||||
|
HFTestLoraMetadata,
|
||||||
|
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(HFTestLoraMetadata)},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
sess.mount(
|
||||||
|
"https://huggingface.co/InvokeAI-test/textual_inversion_tests/resolve/main/learned_embeds-steps-1000.safetensors",
|
||||||
|
TestAdapter(
|
||||||
|
data,
|
||||||
|
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(data)},
|
||||||
|
),
|
||||||
|
)
|
||||||
for root, _, files in os.walk(diffusers_dir):
|
for root, _, files in os.walk(diffusers_dir):
|
||||||
for name in files:
|
for name in files:
|
||||||
path = Path(root, name)
|
path = Path(root, name)
|
||||||
|
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user