mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make model key assignment deterministic
- When installing, model keys are now calculated from the model contents. - .safetensors, .ckpt and other single file models are hashed with sha1 - The contents of diffusers directories are hashed using imohash (faster) fixup yaml->sql db migration script to assign deterministic key - this commit also detects and assigns the correct image encoder for ip adapter models.
This commit is contained in:
parent
d16b9a25bb
commit
64f8535ef5
@ -7,7 +7,6 @@ import time
|
|||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Empty, Queue
|
from queue import Empty, Queue
|
||||||
from random import randbytes
|
|
||||||
from shutil import copyfile, copytree, move, rmtree
|
from shutil import copyfile, copytree, move, rmtree
|
||||||
from tempfile import mkdtemp
|
from tempfile import mkdtemp
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
@ -526,9 +525,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
setattr(info, key, value)
|
setattr(info, key, value)
|
||||||
return info
|
return info
|
||||||
|
|
||||||
def _create_key(self) -> str:
|
|
||||||
return sha256(randbytes(100)).hexdigest()[0:32]
|
|
||||||
|
|
||||||
def _register(
|
def _register(
|
||||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -536,6 +532,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
|
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
|
||||||
config["key"] = config.get("key", self._create_key())
|
config["key"] = config.get("key", self._create_key())
|
||||||
info = info or ModelProbe.probe(model_path, config)
|
info = info or ModelProbe.probe(model_path, config)
|
||||||
|
override_key: Optional[str] = config.get("key") if config else None
|
||||||
|
|
||||||
|
assert info.original_hash # always assigned by probe()
|
||||||
|
info.key = override_key or info.original_hash
|
||||||
|
|
||||||
model_path = model_path.absolute()
|
model_path = model_path.absolute()
|
||||||
if model_path.is_relative_to(self.app_config.models_path):
|
if model_path.is_relative_to(self.app_config.models_path):
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from hashlib import sha1
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -78,14 +77,22 @@ class MigrateModelYamlToDb1:
|
|||||||
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
assert isinstance(model_key, str)
|
|
||||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
|
||||||
|
|
||||||
stanza["base"] = BaseModelType(base_type)
|
stanza["base"] = BaseModelType(base_type)
|
||||||
stanza["type"] = ModelType(model_type)
|
stanza["type"] = ModelType(model_type)
|
||||||
stanza["name"] = model_name
|
stanza["name"] = model_name
|
||||||
stanza["original_hash"] = hash
|
stanza["original_hash"] = hash
|
||||||
stanza["current_hash"] = hash
|
stanza["current_hash"] = hash
|
||||||
|
new_key = hash # deterministic key assignment
|
||||||
|
|
||||||
|
# special case for ip adapters, which need the new `image_encoder_model_id` field
|
||||||
|
if stanza["type"] == ModelType.IPAdapter:
|
||||||
|
try:
|
||||||
|
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
|
||||||
|
self.config.models_path / stanza.path
|
||||||
|
)
|
||||||
|
except OSError:
|
||||||
|
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
|
||||||
|
continue
|
||||||
|
|
||||||
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||||
|
|
||||||
@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
|
|||||||
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||||
self._update_model(key, new_config)
|
self._update_model(key, new_config)
|
||||||
else:
|
else:
|
||||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
self.logger.info(f"Adding model {model_name} with key {new_key}")
|
||||||
self._add_model(new_key, new_config)
|
self._add_model(new_key, new_config)
|
||||||
except DuplicateModelException:
|
except DuplicateModelException:
|
||||||
self.logger.warning(f"Model {model_name} is already in the database")
|
self.logger.warning(f"Model {model_name} is already in the database")
|
||||||
@ -149,3 +156,8 @@ class MigrateModelYamlToDb1:
|
|||||||
)
|
)
|
||||||
except sqlite3.IntegrityError as exc:
|
except sqlite3.IntegrityError as exc:
|
||||||
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
|
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
|
||||||
|
|
||||||
|
def _get_image_encoder_model_id(self, model_path: Path) -> str:
|
||||||
|
with open(model_path / "image_encoder.txt") as f:
|
||||||
|
encoder = f.read()
|
||||||
|
return encoder.strip()
|
||||||
|
@ -28,14 +28,28 @@ class FastModelHash(object):
|
|||||||
"""
|
"""
|
||||||
model_location = Path(model_location)
|
model_location = Path(model_location)
|
||||||
if model_location.is_file():
|
if model_location.is_file():
|
||||||
return cls._hash_file(model_location)
|
return cls._hash_file_sha1(model_location)
|
||||||
elif model_location.is_dir():
|
elif model_location.is_dir():
|
||||||
return cls._hash_dir(model_location)
|
return cls._hash_dir(model_location)
|
||||||
else:
|
else:
|
||||||
raise OSError(f"Not a valid file or directory: {model_location}")
|
raise OSError(f"Not a valid file or directory: {model_location}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _hash_file(cls, model_location: Union[str, Path]) -> str:
|
def _hash_file_sha1(cls, model_location: Union[str, Path]) -> str:
|
||||||
|
"""
|
||||||
|
Compute full sha1 hash over a single file and return its hexdigest.
|
||||||
|
|
||||||
|
:param model_location: Path to the model file
|
||||||
|
"""
|
||||||
|
BLOCK_SIZE = 65536
|
||||||
|
file_hash = hashlib.sha1()
|
||||||
|
with open(model_location, "rb") as f:
|
||||||
|
data = f.read(BLOCK_SIZE)
|
||||||
|
file_hash.update(data)
|
||||||
|
return file_hash.hexdigest()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _hash_file_fast(cls, model_location: Union[str, Path]) -> str:
|
||||||
"""
|
"""
|
||||||
Fasthash a single file and return its hexdigest.
|
Fasthash a single file and return its hexdigest.
|
||||||
|
|
||||||
@ -56,7 +70,7 @@ class FastModelHash(object):
|
|||||||
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
||||||
continue
|
continue
|
||||||
path = (Path(root) / file).as_posix()
|
path = (Path(root) / file).as_posix()
|
||||||
fast_hash = cls._hash_file(path)
|
fast_hash = cls._hash_file_fast(path)
|
||||||
components.update({path: fast_hash})
|
components.update({path: fast_hash})
|
||||||
|
|
||||||
# hash all the model hashes together, using alphabetic file order
|
# hash all the model hashes together, using alphabetic file order
|
||||||
|
Loading…
Reference in New Issue
Block a user