diff --git a/ldm/invoke/dynamic_prompts.py b/ldm/invoke/dynamic_prompts.py index 9a16882da0..c196ce7c33 100755 --- a/ldm/invoke/dynamic_prompts.py +++ b/ldm/invoke/dynamic_prompts.py @@ -13,15 +13,17 @@ import pydoc import re import shutil import sys -import numpy as np +import time +from contextlib import redirect_stderr from io import TextIOBase from itertools import product +from multiprocessing import Process +from multiprocessing.connection import Connection, Pipe from pathlib import Path -from multiprocessing import Process, Pipe -from multiprocessing.connection import Connection -from subprocess import PIPE, Popen -from typing import Iterable, List +from tempfile import gettempdir +from typing import Callable, Iterable, List +import numpy as np import yaml from omegaconf import OmegaConf, dictconfig, listconfig @@ -31,7 +33,7 @@ def expand_prompts( run_invoke: bool = False, invoke_model: str = None, invoke_outdir: Path = None, - processes_per_gpu: int = 1 + processes_per_gpu: int = 1, ): """ :param template_file: A YAML file containing templated prompts and args @@ -48,41 +50,58 @@ def expand_prompts( # loading here to avoid long wait for help message import torch - torch.multiprocessing.set_start_method('spawn') + + torch.multiprocessing.set_start_method("spawn") gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 1 commands = expanded_invokeai_commands(conf, run_invoke) children = list() - + try: if run_invoke: - # bring the big library into memory in order to share it in subprocesses - import ldm.invoke.CLI - invokeai_args = [shutil.which("invokeai"),"--from_file","-"] + invokeai_args = [shutil.which("invokeai"), "--from_file", "-"] if invoke_model: invokeai_args.extend(("--model", invoke_model)) if invoke_outdir: - invokeai_args.extend(("--outdir", os.path.expanduser(invoke_outdir))) + outdir = os.path.expanduser(invoke_outdir) + invokeai_args.extend(("--outdir", outdir)) + else: + outdir = gettempdir() + logdir = Path(outdir, "invokeai-batch-logs") processes_to_launch = gpu_count * processes_per_gpu - print(f'>> Spawning {processes_to_launch} invokeai processes across {gpu_count} CUDA gpus', file=sys.stderr) + print( + f">> Spawning {processes_to_launch} invokeai processes across {gpu_count} CUDA gpus", + file=sys.stderr, + ) + print( + f'>> Outputs will be written into {invoke_outdir or "default InvokeAI outputs directory"}, and error logs will be written to {logdir}', + file=sys.stderr, + ) import ldm.invoke.CLI + parent_conn, child_conn = Pipe() children = set() for i in range(processes_to_launch): - p = Process(target=_run_invoke, - args=(child_conn, - parent_conn, - invokeai_args, - i%gpu_count, - ) - ) + p = Process( + target=_run_invoke, + kwargs=dict( + entry_point=ldm.invoke.CLI.main, + conn_in=child_conn, + conn_out=parent_conn, + args=invokeai_args, + gpu=i % gpu_count, + logdir=logdir, + ), + ) p.start() children.add(p) child_conn.close() sequence = 0 for command in commands: sequence += 1 - parent_conn.send(command+f' --fnformat="dp.{sequence:04}.{{prompt}}.png"') + parent_conn.send( + command + f' --fnformat="dp.{sequence:04}.{{prompt}}.png"' + ) parent_conn.close() else: for command in commands: @@ -91,12 +110,13 @@ def expand_prompts( for p in children: p.terminate() + class MessageToStdin(object): def __init__(self, connection: Connection): self.connection = connection self.linebuffer = list() - def readline(self)->str: + def readline(self) -> str: try: if len(self.linebuffer) == 0: message = self.connection.recv() @@ -106,12 +126,15 @@ class MessageToStdin(object): except EOFError: return None + class FilterStream(object): - def __init__(self, stream: TextIOBase, include: re.Pattern=None, exclude: re.Pattern=None): + def __init__( + self, stream: TextIOBase, include: re.Pattern = None, exclude: re.Pattern = None + ): self.stream = stream self.include = include self.exclude = exclude - + def write(self, data: str): if self.include and self.include.match(data): self.stream.write(data) @@ -122,23 +145,37 @@ class FilterStream(object): def flush(self): self.stream.flush() - -def _run_invoke(conn_in: Connection, conn_out: Connection, args: List[str], gpu: int=0): - from ldm.invoke.CLI import main - print(f'>> Process {os.getpid()} running on GPU {gpu}', file=sys.stderr) + + +def _run_invoke( + entry_point: Callable, + conn_in: Connection, + conn_out: Connection, + args: List[str], + logdir: Path, + gpu: int = 0, +): + pid = os.getpid() + logdir.mkdir(parents=True, exist_ok=True) + logfile = Path(logdir, f'{time.strftime("%Y-%m-%d-%H:%M:%S")}-pid={pid}.txt') + print( + f">> Process {pid} running on GPU {gpu}; logging to {logfile}", file=sys.stderr + ) conn_out.close() - os.environ['CUDA_VISIBLE_DEVICES'] = f"{gpu}" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}" sys.argv = args sys.stdin = MessageToStdin(conn_in) - sys.stdout = FilterStream(sys.stdout,include=re.compile('^\[\d+\]')) - sys.stderr = FilterStream(sys.stdout,exclude=re.compile('^(>>|\s*\d+%|Fetching)')) - main() + sys.stdout = FilterStream(sys.stdout, include=re.compile("^\[\d+\]")) + with open(logfile, "w") as stderr, redirect_stderr(stderr): + entry_point() + def _filter_output(stream: TextIOBase): while line := stream.readline(): - if re.match('^\[\d+\]',line): + if re.match("^\[\d+\]", line): print(line) - + + def main(): parser = argparse.ArgumentParser( description=HELP, @@ -213,31 +250,62 @@ def main(): processes_per_gpu=opt.processes_per_gpu, ) -def expanded_invokeai_commands(conf: OmegaConf, always_switch_models: bool=False)->List[List[str]]: + +def expanded_invokeai_commands( + conf: OmegaConf, always_switch_models: bool = False +) -> List[List[str]]: models = expand_values(conf.get("model")) steps = expand_values(conf.get("steps")) or [30] cfgs = expand_values(conf.get("cfg")) or [7.5] samplers = expand_values(conf.get("sampler")) or ["ddim"] seeds = expand_values(conf.get("seed")) or [0] dimensions = expand_values(conf.get("dimensions")) or ["512x512"] - init_img = expand_values(conf.get('init_img')) or [''] - perlin = expand_values(conf.get('perlin')) or [0] - threshold = expand_values(conf.get('threshold')) or [0] - strength = expand_values(conf.get('strength')) or [0.75] + init_img = expand_values(conf.get("init_img")) or [""] + perlin = expand_values(conf.get("perlin")) or [0] + threshold = expand_values(conf.get("threshold")) or [0] + strength = expand_values(conf.get("strength")) or [0.75] prompts = expand_prompt(conf.get("prompt")) or ["banana sushi"] cross_product = product( - *[models, seeds, prompts, samplers, cfgs, steps, perlin, threshold, init_img, strength, dimensions] + *[ + models, + seeds, + prompts, + samplers, + cfgs, + steps, + perlin, + threshold, + init_img, + strength, + dimensions, + ] ) previous_model = None - + result = list() for p in cross_product: - (model, seed, prompt, sampler, cfg, step, perlin, threshold, init_img, strength, dimensions) = tuple(p) + ( + model, + seed, + prompt, + sampler, + cfg, + step, + perlin, + threshold, + init_img, + strength, + dimensions, + ) = tuple(p) (width, height) = dimensions.split("x") - switch_args = f"!switch {model}\n" if always_switch_models or previous_model != model else '' - image_args = f'-I{init_img} -f{strength}' if init_img else '' - command = f'{switch_args}{prompt} -S{seed} -A{sampler} -C{cfg} -s{step} {image_args} --perlin={perlin} --threshold={threshold} -W{width} -H{height}' + switch_args = ( + f"!switch {model}\n" + if always_switch_models or previous_model != model + else "" + ) + image_args = f"-I{init_img} -f{strength}" if init_img else "" + command = f"{switch_args}{prompt} -S{seed} -A{sampler} -C{cfg} -s{step} {image_args} --perlin={perlin} --threshold={threshold} -W{width} -H{height}" result.append(command) previous_model = model return result @@ -282,11 +350,19 @@ def expand_values(stanza: str | dict | listconfig.ListConfig) -> list | range: if isinstance(stanza, listconfig.ListConfig): return stanza elif match := re.match("^(-?\d+);(-?\d+)(;(\d+))?", str(stanza)): - (start, stop, step) = (int(match.group(1)), int(match.group(2)), int(match.group(4)) or 1) - return range(start, stop+step, step) + (start, stop, step) = ( + int(match.group(1)), + int(match.group(2)), + int(match.group(4)) or 1, + ) + return range(start, stop + step, step) elif match := re.match("^(-?[\d.]+);(-?[\d.]+)(;([\d.]+))?", str(stanza)): - (start, stop, step) = (float(match.group(1)), float(match.group(2)), float(match.group(4)) or 1.0) - return np.arange(start, stop+step, step).tolist() + (start, stop, step) = ( + float(match.group(1)), + float(match.group(2)), + float(match.group(4)) or 1.0, + ) + return np.arange(start, stop + step, step).tolist() else: return [stanza]