chore: ruff check - fix pycodestyle

psychedelicious 2023-11-11 10:54:52 +11:00
parent 99a8ebe3a0
commit 513fceac82
7 changed files with 11 additions and 14 deletions

View File

@@ -266,7 +266,7 @@ class FloatMathInvocation(BaseInvocation):
             raise ValueError("Cannot divide by zero")
         elif info.data["operation"] == "EXP" and info.data["a"] == 0 and v < 0:
             raise ValueError("Cannot raise zero to a negative power")
-        elif info.data["operation"] == "EXP" and type(info.data["a"] ** v) is complex:
+        elif info.data["operation"] == "EXP" and isinstance(info.data["a"] ** v, complex):
             raise ValueError("Root operation resulted in a complex number")
         return v
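For context, a minimal standalone sketch (illustrative only, not InvokeAI code) of the pattern the isinstance hunks in this commit clean up. The trigger is presumably pycodestyle's type-comparison rule (E721): comparing type(x) against a class is a strict identity check, while isinstance() is the idiomatic test and also accepts subclasses.

class Base:
    pass

class Child(Base):
    pass

obj = Child()

# A bare type comparison ignores inheritance entirely.
print(type(obj) is Base)         # False
print(type(obj) is Child)        # True

# isinstance() is the idiomatic check and respects subclassing.
print(isinstance(obj, Base))     # True
print(isinstance(obj, complex))  # False; works the same for built-ins like complex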

View File

@@ -21,11 +21,11 @@ def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
     # sanity check make sure the graph is at least reasonably shaped
     if (
-        type(graph) is not dict
+        not isinstance(graph, dict)
         or "nodes" not in graph
-        or type(graph["nodes"]) is not dict
+        or not isinstance(graph["nodes"], dict)
         or "edges" not in graph
-        or type(graph["edges"]) is not list
+        or not isinstance(graph["edges"], list)
     ):
         # something has gone terribly awry, return an empty dict
         return None

View File

@@ -88,7 +88,7 @@ class Txt2Mask(object):
         provided image and returns a SegmentedGrayscale object in which the brighter
         pixels indicate where the object is inferred to be.
         """
-        if type(image) is str:
+        if isinstance(image, str):
             image = Image.open(image).convert("RGB")

         image = ImageOps.exif_transpose(image)

View File

@@ -56,7 +56,7 @@ class PerceiverAttention(nn.Module):
         x = self.norm1(x)
         latents = self.norm2(latents)
-        b, l, _ = latents.shape
+        b, L, _ = latents.shape
         q = self.to_q(latents)
         kv_input = torch.cat((x, latents), dim=-2)
@@ -72,7 +72,7 @@ class PerceiverAttention(nn.Module):
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         out = weight @ v
-        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+        out = out.permute(0, 2, 1, 3).reshape(b, L, -1)
         return self.to_out(out)
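The two hunks above rename the local variable from a lowercase l to an uppercase L. A minimal sketch (illustrative only, not InvokeAI code) of the rule presumably behind it, pycodestyle E741, which flags the single-character names l, O, and I because they are easily misread as 1 or 0:

# Flagged: the lone lowercase "l" looks like "1" in many fonts.
def seq_len_bad(shape):
    b, l, d = shape  # E741: ambiguous variable name "l"
    return l

# Fine: any other name, including a capital "L" or a descriptive one.
def seq_len_good(shape):
    b, seq_len, d = shape
    return seq_len

print(seq_len_good((2, 77, 768)))  # 77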

View File

@@ -642,7 +642,7 @@ class InvokeAIDiffuserComponent:
         deltas = None
         uncond_latents = None
-        weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
+        weighted_cond_list = c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
         # below is fugly omg
         conditionings = [uc] + [c for c, weight in weighted_cond_list]

View File

@@ -228,11 +228,8 @@ def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10
     angles = 2 * math.pi * rand_val
     gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
-    tile_grads = (
-        lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
-        .repeat_interleave(d[0], 0)
-        .repeat_interleave(d[1], 1)
-    )
+    def tile_grads(slice1, slice2):
+        return gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)

     def dot(grad, shift):
         return (
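The hunk above rewrites a lambda bound to a name as a nested def. A minimal sketch (illustrative only, not InvokeAI code) of the rule presumably behind it, pycodestyle E731, which prefers def so the function carries its own name in tracebacks and repr():

data = [[r * 4 + c for c in range(4)] for r in range(3)]

# Flagged form: a lambda assigned to a name.
crop_lambda = lambda rows, cols: [row[cols[0]:cols[1]] for row in data[rows[0]:rows[1]]]  # noqa: E731

# Preferred form: same behavior, but the function knows its own name.
def crop(rows, cols):
    return [row[cols[0]:cols[1]] for row in data[rows[0]:rows[1]]]

print(crop((0, 2), (1, 3)) == crop_lambda((0, 2), (1, 3)))  # True
print(crop.__name__)         # "crop"
print(crop_lambda.__name__)  # "<lambda>"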

View File

@@ -169,7 +169,7 @@ def test_prepare_values_to_insert(batch_data_collection, batch_graph):
     NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
     # should have 3 node field values
-    assert type(values[0].field_values) is str
+    assert isinstance(values[0].field_values, str)
     assert len(NodeFieldValueValidator.validate_json(values[0].field_values)) == 3
     # should have batch id and priority