mirror of https://github.com/invoke-ai/InvokeAI
check model hash before and after moving in filesystem
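In outline, the change computes a model file's hash, moves the file into the models directory with shutil.move() rather than Path.replace(), then rehashes and asserts that the two hashes match. A minimal sketch of that pattern, assuming a streaming SHA-256 digest over a single file (the project's actual hash routine and helper names may differ):

    import hashlib
    import shutil
    from pathlib import Path

    def file_hash(path: Path, chunk_size: int = 2**20) -> str:
        # Stream in 1 MiB chunks so multi-GB checkpoints never sit fully in RAM.
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()

    def move_and_verify(src: Path, dest: Path) -> Path:
        # Hash before the move, move (copy + delete across filesystems), hash after.
        old_hash = file_hash(src)
        new_path = Path(shutil.move(str(src), str(dest)))
        if file_hash(new_path) != old_hash:
            raise RuntimeError(f"{src}: hash changed during move, possibly corrupted")
        return new_path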
@ -7,7 +7,6 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from processor import Invoker
|
||||
from pydantic import Field, parse_obj_as
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
@@ -28,6 +27,9 @@ from invokeai.backend.model_manager.cache import CacheStats
 from invokeai.backend.model_manager.download import DownloadJobBase
 from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
 
+# processor is giving circular import errors
+# from .processor import Invoker
+
 from .config import InvokeAIAppConfig
 from .events import EventServiceBase
 
@@ -345,9 +347,10 @@ class ModelManagerService(ModelManagerServiceBase):
         kwargs: Dict[str, Any] = {}
         if self._event_bus:
             kwargs.update(event_handlers=[self._event_bus.emit_model_event])
         # TO DO - Pass storage service rather than letting loader create storage service
         self._loader = ModelLoad(config, **kwargs)
 
-    def start(self, invoker: Invoker):
+    def start(self, invoker: Any):  # Because .processor is giving circular import errors, declaring invoker an 'Any'
         """Call automatically at process start."""
         self._loader.installer.scan_models_directory()  # synchronize new/deleted models found in models directory
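Typing invoker as Any sidesteps the circular import at the cost of type safety. A common alternative, sketched below on the assumption that the cycle exists only at import time, is to guard the import with typing.TYPE_CHECKING and use a string annotation, so Invoker stays visible to static checkers but is never imported when the module actually loads:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported only during static analysis; never executed at runtime,
        # so it cannot re-trigger the circular import.
        from .processor import Invoker

    class ModelManagerService:
        def start(self, invoker: "Invoker") -> None:
            """Call automatically at process start."""
            ...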
@@ -458,8 +458,11 @@ class ModelInstall(ModelInstallBase):
         info: ModelProbeInfo = self._probe_model(model_path, overrides)
 
         dest_path = self._app_config.models_path / info.base_type.value / info.model_type.value / model_path.name
+        new_path = self._move_model(model_path, dest_path)
+        new_hash = self.hash(new_path)
+        assert new_hash == info.hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
         return self._register(
-            self._move_model(model_path, dest_path),
+            new_path,
             info,
         )
 
@@ -476,10 +479,7 @@ class ModelInstall(ModelInstallBase):
             if not path.exists():
                 new_path = path
             counter += 1
-        self._logger.warning("Use shutil.move(), not Path.replace() here; hash before and after move")
-        # BUG! This won't work across filesystems.
-        # Rehash before and after moving.
-        return old_path.replace(new_path)
+        return move(old_path, new_path)
 
     def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
        info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
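The replaced Path.replace() is a thin wrapper around os.rename(), which raises OSError (EXDEV) when source and destination sit on different filesystems; shutil.move() falls back to copy-and-delete in that case. A sketch of the surrounding rename-until-free loop combined with the safe move (the loop details are illustrative, reconstructed from the visible context, not the project's exact code):

    from pathlib import Path
    from shutil import move

    def _move_model(old_path: Path, new_path: Path) -> Path:
        # If the destination name is taken, probe name_1.ext, name_2.ext, ...
        counter = 1
        path = new_path
        while path.exists():
            path = new_path.with_name(f"{new_path.stem}_{counter}{new_path.suffix}")
            counter += 1
        # shutil.move() degrades to copy + delete across filesystems, where
        # Path.replace() (os.rename) would fail with EXDEV.
        return Path(move(str(old_path), str(path)))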
@@ -606,6 +606,8 @@ class ModelInstall(ModelInstallBase):
                 f"{old_path.name} is not in the right directory for a model of its type. Moving to {new_path}."
             )
             model.path = self._move_model(old_path, new_path).as_posix()
+            new_hash = self.hash(model.path)
+            assert new_hash == key, f"{model.name}: Model hash changed during installation, possibly corrupted."
         self._store.update_model(key, model)
         return model
 
@@ -122,24 +122,18 @@ class ModelLoad(ModelLoadBase):
     _cache_keys: dict
     _models_file: Path
 
-    def __init__(self, config: InvokeAIAppConfig, event_handlers: List[DownloadEventHandler] = []):
+    def __init__(
+        self,
+        config: InvokeAIAppConfig,
+        store: Optional[ModelConfigStore] = None,
+        event_handlers: List[DownloadEventHandler] = [],
+    ):
         """
         Initialize ModelLoad object.
 
         :param config: The app's InvokeAIAppConfig object.
         """
-        if config.model_conf_path and config.model_conf_path.exists():
-            models_file = config.model_conf_path
-        else:
-            models_file = config.root_path / "configs/models3.yaml"
-        try:
-            store = get_config_store(models_file)
-        except ConfigFileVersionMismatchException:
-            migrate_models_store(config)
-            store = get_config_store(models_file)
-
-        if not store:
-            raise ValueError(f"Invalid model configuration file: {models_file}")
+        store = store or self._create_store(config)
 
         self._app_config = config
         self._store = store
@@ -151,13 +145,13 @@ class ModelLoad(ModelLoadBase):
             event_handlers=event_handlers,
         )
         self._cache_keys = dict()
-        self._models_file = models_file
+        self._models_file = config.model_conf_path
         device = torch.device(choose_torch_device())
         device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
         precision = choose_precision(device) if config.precision == "auto" else config.precision
         dtype = torch.float32 if precision == "float32" else torch.float16
 
-        self._logger.info(f"Using models database {models_file}")
+        self._logger.info(f"Using models database {self._models_file}")
         self._logger.info(f"Rendering device = {device} ({device_name})")
         self._logger.info(f"Maximum RAM cache size: {config.ram}")
         self._logger.info(f"Maximum VRAM cache size: {config.vram}")
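The device and precision lines above reduce to a small decision: "auto" defers to the device, half precision on CUDA and full precision on CPU. A standalone sketch of that branch (choose_torch_device() and choose_precision() are the project's own helpers; this reimplements only the logic visible in the diff, assuming choose_precision() prefers float16 on CUDA):

    import torch

    def pick_dtype(device: torch.device, configured: str = "auto") -> torch.dtype:
        # "auto" defers to the device type; an explicit setting wins otherwise.
        if configured == "auto":
            configured = "float16" if device.type == "cuda" else "float32"
        return torch.float32 if configured == "float32" else torch.float16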
@@ -172,13 +166,27 @@ class ModelLoad(ModelLoadBase):
             logger=self._logger,
         )
 
+    def _create_store(self, config: InvokeAIAppConfig) -> ModelConfigStore:
+        if config.model_conf_path and config.model_conf_path.exists():
+            models_file = config.model_conf_path
+        else:
+            models_file = config.root_path / "configs/models.yaml"
+        try:
+            store = get_config_store(models_file)
+        except ConfigFileVersionMismatchException:
+            migrate_models_store(config)
+            store = get_config_store(models_file)
+        if not store:
+            raise ValueError(f"Invalid model configuration file: {models_file}")
+        return store
+
     @property
     def store(self) -> ModelConfigStore:
         """Return the ModelConfigStore instance used by this class."""
         return self._store
 
     @property
-    def precision(self) -> torch.fp32:
+    def precision(self) -> torch.dtype:
         """Return torch.fp16 or torch.fp32."""
         return self._cache.precision
@@ -99,13 +99,14 @@ sd-1/controlnet/ip2p:
sd-1/embedding/EasyNegative:
  source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
  recommended: True
sd-1/embedding/ahx-beta-453407d:
  source: sd-concepts-library/ahx-beta-453407d
  description: A textual inversion to use in the negative prompt to reduce bad anatomy
sd-1/lora/LowRA:
  source: https://civitai.com/api/download/models/63006
  recommended: True
  description: An embedding that helps generate low-light images
sd-1/lora/Ink scenery:
  source: https://civitai.com/api/download/models/83390
  description: Generate india ink-like landscapes
sd-1/ip_adapter/ip_adapter_sd15:
  source: InvokeAI/ip_adapter_sd15
  recommended: True