convert remainder of print() to log.info()

This commit is contained in:
Lincoln Stein
2023-04-14 15:15:14 -04:00
parent c132dbdefa
commit 0b0e6fe448
22 changed files with 262 additions and 259 deletions

View File

@ -18,6 +18,7 @@ import torch
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
import invokeai.backend.util.logging as log
from .devices import torch_dtype
@ -38,7 +39,7 @@ def log_txt_as_img(wh, xc, size=10):
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
log.warning("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@ -80,8 +81,8 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
log.debug(
f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
)
return total_params
@ -132,8 +133,8 @@ def parallel_data_prefetch(
raise ValueError("list expected but function got ndarray.")
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
print(
'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
log.warning(
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
if target_data_type == "ndarray":
@ -175,7 +176,7 @@ def parallel_data_prefetch(
processes += [p]
# start processes
print("Start prefetching...")
log.info("Start prefetching...")
import time
start = time.time()
@ -194,7 +195,7 @@ def parallel_data_prefetch(
gather_res[res[0]] = res[1]
except Exception as e:
print("Exception: ", e)
log.error("Exception: ", e)
for p in processes:
p.terminate()
@ -202,7 +203,7 @@ def parallel_data_prefetch(
finally:
for p in processes:
p.join()
print(f"Prefetching complete. [{time.time() - start} sec.]")
log.info(f"Prefetching complete. [{time.time() - start} sec.]")
if target_data_type == "ndarray":
if not isinstance(gather_res[0], np.ndarray):
@ -318,23 +319,23 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
resp = requests.get(url, headers=header, stream=True) # new request with range
if exist_size > content_length:
print("* corrupt existing file found. re-downloading")
log.warning("corrupt existing file found. re-downloading")
os.remove(dest)
exist_size = 0
if resp.status_code == 416 or exist_size == content_length:
print(f"* {dest}: complete file found. Skipping.")
log.warning(f"{dest}: complete file found. Skipping.")
return dest
elif resp.status_code == 206 or exist_size > 0:
print(f"* {dest}: partial file found. Resuming...")
log.warning(f"{dest}: partial file found. Resuming...")
elif resp.status_code != 200:
print(f"** An error occurred during downloading {dest}: {resp.reason}")
log.error(f"An error occurred during downloading {dest}: {resp.reason}")
else:
print(f"* {dest}: Downloading...")
log.error(f"{dest}: Downloading...")
try:
if content_length < 2000:
print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
log.error(f"ERROR DOWNLOADING {url}: {resp.text}")
return None
with open(dest, open_mode) as file, tqdm(
@ -349,7 +350,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
size = file.write(data)
bar.update(size)
except Exception as e:
print(f"An error occurred while downloading {dest}: {str(e)}")
log.error(f"An error occurred while downloading {dest}: {str(e)}")
return None
return dest