add logging, support for prompts with shell metachars

This commit is contained in:
Lincoln Stein
2023-03-09 11:57:44 -05:00
parent 84dfd2003e
commit 142ba8c8ea

View File

@ -13,15 +13,17 @@ import pydoc
import re import re
import shutil import shutil
import sys import sys
import numpy as np import time
from contextlib import redirect_stderr
from io import TextIOBase from io import TextIOBase
from itertools import product from itertools import product
from multiprocessing import Process
from multiprocessing.connection import Connection, Pipe
from pathlib import Path from pathlib import Path
from multiprocessing import Process, Pipe from tempfile import gettempdir
from multiprocessing.connection import Connection from typing import Callable, Iterable, List
from subprocess import PIPE, Popen
from typing import Iterable, List
import numpy as np
import yaml import yaml
from omegaconf import OmegaConf, dictconfig, listconfig from omegaconf import OmegaConf, dictconfig, listconfig
@ -31,7 +33,7 @@ def expand_prompts(
run_invoke: bool = False, run_invoke: bool = False,
invoke_model: str = None, invoke_model: str = None,
invoke_outdir: Path = None, invoke_outdir: Path = None,
processes_per_gpu: int = 1 processes_per_gpu: int = 1,
): ):
""" """
:param template_file: A YAML file containing templated prompts and args :param template_file: A YAML file containing templated prompts and args
@ -48,41 +50,58 @@ def expand_prompts(
# loading here to avoid long wait for help message # loading here to avoid long wait for help message
import torch import torch
torch.multiprocessing.set_start_method('spawn')
torch.multiprocessing.set_start_method("spawn")
gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 1 gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
commands = expanded_invokeai_commands(conf, run_invoke) commands = expanded_invokeai_commands(conf, run_invoke)
children = list() children = list()
try: try:
if run_invoke: if run_invoke:
# bring the big library into memory in order to share it in subprocesses invokeai_args = [shutil.which("invokeai"), "--from_file", "-"]
import ldm.invoke.CLI
invokeai_args = [shutil.which("invokeai"),"--from_file","-"]
if invoke_model: if invoke_model:
invokeai_args.extend(("--model", invoke_model)) invokeai_args.extend(("--model", invoke_model))
if invoke_outdir: if invoke_outdir:
invokeai_args.extend(("--outdir", os.path.expanduser(invoke_outdir))) outdir = os.path.expanduser(invoke_outdir)
invokeai_args.extend(("--outdir", outdir))
else:
outdir = gettempdir()
logdir = Path(outdir, "invokeai-batch-logs")
processes_to_launch = gpu_count * processes_per_gpu processes_to_launch = gpu_count * processes_per_gpu
print(f'>> Spawning {processes_to_launch} invokeai processes across {gpu_count} CUDA gpus', file=sys.stderr) print(
f">> Spawning {processes_to_launch} invokeai processes across {gpu_count} CUDA gpus",
file=sys.stderr,
)
print(
f'>> Outputs will be written into {invoke_outdir or "default InvokeAI outputs directory"}, and error logs will be written to {logdir}',
file=sys.stderr,
)
import ldm.invoke.CLI import ldm.invoke.CLI
parent_conn, child_conn = Pipe() parent_conn, child_conn = Pipe()
children = set() children = set()
for i in range(processes_to_launch): for i in range(processes_to_launch):
p = Process(target=_run_invoke, p = Process(
args=(child_conn, target=_run_invoke,
parent_conn, kwargs=dict(
invokeai_args, entry_point=ldm.invoke.CLI.main,
i%gpu_count, conn_in=child_conn,
) conn_out=parent_conn,
) args=invokeai_args,
gpu=i % gpu_count,
logdir=logdir,
),
)
p.start() p.start()
children.add(p) children.add(p)
child_conn.close() child_conn.close()
sequence = 0 sequence = 0
for command in commands: for command in commands:
sequence += 1 sequence += 1
parent_conn.send(command+f' --fnformat="dp.{sequence:04}.{{prompt}}.png"') parent_conn.send(
command + f' --fnformat="dp.{sequence:04}.{{prompt}}.png"'
)
parent_conn.close() parent_conn.close()
else: else:
for command in commands: for command in commands:
@ -91,12 +110,13 @@ def expand_prompts(
for p in children: for p in children:
p.terminate() p.terminate()
class MessageToStdin(object): class MessageToStdin(object):
def __init__(self, connection: Connection): def __init__(self, connection: Connection):
self.connection = connection self.connection = connection
self.linebuffer = list() self.linebuffer = list()
def readline(self)->str: def readline(self) -> str:
try: try:
if len(self.linebuffer) == 0: if len(self.linebuffer) == 0:
message = self.connection.recv() message = self.connection.recv()
@ -106,12 +126,15 @@ class MessageToStdin(object):
except EOFError: except EOFError:
return None return None
class FilterStream(object): class FilterStream(object):
def __init__(self, stream: TextIOBase, include: re.Pattern=None, exclude: re.Pattern=None): def __init__(
self, stream: TextIOBase, include: re.Pattern = None, exclude: re.Pattern = None
):
self.stream = stream self.stream = stream
self.include = include self.include = include
self.exclude = exclude self.exclude = exclude
def write(self, data: str): def write(self, data: str):
if self.include and self.include.match(data): if self.include and self.include.match(data):
self.stream.write(data) self.stream.write(data)
@ -122,23 +145,37 @@ class FilterStream(object):
def flush(self): def flush(self):
self.stream.flush() self.stream.flush()
def _run_invoke(conn_in: Connection, conn_out: Connection, args: List[str], gpu: int=0):
from ldm.invoke.CLI import main def _run_invoke(
print(f'>> Process {os.getpid()} running on GPU {gpu}', file=sys.stderr) entry_point: Callable,
conn_in: Connection,
conn_out: Connection,
args: List[str],
logdir: Path,
gpu: int = 0,
):
pid = os.getpid()
logdir.mkdir(parents=True, exist_ok=True)
logfile = Path(logdir, f'{time.strftime("%Y-%m-%d-%H:%M:%S")}-pid={pid}.txt')
print(
f">> Process {pid} running on GPU {gpu}; logging to {logfile}", file=sys.stderr
)
conn_out.close() conn_out.close()
os.environ['CUDA_VISIBLE_DEVICES'] = f"{gpu}" os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}"
sys.argv = args sys.argv = args
sys.stdin = MessageToStdin(conn_in) sys.stdin = MessageToStdin(conn_in)
sys.stdout = FilterStream(sys.stdout,include=re.compile('^\[\d+\]')) sys.stdout = FilterStream(sys.stdout, include=re.compile("^\[\d+\]"))
sys.stderr = FilterStream(sys.stdout,exclude=re.compile('^(>>|\s*\d+%|Fetching)')) with open(logfile, "w") as stderr, redirect_stderr(stderr):
main() entry_point()
def _filter_output(stream: TextIOBase): def _filter_output(stream: TextIOBase):
while line := stream.readline(): while line := stream.readline():
if re.match('^\[\d+\]',line): if re.match("^\[\d+\]", line):
print(line) print(line)
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=HELP, description=HELP,
@ -213,31 +250,62 @@ def main():
processes_per_gpu=opt.processes_per_gpu, processes_per_gpu=opt.processes_per_gpu,
) )
def expanded_invokeai_commands(conf: OmegaConf, always_switch_models: bool=False)->List[List[str]]:
def expanded_invokeai_commands(
conf: OmegaConf, always_switch_models: bool = False
) -> List[List[str]]:
models = expand_values(conf.get("model")) models = expand_values(conf.get("model"))
steps = expand_values(conf.get("steps")) or [30] steps = expand_values(conf.get("steps")) or [30]
cfgs = expand_values(conf.get("cfg")) or [7.5] cfgs = expand_values(conf.get("cfg")) or [7.5]
samplers = expand_values(conf.get("sampler")) or ["ddim"] samplers = expand_values(conf.get("sampler")) or ["ddim"]
seeds = expand_values(conf.get("seed")) or [0] seeds = expand_values(conf.get("seed")) or [0]
dimensions = expand_values(conf.get("dimensions")) or ["512x512"] dimensions = expand_values(conf.get("dimensions")) or ["512x512"]
init_img = expand_values(conf.get('init_img')) or [''] init_img = expand_values(conf.get("init_img")) or [""]
perlin = expand_values(conf.get('perlin')) or [0] perlin = expand_values(conf.get("perlin")) or [0]
threshold = expand_values(conf.get('threshold')) or [0] threshold = expand_values(conf.get("threshold")) or [0]
strength = expand_values(conf.get('strength')) or [0.75] strength = expand_values(conf.get("strength")) or [0.75]
prompts = expand_prompt(conf.get("prompt")) or ["banana sushi"] prompts = expand_prompt(conf.get("prompt")) or ["banana sushi"]
cross_product = product( cross_product = product(
*[models, seeds, prompts, samplers, cfgs, steps, perlin, threshold, init_img, strength, dimensions] *[
models,
seeds,
prompts,
samplers,
cfgs,
steps,
perlin,
threshold,
init_img,
strength,
dimensions,
]
) )
previous_model = None previous_model = None
result = list() result = list()
for p in cross_product: for p in cross_product:
(model, seed, prompt, sampler, cfg, step, perlin, threshold, init_img, strength, dimensions) = tuple(p) (
model,
seed,
prompt,
sampler,
cfg,
step,
perlin,
threshold,
init_img,
strength,
dimensions,
) = tuple(p)
(width, height) = dimensions.split("x") (width, height) = dimensions.split("x")
switch_args = f"!switch {model}\n" if always_switch_models or previous_model != model else '' switch_args = (
image_args = f'-I{init_img} -f{strength}' if init_img else '' f"!switch {model}\n"
command = f'{switch_args}{prompt} -S{seed} -A{sampler} -C{cfg} -s{step} {image_args} --perlin={perlin} --threshold={threshold} -W{width} -H{height}' if always_switch_models or previous_model != model
else ""
)
image_args = f"-I{init_img} -f{strength}" if init_img else ""
command = f"{switch_args}{prompt} -S{seed} -A{sampler} -C{cfg} -s{step} {image_args} --perlin={perlin} --threshold={threshold} -W{width} -H{height}"
result.append(command) result.append(command)
previous_model = model previous_model = model
return result return result
@ -282,11 +350,19 @@ def expand_values(stanza: str | dict | listconfig.ListConfig) -> list | range:
if isinstance(stanza, listconfig.ListConfig): if isinstance(stanza, listconfig.ListConfig):
return stanza return stanza
elif match := re.match("^(-?\d+);(-?\d+)(;(\d+))?", str(stanza)): elif match := re.match("^(-?\d+);(-?\d+)(;(\d+))?", str(stanza)):
(start, stop, step) = (int(match.group(1)), int(match.group(2)), int(match.group(4)) or 1) (start, stop, step) = (
return range(start, stop+step, step) int(match.group(1)),
int(match.group(2)),
int(match.group(4)) or 1,
)
return range(start, stop + step, step)
elif match := re.match("^(-?[\d.]+);(-?[\d.]+)(;([\d.]+))?", str(stanza)): elif match := re.match("^(-?[\d.]+);(-?[\d.]+)(;([\d.]+))?", str(stanza)):
(start, stop, step) = (float(match.group(1)), float(match.group(2)), float(match.group(4)) or 1.0) (start, stop, step) = (
return np.arange(start, stop+step, step).tolist() float(match.group(1)),
float(match.group(2)),
float(match.group(4)) or 1.0,
)
return np.arange(start, stop + step, step).tolist()
else: else:
return [stanza] return [stanza]