chore: ruff check - fix flake8-bugbear

This commit is contained in:
psychedelicious 2023-11-11 10:51:21 +11:00
parent 3a136420d5
commit 99a8ebe3a0
27 changed files with 48 additions and 48 deletions

View File

@ -137,7 +137,7 @@ def dest_path(dest=None) -> Path:
path_completer = PathCompleter(
only_directories=True,
expanduser=True,
get_paths=lambda: [browse_start],
get_paths=lambda: [browse_start], # noqa: B023
# get_paths=lambda: [".."].extend(list(browse_start.iterdir()))
)

View File

@ -210,7 +210,7 @@ def generate_face_box_mask(
# Check if any face is detected.
if results.multi_face_landmarks: # type: ignore # this are via protobuf and not typed
# Search for the face_id in the detected faces.
for face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
for _face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
# Get the bounding box of the face mesh.
x_coordinates = [landmark.x for landmark in face_landmarks.landmark]
y_coordinates = [landmark.y for landmark in face_landmarks.landmark]

View File

@ -1105,7 +1105,7 @@ class BlendLatentsInvocation(BaseInvocation):
latents_b = context.services.latents.get(self.latents_b.latents_name)
if latents_a.shape != latents_b.shape:
raise "Latents to blend must be the same size."
raise Exception("Latents to blend must be the same size.")
# TODO:
device = choose_torch_device()

View File

@ -1,5 +1,5 @@
from abc import ABC
class InvocationProcessorABC(ABC):
class InvocationProcessorABC(ABC): # noqa: B024
pass

View File

@ -122,7 +122,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
def log_stats(self):
completed = set()
errored = set()
for graph_id, node_log in self._stats.items():
for graph_id, _node_log in self._stats.items():
try:
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
except Exception:

View File

@ -355,7 +355,7 @@ def create_session_nfv_tuples(
for item in batch_datum.items
]
node_field_values_to_zip.append(node_field_values)
data.append(list(zip(*node_field_values_to_zip))) # type: ignore [arg-type]
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]
# create generator to yield session,nfv tuples
count = 0
@ -383,7 +383,7 @@ def calc_session_count(batch: Batch) -> int:
for batch_datum in batch_datum_list:
batch_data_items = range(len(batch_datum.items))
to_zip.append(batch_data_items)
data.append(list(zip(*to_zip)))
data.append(list(zip(*to_zip, strict=True)))
data_product = list(product(*data))
return len(data_product) * batch.runs

View File

@ -1119,7 +1119,7 @@ class GraphExecutionState(BaseModel):
for edge in input_edges
if edge.destination.field == "item"
]
setattr(node, "collection", output_collection)
node.collection = output_collection
else:
for edge in input_edges:
output_value = getattr(self.results[edge.source.node_id], edge.source.field)

View File

@ -59,7 +59,7 @@ def thin_one_time(x, kernels):
def lvmin_thin(x, prunings=True):
y = x
for i in range(32):
for _i in range(32):
y, is_done = thin_one_time(y, lvmin_kernels)
if is_done:
break

View File

@ -123,8 +123,6 @@ class MigrateTo3(object):
logger.error(str(e))
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
for f in files:
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
# let them be copied as part of a tree copy operation
@ -143,8 +141,6 @@ class MigrateTo3(object):
logger.error(str(e))
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def migrate_support_models(self):
"""

View File

@ -176,7 +176,7 @@ class ModelInstall(object):
# logic here a little reversed to maintain backward compatibility
def starter_models(self, all_models: bool = False) -> Set[str]:
models = set()
for key, value in self.datasets.items():
for key, _value in self.datasets.items():
name, base, model_type = ModelManager.parse_key(key)
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
models.add(key)

View File

@ -130,7 +130,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match.

View File

@ -269,7 +269,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
for _i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
@ -1223,7 +1223,7 @@ def download_from_original_stable_diffusion_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
@ -1664,7 +1664,7 @@ def download_controlnet_from_original_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)

View File

@ -242,7 +242,7 @@ class ModelPatcher:
):
skipped_layers = []
try:
for i in range(clip_skip):
for _i in range(clip_skip):
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
yield

View File

@ -26,5 +26,5 @@ def skip_torch_weight_init():
yield None
finally:
for torch_module, saved_function in zip(torch_modules, saved_functions):
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
torch_module.reset_parameters = saved_function

View File

@ -655,7 +655,7 @@ class ModelManager(object):
"""
# TODO: redo
for model_dict in self.list_models():
for model_name, model_info in model_dict.items():
for _model_name, model_info in model_dict.items():
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
print(line)

View File

@ -237,7 +237,7 @@ class ModelProbe(object):
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise "The model {model_name} is potentially infected by malware. Aborting import."
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
# ##################################################3

View File

@ -109,7 +109,7 @@ class OpenAPIModelInfoBase(BaseModel):
model_config = ConfigDict(protected_namespaces=())
for base_model, models in MODEL_CLASSES.items():
for _base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
model_configs = set(model_class._get_configs().values())
model_configs.discard(None)

View File

@ -153,7 +153,7 @@ class ModelBase(metaclass=ABCMeta):
else:
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
res_type = res_type.pipelines
for subtype in subtypes:
res_type = getattr(res_type, subtype)

