chore: ruff format

This commit is contained in:
psychedelicious 2023-11-11 10:55:06 +11:00
parent 513fceac82
commit 6494e8e551
18 changed files with 80 additions and 54 deletions

View File

@ -137,7 +137,7 @@ def dest_path(dest=None) -> Path:
path_completer = PathCompleter( path_completer = PathCompleter(
only_directories=True, only_directories=True,
expanduser=True, expanduser=True,
get_paths=lambda: [browse_start], # noqa: B023 get_paths=lambda: [browse_start], # noqa: B023
# get_paths=lambda: [".."].extend(list(browse_start.iterdir())) # get_paths=lambda: [".."].extend(list(browse_start.iterdir()))
) )
@ -149,7 +149,7 @@ def dest_path(dest=None) -> Path:
completer=path_completer, completer=path_completer,
default=str(browse_start) + os.sep, default=str(browse_start) + os.sep,
vi_mode=True, vi_mode=True,
complete_while_typing=True complete_while_typing=True,
# Test that this is not needed on Windows # Test that this is not needed on Windows
# complete_style=CompleteStyle.READLINE_LIKE, # complete_style=CompleteStyle.READLINE_LIKE,
) )

View File

@ -661,9 +661,7 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
field_kind = ( field_kind = (
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file # _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
field.json_schema_extra.get("_field_kind", None) field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None
if field.json_schema_extra
else None
) )
# must have a field_kind # must have a field_kind

View File

@ -90,20 +90,23 @@ class ImageRecordDeleteException(Exception):
IMAGE_DTO_COLS = ", ".join( IMAGE_DTO_COLS = ", ".join(
["images." + c for c in [ [
"image_name", "images." + c
"image_origin", for c in [
"image_category", "image_name",
"width", "image_origin",
"height", "image_category",
"session_id", "width",
"node_id", "height",
"is_intermediate", "session_id",
"created_at", "node_id",
"updated_at", "is_intermediate",
"deleted_at", "created_at",
"starred", "updated_at",
]] "deleted_at",
"starred",
]
]
) )

View File

@ -217,13 +217,16 @@ class ImageService(ImageServiceABC):
board_id, board_id,
) )
image_dtos = [image_record_to_dto( image_dtos = [
image_record=r, image_record_to_dto(
image_url=self.__invoker.services.urls.get_image_url(r.image_name), image_record=r,
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True), image_url=self.__invoker.services.urls.get_image_url(r.image_name),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name), thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name), board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
) for r in results.items] workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
)
for r in results.items
]
return OffsetPaginatedResults[ImageDTO]( return OffsetPaginatedResults[ImageDTO](
items=image_dtos, items=image_dtos,

View File

@ -1,5 +1,5 @@
from abc import ABC from abc import ABC
class InvocationProcessorABC(ABC): # noqa: B024 class InvocationProcessorABC(ABC): # noqa: B024
pass pass

View File

@ -34,7 +34,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
name="session_processor", name="session_processor",
target=self.__process, target=self.__process,
kwargs={ kwargs={
"stop_event": self.__stop_event, "poll_now_event": self.__poll_now_event, "resume_event": self.__resume_event "stop_event": self.__stop_event,
"poll_now_event": self.__poll_now_event,
"resume_event": self.__resume_event,
}, },
) )
self.__thread.start() self.__thread.start()

View File

