Merge branch 'main' into refactor/model-manager-2

This commit is contained in:
psychedelicious
2023-11-14 07:51:57 +11:00
committed by GitHub
253 changed files with 1712 additions and 3981 deletions

View File

@ -88,7 +88,7 @@ class PromptFormatter:
t2i = self.t2i
opt = self.opt
switches = list()
switches = []
switches.append(f'"{opt.prompt}"')
switches.append(f"-s{opt.steps or t2i.steps}")
switches.append(f"-W{opt.width or t2i.width}")

View File

@ -88,7 +88,7 @@ class Txt2Mask(object):
provided image and returns a SegmentedGrayscale object in which the brighter
pixels indicate where the object is inferred to be.
"""
if type(image) is str:
if isinstance(image, str):
image = Image.open(image).convert("RGB")
image = ImageOps.exif_transpose(image)

View File

@ -40,7 +40,7 @@ class InitImageResizer:
(rw, rh) = (int(scale * im.width), int(scale * im.height))
# round everything to multiples of 64
width, height, rw, rh = map(lambda x: x - x % 64, (width, height, rw, rh))
width, height, rw, rh = (x - x % 64 for x in (width, height, rw, rh))
# no resize necessary, but return a copy
if im.width == width and im.height == height:

View File

@ -197,7 +197,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
def download_conversion_models():
target_dir = config.models_path / "core/convert"
kwargs = dict() # for future use
kwargs = {} # for future use
try:
logger.info("Downloading core tokenizers and text encoders")
@ -252,26 +252,26 @@ def download_conversion_models():
def download_realesrgan():
logger.info("Installing ESRGAN Upscaling models...")
URLs = [
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
description="RealESRGAN_x4plus.pth",
),
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
description="RealESRGAN_x4plus_anime_6B.pth",
),
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
description="ESRGAN_SRx4_DF2KOST_official.pth",
),
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
description="RealESRGAN_x2plus.pth",
),
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
"description": "RealESRGAN_x4plus.pth",
},
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
"description": "RealESRGAN_x4plus_anime_6B.pth",
},
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"dest": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"description": "ESRGAN_SRx4_DF2KOST_official.pth",
},
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
"dest": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
"description": "RealESRGAN_x2plus.pth",
},
]
for model in URLs:
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
@ -680,7 +680,7 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
if program_opts.default_only
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
if program_opts.yes_to_all
else list(),
else [],
)

View File

@ -123,8 +123,6 @@ class MigrateTo3(object):
logger.error(str(e))
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
for f in files:
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
# let them be copied as part of a tree copy operation
@ -143,8 +141,6 @@ class MigrateTo3(object):
logger.error(str(e))
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def migrate_support_models(self):
"""
@ -182,10 +178,10 @@ class MigrateTo3(object):
"""
dest_directory = self.dest_models
kwargs = dict(
cache_dir=self.root_directory / "models/hub",
kwargs = {
"cache_dir": self.root_directory / "models/hub",
# local_files_only = True
)
}
try:
logger.info("Migrating core tokenizers and text encoders")
target_dir = dest_directory / "core" / "convert"
@ -316,11 +312,11 @@ class MigrateTo3(object):
dest_dir = self.dest_models
cache = self.root_directory / "models/hub"
kwargs = dict(
cache_dir=cache,
safety_checker=None,
kwargs = {
"cache_dir": cache,
"safety_checker": None,
# local_files_only = True,
)
}
owner, repo_name = repo_id.split("/")
model_name = model_name or repo_name

View File

@ -120,7 +120,7 @@ class ModelInstall(object):
be treated uniformly. It also sorts the models alphabetically
by their name, to improve the display somewhat.
"""
model_dict = dict()
model_dict = {}
# first populate with the entries in INITIAL_MODELS.yaml
for key, value in self.datasets.items():
@ -134,7 +134,7 @@ class ModelInstall(object):
model_dict[key] = model_info
# supplement with entries in models.yaml
installed_models = [x for x in self.mgr.list_models()]
installed_models = list(self.mgr.list_models())
for md in installed_models:
base = md["base_model"]
@ -176,7 +176,7 @@ class ModelInstall(object):
# logic here a little reversed to maintain backward compatibility
def starter_models(self, all_models: bool = False) -> Set[str]:
models = set()
for key, value in self.datasets.items():
for key, _value in self.datasets.items():
name, base, model_type = ModelManager.parse_key(key)
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
models.add(key)
@ -184,7 +184,7 @@ class ModelInstall(object):
def recommended_models(self) -> Set[str]:
starters = self.starter_models(all_models=True)
return set([x for x in starters if self.datasets[x].get("recommended", False)])
return {x for x in starters if self.datasets[x].get("recommended", False)}
def default_model(self) -> str:
starters = self.starter_models()
@ -234,7 +234,7 @@ class ModelInstall(object):
"""
if not models_installed:
models_installed = dict()
models_installed = {}
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
@ -252,16 +252,14 @@ class ModelInstall(object):
# folders style or similar
elif path.is_dir() and any(
[
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"pytorch_lora_weights.safetensors",
}
]
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"pytorch_lora_weights.safetensors",
}
):
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
@ -433,17 +431,17 @@ class ModelInstall(object):
rel_path = self.relative_to_root(path, self.config.models_path)
attributes = dict(
path=str(rel_path),
description=str(description),
model_format=info.format,
)
attributes = {
"path": str(rel_path),
"description": str(description),
"model_format": info.format,
}
legacy_conf = None
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
attributes.update(
dict(
variant=info.variant_type,
)
{
"variant": info.variant_type,
}
)
if info.format == "checkpoint":
try:
@ -474,7 +472,7 @@ class ModelInstall(object):
)
if legacy_conf:
attributes.update(dict(config=str(legacy_conf)))
attributes.update({"config": str(legacy_conf)})
return attributes
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
@ -519,7 +517,7 @@ class ModelInstall(object):
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
_, name = repo_id.split("/")
location = staging / name
paths = list()
paths = []
for filename in files:
filePath = Path(filename)
p = hf_download_with_resume(

View File

@ -130,7 +130,9 @@ class IPAttnProcessor2_0(torch.nn.Module):
assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
for ipa_embed, ipa_weights, scale in zip(
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match.

View File

@ -56,7 +56,7 @@ class PerceiverAttention(nn.Module):
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
b, L, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
@ -72,7 +72,7 @@ class PerceiverAttention(nn.Module):
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
out = out.permute(0, 2, 1, 3).reshape(b, L, -1)
return self.to_out(out)

View File

@ -269,7 +269,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
for _i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
@ -1223,7 +1223,7 @@ def download_from_original_stable_diffusion_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
@ -1664,7 +1664,7 @@ def download_controlnet_from_original_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)

View File

@ -104,7 +104,7 @@ class ModelPatcher:
loras: List[Tuple[LoRAModel, float]],
prefix: str,
):
original_weights = dict()
original_weights = {}
try:
with torch.no_grad():
for lora, lora_weight in loras:
@ -242,7 +242,7 @@ class ModelPatcher:
):
skipped_layers = []
try:
for i in range(clip_skip):
for _i in range(clip_skip):
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
yield
@ -324,7 +324,7 @@ class TextualInversionManager(BaseTextualInversionManager):
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = dict()
self.pad_tokens = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
@ -385,10 +385,10 @@ class ONNXModelPatcher:
if not isinstance(model, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
orig_weights = dict()
orig_weights = {}
try:
blended_loras = dict()
blended_loras = {}
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():
@ -404,7 +404,7 @@ class ONNXModelPatcher:
else:
blended_loras[layer_key] = layer_weight
node_names = dict()
node_names = {}
for node in model.nodes.values():
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name

View File

@ -66,11 +66,13 @@ class CacheStats(object):
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
@ -132,7 +134,7 @@ class ModelCache(object):
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
self.model_infos: Dict[str, ModelBase] = dict()
self.model_infos: Dict[str, ModelBase] = {}
# allow lazy offloading only when vram cache enabled
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self.precision: torch.dtype = precision
@ -147,8 +149,8 @@ class ModelCache(object):
# used for stats collection
self.stats = None
self._cached_models = dict()
self._cache_stack = list()
self._cached_models = {}
self._cache_stack = []
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:

View File

@ -26,5 +26,5 @@ def skip_torch_weight_init():
yield None
finally:
for torch_module, saved_function in zip(torch_modules, saved_functions):
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
torch_module.reset_parameters = saved_function

View File

@ -363,7 +363,7 @@ class ModelManager(object):
else:
return
self.models = dict()
self.models = {}
for model_key, model_config in config.items():
if model_key.startswith("_"):
continue
@ -374,7 +374,7 @@ class ModelManager(object):
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
self.cache_keys = dict()
self.cache_keys = {}
# add controlnet, lora and textual_inversion models from disk
self.scan_models_directory()
@ -655,7 +655,7 @@ class ModelManager(object):
"""
# TODO: redo
for model_dict in self.list_models():
for model_name, model_info in model_dict.items():
for _model_name, model_info in model_dict.items():
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
print(line)
@ -902,7 +902,7 @@ class ModelManager(object):
"""
Write current configuration out to the indicated file.
"""
data_to_save = dict()
data_to_save = {}
data_to_save["__metadata__"] = self.config_meta.model_dump()
for model_key, model_config in self.models.items():
@ -1034,7 +1034,7 @@ class ModelManager(object):
self.ignore = ignore
def on_search_started(self):
self.new_models_found = dict()
self.new_models_found = {}
def on_model_found(self, model: Path):
if model not in self.ignore:
@ -1106,7 +1106,7 @@ class ModelManager(object):
# avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall
successfully_installed = dict()
successfully_installed = {}
installer = ModelInstall(
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self

View File

@ -92,7 +92,7 @@ class ModelMerger(object):
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths = list()
model_paths = []
config = self.manager.app_config
base_model = BaseModelType(base_model)
vae = None
@ -124,13 +124,13 @@ class ModelMerger(object):
dump_path = (dump_path / merged_model_name).as_posix()
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = dict(
path=dump_path,
description=f"Merge of models {', '.join(model_names)}",
model_format="diffusers",
variant=ModelVariantType.Normal.value,
vae=vae,
)
attributes = {
"path": dump_path,
"description": f"Merge of models {', '.join(model_names)}",
"model_format": "diffusers",
"variant": ModelVariantType.Normal.value,
"vae": vae,
}
return self.manager.add_model(
merged_model_name,
base_model=base_model,

View File

@ -237,7 +237,7 @@ class ModelProbe(object):
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise "The model {model_name} is potentially infected by malware. Aborting import."
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
# ##################################################3

View File

@ -59,7 +59,7 @@ class ModelSearch(ABC):
for root, dirs, files in os.walk(path, followlinks=True):
if str(Path(root).name).startswith("."):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
continue
self._items_scanned += len(dirs) + len(files)
@ -69,16 +69,14 @@ class ModelSearch(ABC):
self._scanned_dirs.add(path)
continue
if any(
[
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
}
]
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
}
):
try:
self.on_model_found(path)

