mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: ruff format
This commit is contained in:
parent
513fceac82
commit
6494e8e551
@ -137,7 +137,7 @@ def dest_path(dest=None) -> Path:
|
||||
path_completer = PathCompleter(
|
||||
only_directories=True,
|
||||
expanduser=True,
|
||||
get_paths=lambda: [browse_start], # noqa: B023
|
||||
get_paths=lambda: [browse_start], # noqa: B023
|
||||
# get_paths=lambda: [".."].extend(list(browse_start.iterdir()))
|
||||
)
|
||||
|
||||
@ -149,7 +149,7 @@ def dest_path(dest=None) -> Path:
|
||||
completer=path_completer,
|
||||
default=str(browse_start) + os.sep,
|
||||
vi_mode=True,
|
||||
complete_while_typing=True
|
||||
complete_while_typing=True,
|
||||
# Test that this is not needed on Windows
|
||||
# complete_style=CompleteStyle.READLINE_LIKE,
|
||||
)
|
||||
|
@ -661,9 +661,7 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
|
||||
|
||||
field_kind = (
|
||||
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
|
||||
field.json_schema_extra.get("_field_kind", None)
|
||||
if field.json_schema_extra
|
||||
else None
|
||||
field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None
|
||||
)
|
||||
|
||||
# must have a field_kind
|
||||
|
@ -90,20 +90,23 @@ class ImageRecordDeleteException(Exception):
|
||||
|
||||
|
||||
IMAGE_DTO_COLS = ", ".join(
|
||||
["images." + c for c in [
|
||||
"image_name",
|
||||
"image_origin",
|
||||
"image_category",
|
||||
"width",
|
||||
"height",
|
||||
"session_id",
|
||||
"node_id",
|
||||
"is_intermediate",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"deleted_at",
|
||||
"starred",
|
||||
]]
|
||||
[
|
||||
"images." + c
|
||||
for c in [
|
||||
"image_name",
|
||||
"image_origin",
|
||||
"image_category",
|
||||
"width",
|
||||
"height",
|
||||
"session_id",
|
||||
"node_id",
|
||||
"is_intermediate",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"deleted_at",
|
||||
"starred",
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
@ -217,13 +217,16 @@ class ImageService(ImageServiceABC):
|
||||
board_id,
|
||||
)
|
||||
|
||||
image_dtos = [image_record_to_dto(
|
||||
image_record=r,
|
||||
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
|
||||
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
|
||||
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
||||
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
|
||||
) for r in results.items]
|
||||
image_dtos = [
|
||||
image_record_to_dto(
|
||||
image_record=r,
|
||||
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
|
||||
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
|
||||
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
||||
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
|
||||
)
|
||||
for r in results.items
|
||||
]
|
||||
|
||||
return OffsetPaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC
|
||||
|
||||
|
||||
class InvocationProcessorABC(ABC): # noqa: B024
|
||||
class InvocationProcessorABC(ABC): # noqa: B024
|
||||
pass
|
||||
|
@ -34,7 +34,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
name="session_processor",
|
||||
target=self.__process,
|
||||
kwargs={
|
||||
"stop_event": self.__stop_event, "poll_now_event": self.__poll_now_event, "resume_event": self.__resume_event
|
||||
"stop_event": self.__stop_event,
|
||||
"poll_now_event": self.__poll_now_event,
|
||||
"resume_event": self.__resume_event,
|
||||
},
|
||||
)
|
||||
self.__thread.start()
|
||||
|
@ -728,9 +728,9 @@ class Graph(BaseModel):
|
||||
# Validate that all inputs are derived from or match a single type
|
||||
input_field_types = {
|
||||
t
|
||||
for input_field in input_fields
|
||||
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
|
||||
if t != NoneType
|
||||
for input_field in input_fields
|
||||
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
|
||||
if t != NoneType
|
||||
} # Get unique types
|
||||
type_tree = nx.DiGraph()
|
||||
type_tree.add_nodes_from(input_field_types)
|
||||
@ -1053,7 +1053,10 @@ class GraphExecutionState(BaseModel):
|
||||
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
|
||||
# TODO: Handle a node mapping to none
|
||||
eg = self.execution_graph.nx_graph_flat()
|
||||
prepared_parent_mappings = [[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] # type: ignore
|
||||
prepared_parent_mappings = [
|
||||
[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
|
||||
for it in iterator_node_prepared_combinations
|
||||
] # type: ignore
|
||||
|
||||
# Create execution node for each iteration
|
||||
for iteration_mappings in prepared_parent_mappings:
|
||||
|
@ -253,13 +253,13 @@ class ModelInstall(object):
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any(
|
||||
(path / x).exists()
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"pytorch_lora_weights.safetensors",
|
||||
}
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"pytorch_lora_weights.safetensors",
|
||||
}
|
||||
):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
||||
|
||||
|
@ -130,7 +130,9 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
assert ip_adapter_image_prompt_embeds is not None
|
||||
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
|
||||
|
||||
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True):
|
||||
for ipa_embed, ipa_weights, scale in zip(
|
||||
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
|
||||
):
|
||||
# The batch dimensions should match.
|
||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The token_len dimensions should match.
|
||||
|
@ -66,11 +66,13 @@ class CacheStats(object):
|
||||
|
||||
class ModelLocker(object):
|
||||
"Forward declaration"
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelCache(object):
|
||||
"Forward declaration"
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
@ -70,13 +70,13 @@ class ModelSearch(ABC):
|
||||
continue
|
||||
if any(
|
||||
(path / x).exists()
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
}
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
}
|
||||
):
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
|
@ -193,6 +193,7 @@ class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
||||
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
|
||||
after generation completes. Optional.
|
||||
"""
|
||||
|
||||
attention_map_saver: Optional[AttentionMapSaver]
|
||||
|
||||
|
||||
|
@ -433,7 +433,7 @@ def inject_attention_function(unet, context: Context):
|
||||
module.identifier = identifier
|
||||
try:
|
||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
|
||||
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
|
||||
|
@ -642,7 +642,9 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
deltas = None
|
||||
uncond_latents = None
|
||||
weighted_cond_list = c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
|
||||
weighted_cond_list = (
|
||||
c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
|
||||
)
|
||||
|
||||
# below is fugly omg
|
||||
conditionings = [uc] + [c for c, weight in weighted_cond_list]
|
||||
|
@ -732,7 +732,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks, strict=True):
|
||||
for down_block_res_sample, controlnet_block in zip(
|
||||
down_block_res_samples, self.controlnet_down_blocks, strict=True
|
||||
):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
@ -745,7 +747,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=True)]
|
||||
down_block_res_samples = [
|
||||
sample * scale for sample, scale in zip(down_block_res_samples, scales, strict=True)
|
||||
]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
|
@ -229,7 +229,11 @@ def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10
|
||||
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
|
||||
|
||||
def tile_grads(slice1, slice2):
|
||||
return gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
|
||||
return (
|
||||
gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
||||
.repeat_interleave(d[0], 0)
|
||||
.repeat_interleave(d[1], 1)
|
||||
)
|
||||
|
||||
def dot(grad, shift):
|
||||
return (
|
||||
|
@ -72,7 +72,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
def __init__(self, parentApp, name, multipage=False, *args, **keywords):
|
||||
self.multipage = multipage
|
||||
self.subprocess = None
|
||||
super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad?
|
||||
super().__init__(parentApp=parentApp, name=name, *args, **keywords) # noqa: B026 # TODO: maybe this is bad?
|
||||
|
||||
def create(self):
|
||||
self.keypress_timeout = 10
|
||||
|
@ -6,5 +6,7 @@ import warnings
|
||||
from invokeai.frontend.install.invokeai_configure import invokeai_configure as configure
|
||||
|
||||
if __name__ == "__main__":
|
||||
warnings.warn("configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning, stacklevel=2)
|
||||
warnings.warn(
|
||||
"configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning, stacklevel=2
|
||||
)
|
||||
configure()
|
||||
|
Loading…
Reference in New Issue
Block a user