mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: ruff check - fix flake8-comprensions
This commit is contained in:
@ -104,7 +104,7 @@ class ModelPatcher:
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
prefix: str,
|
||||
):
|
||||
original_weights = dict()
|
||||
original_weights = {}
|
||||
try:
|
||||
with torch.no_grad():
|
||||
for lora, lora_weight in loras:
|
||||
@ -324,7 +324,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
def __init__(self, tokenizer: CLIPTokenizer):
|
||||
self.pad_tokens = dict()
|
||||
self.pad_tokens = {}
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
||||
@ -385,10 +385,10 @@ class ONNXModelPatcher:
|
||||
if not isinstance(model, IAIOnnxRuntimeModel):
|
||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
||||
|
||||
orig_weights = dict()
|
||||
orig_weights = {}
|
||||
|
||||
try:
|
||||
blended_loras = dict()
|
||||
blended_loras = {}
|
||||
|
||||
for lora, lora_weight in loras:
|
||||
for layer_key, layer in lora.layers.items():
|
||||
@ -404,7 +404,7 @@ class ONNXModelPatcher:
|
||||
else:
|
||||
blended_loras[layer_key] = layer_weight
|
||||
|
||||
node_names = dict()
|
||||
node_names = {}
|
||||
for node in model.nodes.values():
|
||||
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
|
||||
|
||||
|
@ -132,7 +132,7 @@ class ModelCache(object):
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
"""
|
||||
self.model_infos: Dict[str, ModelBase] = dict()
|
||||
self.model_infos: Dict[str, ModelBase] = {}
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self.precision: torch.dtype = precision
|
||||
@ -147,8 +147,8 @@ class ModelCache(object):
|
||||
# used for stats collection
|
||||
self.stats = None
|
||||
|
||||
self._cached_models = dict()
|
||||
self._cache_stack = list()
|
||||
self._cached_models = {}
|
||||
self._cache_stack = []
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
|
@ -363,7 +363,7 @@ class ModelManager(object):
|
||||
else:
|
||||
return
|
||||
|
||||
self.models = dict()
|
||||
self.models = {}
|
||||
for model_key, model_config in config.items():
|
||||
if model_key.startswith("_"):
|
||||
continue
|
||||
@ -374,7 +374,7 @@ class ModelManager(object):
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
|
||||
# check config version number and update on disk/RAM if necessary
|
||||
self.cache_keys = dict()
|
||||
self.cache_keys = {}
|
||||
|
||||
# add controlnet, lora and textual_inversion models from disk
|
||||
self.scan_models_directory()
|
||||
@ -902,7 +902,7 @@ class ModelManager(object):
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
"""
|
||||
data_to_save = dict()
|
||||
data_to_save = {}
|
||||
data_to_save["__metadata__"] = self.config_meta.model_dump()
|
||||
|
||||
for model_key, model_config in self.models.items():
|
||||
@ -1034,7 +1034,7 @@ class ModelManager(object):
|
||||
self.ignore = ignore
|
||||
|
||||
def on_search_started(self):
|
||||
self.new_models_found = dict()
|
||||
self.new_models_found = {}
|
||||
|
||||
def on_model_found(self, model: Path):
|
||||
if model not in self.ignore:
|
||||
@ -1106,7 +1106,7 @@ class ModelManager(object):
|
||||
# avoid circular import here
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
|
||||
successfully_installed = dict()
|
||||
successfully_installed = {}
|
||||
|
||||
installer = ModelInstall(
|
||||
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self
|
||||
|
@ -92,7 +92,7 @@ class ModelMerger(object):
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
model_paths = list()
|
||||
model_paths = []
|
||||
config = self.manager.app_config
|
||||
base_model = BaseModelType(base_model)
|
||||
vae = None
|
||||
@ -124,13 +124,13 @@ class ModelMerger(object):
|
||||
dump_path = (dump_path / merged_model_name).as_posix()
|
||||
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
attributes = dict(
|
||||
path=dump_path,
|
||||
description=f"Merge of models {', '.join(model_names)}",
|
||||
model_format="diffusers",
|
||||
variant=ModelVariantType.Normal.value,
|
||||
vae=vae,
|
||||
)
|
||||
attributes = {
|
||||
"path": dump_path,
|
||||
"description": f"Merge of models {', '.join(model_names)}",
|
||||
"model_format": "diffusers",
|
||||
"variant": ModelVariantType.Normal.value,
|
||||
"vae": vae,
|
||||
}
|
||||
return self.manager.add_model(
|
||||
merged_model_name,
|
||||
base_model=base_model,
|
||||
|
@ -59,7 +59,7 @@ class ModelSearch(ABC):
|
||||
for root, dirs, files in os.walk(path, followlinks=True):
|
||||
if str(Path(root).name).startswith("."):
|
||||
self._pruned_paths.add(root)
|
||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||
if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
|
||||
continue
|
||||
|
||||
self._items_scanned += len(dirs) + len(files)
|
||||
@ -69,8 +69,7 @@ class ModelSearch(ABC):
|
||||
self._scanned_dirs.add(path)
|
||||
continue
|
||||
if any(
|
||||
[
|
||||
(path / x).exists()
|
||||
(path / x).exists()
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
@ -78,7 +77,6 @@ class ModelSearch(ABC):
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
}
|
||||
]
|
||||
):
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
|
@ -97,8 +97,8 @@ MODEL_CLASSES = {
|
||||
# },
|
||||
}
|
||||
|
||||
MODEL_CONFIGS = list()
|
||||
OPENAPI_MODEL_CONFIGS = list()
|
||||
MODEL_CONFIGS = []
|
||||
OPENAPI_MODEL_CONFIGS = []
|
||||
|
||||
|
||||
class OpenAPIModelInfoBase(BaseModel):
|
||||
@ -133,7 +133,7 @@ for base_model, models in MODEL_CLASSES.items():
|
||||
|
||||
|
||||
def get_model_config_enums():
|
||||
enums = list()
|
||||
enums = []
|
||||
|
||||
for model_config in MODEL_CONFIGS:
|
||||
if hasattr(inspect, "get_annotations"):
|
||||
|
@ -164,7 +164,7 @@ class ModelBase(metaclass=ABCMeta):
|
||||
with suppress(Exception):
|
||||
return cls.__configs
|
||||
|
||||
configs = dict()
|
||||
configs = {}
|
||||
for name in dir(cls):
|
||||
if name.startswith("__"):
|
||||
continue
|
||||
@ -246,8 +246,8 @@ class DiffusersModel(ModelBase):
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.child_types: Dict[str, Type] = dict()
|
||||
self.child_sizes: Dict[str, int] = dict()
|
||||
self.child_types: Dict[str, Type] = {}
|
||||
self.child_sizes: Dict[str, int] = {}
|
||||
|
||||
try:
|
||||
config_data = DiffusionPipeline.load_config(self.model_path)
|
||||
@ -326,8 +326,8 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
|
||||
all_files = os.listdir(model_path)
|
||||
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
|
||||
|
||||
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
|
||||
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
|
||||
fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
|
||||
bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
|
||||
other_files = set(all_files) - fp16_files - bit8_files
|
||||
|
||||
if variant is None:
|
||||
@ -413,7 +413,7 @@ def _calc_onnx_model_by_data(model) -> int:
|
||||
|
||||
|
||||
def _fast_safetensors_reader(path: str):
|
||||
checkpoint = dict()
|
||||
checkpoint = {}
|
||||
device = torch.device("meta")
|
||||
with open(path, "rb") as f:
|
||||
definition_len = int.from_bytes(f.read(8), "little")
|
||||
@ -483,7 +483,7 @@ class IAIOnnxRuntimeModel:
|
||||
class _tensor_access:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self.indexes = dict()
|
||||
self.indexes = {}
|
||||
for idx, obj in enumerate(self.model.proto.graph.initializer):
|
||||
self.indexes[obj.name] = idx
|
||||
|
||||
@ -524,7 +524,7 @@ class IAIOnnxRuntimeModel:
|
||||
|
||||
class _access_helper:
|
||||
def __init__(self, raw_proto):
|
||||
self.indexes = dict()
|
||||
self.indexes = {}
|
||||
self.raw_proto = raw_proto
|
||||
for idx, obj in enumerate(raw_proto):
|
||||
self.indexes[obj.name] = idx
|
||||
@ -549,7 +549,7 @@ class IAIOnnxRuntimeModel:
|
||||
return self.indexes.keys()
|
||||
|
||||
def values(self):
|
||||
return [obj for obj in self.raw_proto]
|
||||
return list(self.raw_proto)
|
||||
|
||||
def __init__(self, model_path: str, provider: Optional[str]):
|
||||
self.path = model_path
|
||||
|
@ -104,7 +104,7 @@ class ControlNetModel(ModelBase):
|
||||
return ControlNetModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
|
||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
|
||||
return ControlNetModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
@ -73,7 +73,7 @@ class LoRAModel(ModelBase):
|
||||
return LoRAModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||
return LoRAModelFormat.LyCORIS
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
@ -499,7 +499,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict = dict()
|
||||
new_state_dict = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
@ -545,7 +545,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
|
||||
model = cls(
|
||||
name=file_path.stem, # TODO:
|
||||
layers=dict(),
|
||||
layers={},
|
||||
)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
@ -593,12 +593,12 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: dict):
|
||||
state_dict_groupped = dict()
|
||||
state_dict_groupped = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
if stem not in state_dict_groupped:
|
||||
state_dict_groupped[stem] = dict()
|
||||
state_dict_groupped[stem] = {}
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
return state_dict_groupped
|
||||
|
@ -110,7 +110,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
return StableDiffusion1ModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(model_path):
|
||||
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||
return StableDiffusion1ModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
@ -221,7 +221,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
return StableDiffusion2ModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(model_path):
|
||||
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||
return StableDiffusion2ModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {model_path}")
|
||||
|
@ -71,7 +71,7 @@ class TextualInversionModel(ModelBase):
|
||||
return None # diffusers-ti
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
|
||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]):
|
||||
return None
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
@ -89,7 +89,7 @@ class VaeModel(ModelBase):
|
||||
return VaeModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
|
||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||
return VaeModelFormat.Checkpoint
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
Reference in New Issue
Block a user