View File

@ -97,8 +97,8 @@ MODEL_CLASSES = {
# },
}
MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list()
MODEL_CONFIGS = []
OPENAPI_MODEL_CONFIGS = []
class OpenAPIModelInfoBase(BaseModel):
@ -109,7 +109,7 @@ class OpenAPIModelInfoBase(BaseModel):
model_config = ConfigDict(protected_namespaces=())
for base_model, models in MODEL_CLASSES.items():
for _base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
model_configs = set(model_class._get_configs().values())
model_configs.discard(None)
@ -133,7 +133,7 @@ for base_model, models in MODEL_CLASSES.items():
def get_model_config_enums():
enums = list()
enums = []
for model_config in MODEL_CONFIGS:
if hasattr(inspect, "get_annotations"):

View File

@ -153,7 +153,7 @@ class ModelBase(metaclass=ABCMeta):
else:
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
res_type = res_type.pipelines
for subtype in subtypes:
res_type = getattr(res_type, subtype)
@ -164,7 +164,7 @@ class ModelBase(metaclass=ABCMeta):
with suppress(Exception):
return cls.__configs
configs = dict()
configs = {}
for name in dir(cls):
if name.startswith("__"):
continue
@ -246,8 +246,8 @@ class DiffusersModel(ModelBase):
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
self.child_types: Dict[str, Type] = {}
self.child_sizes: Dict[str, int] = {}
try:
config_data = DiffusionPipeline.load_config(self.model_path)
@ -326,8 +326,8 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
@ -413,7 +413,7 @@ def _calc_onnx_model_by_data(model) -> int:
def _fast_safetensors_reader(path: str):
checkpoint = dict()
checkpoint = {}
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), "little")
@ -483,7 +483,7 @@ class IAIOnnxRuntimeModel:
class _tensor_access:
def __init__(self, model):
self.model = model
self.indexes = dict()
self.indexes = {}
for idx, obj in enumerate(self.model.proto.graph.initializer):
self.indexes[obj.name] = idx
@ -524,7 +524,7 @@ class IAIOnnxRuntimeModel:
class _access_helper:
def __init__(self, raw_proto):
self.indexes = dict()
self.indexes = {}
self.raw_proto = raw_proto
for idx, obj in enumerate(raw_proto):
self.indexes[obj.name] = idx
@ -549,7 +549,7 @@ class IAIOnnxRuntimeModel:
return self.indexes.keys()
def values(self):
return [obj for obj in self.raw_proto]
return list(self.raw_proto)
def __init__(self, model_path: str, provider: Optional[str]):
self.path = model_path

