check model hash before and after moving in filesystem

This commit is contained in:
Lincoln Stein
2023-10-04 09:40:15 -04:00
parent 16ec7a323b
commit a180c0f241
4 changed files with 39 additions and 25 deletions

View File

@ -7,7 +7,6 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from processor import Invoker
from pydantic import Field, parse_obj_as
from pydantic.networks import AnyHttpUrl
@ -28,6 +27,9 @@ from invokeai.backend.model_manager.cache import CacheStats
from invokeai.backend.model_manager.download import DownloadJobBase
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
# processor is giving circular import errors
# from .processor import Invoker
from .config import InvokeAIAppConfig
from .events import EventServiceBase
@ -345,9 +347,10 @@ class ModelManagerService(ModelManagerServiceBase):
kwargs: Dict[str, Any] = {}
if self._event_bus:
kwargs.update(event_handlers=[self._event_bus.emit_model_event])
# TO DO - Pass storage service rather than letting loader create storage service
self._loader = ModelLoad(config, **kwargs)
def start(self, invoker: Invoker):
def start(self, invoker: Any): # Because .processor is giving circular import errors, declaring invoker an 'Any'
"""Call automatically at process start."""
self._loader.installer.scan_models_directory() # synchronize new/deleted models found in models directory

View File

@ -458,8 +458,11 @@ class ModelInstall(ModelInstallBase):
info: ModelProbeInfo = self._probe_model(model_path, overrides)
dest_path = self._app_config.models_path / info.base_type.value / info.model_type.value / model_path.name
new_path = self._move_model(model_path, dest_path)
new_hash = self.hash(new_path)
assert new_hash == info.hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
return self._register(
self._move_model(model_path, dest_path),
new_path,
info,
)
@ -476,10 +479,7 @@ class ModelInstall(ModelInstallBase):
if not path.exists():
new_path = path
counter += 1
self._logger.warning("Use shutil.move(), not Path.replace() here; hash before and after move")
# BUG! This won't work across filesystems.
# Rehash before and after moving.
return old_path.replace(new_path)
return move(old_path, new_path)
def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
@ -606,6 +606,8 @@ class ModelInstall(ModelInstallBase):
f"{old_path.name} is not in the right directory for a model of its type. Moving to {new_path}."
)
model.path = self._move_model(old_path, new_path).as_posix()
new_hash = self.hash(model.path)
assert new_hash == key, f"{model.name}: Model hash changed during installation, possibly corrupted."
self._store.update_model(key, model)
return model

View File

@ -122,24 +122,18 @@ class ModelLoad(ModelLoadBase):
_cache_keys: dict
_models_file: Path
def __init__(self, config: InvokeAIAppConfig, event_handlers: List[DownloadEventHandler] = []):
def __init__(
self,
config: InvokeAIAppConfig,
store: Optional[ModelConfigStore] = None,
event_handlers: List[DownloadEventHandler] = [],
):
"""
Initialize ModelLoad object.
:param config: The app's InvokeAIAppConfig object.
"""
if config.model_conf_path and config.model_conf_path.exists():
models_file = config.model_conf_path
else:
models_file = config.root_path / "configs/models3.yaml"
try:
store = get_config_store(models_file)
except ConfigFileVersionMismatchException:
migrate_models_store(config)
store = get_config_store(models_file)
if not store:
raise ValueError(f"Invalid model configuration file: {models_file}")
store = store or self._create_store(config)
self._app_config = config
self._store = store
@ -151,13 +145,13 @@ class ModelLoad(ModelLoadBase):
event_handlers=event_handlers,
)
self._cache_keys = dict()
self._models_file = models_file
self._models_file = config.model_conf_path
device = torch.device(choose_torch_device())
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
precision = choose_precision(device) if config.precision == "auto" else config.precision
dtype = torch.float32 if precision == "float32" else torch.float16
self._logger.info(f"Using models database {models_file}")
self._logger.info(f"Using models database {self._models_file}")
self._logger.info(f"Rendering device = {device} ({device_name})")
self._logger.info(f"Maximum RAM cache size: {config.ram}")
self._logger.info(f"Maximum VRAM cache size: {config.vram}")
@ -172,13 +166,27 @@ class ModelLoad(ModelLoadBase):
logger=self._logger,
)
def _create_store(self, config: InvokeAIAppConfig) -> ModelConfigStore:
if config.model_conf_path and config.model_conf_path.exists():
models_file = config.model_conf_path
else:
models_file = config.root_path / "configs/models.yaml"
try:
store = get_config_store(models_file)
except ConfigFileVersionMismatchException:
migrate_models_store(config)
store = get_config_store(models_file)
if not store:
raise ValueError(f"Invalid model configuration file: {models_file}")
return store
@property
def store(self) -> ModelConfigStore:
"""Return the ModelConfigStore instance used by this class."""
return self._store
@property
def precision(self) -> torch.fp32:
def precision(self) -> torch.dtype:
"""Return torch.fp16 or torch.fp32."""
return self._cache.precision

View File

@ -99,13 +99,14 @@ sd-1/controlnet/ip2p:
sd-1/embedding/EasyNegative:
source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
recommended: True
sd-1/embedding/ahx-beta-453407d:
source: sd-concepts-library/ahx-beta-453407d
description: A textual inversion to use in the negative prompt to reduce bad anatomy
sd-1/lora/LowRA:
source: https://civitai.com/api/download/models/63006
recommended: True
description: An embedding that helps generate low-light images
sd-1/lora/Ink scenery:
source: https://civitai.com/api/download/models/83390
description: Generate india ink-like landscapes
sd-1/ip_adapter/ip_adapter_sd15:
source: InvokeAI/ip_adapter_sd15
recommended: True