mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
2 Commits
lstein/doc
...
feat/fast-
Author | SHA1 | Date | |
---|---|---|---|
49fd6e2d0f | |||
aae8bab8f2 |
73
invokeai/backend/model_management/model_hash.py
Normal file
73
invokeai/backend/model_management/model_hash.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
"""
|
||||||
|
Fast hashing of diffusers and checkpoint-style models.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from invokeai.backend.model_management.model_hash import FastModelHash
|
||||||
|
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
||||||
|
'a8e693a126ea5b831c96064dc569956f'
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import hashlib
|
||||||
|
from imohash import hashfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Union
|
||||||
|
|
||||||
|
|
||||||
|
class FastModelHash:
    """FastModelHash object provides one public class method, hash().

    Single files are hashed with imohash's sampled hashing (fast even on
    multi-GB weight files); directories are hashed by combining the
    hashes of their large files in alphabetic path order.
    """

    # When traversing directories, ignore files smaller than this
    # minimum value. Small files (configs, etc.) contain things like
    # the diffusers point version, which change locally without the
    # model weights changing.
    MINIMUM_FILE_SIZE = 100000

    @classmethod
    def hash(cls, model_location: Union[str, Path]) -> str:
        """
        Return hexdigest string for model located at model_location.

        :param model_location: Path to the model (file or directory)
        :raises InvalidModelException: if model_location is neither a
            file nor a directory
        """
        model_location = Path(model_location)
        if model_location.is_file():
            return cls._hash_file(model_location)
        elif model_location.is_dir():
            return cls._hash_dir(model_location)
        else:
            # avoid circular import
            from .models import InvalidModelException

            raise InvalidModelException(f"Not a valid file or directory: {model_location}")

    @classmethod
    def _hash_file(cls, model_location: Union[str, Path]) -> str:
        """
        Fasthash a single file and return its hexdigest.

        :param model_location: Path to the model file
        """
        # we return sha256 hash of the filehash in order to be
        # consistent with length of hashes returned by _hash_dir()
        return hashlib.sha256(hashfile(model_location)).hexdigest()

    @classmethod
    def _hash_dir(cls, model_location: Union[str, Path]) -> str:
        """
        Fasthash every large file under model_location and return a
        combined hexdigest.

        :param model_location: Path to the model directory
        """
        components: Dict[str, str] = {}

        for root, _dirs, files in os.walk(model_location):
            for file in files:
                # Only pay attention to the big files; see
                # MINIMUM_FILE_SIZE above for the rationale.
                path = Path(root) / file
                if path.stat().st_size < cls.MINIMUM_FILE_SIZE:
                    continue
                components[str(path)] = cls._hash_file(path)

        # hash all the model hashes together, using alphabetic file order
        # (only the per-file hashes enter the digest; the paths are used
        # solely to fix the ordering)
        sha = hashlib.sha256()
        for _path, fast_hash in sorted(components.items()):
            sha.update(fast_hash.encode("utf-8"))
        return sha.hexdigest()
|
@ -260,6 +260,7 @@ from .models import (
|
|||||||
InvalidModelException,
|
InvalidModelException,
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
)
|
)
|
||||||
|
from .model_hash import FastModelHash
|
||||||
|
|
||||||
# We are only starting to number the config file with release 3.
|
# We are only starting to number the config file with release 3.
|
||||||
# The config file version doesn't have to start at release version, but it will help
|
# The config file version doesn't have to start at release version, but it will help
|
||||||
@ -364,6 +365,8 @@ class ModelManager(object):
|
|||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
# alias for config file
|
# alias for config file
|
||||||
model_config["model_format"] = model_config.pop("format")
|
model_config["model_format"] = model_config.pop("format")
|
||||||
|
if not model_config.get("hash"):
|
||||||
|
model_config["hash"] = FastModelHash.hash(self.resolve_model_path(model_config["path"]))
|
||||||
self.models[model_key] = model_class.create_config(**model_config)
|
self.models[model_key] = model_class.create_config(**model_config)
|
||||||
|
|
||||||
# check config version number and update on disk/RAM if necessary
|
# check config version number and update on disk/RAM if necessary
|
||||||
@ -431,6 +434,28 @@ class ModelManager(object):
|
|||||||
with open(config_path, "w") as yaml_file:
|
with open(config_path, "w") as yaml_file:
|
||||||
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||||
|
|
||||||
|
def get_model_by_hash(
    self,
    model_hash: str,
    submodel_type: Optional[SubModelType] = None,
) -> ModelInfo:
    """
    Given a model's unique hash, return its ModelInfo.

    :param model_hash: Unique hash for this model.
    :param submodel_type: Optional submodel to select within the model.
    :raises InvalidModelException: if no registered model carries this hash
    :raises DuplicateModelException: if more than one model carries this hash
    """
    info = self.list_models()
    keys = [x for x in info if x["hash"] == model_hash]
    if len(keys) == 0:
        raise InvalidModelException(f"No model with hash {model_hash} found")
    if len(keys) > 1:
        raise DuplicateModelException(f"Duplicate models detected: {keys}")
    return self.get_model(
        keys[0]["model_name"],
        base_model=keys[0]["base_model"],
        model_type=keys[0]["model_type"],
        # was accepted but silently dropped before; forward it so callers
        # actually get the submodel they asked for.
        # NOTE(review): confirm get_model() accepts a submodel_type kwarg
        submodel_type=submodel_type,
    )
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -500,14 +525,12 @@ class ModelManager(object):
|
|||||||
self.cache_keys[model_key] = set()
|
self.cache_keys[model_key] = set()
|
||||||
self.cache_keys[model_key].add(model_context.key)
|
self.cache_keys[model_key].add(model_context.key)
|
||||||
|
|
||||||
model_hash = "<NO_HASH>" # TODO:
|
|
||||||
|
|
||||||
return ModelInfo(
|
return ModelInfo(
|
||||||
context=model_context,
|
context=model_context,
|
||||||
name=model_name,
|
name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
type=submodel_type or model_type,
|
type=submodel_type or model_type,
|
||||||
hash=model_hash,
|
hash=model_config.hash,
|
||||||
location=model_path, # TODO:
|
location=model_path, # TODO:
|
||||||
precision=self.cache.precision,
|
precision=self.cache.precision,
|
||||||
_cache=self.cache,
|
_cache=self.cache,
|
||||||
@ -660,12 +683,22 @@ class ModelManager(object):
|
|||||||
if path := model_attributes.get("path"):
|
if path := model_attributes.get("path"):
|
||||||
model_attributes["path"] = str(self.relative_model_path(Path(path)))
|
model_attributes["path"] = str(self.relative_model_path(Path(path)))
|
||||||
|
|
||||||
|
if not model_attributes.get("hash"):
|
||||||
|
hash = FastModelHash.hash(self.resolve_model_path(model_attributes["path"]))
|
||||||
|
model_attributes["hash"] = hash
|
||||||
|
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
model_config = model_class.create_config(**model_attributes)
|
model_config = model_class.create_config(**model_attributes)
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
|
||||||
if model_key in self.models and not clobber:
|
if not clobber:
|
||||||
|
if model_key in self.models:
|
||||||
raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
|
raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
|
||||||
|
try:
|
||||||
|
i = self.get_model_by_hash(model_attributes["hash"])
|
||||||
|
raise DuplicateModelException(f"There is already a model with hash {hash}: {i['name']}")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
old_model = self.models.pop(model_key, None)
|
old_model = self.models.pop(model_key, None)
|
||||||
if old_model is not None:
|
if old_model is not None:
|
||||||
@ -941,7 +974,11 @@ class ModelManager(object):
|
|||||||
raise DuplicateModelException(f"Model with key {model_key} added twice")
|
raise DuplicateModelException(f"Model with key {model_key} added twice")
|
||||||
|
|
||||||
model_path = self.relative_model_path(model_path)
|
model_path = self.relative_model_path(model_path)
|
||||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
model_config: ModelConfigBase = model_class.probe_config(
|
||||||
|
str(model_path),
|
||||||
|
hash=FastModelHash.hash(model_path),
|
||||||
|
model_base=cur_base_model,
|
||||||
|
)
|
||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
new_models_found = True
|
new_models_found = True
|
||||||
except DuplicateModelException as e:
|
except DuplicateModelException as e:
|
||||||
|
@ -345,8 +345,12 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
|||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
elif lora_token_vector_length == 1024:
|
elif lora_token_vector_length == 1024:
|
||||||
return BaseModelType.StableDiffusion2
|
return BaseModelType.StableDiffusion2
|
||||||
|
elif lora_token_vector_length is None: # variant w/o the text encoder!
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
else:
|
else:
|
||||||
raise InvalidModelException(f"Unknown LoRA type")
|
raise InvalidModelException(
|
||||||
|
f"Unknown LoRA type: {self.checkpoint_path}, lora_token_vector_length={lora_token_vector_length}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||||
|
@ -89,6 +89,7 @@ class ModelConfigBase(BaseModel):
|
|||||||
path: str # or Path
|
path: str # or Path
|
||||||
description: Optional[str] = Field(None)
|
description: Optional[str] = Field(None)
|
||||||
model_format: Optional[str] = Field(None)
|
model_format: Optional[str] = Field(None)
|
||||||
|
hash: Optional[str] = Field(None)
|
||||||
error: Optional[ModelError] = Field(None)
|
error: Optional[ModelError] = Field(None)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -197,15 +198,16 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
def create_config(cls, **kwargs) -> ModelConfigBase:
    """Build the config object matching kwargs['model_format'].

    :param kwargs: config fields; must include 'model_format'.
    :raises Exception: if 'model_format' is missing from kwargs.
    """
    if "model_format" not in kwargs:
        raise Exception("Field 'model_format' not found in model config")

    # look up the config class registered for this format and
    # instantiate it with the remaining fields
    config_class = cls._get_configs()[kwargs["model_format"]]
    return config_class(**kwargs)
|
||||||
|
|
||||||
@classmethod
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
    """Probe the model at `path` and build its config.

    :param path: Filesystem path to the model.
    :param kwargs: May include 'hash', the model's fast hash.
    """
    return cls.create_config(
        path=path,
        model_format=cls.detect_format(path),
        # use .get() so callers that supply no hash fall back to the
        # field's Optional default instead of raising KeyError
        hash=kwargs.get("hash"),
    )
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -13,8 +13,11 @@ from .base import (
|
|||||||
read_checkpoint_meta,
|
read_checkpoint_meta,
|
||||||
classproperty,
|
classproperty,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLModelFormat(str, Enum):
|
class StableDiffusionXLModelFormat(str, Enum):
|
||||||
Checkpoint = "checkpoint"
|
Checkpoint = "checkpoint"
|
||||||
@ -22,7 +25,7 @@ class StableDiffusionXLModelFormat(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class StableDiffusionXLModel(DiffusersModel):
|
class StableDiffusionXLModel(DiffusersModel):
|
||||||
# TODO: check that configs overwriten properly
|
# TODO: check that configs overwritten properly
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
|
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
@ -79,14 +82,19 @@ class StableDiffusionXLModel(DiffusersModel):
|
|||||||
else:
|
else:
|
||||||
raise Exception("Unkown stable diffusion 2.* model format")
|
raise Exception("Unkown stable diffusion 2.* model format")
|
||||||
|
|
||||||
if ckpt_config_path is None:
|
if ckpt_config_path is None and "model_base" in kwargs:
|
||||||
# TO DO: implement picking
|
ckpt_config_path = (
|
||||||
pass
|
app_config.legacy_conf_path / "sd_xl_base.yaml"
|
||||||
|
if kwargs["model_base"] == BaseModelType.StableDiffusionXL
|
||||||
|
else app_config.legacy_conf_path / "sd_xl_refiner.yaml"
|
||||||
|
if kwargs["model_base"] == BaseModelType.StableDiffusionXLRefiner
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
return cls.create_config(
|
return cls.create_config(
|
||||||
path=path,
|
path=path,
|
||||||
model_format=model_format,
|
model_format=model_format,
|
||||||
config=ckpt_config_path,
|
config=str(ckpt_config_path),
|
||||||
variant=variant,
|
variant=variant,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -55,6 +55,7 @@ dependencies = [
|
|||||||
"flask_socketio==5.3.0",
|
"flask_socketio==5.3.0",
|
||||||
"flaskwebgui==1.0.3",
|
"flaskwebgui==1.0.3",
|
||||||
"huggingface-hub>=0.11.1",
|
"huggingface-hub>=0.11.1",
|
||||||
|
"imohash~=1.0.0",
|
||||||
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||||
"matplotlib", # needed for plotting of Penner easing functions
|
"matplotlib", # needed for plotting of Penner easing functions
|
||||||
"mediapipe", # needed for "mediapipeface" controlnet model
|
"mediapipe", # needed for "mediapipeface" controlnet model
|
||||||
|
Reference in New Issue
Block a user