Apply black

This commit is contained in:
Martin Kristiansen
2023-07-27 10:54:01 -04:00
parent 2183dba5c5
commit 218b6d0546
148 changed files with 5486 additions and 6296 deletions

View File

@ -32,9 +32,7 @@ def log_txt_as_img(wh, xc, size=10):
draw = ImageDraw.Draw(txt)
font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256))
lines = "\n".join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
lines = "\n".join(xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc))
try:
draw.text((0, 0), lines, fill="black", font=font)
@ -81,9 +79,7 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
logger.debug(
f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
)
logger.debug(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
return total_params
@ -154,21 +150,12 @@ def parallel_data_prefetch(
proc = Thread
# spawn processes
if target_data_type == "ndarray":
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc))
]
arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))]
else:
step = (
int(len(data) / n_proc + 1)
if len(data) % n_proc != 0
else int(len(data) / n_proc)
)
step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc)
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(
[data[i : i + step] for i in range(0, len(data), step)]
)
for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)])
]
processes = []
for i in range(n_proc):
@ -220,9 +207,7 @@ def parallel_data_prefetch(
return gather_res
def rand_perlin_2d(
shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
):
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
@ -265,9 +250,9 @@ def rand_perlin_2d(
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device)
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device)
t = fade(grid[: shape[0], : shape[1]])
noise = math.sqrt(2) * torch.lerp(
torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
).to(device)
noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(
device
)
return noise.to(dtype=torch_dtype(device))
@ -276,9 +261,7 @@ def ask_user(question: str, answers: list):
user_prompt = f"\n>> {question} {answers}: "
invalid_answer_msg = "Invalid answer. Please try again."
pose_question = chain(
[user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt]))
)
pose_question = chain([user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt])))
user_answers = map(input, pose_question)
valid_response = next(filter(answers.__contains__, user_answers))
return valid_response
@ -303,9 +286,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
if dest.is_dir():
try:
file_name = re.search(
'filename="(.+)"', resp.headers.get("Content-Disposition")
).group(1)
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
except:
file_name = os.path.basename(url)
dest = dest / file_name
@ -322,7 +303,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
logger.warning("corrupt existing file found. re-downloading")
os.remove(dest)
exist_size = 0
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
logger.warning(f"{dest}: complete file found. Skipping.")
return dest
@ -377,16 +358,16 @@ def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
buffered = io.BytesIO()
image.save(buffered, format=image_format)
mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
image_base64 = f"data:{mime_type};base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
image_base64 = f"data:{mime_type};base64," + base64.b64encode(buffered.getvalue()).decode("UTF-8")
return image_base64
class Chdir(object):
'''Context manager to chdir to desired directory and change back after context exits:
"""Context manager to chdir to desired directory and change back after context exits:
Args:
path (Path): The path to the cwd
'''
"""
def __init__(self, path: Path):
self.path = path
self.original = Path().absolute()
@ -394,5 +375,5 @@ class Chdir(object):
def __enter__(self):
os.chdir(self.path)
def __exit__(self,*args):
def __exit__(self, *args):
os.chdir(self.original)