mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: ruff check - fix flake8-bugbear
This commit is contained in:
parent
3a136420d5
commit
99a8ebe3a0
@ -137,7 +137,7 @@ def dest_path(dest=None) -> Path:
|
||||
path_completer = PathCompleter(
|
||||
only_directories=True,
|
||||
expanduser=True,
|
||||
get_paths=lambda: [browse_start],
|
||||
get_paths=lambda: [browse_start], # noqa: B023
|
||||
# get_paths=lambda: [".."].extend(list(browse_start.iterdir()))
|
||||
)
|
||||
|
||||
|
@ -210,7 +210,7 @@ def generate_face_box_mask(
|
||||
# Check if any face is detected.
|
||||
if results.multi_face_landmarks: # type: ignore # this are via protobuf and not typed
|
||||
# Search for the face_id in the detected faces.
|
||||
for face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
|
||||
for _face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
|
||||
# Get the bounding box of the face mesh.
|
||||
x_coordinates = [landmark.x for landmark in face_landmarks.landmark]
|
||||
y_coordinates = [landmark.y for landmark in face_landmarks.landmark]
|
||||
|
@ -1105,7 +1105,7 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
latents_b = context.services.latents.get(self.latents_b.latents_name)
|
||||
|
||||
if latents_a.shape != latents_b.shape:
|
||||
raise "Latents to blend must be the same size."
|
||||
raise Exception("Latents to blend must be the same size.")
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC
|
||||
|
||||
|
||||
class InvocationProcessorABC(ABC):
|
||||
class InvocationProcessorABC(ABC): # noqa: B024
|
||||
pass
|
||||
|
@ -122,7 +122,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
def log_stats(self):
|
||||
completed = set()
|
||||
errored = set()
|
||||
for graph_id, node_log in self._stats.items():
|
||||
for graph_id, _node_log in self._stats.items():
|
||||
try:
|
||||
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
|
||||
except Exception:
|
||||
|
@ -355,7 +355,7 @@ def create_session_nfv_tuples(
|
||||
for item in batch_datum.items
|
||||
]
|
||||
node_field_values_to_zip.append(node_field_values)
|
||||
data.append(list(zip(*node_field_values_to_zip))) # type: ignore [arg-type]
|
||||
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]
|
||||
|
||||
# create generator to yield session,nfv tuples
|
||||
count = 0
|
||||
@ -383,7 +383,7 @@ def calc_session_count(batch: Batch) -> int:
|
||||
for batch_datum in batch_datum_list:
|
||||
batch_data_items = range(len(batch_datum.items))
|
||||
to_zip.append(batch_data_items)
|
||||
data.append(list(zip(*to_zip)))
|
||||
data.append(list(zip(*to_zip, strict=True)))
|
||||
data_product = list(product(*data))
|
||||
return len(data_product) * batch.runs
|
||||
|
||||
|
@ -1119,7 +1119,7 @@ class GraphExecutionState(BaseModel):
|
||||
for edge in input_edges
|
||||
if edge.destination.field == "item"
|
||||
]
|
||||
setattr(node, "collection", output_collection)
|
||||
node.collection = output_collection
|
||||
else:
|
||||
for edge in input_edges:
|
||||
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
|
||||
|
@ -59,7 +59,7 @@ def thin_one_time(x, kernels):
|
||||
|
||||
def lvmin_thin(x, prunings=True):
|
||||
y = x
|
||||
for i in range(32):
|
||||
for _i in range(32):
|
||||
y, is_done = thin_one_time(y, lvmin_kernels)
|
||||
if is_done:
|
||||
break
|
||||
|
@ -123,8 +123,6 @@ class MigrateTo3(object):
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
for f in files:
|
||||
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
||||
# let them be copied as part of a tree copy operation
|
||||
@ -143,8 +141,6 @@ class MigrateTo3(object):
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
def migrate_support_models(self):
|
||||
"""
|
||||
|
@ -176,7 +176,7 @@ class ModelInstall(object):
|
||||
# logic here a little reversed to maintain backward compatibility
|
||||
def starter_models(self, all_models: bool = False) -> Set[str]:
|
||||
models = set()
|
||||
for key, value in self.datasets.items():
|
||||
for key, _value in self.datasets.items():
|
||||
name, base, model_type = ModelManager.parse_key(key)
|
||||
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
|
||||
models.add(key)
|
||||
|
@ -130,7 +130,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
assert ip_adapter_image_prompt_embeds is not None
|
||||
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
|
||||
|
||||
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
|
||||
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True):
|
||||
# The batch dimensions should match.
|
||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The token_len dimensions should match.
|
||||
|
@ -269,7 +269,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
resolution *= 2
|
||||
|
||||
up_block_types = []
|
||||
for i in range(len(block_out_channels)):
|
||||
for _i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
@ -1223,7 +1223,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint_path)
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
|
||||
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
@ -1664,7 +1664,7 @@ def download_controlnet_from_original_ckpt(
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint_path)
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
|
||||
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
|
@ -242,7 +242,7 @@ class ModelPatcher:
|
||||
):
|
||||
skipped_layers = []
|
||||
try:
|
||||
for i in range(clip_skip):
|
||||
for _i in range(clip_skip):
|
||||
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
|
||||
|
||||
yield
|
||||
|
@ -26,5 +26,5 @@ def skip_torch_weight_init():
|
||||
|
||||
yield None
|
||||
finally:
|
||||
for torch_module, saved_function in zip(torch_modules, saved_functions):
|
||||
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
|
||||
torch_module.reset_parameters = saved_function
|
||||
|
@ -655,7 +655,7 @@ class ModelManager(object):
|
||||
"""
|
||||
# TODO: redo
|
||||
for model_dict in self.list_models():
|
||||
for model_name, model_info in model_dict.items():
|
||||
for _model_name, model_info in model_dict.items():
|
||||
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
||||
print(line)
|
||||
|
||||
|
@ -237,7 +237,7 @@ class ModelProbe(object):
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
raise "The model {model_name} is potentially infected by malware. Aborting import."
|
||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
|
||||
|
||||
# ##################################################3
|
||||
|
@ -109,7 +109,7 @@ class OpenAPIModelInfoBase(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
for base_model, models in MODEL_CLASSES.items():
|
||||
for _base_model, models in MODEL_CLASSES.items():
|
||||
for model_type, model_class in models.items():
|
||||
model_configs = set(model_class._get_configs().values())
|
||||
model_configs.discard(None)
|
||||
|
@ -153,7 +153,7 @@ class ModelBase(metaclass=ABCMeta):
|
||||
|
||||
else:
|
||||
res_type = sys.modules["diffusers"]
|
||||
res_type = getattr(res_type, "pipelines")
|
||||
res_type = res_type.pipelines
|
||||
|
||||
for subtype in subtypes:
|
||||
res_type = getattr(res_type, subtype)
|
||||
|
@ -462,7 +462,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
# TODO: try revert if exception?
|
||||
for key, layer in self.layers.items():
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
|
@ -54,13 +54,13 @@ class Context:
|
||||
self.clear_requests(cleanup=True)
|
||||
|
||||
def register_cross_attention_modules(self, model):
|
||||
for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF):
|
||||
for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
|
||||
if name in self.self_cross_attention_module_identifiers:
|
||||
assert False, f"name {name} cannot appear more than once"
|
||||
raise AssertionError(f"name {name} cannot appear more than once")
|
||||
self.self_cross_attention_module_identifiers.append(name)
|
||||
for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
if name in self.tokens_cross_attention_module_identifiers:
|
||||
assert False, f"name {name} cannot appear more than once"
|
||||
raise AssertionError(f"name {name} cannot appear more than once")
|
||||
self.tokens_cross_attention_module_identifiers.append(name)
|
||||
|
||||
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
@ -170,7 +170,7 @@ class Context:
|
||||
self.saved_cross_attention_maps = {}
|
||||
|
||||
def offload_saved_attention_slices_to_cpu(self):
|
||||
for key, map_dict in self.saved_cross_attention_maps.items():
|
||||
for _key, map_dict in self.saved_cross_attention_maps.items():
|
||||
for offset, slice in map_dict["slices"].items():
|
||||
map_dict[offset] = slice.to("cpu")
|
||||
|
||||
@ -433,7 +433,7 @@ def inject_attention_function(unet, context: Context):
|
||||
module.identifier = identifier
|
||||
try:
|
||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))
|
||||
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
|
||||
@ -445,7 +445,7 @@ def remove_attention_function(unet):
|
||||
cross_attention_modules = get_cross_attention_modules(
|
||||
unet, CrossAttentionType.TOKENS
|
||||
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for identifier, module in cross_attention_modules:
|
||||
for _identifier, module in cross_attention_modules:
|
||||
try:
|
||||
# clear wrangler callback
|
||||
module.set_attention_slice_wrangler(None)
|
||||
|
@ -56,7 +56,7 @@ class AttentionMapSaver:
|
||||
|
||||
merged = None
|
||||
|
||||
for key, maps in self.collated_maps.items():
|
||||
for _key, maps in self.collated_maps.items():
|
||||
# maps has shape [(H*W), N] for N tokens
|
||||
# but we want [N, H, W]
|
||||
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
|
||||
|
@ -123,7 +123,7 @@ class InvokeAIDiffuserComponent:
|
||||
# control_data should be type List[ControlNetData]
|
||||
# this loop covers both ControlNet (one ControlNetData in list)
|
||||
# and MultiControlNet (multiple ControlNetData in list)
|
||||
for i, control_datum in enumerate(control_data):
|
||||
for _i, control_datum in enumerate(control_data):
|
||||
control_mode = control_datum.control_mode
|
||||
# soft_injection and cfg_injection are the two ControlNet control_mode booleans
|
||||
# that are combined at higher level to make control_mode enum
|
||||
@ -214,7 +214,7 @@ class InvokeAIDiffuserComponent:
|
||||
# add controlnet outputs together if have multiple controlnets
|
||||
down_block_res_samples = [
|
||||
samples_prev + samples_curr
|
||||
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
|
||||
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=True)
|
||||
]
|
||||
mid_block_res_sample += mid_sample
|
||||
|
||||
|
@ -732,7 +732,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks, strict=True):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
@ -745,7 +745,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=True)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
|
@ -72,7 +72,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
|
||||
self.multipage = multipage
|
||||
self.subprocess = None
|
||||
super().__init__(parentApp=parentApp, name=name, *args, **keywords)
|
||||
super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad?
|
||||
|
||||
def create(self):
|
||||
self.keypress_timeout = 10
|
||||
@ -203,7 +203,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
)
|
||||
|
||||
# This restores the selected page on return from an installation
|
||||
for i in range(1, self.current_tab + 1):
|
||||
for _i in range(1, self.current_tab + 1):
|
||||
self.tabs.h_cursor_line_down(1)
|
||||
self._toggle_tables([self.current_tab])
|
||||
|
||||
@ -258,9 +258,11 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
model_type: ModelType,
|
||||
window_width: int = 120,
|
||||
install_prompt: str = None,
|
||||
exclude: set = set(),
|
||||
exclude: set = None,
|
||||
) -> dict[str, npyscreen.widget]:
|
||||
"""Generic code to create model selection widgets"""
|
||||
if exclude is None:
|
||||
exclude = set()
|
||||
widgets = {}
|
||||
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
|
||||
model_labels = [self.model_labels[x] for x in model_list]
|
||||
@ -366,13 +368,13 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
]
|
||||
|
||||
for group in widgets:
|
||||
for k, v in group.items():
|
||||
for _k, v in group.items():
|
||||
try:
|
||||
v.hidden = True
|
||||
v.editable = False
|
||||
except Exception:
|
||||
pass
|
||||
for k, v in widgets[selected_tab].items():
|
||||
for _k, v in widgets[selected_tab].items():
|
||||
try:
|
||||
v.hidden = False
|
||||
if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
|
||||
|
@ -11,6 +11,7 @@ import sys
|
||||
import textwrap
|
||||
from curses import BUTTON2_CLICKED, BUTTON3_CLICKED
|
||||
from shutil import get_terminal_size
|
||||
from typing import Optional
|
||||
|
||||
import npyscreen
|
||||
import npyscreen.wgmultiline as wgmultiline
|
||||
@ -243,7 +244,9 @@ class SelectColumnBase:
|
||||
|
||||
|
||||
class MultiSelectColumns(SelectColumnBase, npyscreen.MultiSelect):
|
||||
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
||||
def __init__(self, screen, columns: int = 1, values: Optional[list] = None, **keywords):
|
||||
if values is None:
|
||||
values = []
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||
@ -267,7 +270,9 @@ class SingleSelectWithChanged(npyscreen.SelectOne):
|
||||
class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged):
|
||||
"""Row of radio buttons. Spacebar to select."""
|
||||
|
||||
def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
|
||||
def __init__(self, screen, columns: int = 1, values: list = None, **keywords):
|
||||
if values is None:
|
||||
values = []
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||
|
@ -6,5 +6,5 @@ import warnings
|
||||
from invokeai.frontend.install.invokeai_configure import invokeai_configure as configure
|
||||
|
||||
if __name__ == "__main__":
|
||||
warnings.warn("configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning)
|
||||
warnings.warn("configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning, stacklevel=2)
|
||||
configure()
|
||||
|
@ -471,7 +471,6 @@ def test_graph_gets_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
n1.graph.add_node
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
n1.graph.add_node(n1_1)
|
||||
@ -544,7 +543,6 @@ def test_graph_fails_to_get_missing_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
n1.graph.add_node
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
n1.graph.add_node(n1_1)
|
||||
@ -559,7 +557,6 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
n1.graph.add_node
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
n1.graph.add_node(n1_1)
|
||||
|
Loading…
Reference in New Issue
Block a user