mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat/refactor_generation_backend
This commit is contained in:
@ -10,12 +10,15 @@ import sys
|
||||
import argparse
|
||||
import io
|
||||
import os
|
||||
import psutil
|
||||
import shutil
|
||||
import textwrap
|
||||
import torch
|
||||
import traceback
|
||||
import yaml
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import get_type_hints
|
||||
@ -44,6 +47,8 @@ from invokeai.app.services.config import (
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
|
||||
# TO DO - Move all the frontend code into invokeai.frontend.install
|
||||
from invokeai.frontend.install.widgets import (
|
||||
SingleSelectColumns,
|
||||
CenteredButtonPress,
|
||||
@ -53,6 +58,7 @@ from invokeai.frontend.install.widgets import (
|
||||
CyclingForm,
|
||||
MIN_COLS,
|
||||
MIN_LINES,
|
||||
WindowTooSmallException,
|
||||
)
|
||||
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
||||
from invokeai.backend.install.model_install_backend import (
|
||||
@ -61,6 +67,7 @@ from invokeai.backend.install.model_install_backend import (
|
||||
ModelInstall,
|
||||
)
|
||||
from invokeai.backend.model_management.model_probe import ModelType, BaseModelType
|
||||
from pydantic.error_wrappers import ValidationError
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
@ -76,6 +83,13 @@ Default_config_file = config.model_conf_path
|
||||
SD_Configs = config.legacy_conf_path
|
||||
|
||||
PRECISION_CHOICES = ["auto", "float16", "float32"]
|
||||
GB = 1073741824 # GB in bytes
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
|
||||
|
||||
|
||||
MAX_VRAM /= GB
|
||||
MAX_RAM = psutil.virtual_memory().total / GB
|
||||
|
||||
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# This is the InvokeAI initialization file, which contains command-line default values.
|
||||
@ -86,6 +100,12 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
|
||||
class DummyWidgetValue(Enum):
|
||||
zero = 0
|
||||
true = True
|
||||
false = False
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
def postscript(errors: None):
|
||||
if not any(errors):
|
||||
@ -378,13 +398,35 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
)
|
||||
self.max_cache_size = self.add_widget_intelligent(
|
||||
IntTitleSlider,
|
||||
name="Size of the RAM cache used for fast model switching (GB)",
|
||||
name="RAM cache size (GB). Make this at least large enough to hold a single full model.",
|
||||
value=old_opts.max_cache_size,
|
||||
out_of=20,
|
||||
out_of=MAX_RAM,
|
||||
lowest=3,
|
||||
begin_entry_at=6,
|
||||
scroll_exit=True,
|
||||
)
|
||||
if HAS_CUDA:
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.max_vram_cache_size = self.add_widget_intelligent(
|
||||
npyscreen.Slider,
|
||||
value=old_opts.max_vram_cache_size,
|
||||
out_of=round(MAX_VRAM * 2) / 2,
|
||||
lowest=0.0,
|
||||
relx=8,
|
||||
step=0.25,
|
||||
scroll_exit=True,
|
||||
)
|
||||
else:
|
||||
self.max_vram_cache_size = DummyWidgetValue.zero
|
||||
self.nextrely += 1
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
@ -401,7 +443,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
self.autoimport_dirs = {}
|
||||
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
name=f"Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.root_path / config.autoimport_dir),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
@ -476,6 +518,7 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
"outdir",
|
||||
"free_gpu_mem",
|
||||
"max_cache_size",
|
||||
"max_vram_cache_size",
|
||||
"xformers_enabled",
|
||||
"always_use_cpu",
|
||||
]:
|
||||
@ -592,13 +635,13 @@ def maybe_create_models_yaml(root: Path):
|
||||
|
||||
# -------------------------------------
|
||||
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
|
||||
# parse_args() will read from init file if present
|
||||
invokeai_opts = default_startup_options(initfile)
|
||||
invokeai_opts.root = program_opts.root
|
||||
|
||||
# The third argument is needed in the Windows 11 environment to
|
||||
# launch a console window running this program.
|
||||
set_min_terminal_size(MIN_COLS, MIN_LINES)
|
||||
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
|
||||
raise WindowTooSmallException(
|
||||
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
||||
)
|
||||
|
||||
# the install-models application spawns a subprocess to install
|
||||
# models, and will crash unless this is set before running.
|
||||
@ -654,10 +697,13 @@ def migrate_init_file(legacy_format: Path):
|
||||
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
|
||||
new = InvokeAIAppConfig.get_config()
|
||||
|
||||
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
||||
fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"]
|
||||
for attr in fields:
|
||||
if hasattr(old, attr):
|
||||
setattr(new, attr, getattr(old, attr))
|
||||
try:
|
||||
setattr(new, attr, getattr(old, attr))
|
||||
except ValidationError as e:
|
||||
print(f"* Ignoring incompatible value for field {attr}:\n {str(e)}")
|
||||
|
||||
# a few places where the field names have changed and we have to
|
||||
# manually add in the new names/values
|
||||
@ -777,6 +823,7 @@ def main():
|
||||
|
||||
models_to_download = default_user_selections(opt)
|
||||
new_init_file = config.root_path / "invokeai.yaml"
|
||||
|
||||
if opt.yes_to_all:
|
||||
write_default_options(opt, new_init_file)
|
||||
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
|
||||
@ -802,6 +849,8 @@ def main():
|
||||
postscript(errors=errors)
|
||||
if not opt.yes_to_all:
|
||||
input("Press any key to continue...")
|
||||
except WindowTooSmallException as e:
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye! Come back soon.")
|
||||
|
||||
|
@ -101,9 +101,9 @@ class ModelInstall(object):
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
||||
model_manager: ModelManager = None,
|
||||
access_token: str = None,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
model_manager: Optional[ModelManager] = None,
|
||||
access_token: Optional[str] = None,
|
||||
):
|
||||
self.config = config
|
||||
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
||||
|
@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import os
|
||||
import textwrap
|
||||
import yaml
|
||||
import types
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||
from shutil import rmtree, move
|
||||
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
@ -259,6 +259,7 @@ from .models import (
|
||||
ModelNotFoundException,
|
||||
InvalidModelException,
|
||||
DuplicateModelException,
|
||||
ModelBase,
|
||||
)
|
||||
|
||||
# We are only starting to number the config file with release 3.
|
||||
@ -361,7 +362,7 @@ class ModelManager(object):
|
||||
if model_key.startswith("_"):
|
||||
continue
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
# alias for config file
|
||||
model_config["model_format"] = model_config.pop("format")
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
@ -381,18 +382,24 @@ class ModelManager(object):
|
||||
# causing otherwise unreferenced models to be removed from memory
|
||||
self._read_models()
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
|
||||
"""
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
Given a model name, returns True if it is a valid identifier.
|
||||
|
||||
:param model_name: symbolic name of the model in models.yaml
|
||||
:param model_type: ModelType enum indicating the type of model to return
|
||||
:param base_model: BaseModelType enum indicating the base model used by this model
|
||||
:param rescan: if True, scan_models_directory
|
||||
"""
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
return model_key in self.models
|
||||
exists = model_key in self.models
|
||||
|
||||
# if model not found try to find it (maybe file just pasted)
|
||||
if rescan and not exists:
|
||||
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
||||
exists = self.model_exists(model_name, base_model, model_type, rescan=False)
|
||||
|
||||
return exists
|
||||
|
||||
@classmethod
|
||||
def create_key(
|
||||
@ -443,39 +450,32 @@ class ModelManager(object):
|
||||
:param model_name: symbolic name of the model in models.yaml
|
||||
:param model_type: ModelType enum indicating the type of model to return
|
||||
:param base_model: BaseModelType enum indicating the base model used by this model
|
||||
:param submode_typel: an ModelType enum indicating the portion of
|
||||
:param submodel_type: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
|
||||
# if model not found try to find it (maybe file just pasted)
|
||||
if model_key not in self.models:
|
||||
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
||||
if model_key not in self.models:
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
if not self.model_exists(model_name, base_model, model_type, rescan=True):
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
|
||||
model_config = self.models[model_key]
|
||||
model_path = self.resolve_model_path(model_config.path)
|
||||
model_config = self._get_model_config(base_model, model_name, model_type)
|
||||
|
||||
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
|
||||
|
||||
if is_submodel_override:
|
||||
model_type = submodel_type
|
||||
submodel_type = None
|
||||
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
|
||||
if not model_path.exists():
|
||||
if model_class.save_to_config:
|
||||
self.models[model_key].error = ModelError.NotFound
|
||||
raise Exception(f'Files for model "{model_key}" not found')
|
||||
raise Exception(f'Files for model "{model_key}" not found at {model_path}')
|
||||
|
||||
else:
|
||||
self.models.pop(model_key, None)
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
|
||||
# vae/movq override
|
||||
# TODO:
|
||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||
override_path = getattr(model_config, submodel_type)
|
||||
if override_path:
|
||||
model_path = self.resolve_path(override_path)
|
||||
model_type = submodel_type
|
||||
submodel_type = None
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')
|
||||
|
||||
# TODO: path
|
||||
# TODO: is it accurate to use path as id
|
||||
@ -513,12 +513,61 @@ class ModelManager(object):
|
||||
_cache=self.cache,
|
||||
)
|
||||
|
||||
def _get_model_path(
|
||||
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
|
||||
) -> (Path, bool):
|
||||
"""Extract a model's filesystem path from its config.
|
||||
|
||||
:return: The fully qualified Path of the module (or submodule).
|
||||
"""
|
||||
model_path = model_config.path
|
||||
is_submodel_override = False
|
||||
|
||||
# Does the config explicitly override the submodel?
|
||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||
submodel_path = getattr(model_config, submodel_type)
|
||||
if submodel_path is not None:
|
||||
model_path = getattr(model_config, submodel_type)
|
||||
is_submodel_override = True
|
||||
|
||||
model_path = self.resolve_model_path(model_path)
|
||||
return model_path, is_submodel_override
|
||||
|
||||
def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
|
||||
"""Get a model's config object."""
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
try:
|
||||
model_config = self.models[model_key]
|
||||
except KeyError:
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
return model_config
|
||||
|
||||
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
|
||||
"""Get the concrete implementation class for a specific model type."""
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
return model_class
|
||||
|
||||
def _instantiate(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> ModelBase:
|
||||
"""Make a new instance of this model, without loading it."""
|
||||
model_config = self._get_model_config(base_model, model_name, model_type)
|
||||
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
|
||||
# FIXME: do non-overriden submodels get the right class?
|
||||
constructor = self._get_implementation(base_model, model_type)
|
||||
instance = constructor(model_path, base_model, model_type)
|
||||
return instance
|
||||
|
||||
def model_info(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> dict:
|
||||
) -> Union[dict, None]:
|
||||
"""
|
||||
Given a model name returns the OmegaConf (dict-like) object describing it.
|
||||
"""
|
||||
@ -540,13 +589,16 @@ class ModelManager(object):
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> dict:
|
||||
) -> Union[dict, None]:
|
||||
"""
|
||||
Returns a dict describing one installed model, using
|
||||
the combined format of the list_models() method.
|
||||
"""
|
||||
models = self.list_models(base_model, model_type, model_name)
|
||||
return models[0] if models else None
|
||||
if len(models) >= 1:
|
||||
return models[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def list_models(
|
||||
self,
|
||||
@ -560,7 +612,7 @@ class ModelManager(object):
|
||||
|
||||
model_keys = (
|
||||
[self.create_key(model_name, base_model, model_type)]
|
||||
if model_name
|
||||
if model_name and base_model and model_type
|
||||
else sorted(self.models, key=str.casefold)
|
||||
)
|
||||
models = []
|
||||
@ -596,7 +648,7 @@ class ModelManager(object):
|
||||
Print a table of models and their descriptions. This needs to be redone
|
||||
"""
|
||||
# TODO: redo
|
||||
for model_type, model_dict in self.list_models().items():
|
||||
for model_dict in self.list_models():
|
||||
for model_name, model_info in model_dict.items():
|
||||
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
||||
print(line)
|
||||
@ -658,7 +710,7 @@ class ModelManager(object):
|
||||
if path := model_attributes.get("path"):
|
||||
model_attributes["path"] = str(self.relative_model_path(Path(path)))
|
||||
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
model_config = model_class.create_config(**model_attributes)
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
|
||||
@ -699,8 +751,8 @@ class ModelManager(object):
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str = None,
|
||||
new_base: BaseModelType = None,
|
||||
new_name: Optional[str] = None,
|
||||
new_base: Optional[BaseModelType] = None,
|
||||
):
|
||||
"""
|
||||
Rename or rebase a model.
|
||||
@ -753,7 +805,7 @@ class ModelManager(object):
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main, ModelType.Vae],
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
dest_directory: Optional[Path] = None,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
@ -767,6 +819,10 @@ class ModelManager(object):
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
"""
|
||||
info = self.model_info(model_name, base_model, model_type)
|
||||
|
||||
if info is None:
|
||||
raise FileNotFoundError(f"model not found: {model_name}")
|
||||
|
||||
if info["model_format"] != "checkpoint":
|
||||
raise ValueError(f"not a checkpoint format model: {model_name}")
|
||||
|
||||
@ -836,7 +892,7 @@ class ModelManager(object):
|
||||
|
||||
return search_folder, found_models
|
||||
|
||||
def commit(self, conf_file: Path = None) -> None:
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
"""
|
||||
@ -845,7 +901,7 @@ class ModelManager(object):
|
||||
|
||||
for model_key, model_config in self.models.items():
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
if model_class.save_to_config:
|
||||
# TODO: or exclude_unset better fits here?
|
||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
|
||||
@ -903,7 +959,7 @@ class ModelManager(object):
|
||||
|
||||
model_path = self.resolve_model_path(model_config.path).absolute()
|
||||
if not model_path.exists():
|
||||
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||
model_class = self._get_implementation(cur_base_model, cur_model_type)
|
||||
if model_class.save_to_config:
|
||||
model_config.error = ModelError.NotFound
|
||||
self.models.pop(model_key, None)
|
||||
@ -919,7 +975,7 @@ class ModelManager(object):
|
||||
for cur_model_type in ModelType:
|
||||
if model_type is not None and cur_model_type != model_type:
|
||||
continue
|
||||
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||
model_class = self._get_implementation(cur_base_model, cur_model_type)
|
||||
models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
|
||||
|
||||
if not models_dir.exists():
|
||||
@ -935,7 +991,9 @@ class ModelManager(object):
|
||||
raise DuplicateModelException(f"Model with key {model_key} added twice")
|
||||
|
||||
model_path = self.relative_model_path(model_path)
|
||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||
model_config: ModelConfigBase = model_class.probe_config(
|
||||
str(model_path), model_base=cur_base_model
|
||||
)
|
||||
self.models[model_key] = model_config
|
||||
new_models_found = True
|
||||
except DuplicateModelException as e:
|
||||
@ -983,7 +1041,7 @@ class ModelManager(object):
|
||||
# LS: hacky
|
||||
# Patch in the SD VAE from core so that it is available for use by the UI
|
||||
try:
|
||||
self.heuristic_import({self.resolve_model_path("core/convert/sd-vae-ft-mse")})
|
||||
self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
|
||||
except:
|
||||
pass
|
||||
|
||||
@ -1011,7 +1069,7 @@ class ModelManager(object):
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> Dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
|
@ -33,7 +33,7 @@ class ModelMerger(object):
|
||||
self,
|
||||
model_paths: List[Path],
|
||||
alpha: float = 0.5,
|
||||
interp: MergeInterpolationMethod = None,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> DiffusionPipeline:
|
||||
@ -73,7 +73,7 @@ class ModelMerger(object):
|
||||
base_model: Union[BaseModelType, str],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
interp: MergeInterpolationMethod = None,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
**kwargs,
|
||||
@ -122,7 +122,7 @@ class ModelMerger(object):
|
||||
dump_path.mkdir(parents=True, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
attributes = dict(
|
||||
path=str(dump_path),
|
||||
description=f"Merge of models {', '.join(model_names)}",
|
||||
|
@ -80,8 +80,10 @@ class StableDiffusionXLModel(DiffusersModel):
|
||||
raise Exception("Unkown stable diffusion 2.* model format")
|
||||
|
||||
if ckpt_config_path is None:
|
||||
# TO DO: implement picking
|
||||
pass
|
||||
# avoid circular import
|
||||
from .stable_diffusion import _select_ckpt_config
|
||||
|
||||
ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant)
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
|
@ -1,9 +1,14 @@
|
||||
import os
|
||||
import torch
|
||||
import safetensors
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Literal
|
||||
from typing import Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from diffusers.utils import is_safetensors_available
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
@ -18,9 +23,6 @@ from .base import (
|
||||
InvalidModelException,
|
||||
ModelNotFoundException,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from diffusers.utils import is_safetensors_available
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
class VaeModelFormat(str, Enum):
|
||||
@ -80,7 +82,7 @@ class VaeModel(ModelBase):
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
raise ModelNotFoundException(f"Does not exist as local file: {path}")
|
||||
|
||||
if os.path.isdir(path):
|
||||
if os.path.exists(os.path.join(path, "config.json")):
|
||||
|
@ -1,18 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import math
|
||||
import secrets
|
||||
from dataclasses import dataclass, field
|
||||
import inspect
|
||||
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
||||
from pydantic import Field
|
||||
|
||||
import math
|
||||
import einops
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
import einops
|
||||
import psutil
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
@ -27,17 +28,18 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.outputs import BaseOutput
|
||||
from pydantic import Field
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ..util import CPU_DEVICE, normalize_device
|
||||
from .diffusion import (
|
||||
AttentionMapSaver,
|
||||
InvokeAIDiffuserComponent,
|
||||
PostprocessingSettings,
|
||||
)
|
||||
from ..util import normalize_device
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -292,9 +294,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
safety_checker: Optional[StableDiffusionSafetyChecker],
|
||||
feature_extractor: Optional[CLIPFeatureExtractor],
|
||||
requires_safety_checker: bool = False,
|
||||
precision: str = "float32",
|
||||
control_model: ControlNetModel = None,
|
||||
execution_device: Optional[torch.device] = None,
|
||||
):
|
||||
super().__init__(
|
||||
vae,
|
||||
@ -335,12 +335,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
return
|
||||
|
||||
|
||||
if self.device.type == "cpu" or self.device.type == "mps":
|
||||
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif self.device.type == "cuda":
|
||||
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
|
||||
elif self.unet.device.type == "cuda":
|
||||
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
|
||||
else:
|
||||
raise ValueError(f"unrecognized device {self.device}")
|
||||
raise ValueError(f"unrecognized device {self.unet.device}")
|
||||
# input tensor of [1, 4, h/8, w/8]
|
||||
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
|
||||
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
|
||||
@ -363,10 +363,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
||||
raise Exception("Should not be called")
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.unet.device
|
||||
|
||||
def latents_from_embeddings(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
|
Reference in New Issue
Block a user