View File

@ -462,7 +462,7 @@ class LoRAModelRaw: # (torch.nn.Module):
dtype: Optional[torch.dtype] = None,
):
# TODO: try revert if exception?
for key, layer in self.layers.items():
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:

View File

@ -54,13 +54,13 @@ class Context:
self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model):
for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF):
for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
if name in self.self_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
raise AssertionError(f"name {name} cannot appear more than once")
self.self_cross_attention_module_identifiers.append(name)
for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
if name in self.tokens_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
raise AssertionError(f"name {name} cannot appear more than once")
self.tokens_cross_attention_module_identifiers.append(name)
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
@ -170,7 +170,7 @@ class Context:
self.saved_cross_attention_maps = {}
def offload_saved_attention_slices_to_cpu(self):
for key, map_dict in self.saved_cross_attention_maps.items():
for _key, map_dict in self.saved_cross_attention_maps.items():
for offset, slice in map_dict["slices"].items():
map_dict[offset] = slice.to("cpu")
@ -433,7 +433,7 @@ def inject_attention_function(unet, context: Context):
module.identifier = identifier
try:
module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
@ -445,7 +445,7 @@ def remove_attention_function(unet):
cross_attention_modules = get_cross_attention_modules(
unet, CrossAttentionType.TOKENS
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for identifier, module in cross_attention_modules:
for _identifier, module in cross_attention_modules:
try:
# clear wrangler callback
module.set_attention_slice_wrangler(None)

View File

@ -56,7 +56,7 @@ class AttentionMapSaver:
merged = None
for key, maps in self.collated_maps.items():
for _key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))

View File

@ -123,7 +123,7 @@ class InvokeAIDiffuserComponent:
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
for _i, control_datum in enumerate(control_data):
control_mode = control_datum.control_mode
# soft_injection and cfg_injection are the two ControlNet control_mode booleans
# that are combined at higher level to make control_mode enum
@ -214,7 +214,7 @@ class InvokeAIDiffuserComponent:
# add controlnet outputs together if have multiple controlnets
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=True)
]
mid_block_res_sample += mid_sample

View File

@ -732,7 +732,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks, strict=True):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
@ -745,7 +745,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=True)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]

View File

@ -72,7 +72,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
self.multipage = multipage
self.subprocess = None
super().__init__(parentApp=parentApp, name=name, *args, **keywords)
super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad?
def create(self):
self.keypress_timeout = 10
@ -203,7 +203,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
)
# This restores the selected page on return from an installation
for i in range(1, self.current_tab + 1):
for _i in range(1, self.current_tab + 1):
self.tabs.h_cursor_line_down(1)
self._toggle_tables([self.current_tab])
@ -258,9 +258,11 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
model_type: ModelType,
window_width: int = 120,
install_prompt: str = None,
exclude: set = set(),
exclude: set = None,
) -> dict[str, npyscreen.widget]:
"""Generic code to create model selection widgets"""
if exclude is None:
exclude = set()
widgets = {}
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
model_labels = [self.model_labels[x] for x in model_list]
@ -366,13 +368,13 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
]
for group in widgets:
for k, v in group.items():
for _k, v in group.items():
try:
v.hidden = True
v.editable = False
except Exception:
pass
for k, v in widgets[selected_tab].items():
for _k, v in widgets[selected_tab].items():
try:
v.hidden = False
if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):

View File

@ -11,6 +11,7 @@ import sys
import textwrap
from curses import BUTTON2_CLICKED, BUTTON3_CLICKED
from shutil import get_terminal_size
from typing import Optional
import npyscreen
import npyscreen.wgmultiline as wgmultiline
@ -243,7 +244,9 @@ class SelectColumnBase:
class MultiSelectColumns(SelectColumnBase, npyscreen.MultiSelect):
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
def __init__(self, screen, columns: int = 1, values: Optional[list] = None, **keywords):
if values is None:
values = []
self.columns = columns
self.value_cnt = len(values)
self.rows = math.ceil(self.value_cnt / self.columns)
@ -267,7 +270,9 @@ class SingleSelectWithChanged(npyscreen.SelectOne):
class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged):
"""Row of radio buttons. Spacebar to select."""
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
def __init__(self, screen, columns: int = 1, values: list = None, **keywords):
if values is None:
values = []
self.columns = columns
self.value_cnt = len(values)
self.rows = math.ceil(self.value_cnt / self.columns)

View File

@ -6,5 +6,5 @@ import warnings
from invokeai.frontend.install.invokeai_configure import invokeai_configure as configure
if __name__ == "__main__":
warnings.warn("configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning)
warnings.warn("configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning, stacklevel=2)
configure()

View File

@ -471,7 +471,6 @@ def test_graph_gets_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
n1.graph.add_node
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)
@ -544,7 +543,6 @@ def test_graph_fails_to_get_missing_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
n1.graph.add_node
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)
@ -559,7 +557,6 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
g = Graph()
n1 = GraphInvocation(id="1")
n1.graph = Graph()
n1.graph.add_node
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
n1.graph.add_node(n1_1)