@ -728,9 +728,9 @@ class Graph(BaseModel):
# Validate that all inputs are derived from or match a single type # Validate that all inputs are derived from or match a single type
input_field_types = { input_field_types = {
t t
for input_field in input_fields for input_field in input_fields
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field)) for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
if t != NoneType if t != NoneType
} # Get unique types } # Get unique types
type_tree = nx.DiGraph() type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types) type_tree.add_nodes_from(input_field_types)
@ -1053,7 +1053,10 @@ class GraphExecutionState(BaseModel):
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
# TODO: Handle a node mapping to none # TODO: Handle a node mapping to none
eg = self.execution_graph.nx_graph_flat() eg = self.execution_graph.nx_graph_flat()
prepared_parent_mappings = [[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] # type: ignore prepared_parent_mappings = [
[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
for it in iterator_node_prepared_combinations
] # type: ignore
# Create execution node for each iteration # Create execution node for each iteration
for iteration_mappings in prepared_parent_mappings: for iteration_mappings in prepared_parent_mappings:

View File

@ -253,13 +253,13 @@ class ModelInstall(object):
# folders style or similar # folders style or similar
elif path.is_dir() and any( elif path.is_dir() and any(
(path / x).exists() (path / x).exists()
for x in { for x in {
"config.json", "config.json",
"model_index.json", "model_index.json",
"learned_embeds.bin", "learned_embeds.bin",
"pytorch_lora_weights.bin", "pytorch_lora_weights.bin",
"pytorch_lora_weights.safetensors", "pytorch_lora_weights.safetensors",
} }
): ):
models_installed.update({str(model_path_id_or_url): self._install_path(path)}) models_installed.update({str(model_path_id_or_url): self._install_path(path)})

View File

@ -130,7 +130,9 @@ class IPAttnProcessor2_0(torch.nn.Module):
assert ip_adapter_image_prompt_embeds is not None assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self._weights) assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True): for ipa_embed, ipa_weights, scale in zip(
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
):
# The batch dimensions should match. # The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0] assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match. # The token_len dimensions should match.

View File

@ -66,11 +66,13 @@ class CacheStats(object):
class ModelLocker(object): class ModelLocker(object):
"Forward declaration" "Forward declaration"
pass pass
class ModelCache(object): class ModelCache(object):
"Forward declaration" "Forward declaration"
pass pass

View File

@ -70,13 +70,13 @@ class ModelSearch(ABC):
continue continue
if any( if any(
(path / x).exists() (path / x).exists()
for x in { for x in {
"config.json", "config.json",
"model_index.json", "model_index.json",
"learned_embeds.bin", "learned_embeds.bin",
"pytorch_lora_weights.bin", "pytorch_lora_weights.bin",
"image_encoder.txt", "image_encoder.txt",
} }
): ):
try: try:
self.on_model_found(path) self.on_model_found(path)

View File

@ -193,6 +193,7 @@ class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
after generation completes. Optional. after generation completes. Optional.
""" """
attention_map_saver: Optional[AttentionMapSaver] attention_map_saver: Optional[AttentionMapSaver]

View File

@ -433,7 +433,7 @@ def inject_attention_function(unet, context: Context):
module.identifier = identifier module.identifier = identifier
try: try:
module.set_attention_slice_wrangler(attention_slice_wrangler) module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023 module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
except AttributeError as e: except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"): if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO

View File

@ -642,7 +642,9 @@ class InvokeAIDiffuserComponent:
deltas = None deltas = None
uncond_latents = None uncond_latents = None
weighted_cond_list = c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)] weighted_cond_list = (
c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
)
# below is fugly omg # below is fugly omg
conditionings = [uc] + [c for c, weight in weighted_cond_list] conditionings = [uc] + [c for c, weight in weighted_cond_list]

View File

@ -732,7 +732,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
controlnet_down_block_res_samples = () controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks, strict=True): for down_block_res_sample, controlnet_block in zip(
down_block_res_samples, self.controlnet_down_blocks, strict=True
):
down_block_res_sample = controlnet_block(down_block_res_sample) down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
@ -745,7 +747,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=True)] down_block_res_samples = [
sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=True)
]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else: else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]

View File

@ -229,7 +229,11 @@ def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device) gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
def tile_grads(slice1, slice2): def tile_grads(slice1, slice2):
return gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1) return (
gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
.repeat_interleave(d[0], 0)
.repeat_interleave(d[1], 1)
)
def dot(grad, shift): def dot(grad, shift):
return ( return (

View File

@ -72,7 +72,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
def __init__(self, parentApp, name, multipage=False, *args, **keywords): def __init__(self, parentApp, name, multipage=False, *args, **keywords):
self.multipage = multipage self.multipage = multipage
self.subprocess = None self.subprocess = None
super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad? super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad?
def create(self): def create(self):
self.keypress_timeout = 10 self.keypress_timeout = 10

View File

@ -6,5 +6,7 @@ import warnings
from invokeai.frontend.install.invokeai_configure import invokeai_configure as configure from invokeai.frontend.install.invokeai_configure import invokeai_configure as configure
if __name__ == "__main__": if __name__ == "__main__":
warnings.warn("configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning, stacklevel=2) warnings.warn(
"configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning, stacklevel=2
)
configure() configure()