Apply black

This commit is contained in:
Martin Kristiansen
2023-07-27 10:54:01 -04:00
parent 2183dba5c5
commit 218b6d0546
148 changed files with 5486 additions and 6296 deletions

View File

@ -249,20 +249,26 @@ from invokeai.backend.util import CUDA_DEVICE, Chdir
from .model_cache import ModelCache, ModelLocker
from .model_search import ModelSearch
from .models import (
BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES,
BaseModelType,
ModelType,
SubModelType,
ModelError,
SchedulerPredictionType,
MODEL_CLASSES,
ModelConfigBase,
ModelNotFoundException, InvalidModelException,
ModelNotFoundException,
InvalidModelException,
DuplicateModelException,
)
# We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help
# reduce confusion.
CONFIG_FILE_VERSION='3.0.0'
CONFIG_FILE_VERSION = "3.0.0"
@dataclass
class ModelInfo():
class ModelInfo:
context: ModelLocker
name: str
base_model: BaseModelType
@ -275,20 +281,24 @@ class ModelInfo():
def __enter__(self):
return self.context.__enter__()
def __exit__(self,*args, **kwargs):
def __exit__(self, *args, **kwargs):
self.context.__exit__(*args, **kwargs)
class AddModelResult(BaseModel):
name: str = Field(description="The name of the model after installation")
model_type: ModelType = Field(description="The type of model")
base_model: BaseModelType = Field(description="The base model")
config: ModelConfigBase = Field(description="The configuration of the model")
MAX_CACHE_SIZE = 6.0 # GB
class ConfigMeta(BaseModel):
version: str
class ModelManager(object):
"""
High-level interface to model management.
@ -315,12 +325,12 @@ class ModelManager(object):
if isinstance(config, (str, Path)):
self.config_path = Path(config)
if not self.config_path.exists():
logger.warning(f'The file {self.config_path} was not found. Initializing a new file')
logger.warning(f"The file {self.config_path} was not found. Initializing a new file")
self.initialize_model_config(self.config_path)
config = OmegaConf.load(self.config_path)
elif not isinstance(config, DictConfig):
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
raise ValueError("config argument must be an OmegaConf object, a Path or a string")
self.config_meta = ConfigMeta(**config.pop("__metadata__"))
# TODO: metadata not found
@ -330,11 +340,11 @@ class ModelManager(object):
self.logger = logger
self.cache = ModelCache(
max_cache_size=max_cache_size,
max_vram_cache_size = self.app_config.max_vram_cache_size,
execution_device = device_type,
precision = precision,
sequential_offload = sequential_offload,
logger = logger,
max_vram_cache_size=self.app_config.max_vram_cache_size,
execution_device=device_type,
precision=precision,
sequential_offload=sequential_offload,
logger=logger,
)
self._read_models(config)
@ -348,7 +358,7 @@ class ModelManager(object):
self.models = dict()
for model_key, model_config in config.items():
if model_key.startswith('_'):
if model_key.startswith("_"):
continue
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
@ -395,7 +405,7 @@ class ModelManager(object):
@classmethod
def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
base_model_str, model_type_str, model_name = model_key.split('/', 2)
base_model_str, model_type_str, model_name = model_key.split("/", 2)
try:
model_type = ModelType(model_type_str)
except:
@ -414,20 +424,16 @@ class ModelManager(object):
@classmethod
def initialize_model_config(cls, config_path: Path):
"""Create empty config file"""
with open(config_path,'w') as yaml_file:
yaml_file.write(yaml.dump({'__metadata__':
{'version':'3.0.0'}
}
)
)
with open(config_path, "w") as yaml_file:
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None
)->ModelInfo:
submodel_type: Optional[SubModelType] = None,
) -> ModelInfo:
"""Given a model named identified in models.yaml, return
an ModelInfo object describing it.
:param model_name: symbolic name of the model in models.yaml
@ -451,7 +457,7 @@ class ModelManager(object):
if not model_path.exists():
if model_class.save_to_config:
self.models[model_key].error = ModelError.NotFound
raise Exception(f"Files for model \"{model_key}\" not found")
raise Exception(f'Files for model "{model_key}" not found')
else:
self.models.pop(model_key, None)
@ -473,7 +479,7 @@ class ModelManager(object):
model_path = model_class.convert_if_required(
base_model=base_model,
model_path=str(model_path), # TODO: refactor str/Path types logic
model_path=str(model_path), # TODO: refactor str/Path types logic
output_path=dst_convert_path,
config=model_config,
)
@ -490,17 +496,17 @@ class ModelManager(object):
self.cache_keys[model_key] = set()
self.cache_keys[model_key].add(model_context.key)
model_hash = "<NO_HASH>" # TODO:
model_hash = "<NO_HASH>" # TODO:
return ModelInfo(
context = model_context,
name = model_name,
base_model = base_model,
type = submodel_type or model_type,
hash = model_hash,
location = model_path, # TODO:
precision = self.cache.precision,
_cache = self.cache,
context=model_context,
name=model_name,
base_model=base_model,
type=submodel_type or model_type,
hash=model_hash,
location=model_path, # TODO:
precision=self.cache.precision,
_cache=self.cache,
)
def model_info(
@ -516,7 +522,7 @@ class ModelManager(object):
if model_key in self.models:
return self.models[model_key].dict(exclude_defaults=True)
else:
return None # TODO: None or empty dict on not found
return None # TODO: None or empty dict on not found
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
@ -526,16 +532,16 @@ class ModelManager(object):
return [(self.parse_key(x)) for x in self.models.keys()]
def list_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> dict:
"""
Returns a dict describing one installed model, using
the combined format of the list_models() method.
"""
models = self.list_models(base_model,model_type,model_name)
models = self.list_models(base_model, model_type, model_name)
return models[0] if models else None
def list_models(
@ -548,13 +554,17 @@ class ModelManager(object):
Return a list of models.
"""
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
model_keys = (
[self.create_key(model_name, base_model, model_type)]
if model_name
else sorted(self.models, key=str.casefold)
)
models = []
for model_key in model_keys:
model_config = self.models.get(model_key)
if not model_config:
self.logger.error(f'Unknown model {model_name}')
raise ModelNotFoundException(f'Unknown model {model_name}')
self.logger.error(f"Unknown model {model_name}")
raise ModelNotFoundException(f"Unknown model {model_name}")
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
if base_model is not None and cur_base_model != base_model:
@ -571,8 +581,8 @@ class ModelManager(object):
)
# expose paths as absolute to help web UI
if path := model_dict.get('path'):
model_dict['path'] = str(self.app_config.root_path / path)
if path := model_dict.get("path"):
model_dict["path"] = str(self.app_config.root_path / path)
models.append(model_dict)
return models
@ -641,15 +651,15 @@ class ModelManager(object):
model_info().
"""
# relativize paths as they go in - this makes it easier to move the root directory around
if path := model_attributes.get('path'):
if path := model_attributes.get("path"):
if Path(path).is_relative_to(self.app_config.root_path):
model_attributes['path'] = str(Path(path).relative_to(self.app_config.root_path))
model_attributes["path"] = str(Path(path).relative_to(self.app_config.root_path))
model_class = MODEL_CLASSES[base_model][model_type]
model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type)
if model_key in self.models and not clobber:
if model_key in self.models and not clobber:
raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
old_model = self.models.pop(model_key, None)
@ -675,23 +685,23 @@ class ModelManager(object):
self.commit()
return AddModelResult(
name = model_name,
model_type = model_type,
base_model = base_model,
config = model_config,
name=model_name,
model_type=model_type,
base_model=base_model,
config=model_config,
)
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
):
'''
"""
Rename or rebase a model.
'''
"""
if new_name is None and new_base is None:
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
return
@ -710,7 +720,13 @@ class ModelManager(object):
# if this is a model file/directory that we manage ourselves, we need to move it
if old_path.is_relative_to(self.app_config.models_path):
new_path = self.app_config.root_path / 'models' / BaseModelType(new_base).value / ModelType(model_type).value / new_name
new_path = (
self.app_config.root_path
/ "models"
/ BaseModelType(new_base).value
/ ModelType(model_type).value
/ new_name
)
move(old_path, new_path)
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
@ -726,18 +742,18 @@ class ModelManager(object):
for cache_id in cache_ids:
self.cache.uncache_model(cache_id)
self.models.pop(model_key, None) # delete
self.models.pop(model_key, None) # delete
self.models[new_key] = model_cfg
self.commit()
def convert_model (
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
dest_directory: Optional[Path]=None,
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main, ModelType.Vae],
dest_directory: Optional[Path] = None,
) -> AddModelResult:
'''
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
@ -746,7 +762,7 @@ class ModelManager(object):
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError unless the model is a checkpoint.
'''
"""
info = self.model_info(model_name, base_model, model_type)
if info["model_format"] != "checkpoint":
raise ValueError(f"not a checkpoint format model: {model_name}")
@ -754,27 +770,32 @@ class ModelManager(object):
# We are taking advantage of a side effect of get_model() that converts check points
# into cached diffusers directories stored at `location`. It doesn't matter
# what submodeltype we request here, so we get the smallest.
submodel = {"submodel_type": SubModelType.Scheduler} if model_type==ModelType.Main else {}
model = self.get_model(model_name,
base_model,
model_type,
**submodel,
)
submodel = {"submodel_type": SubModelType.Scheduler} if model_type == ModelType.Main else {}
model = self.get_model(
model_name,
base_model,
model_type,
**submodel,
)
checkpoint_path = self.app_config.root_path / info["path"]
old_diffusers_path = self.app_config.models_path / model.location
new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
new_diffusers_path = (
dest_directory or self.app_config.models_path / base_model.value / model_type.value
) / model_name
if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
try:
move(old_diffusers_path,new_diffusers_path)
move(old_diffusers_path, new_diffusers_path)
info["model_format"] = "diffusers"
info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
info.pop('config')
info["path"] = (
str(new_diffusers_path)
if dest_directory
else str(new_diffusers_path.relative_to(self.app_config.root_path))
)
info.pop("config")
result = self.add_model(model_name, base_model, model_type,
model_attributes = info,
clobber=True)
result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True)
except:
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
rmtree(new_diffusers_path)
@ -798,15 +819,12 @@ class ModelManager(object):
found_models = []
for file in files:
location = str(file.resolve()).replace("\\", "/")
if (
"model.safetensors" not in location
and "diffusion_pytorch_model.safetensors" not in location
):
if "model.safetensors" not in location and "diffusion_pytorch_model.safetensors" not in location:
found_models.append({"name": file.stem, "location": location})
return search_folder, found_models
def commit(self, conf_file: Path=None) -> None:
def commit(self, conf_file: Path = None) -> None:
"""
Write current configuration out to the indicated file.
"""
@ -824,7 +842,7 @@ class ModelManager(object):
yaml_str = OmegaConf.to_yaml(data_to_save)
config_file_path = conf_file or self.config_path
assert config_file_path is not None,'no config file path to write to'
assert config_file_path is not None, "no config file path to write to"
config_file_path = self.app_config.root_path / config_file_path
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
try:
@ -857,11 +875,10 @@ class ModelManager(object):
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
):
loaded_files = set()
new_models_found = False
self.logger.info(f'Scanning {self.app_config.models_path} for new models')
self.logger.info(f"Scanning {self.app_config.models_path} for new models")
with Chdir(self.app_config.root_path):
for model_key, model_config in list(self.models.items()):
model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
@ -887,10 +904,10 @@ class ModelManager(object):
models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value
if not models_dir.exists():
continue # TODO: or create all folders?
continue # TODO: or create all folders?
for model_path in models_dir.iterdir():
if model_path not in loaded_files: # TODO: check
if model_path not in loaded_files: # TODO: check
model_name = model_path.name if model_path.is_dir() else model_path.stem
model_key = self.create_key(model_name, cur_base_model, cur_model_type)
@ -900,7 +917,7 @@ class ModelManager(object):
if model_path.is_relative_to(self.app_config.root_path):
model_path = model_path.relative_to(self.app_config.root_path)
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
self.models[model_key] = model_config
new_models_found = True
@ -916,11 +933,10 @@ class ModelManager(object):
if (new_models_found or imported_models) and self.config_path:
self.commit()
def autoimport(self)->Dict[str, AddModelResult]:
'''
def autoimport(self) -> Dict[str, AddModelResult]:
"""
Scan the autoimport directory (if defined) and import new models, delete defunct models.
'''
"""
# avoid circular import
from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
@ -939,7 +955,9 @@ class ModelManager(object):
self.new_models_found.update(self.installer.heuristic_import(model))
def on_search_completed(self):
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
self.logger.info(
f"Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models"
)
def models_found(self):
return self.new_models_found
@ -949,31 +967,37 @@ class ModelManager(object):
# LS: hacky
# Patch in the SD VAE from core so that it is available for use by the UI
try:
self.heuristic_import({config.root_path / 'models/core/convert/sd-vae-ft-mse'})
self.heuristic_import({config.root_path / "models/core/convert/sd-vae-ft-mse"})
except:
pass
installer = ModelInstall(config = self.app_config,
model_manager = self,
prediction_type_helper = ask_user_for_prediction_type,
)
known_paths = {config.root_path / x['path'] for x in self.list_models()}
directories = {config.root_path / x for x in [config.autoimport_dir,
config.lora_dir,
config.embedding_dir,
config.controlnet_dir,
] if x
}
installer = ModelInstall(
config=self.app_config,
model_manager=self,
prediction_type_helper=ask_user_for_prediction_type,
)
known_paths = {config.root_path / x["path"] for x in self.list_models()}
directories = {
config.root_path / x
for x in [
config.autoimport_dir,
config.lora_dir,
config.embedding_dir,
config.controlnet_dir,
]
if x
}
scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
scanner.search()
return scanner.models_found()
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
)->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
def heuristic_import(
self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
) -> Dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
@ -992,14 +1016,15 @@ class ModelManager(object):
May return the following exceptions:
- ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
- ValueError - a corresponding model already exists
'''
"""
# avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall
successfully_installed = dict()
installer = ModelInstall(config = self.app_config,
prediction_type_helper = prediction_type_helper,
model_manager = self)
installer = ModelInstall(
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self
)
for thing in items_to_import:
installed = installer.heuristic_import(thing)
successfully_installed.update(installed)