View File

@ -104,7 +104,7 @@ class ControlNetModel(ModelBase):
return ControlNetModelFormat.Diffusers
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
return ControlNetModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -73,7 +73,7 @@ class LoRAModel(ModelBase):
return LoRAModelFormat.Diffusers
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return LoRAModelFormat.LyCORIS
raise InvalidModelException(f"Not a valid model: {path}")
@ -462,7 +462,7 @@ class LoRAModelRaw: # (torch.nn.Module):
dtype: Optional[torch.dtype] = None,
):
# TODO: try revert if exception?
for key, layer in self.layers.items():
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
@ -499,7 +499,7 @@ class LoRAModelRaw: # (torch.nn.Module):
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict = dict()
new_state_dict = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
@ -545,7 +545,7 @@ class LoRAModelRaw: # (torch.nn.Module):
model = cls(
name=file_path.stem, # TODO:
layers=dict(),
layers={},
)
if file_path.suffix == ".safetensors":
@ -593,12 +593,12 @@ class LoRAModelRaw: # (torch.nn.Module):
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = dict()
state_dict_groupped = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped

View File

@ -110,7 +110,7 @@ class StableDiffusion1Model(DiffusersModel):
return StableDiffusion1ModelFormat.Diffusers
if os.path.isfile(model_path):
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return StableDiffusion1ModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {model_path}")
@ -221,7 +221,7 @@ class StableDiffusion2Model(DiffusersModel):
return StableDiffusion2ModelFormat.Diffusers
if os.path.isfile(model_path):
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return StableDiffusion2ModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {model_path}")

