mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: ruff check - fix flake8-comprensions
This commit is contained in:
@ -352,7 +352,7 @@ class Graph(BaseModel):
|
||||
|
||||
# Validate that all node ids are unique
|
||||
node_ids = [n.id for n in self.nodes.values()]
|
||||
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
|
||||
duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2}
|
||||
if duplicate_node_ids:
|
||||
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
|
||||
|
||||
@ -616,7 +616,7 @@ class Graph(BaseModel):
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = list()
|
||||
edges = []
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
|
||||
@ -658,7 +658,7 @@ class Graph(BaseModel):
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = list()
|
||||
edges = []
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
|
||||
@ -680,8 +680,8 @@ class Graph(BaseModel):
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
|
||||
outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
|
||||
inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@ -694,7 +694,7 @@ class Graph(BaseModel):
|
||||
|
||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
|
||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
||||
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||
|
||||
# Input type must be a list
|
||||
if get_origin(input_field) != list:
|
||||
@ -713,8 +713,8 @@ class Graph(BaseModel):
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
|
||||
outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
|
||||
inputs = [e.source for e in self._get_input_edges(node_path, "item")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@ -722,18 +722,16 @@ class Graph(BaseModel):
|
||||
outputs.append(new_output)
|
||||
|
||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
|
||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
||||
input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
|
||||
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||
|
||||
# Validate that all inputs are derived from or match a single type
|
||||
input_field_types = set(
|
||||
[
|
||||
t
|
||||
input_field_types = {
|
||||
t
|
||||
for input_field in input_fields
|
||||
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
|
||||
if t != NoneType
|
||||
]
|
||||
) # Get unique types
|
||||
} # Get unique types
|
||||
type_tree = nx.DiGraph()
|
||||
type_tree.add_nodes_from(input_field_types)
|
||||
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
|
||||
@ -761,15 +759,15 @@ class Graph(BaseModel):
|
||||
"""Returns a NetworkX DiGraph representing the layout of this graph"""
|
||||
# TODO: Cache this?
|
||||
g = nx.DiGraph()
|
||||
g.add_nodes_from([n for n in self.nodes.keys()])
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
g.add_nodes_from(list(self.nodes.keys()))
|
||||
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
|
||||
return g
|
||||
|
||||
def nx_graph_with_data(self) -> nx.DiGraph:
|
||||
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
|
||||
g = nx.DiGraph()
|
||||
g.add_nodes_from([n for n in self.nodes.items()])
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
g.add_nodes_from(list(self.nodes.items()))
|
||||
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
|
||||
return g
|
||||
|
||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
|
||||
@ -791,7 +789,7 @@ class Graph(BaseModel):
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
|
||||
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
|
||||
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
|
||||
return g
|
||||
|
||||
@ -843,8 +841,8 @@ class GraphExecutionState(BaseModel):
|
||||
return v
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=dict(
|
||||
required=[
|
||||
json_schema_extra={
|
||||
"required": [
|
||||
"id",
|
||||
"graph",
|
||||
"execution_graph",
|
||||
@ -855,7 +853,7 @@ class GraphExecutionState(BaseModel):
|
||||
"prepared_source_mapping",
|
||||
"source_prepared_mapping",
|
||||
]
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
def next(self) -> Optional[BaseInvocation]:
|
||||
@ -895,7 +893,7 @@ class GraphExecutionState(BaseModel):
|
||||
source_node = self.prepared_source_mapping[node_id]
|
||||
prepared_nodes = self.source_prepared_mapping[source_node]
|
||||
|
||||
if all([n in self.executed for n in prepared_nodes]):
|
||||
if all(n in self.executed for n in prepared_nodes):
|
||||
self.executed.add(source_node)
|
||||
self.executed_history.append(source_node)
|
||||
|
||||
@ -930,7 +928,7 @@ class GraphExecutionState(BaseModel):
|
||||
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
|
||||
self_iteration_count = len(input_collection)
|
||||
|
||||
new_nodes: list[str] = list()
|
||||
new_nodes: list[str] = []
|
||||
if self_iteration_count == 0:
|
||||
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
|
||||
return new_nodes
|
||||
@ -940,7 +938,7 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Create new edges for this iteration
|
||||
# For collect nodes, this may contain multiple inputs to the same field
|
||||
new_edges: list[Edge] = list()
|
||||
new_edges: list[Edge] = []
|
||||
for edge in input_edges:
|
||||
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
|
||||
new_edge = Edge(
|
||||
@ -1034,7 +1032,7 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Create execution nodes
|
||||
next_node = self.graph.get_node(next_node_id)
|
||||
new_node_ids = list()
|
||||
new_node_ids = []
|
||||
if isinstance(next_node, CollectInvocation):
|
||||
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
||||
all_iteration_mappings = list(
|
||||
@ -1201,7 +1199,7 @@ class LibraryGraph(BaseModel):
|
||||
|
||||
@field_validator("exposed_inputs", "exposed_outputs")
|
||||
def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
|
||||
if len(v) != len(set(i.alias for i in v)):
|
||||
if len(v) != len({i.alias for i in v}):
|
||||
raise ValueError("Duplicate exposed alias")
|
||||
return v
|
||||
|
||||
|
Reference in New Issue
Block a user