chore: ruff check - fix flake8-comprehensions

Author: psychedelicious
Date:   2023-11-11 10:44:43 +11:00
parent  43f2398e14
commit  3a136420d5

60 changed files with 489 additions and 512 deletions
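
Context: flake8-comprehensions flags unnecessary dict/list/set calls and comprehensions; ruff implements the same checks under the C4 rule codes and can auto-fix them. A minimal pyproject.toml sketch for enabling these rules with ruff — the select list here is illustrative, not necessarily InvokeAI's actual config:

    [tool.ruff]
    # C4xx = flake8-comprehensions (C400-C419)
    select = ["C4"]
    fix = true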


@@ -104,7 +104,7 @@ class ModelPatcher:
         loras: List[Tuple[LoRAModel, float]],
         prefix: str,
     ):
-        original_weights = dict()
+        original_weights = {}
         try:
             with torch.no_grad():
                 for lora, lora_weight in loras:
@@ -324,7 +324,7 @@ class TextualInversionManager(BaseTextualInversionManager):
     tokenizer: CLIPTokenizer

     def __init__(self, tokenizer: CLIPTokenizer):
-        self.pad_tokens = dict()
+        self.pad_tokens = {}
         self.tokenizer = tokenizer

     def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
@@ -385,10 +385,10 @@ class ONNXModelPatcher:
         if not isinstance(model, IAIOnnxRuntimeModel):
             raise Exception("Only IAIOnnxRuntimeModel models supported")

-        orig_weights = dict()
+        orig_weights = {}

         try:
-            blended_loras = dict()
+            blended_loras = {}

             for lora, lora_weight in loras:
                 for layer_key, layer in lora.layers.items():
@@ -404,7 +404,7 @@ class ONNXModelPatcher:
                     else:
                         blended_loras[layer_key] = layer_weight

-            node_names = dict()
+            node_names = {}
             for node in model.nodes.values():
                 node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
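
All of the dict() → {} rewrites above are rule C408 (unnecessary dict()/list()/tuple() call): the literal is equivalent but skips the global name lookup and call overhead, and cannot be affected by shadowing of the builtin. A quick, machine-dependent illustration:

    import timeit

    # {} compiles to a single BUILD_MAP opcode; dict() must resolve
    # the name "dict" at runtime and make a call.
    print(timeit.timeit("{}"))      # typically the faster of the two
    print(timeit.timeit("dict()"))
    assert dict() == {} and list() == []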


@@ -132,7 +132,7 @@ class ModelCache(object):
         snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
         behaviour.
         """
-        self.model_infos: Dict[str, ModelBase] = dict()
+        self.model_infos: Dict[str, ModelBase] = {}
         # allow lazy offloading only when vram cache enabled
         self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
         self.precision: torch.dtype = precision
@@ -147,8 +147,8 @@ class ModelCache(object):
         # used for stats collection
         self.stats = None

-        self._cached_models = dict()
-        self._cache_stack = list()
+        self._cached_models = {}
+        self._cache_stack = []

     def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
         if self._log_memory_usage:


@@ -363,7 +363,7 @@ class ModelManager(object):
         else:
             return

-        self.models = dict()
+        self.models = {}
         for model_key, model_config in config.items():
             if model_key.startswith("_"):
                 continue
@@ -374,7 +374,7 @@ class ModelManager(object):
             self.models[model_key] = model_class.create_config(**model_config)

         # check config version number and update on disk/RAM if necessary
-        self.cache_keys = dict()
+        self.cache_keys = {}

         # add controlnet, lora and textual_inversion models from disk
         self.scan_models_directory()
@@ -902,7 +902,7 @@ class ModelManager(object):
         """
         Write current configuration out to the indicated file.
         """
-        data_to_save = dict()
+        data_to_save = {}
         data_to_save["__metadata__"] = self.config_meta.model_dump()

         for model_key, model_config in self.models.items():
@@ -1034,7 +1034,7 @@ class ModelManager(object):
             self.ignore = ignore

         def on_search_started(self):
-            self.new_models_found = dict()
+            self.new_models_found = {}

         def on_model_found(self, model: Path):
             if model not in self.ignore:
@@ -1106,7 +1106,7 @@ class ModelManager(object):
         # avoid circular import here
         from invokeai.backend.install.model_install_backend import ModelInstall

-        successfully_installed = dict()
+        successfully_installed = {}

         installer = ModelInstall(
             config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self


@@ -92,7 +92,7 @@ class ModelMerger(object):
             **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
                  cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
         """
-        model_paths = list()
+        model_paths = []
         config = self.manager.app_config
         base_model = BaseModelType(base_model)
         vae = None
@@ -124,13 +124,13 @@ class ModelMerger(object):
             dump_path = (dump_path / merged_model_name).as_posix()

             merged_pipe.save_pretrained(dump_path, safe_serialization=True)
-            attributes = dict(
-                path=dump_path,
-                description=f"Merge of models {', '.join(model_names)}",
-                model_format="diffusers",
-                variant=ModelVariantType.Normal.value,
-                vae=vae,
-            )
+            attributes = {
+                "path": dump_path,
+                "description": f"Merge of models {', '.join(model_names)}",
+                "model_format": "diffusers",
+                "variant": ModelVariantType.Normal.value,
+                "vae": vae,
+            }
             return self.manager.add_model(
                 merged_model_name,
                 base_model=base_model,
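
The attributes change is C408 again, this time with keyword arguments: dict(path=...) always produces string keys, so the literal spelling is equivalent whenever every key is a valid identifier (and unlike the keyword form, it also permits keys that are not). A small sketch with made-up values:

    kw = dict(path="/tmp/model", model_format="diffusers")
    lit = {"path": "/tmp/model", "model_format": "diffusers"}
    assert kw == lit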


@@ -59,7 +59,7 @@ class ModelSearch(ABC):
         for root, dirs, files in os.walk(path, followlinks=True):
             if str(Path(root).name).startswith("."):
                 self._pruned_paths.add(root)
-            if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
+            if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
                 continue

             self._items_scanned += len(dirs) + len(files)
@@ -69,8 +69,7 @@ class ModelSearch(ABC):
                 self._scanned_dirs.add(path)
                 continue

             if any(
-                [
-                    (path / x).exists()
+                (path / x).exists()
                 for x in {
                     "config.json",
                     "model_index.json",
@@ -78,7 +77,6 @@ class ModelSearch(ABC):
                     "pytorch_lora_weights.bin",
                     "image_encoder.txt",
                 }
-            ]
             ):
                 try:
                     self.on_model_found(path)
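
Dropping the square brackets inside any() is rule C419: a list comprehension materializes every element before any() runs, whereas a generator expression lets any() stop at the first truthy result. A runnable sketch (is_hidden is a hypothetical stand-in for the predicates above):

    def is_hidden(name: str) -> bool:
        print(f"checking {name}")
        return name.startswith(".")

    names = [".cache", "models", "configs"]
    any([is_hidden(n) for n in names])  # prints all three names
    any(is_hidden(n) for n in names)    # prints only ".cache", then stops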


@@ -97,8 +97,8 @@ MODEL_CLASSES = {
     # },
 }

-MODEL_CONFIGS = list()
-OPENAPI_MODEL_CONFIGS = list()
+MODEL_CONFIGS = []
+OPENAPI_MODEL_CONFIGS = []


 class OpenAPIModelInfoBase(BaseModel):
@@ -133,7 +133,7 @@ for base_model, models in MODEL_CLASSES.items():


 def get_model_config_enums():
-    enums = list()
+    enums = []

     for model_config in MODEL_CONFIGS:
         if hasattr(inspect, "get_annotations"):


@@ -164,7 +164,7 @@ class ModelBase(metaclass=ABCMeta):
         with suppress(Exception):
             return cls.__configs

-        configs = dict()
+        configs = {}
         for name in dir(cls):
             if name.startswith("__"):
                 continue
@@ -246,8 +246,8 @@ class DiffusersModel(ModelBase):
     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         super().__init__(model_path, base_model, model_type)

-        self.child_types: Dict[str, Type] = dict()
-        self.child_sizes: Dict[str, int] = dict()
+        self.child_types: Dict[str, Type] = {}
+        self.child_sizes: Dict[str, int] = {}

         try:
             config_data = DiffusionPipeline.load_config(self.model_path)
@@ -326,8 +326,8 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
     all_files = os.listdir(model_path)
     all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]

-    fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
-    bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
+    fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
+    bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
     other_files = set(all_files) - fp16_files - bit8_files

     if variant is None:
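
The fp16_files/bit8_files change is rule C403 (unnecessary list comprehension inside set()): the old form builds a throwaway list that the set constructor immediately consumes, while a set comprehension fills the set directly. For example, with hypothetical filenames:

    files = ["a.fp16.bin", "b.bin", "c.fp16-vae.bin"]
    old_style = set([f for f in files if ".fp16." in f or ".fp16-" in f])
    new_style = {f for f in files if ".fp16." in f or ".fp16-" in f}
    assert old_style == new_style == {"a.fp16.bin", "c.fp16-vae.bin"}
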
@@ -413,7 +413,7 @@ def _calc_onnx_model_by_data(model) -> int:


 def _fast_safetensors_reader(path: str):
-    checkpoint = dict()
+    checkpoint = {}
     device = torch.device("meta")
     with open(path, "rb") as f:
         definition_len = int.from_bytes(f.read(8), "little")
@@ -483,7 +483,7 @@ class IAIOnnxRuntimeModel:
     class _tensor_access:
         def __init__(self, model):
             self.model = model
-            self.indexes = dict()
+            self.indexes = {}
             for idx, obj in enumerate(self.model.proto.graph.initializer):
                 self.indexes[obj.name] = idx
@@ -524,7 +524,7 @@ class IAIOnnxRuntimeModel:
     class _access_helper:
         def __init__(self, raw_proto):
-            self.indexes = dict()
+            self.indexes = {}
             self.raw_proto = raw_proto
             for idx, obj in enumerate(raw_proto):
                 self.indexes[obj.name] = idx
@@ -549,7 +549,7 @@ class IAIOnnxRuntimeModel:
             return self.indexes.keys()

         def values(self):
-            return [obj for obj in self.raw_proto]
+            return list(self.raw_proto)

     def __init__(self, model_path: str, provider: Optional[str]):
         self.path = model_path
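
The values() change is rule C416 (unnecessary comprehension): a comprehension that copies each element unchanged is just list(), which is clearer and can preallocate when the source length is known. For example:

    raw_proto = (1, 2, 3)  # hypothetical stand-in for the ONNX repeated field
    assert [obj for obj in raw_proto] == list(raw_proto)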


@@ -104,7 +104,7 @@ class ControlNetModel(ModelBase):
             return ControlNetModelFormat.Diffusers

         if os.path.isfile(path):
-            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
+            if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
                 return ControlNetModelFormat.Checkpoint

         raise InvalidModelException(f"Not a valid model: {path}")
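
This C419 fix recurs in each of the format-detection classifiers below. As an aside that goes beyond this commit, str.endswith also accepts a tuple of suffixes, which would remove the need for any() entirely:

    path = "model.safetensors"
    exts = ["safetensors", "ckpt", "pt", "pth"]
    assert any(path.endswith(f".{ext}") for ext in exts)
    # Equivalent, with no any() at all:
    assert path.endswith(tuple(f".{e}" for e in exts))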


@@ -73,7 +73,7 @@ class LoRAModel(ModelBase):
             return LoRAModelFormat.Diffusers

         if os.path.isfile(path):
-            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+            if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
                 return LoRAModelFormat.LyCORIS

         raise InvalidModelException(f"Not a valid model: {path}")
@@ -499,7 +499,7 @@ class LoRAModelRaw: # (torch.nn.Module):
         stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
         stability_unet_keys.sort()

-        new_state_dict = dict()
+        new_state_dict = {}
         for full_key, value in state_dict.items():
             if full_key.startswith("lora_unet_"):
                 search_key = full_key.replace("lora_unet_", "")
@@ -545,7 +545,7 @@ class LoRAModelRaw: # (torch.nn.Module):
         model = cls(
             name=file_path.stem,  # TODO:
-            layers=dict(),
+            layers={},
         )

         if file_path.suffix == ".safetensors":
@@ -593,12 +593,12 @@ class LoRAModelRaw: # (torch.nn.Module):
     @staticmethod
     def _group_state(state_dict: dict):
-        state_dict_groupped = dict()
+        state_dict_groupped = {}

         for key, value in state_dict.items():
             stem, leaf = key.split(".", 1)

             if stem not in state_dict_groupped:
-                state_dict_groupped[stem] = dict()
+                state_dict_groupped[stem] = {}
             state_dict_groupped[stem][leaf] = value

         return state_dict_groupped
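
_group_state's check-then-create pattern is left alone by this commit; for what it's worth, the same grouping can be expressed with dict.setdefault — a sketch under the same semantics, not the project's code:

    def group_state(state_dict: dict) -> dict:
        grouped: dict = {}
        for key, value in state_dict.items():
            stem, leaf = key.split(".", 1)
            # setdefault returns the existing sub-dict or inserts a new one
            grouped.setdefault(stem, {})[leaf] = value
        return grouped

    assert group_state({"a.x": 1, "a.y": 2}) == {"a": {"x": 1, "y": 2}}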


@@ -110,7 +110,7 @@ class StableDiffusion1Model(DiffusersModel):
             return StableDiffusion1ModelFormat.Diffusers

         if os.path.isfile(model_path):
-            if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+            if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
                 return StableDiffusion1ModelFormat.Checkpoint

         raise InvalidModelException(f"Not a valid model: {model_path}")
@@ -221,7 +221,7 @@ class StableDiffusion2Model(DiffusersModel):
             return StableDiffusion2ModelFormat.Diffusers

         if os.path.isfile(model_path):
-            if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+            if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
                 return StableDiffusion2ModelFormat.Checkpoint

         raise InvalidModelException(f"Not a valid model: {model_path}")


@@ -71,7 +71,7 @@ class TextualInversionModel(ModelBase):
             return None  # diffusers-ti

         if os.path.isfile(path):
-            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
+            if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]):
                 return None

         raise InvalidModelException(f"Not a valid model: {path}")


@@ -89,7 +89,7 @@ class VaeModel(ModelBase):
             return VaeModelFormat.Diffusers

         if os.path.isfile(path):
-            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+            if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
                 return VaeModelFormat.Checkpoint

         raise InvalidModelException(f"Not a valid model: {path}")