View File

@ -71,7 +71,7 @@ class TextualInversionModel(ModelBase):
return None # diffusers-ti
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]):
return None
raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -89,7 +89,7 @@ class VaeModel(ModelBase):
return VaeModelFormat.Diffusers
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return VaeModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -193,6 +193,7 @@ class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
after generation completes. Optional.
"""
attention_map_saver: Optional[AttentionMapSaver]

View File

@ -54,13 +54,13 @@ class Context:
self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model):
for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF):
for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
if name in self.self_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
raise AssertionError(f"name {name} cannot appear more than once")
self.self_cross_attention_module_identifiers.append(name)
for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
if name in self.tokens_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
raise AssertionError(f"name {name} cannot appear more than once")
self.tokens_cross_attention_module_identifiers.append(name)
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
@ -170,7 +170,7 @@ class Context:
self.saved_cross_attention_maps = {}
def offload_saved_attention_slices_to_cpu(self):
for key, map_dict in self.saved_cross_attention_maps.items():
for _key, map_dict in self.saved_cross_attention_maps.items():
for offset, slice in map_dict["slices"].items():
map_dict[offset] = slice.to("cpu")
@ -433,7 +433,7 @@ def inject_attention_function(unet, context: Context):
module.identifier = identifier
try:
module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
@ -445,7 +445,7 @@ def remove_attention_function(unet):
cross_attention_modules = get_cross_attention_modules(
unet, CrossAttentionType.TOKENS
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for identifier, module in cross_attention_modules:
for _identifier, module in cross_attention_modules:
try:
# clear wrangler callback
module.set_attention_slice_wrangler(None)

View File

@ -56,7 +56,7 @@ class AttentionMapSaver:
merged = None
for key, maps in self.collated_maps.items():
for _key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))

View File

@ -123,7 +123,7 @@ class InvokeAIDiffuserComponent:
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
for _i, control_datum in enumerate(control_data):
control_mode = control_datum.control_mode
# soft_injection and cfg_injection are the two ControlNet control_mode booleans
# that are combined at higher level to make control_mode enum
@ -214,7 +214,7 @@ class InvokeAIDiffuserComponent:
# add controlnet outputs together if have multiple controlnets
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=True)
]
mid_block_res_sample += mid_sample
@ -642,7 +642,9 @@ class InvokeAIDiffuserComponent:
deltas = None
uncond_latents = None
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
weighted_cond_list = (
c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
)
# below is fugly omg
conditionings = [uc] + [c for c, weight in weighted_cond_list]

View File

@ -16,28 +16,28 @@ from diffusers import (
UniPCMultistepScheduler,
)
SCHEDULER_MAP = dict(
ddim=(DDIMScheduler, dict()),
ddpm=(DDPMScheduler, dict()),
deis=(DEISMultistepScheduler, dict()),
lms=(LMSDiscreteScheduler, dict(use_karras_sigmas=False)),
lms_k=(LMSDiscreteScheduler, dict(use_karras_sigmas=True)),
pndm=(PNDMScheduler, dict()),
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
euler_a=(EulerAncestralDiscreteScheduler, dict()),
kdpm_2=(KDPM2DiscreteScheduler, dict()),
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
dpmpp_2s=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=False)),
dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)),
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type="sde-dpmsolver++")),
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")),
dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)),
dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)),
unipc=(UniPCMultistepScheduler, dict(cpu_only=True)),
lcm=(LCMScheduler, dict()),
)
SCHEDULER_MAP = {
"ddim": (DDIMScheduler, {}),
"ddpm": (DDPMScheduler, {}),
"deis": (DEISMultistepScheduler, {}),
"lms": (LMSDiscreteScheduler, {"use_karras_sigmas": False}),
"lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
"pndm": (PNDMScheduler, {}),
"heun": (HeunDiscreteScheduler, {"use_karras_sigmas": False}),
"heun_k": (HeunDiscreteScheduler, {"use_karras_sigmas": True}),
"euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
"euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
"euler_a": (EulerAncestralDiscreteScheduler, {}),
"kdpm_2": (KDPM2DiscreteScheduler, {}),
"kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
"dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
"dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
"dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}),
"dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
"dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
"dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
"dpmpp_sde": (DPMSolverSDEScheduler, {"use_karras_sigmas": False, "noise_sampler_seed": 0}),
"dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}),
"unipc": (UniPCMultistepScheduler, {"cpu_only": True}),
"lcm": (LCMScheduler, {}),
}

View File

@ -615,7 +615,7 @@ def do_textual_inversion_training(
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
pipeline_args = dict(local_files_only=True)
pipeline_args = {"local_files_only": True}
if tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
else:

View File

@ -732,7 +732,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
for down_block_res_sample, controlnet_block in zip(
down_block_res_samples, self.controlnet_down_blocks, strict=True
):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
@ -745,7 +747,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
down_block_res_samples = [
sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=False)
]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]

View File

@ -225,34 +225,34 @@ def basicConfig(**kwargs):
_FACILITY_MAP = (
dict(
LOG_KERN=syslog.LOG_KERN,
LOG_USER=syslog.LOG_USER,
LOG_MAIL=syslog.LOG_MAIL,
LOG_DAEMON=syslog.LOG_DAEMON,
LOG_AUTH=syslog.LOG_AUTH,
LOG_LPR=syslog.LOG_LPR,
LOG_NEWS=syslog.LOG_NEWS,
LOG_UUCP=syslog.LOG_UUCP,
LOG_CRON=syslog.LOG_CRON,
LOG_SYSLOG=syslog.LOG_SYSLOG,
LOG_LOCAL0=syslog.LOG_LOCAL0,
LOG_LOCAL1=syslog.LOG_LOCAL1,
LOG_LOCAL2=syslog.LOG_LOCAL2,
LOG_LOCAL3=syslog.LOG_LOCAL3,
LOG_LOCAL4=syslog.LOG_LOCAL4,
LOG_LOCAL5=syslog.LOG_LOCAL5,
LOG_LOCAL6=syslog.LOG_LOCAL6,
LOG_LOCAL7=syslog.LOG_LOCAL7,
)
{
"LOG_KERN": syslog.LOG_KERN,
"LOG_USER": syslog.LOG_USER,
"LOG_MAIL": syslog.LOG_MAIL,
"LOG_DAEMON": syslog.LOG_DAEMON,
"LOG_AUTH": syslog.LOG_AUTH,
"LOG_LPR": syslog.LOG_LPR,
"LOG_NEWS": syslog.LOG_NEWS,
"LOG_UUCP": syslog.LOG_UUCP,
"LOG_CRON": syslog.LOG_CRON,
"LOG_SYSLOG": syslog.LOG_SYSLOG,
"LOG_LOCAL0": syslog.LOG_LOCAL0,
"LOG_LOCAL1": syslog.LOG_LOCAL1,
"LOG_LOCAL2": syslog.LOG_LOCAL2,
"LOG_LOCAL3": syslog.LOG_LOCAL3,
"LOG_LOCAL4": syslog.LOG_LOCAL4,
"LOG_LOCAL5": syslog.LOG_LOCAL5,
"LOG_LOCAL6": syslog.LOG_LOCAL6,
"LOG_LOCAL7": syslog.LOG_LOCAL7,
}
if SYSLOG_AVAILABLE
else dict()
else {}
)
_SOCK_MAP = dict(
SOCK_STREAM=socket.SOCK_STREAM,
SOCK_DGRAM=socket.SOCK_DGRAM,
)
_SOCK_MAP = {
"SOCK_STREAM": socket.SOCK_STREAM,
"SOCK_DGRAM": socket.SOCK_DGRAM,
}
class InvokeAIFormatter(logging.Formatter):
@ -344,7 +344,7 @@ LOG_FORMATTERS = {
class InvokeAILogger(object):
loggers = dict()
loggers = {}
@classmethod
def get_logger(
@ -364,7 +364,7 @@ class InvokeAILogger(object):
@classmethod
def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
handler_strs = config.log_handlers
handlers = list()
handlers = []
for handler in handler_strs:
handler_name, *args = handler.split("=", 2)
args = args[0] if len(args) > 0 else None
@ -398,7 +398,7 @@ class InvokeAILogger(object):
raise ValueError("syslog is not available on this system")
if not args:
args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
syslog_args = dict()
syslog_args = {}
try:
for a in args.split(","):
arg_name, *arg_value = a.split(":", 2)
@ -434,7 +434,7 @@ class InvokeAILogger(object):
path = url.path
port = url.port or 80
syslog_args = dict()
syslog_args = {}
for a in arg_list:
arg_name, *arg_value = a.split(":", 2)
if arg_name == "method":

View File

@ -29,7 +29,7 @@ def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
txts = []
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
@ -93,7 +93,7 @@ def instantiate_from_config(config, **kwargs):
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
return get_obj_from_str(config["target"])(**config.get("params", {}), **kwargs)
def get_obj_from_str(string, reload=False):
@ -231,11 +231,12 @@ def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10
angles = 2 * math.pi * rand_val
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
tile_grads = (
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
.repeat_interleave(d[0], 0)
.repeat_interleave(d[1], 1)
)
def tile_grads(slice1, slice2):
return (
gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
.repeat_interleave(d[0], 0)
.repeat_interleave(d[1], 1)
)
def dot(grad, shift):
return (