Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
add logging, support for prompts with shell metachars
@@ -13,15 +13,17 @@ import pydoc
 import re
 import shutil
 import sys
-import numpy as np
+import time
+from contextlib import redirect_stderr
 from io import TextIOBase
 from itertools import product
+from multiprocessing import Process
+from multiprocessing.connection import Connection, Pipe
 from pathlib import Path
-from multiprocessing import Process, Pipe
-from multiprocessing.connection import Connection
-from subprocess import PIPE, Popen
-from typing import Iterable, List
+from tempfile import gettempdir
+from typing import Callable, Iterable, List
 
+import numpy as np
 import yaml
 from omegaconf import OmegaConf, dictconfig, listconfig
 
@@ -31,7 +33,7 @@ def expand_prompts(
     run_invoke: bool = False,
     invoke_model: str = None,
     invoke_outdir: Path = None,
-    processes_per_gpu: int = 1
+    processes_per_gpu: int = 1,
 ):
     """
     :param template_file: A YAML file containing templated prompts and args
@@ -48,41 +50,58 @@ def expand_prompts(
 
     # loading here to avoid long wait for help message
     import torch
-    torch.multiprocessing.set_start_method('spawn')
+
+    torch.multiprocessing.set_start_method("spawn")
     gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
     commands = expanded_invokeai_commands(conf, run_invoke)
     children = list()
 
     try:
         if run_invoke:
-            # bring the big library into memory in order to share it in subprocesses
-            import ldm.invoke.CLI
-            invokeai_args = [shutil.which("invokeai"),"--from_file","-"]
+            invokeai_args = [shutil.which("invokeai"), "--from_file", "-"]
             if invoke_model:
                 invokeai_args.extend(("--model", invoke_model))
             if invoke_outdir:
-                invokeai_args.extend(("--outdir", os.path.expanduser(invoke_outdir)))
+                outdir = os.path.expanduser(invoke_outdir)
+                invokeai_args.extend(("--outdir", outdir))
+            else:
+                outdir = gettempdir()
+            logdir = Path(outdir, "invokeai-batch-logs")
+
             processes_to_launch = gpu_count * processes_per_gpu
-            print(f'>> Spawning {processes_to_launch} invokeai processes across {gpu_count} CUDA gpus', file=sys.stderr)
+            print(
+                f">> Spawning {processes_to_launch} invokeai processes across {gpu_count} CUDA gpus",
+                file=sys.stderr,
+            )
+            print(
+                f'>> Outputs will be written into {invoke_outdir or "default InvokeAI outputs directory"}, and error logs will be written to {logdir}',
+                file=sys.stderr,
+            )
             import ldm.invoke.CLI
 
             parent_conn, child_conn = Pipe()
             children = set()
             for i in range(processes_to_launch):
-                p = Process(target=_run_invoke,
-                            args=(child_conn,
-                                  parent_conn,
-                                  invokeai_args,
-                                  i%gpu_count,
-                                  )
-                            )
+                p = Process(
+                    target=_run_invoke,
+                    kwargs=dict(
+                        entry_point=ldm.invoke.CLI.main,
+                        conn_in=child_conn,
+                        conn_out=parent_conn,
+                        args=invokeai_args,
+                        gpu=i % gpu_count,
+                        logdir=logdir,
+                    ),
+                )
                 p.start()
                 children.add(p)
             child_conn.close()
             sequence = 0
             for command in commands:
                 sequence += 1
-                parent_conn.send(command+f' --fnformat="dp.{sequence:04}.{{prompt}}.png"')
+                parent_conn.send(
+                    command + f' --fnformat="dp.{sequence:04}.{{prompt}}.png"'
+                )
             parent_conn.close()
         else:
             for command in commands:
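
Note: the hunk above fans a single multiprocessing Pipe out to several workers, each of which recv()s commands until the sender closes its end. Below is a minimal, self-contained sketch of that pattern (the worker function and prompts are illustrative, not from the commit); as in the commit, each worker closes its inherited copy of the parent connection so that EOF can propagate:

    from multiprocessing import Pipe, Process
    from multiprocessing.connection import Connection


    def worker(conn_in: Connection, conn_out: Connection) -> None:
        conn_out.close()  # drop the inherited send handle so EOF can reach us
        while True:
            try:
                command = conn_in.recv()  # each message goes to exactly one worker
            except EOFError:
                return
            print(f"worker got: {command}")


    if __name__ == "__main__":
        parent_conn, child_conn = Pipe()
        workers = [
            Process(target=worker, args=(child_conn, parent_conn)) for _ in range(2)
        ]
        for p in workers:
            p.start()
        child_conn.close()  # parent keeps only the sending end
        for command in ("banana sushi -S42", "banana sushi -S43"):
            parent_conn.send(command)
        parent_conn.close()  # workers now see EOFError and exit
        for p in workers:
            p.join()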
@@ -91,12 +110,13 @@ def expand_prompts(
         for p in children:
             p.terminate()
 
+
 class MessageToStdin(object):
     def __init__(self, connection: Connection):
         self.connection = connection
         self.linebuffer = list()
 
-    def readline(self)->str:
+    def readline(self) -> str:
         try:
             if len(self.linebuffer) == 0:
                 message = self.connection.recv()
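
Note: MessageToStdin lets a Pipe connection stand in for standard input; this works because the consumer apparently only calls readline() on sys.stdin, so any object providing a compatible readline() suffices. A tiny independent demonstration of the same duck typing (the class and data here are made up):

    import sys


    class StringStdin:
        """Feeds canned lines to code that reads via sys.stdin.readline()."""

        def __init__(self, lines):
            self.lines = list(lines)

        def readline(self) -> str:
            return self.lines.pop(0) if self.lines else ""  # "" signals EOF


    sys.stdin = StringStdin(["banana sushi -S42\n"])
    print(sys.stdin.readline(), end="")  # prints: banana sushi -S42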
@@ -106,12 +126,15 @@ class MessageToStdin(object):
         except EOFError:
             return None
 
+
 class FilterStream(object):
-    def __init__(self, stream: TextIOBase, include: re.Pattern=None, exclude: re.Pattern=None):
+    def __init__(
+        self, stream: TextIOBase, include: re.Pattern = None, exclude: re.Pattern = None
+    ):
         self.stream = stream
         self.include = include
         self.exclude = exclude
 
     def write(self, data: str):
         if self.include and self.include.match(data):
             self.stream.write(data)
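
Note: FilterStream forwards only the lines that match its include pattern (or, presumably, that fail to match its exclude pattern; that branch falls outside the hunks shown). A condensed, runnable re-creation for illustration only:

    import re
    from io import StringIO


    class DemoFilterStream:
        """Condensed re-creation of FilterStream for illustration."""

        def __init__(self, stream, include=None, exclude=None):
            self.stream, self.include, self.exclude = stream, include, exclude

        def write(self, data: str):
            if self.include and self.include.match(data):
                self.stream.write(data)
            elif self.exclude and not self.exclude.match(data):
                self.stream.write(data)

        def flush(self):
            self.stream.flush()


    buf = StringIO()
    out = DemoFilterStream(buf, include=re.compile(r"^\[\d+\]"))
    out.write("[1] results line, kept\n")    # matches ^\[\d+\]
    out.write(">> progress chatter\n")       # no match, filtered out
    print(buf.getvalue(), end="")            # prints only the [1] line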
@@ -122,23 +145,37 @@ class FilterStream(object):
 
     def flush(self):
         self.stream.flush()
 
-def _run_invoke(conn_in: Connection, conn_out: Connection, args: List[str], gpu: int=0):
-    from ldm.invoke.CLI import main
-    print(f'>> Process {os.getpid()} running on GPU {gpu}', file=sys.stderr)
+
+def _run_invoke(
+    entry_point: Callable,
+    conn_in: Connection,
+    conn_out: Connection,
+    args: List[str],
+    logdir: Path,
+    gpu: int = 0,
+):
+    pid = os.getpid()
+    logdir.mkdir(parents=True, exist_ok=True)
+    logfile = Path(logdir, f'{time.strftime("%Y-%m-%d-%H:%M:%S")}-pid={pid}.txt')
+    print(
+        f">> Process {pid} running on GPU {gpu}; logging to {logfile}", file=sys.stderr
+    )
     conn_out.close()
-    os.environ['CUDA_VISIBLE_DEVICES'] = f"{gpu}"
+    os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}"
     sys.argv = args
     sys.stdin = MessageToStdin(conn_in)
-    sys.stdout = FilterStream(sys.stdout,include=re.compile('^\[\d+\]'))
-    sys.stderr = FilterStream(sys.stdout,exclude=re.compile('^(>>|\s*\d+%|Fetching)'))
-    main()
+    sys.stdout = FilterStream(sys.stdout, include=re.compile("^\[\d+\]"))
+    with open(logfile, "w") as stderr, redirect_stderr(stderr):
+        entry_point()
 
+
 def _filter_output(stream: TextIOBase):
     while line := stream.readline():
-        if re.match('^\[\d+\]',line):
+        if re.match("^\[\d+\]", line):
             print(line)
 
+
 def main():
     parser = argparse.ArgumentParser(
         description=HELP,
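
Note: the per-process logging introduced here redirects each worker's stderr into a timestamped, per-PID file with contextlib.redirect_stderr. A standalone sketch of the same idiom, using the gettempdir()-based location the commit falls back to:

    import os
    import sys
    import time
    from contextlib import redirect_stderr
    from pathlib import Path
    from tempfile import gettempdir

    logdir = Path(gettempdir(), "invokeai-batch-logs")
    logdir.mkdir(parents=True, exist_ok=True)
    logfile = Path(logdir, f'{time.strftime("%Y-%m-%d-%H:%M:%S")}-pid={os.getpid()}.txt')

    with open(logfile, "w") as stderr, redirect_stderr(stderr):
        print("captured in the log file", file=sys.stderr)
    print(f"stderr restored; log written to {logfile}", file=sys.stderr)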
@@ -213,31 +250,62 @@ def main():
         processes_per_gpu=opt.processes_per_gpu,
     )
 
-def expanded_invokeai_commands(conf: OmegaConf, always_switch_models: bool=False)->List[List[str]]:
+
+def expanded_invokeai_commands(
+    conf: OmegaConf, always_switch_models: bool = False
+) -> List[List[str]]:
     models = expand_values(conf.get("model"))
     steps = expand_values(conf.get("steps")) or [30]
     cfgs = expand_values(conf.get("cfg")) or [7.5]
     samplers = expand_values(conf.get("sampler")) or ["ddim"]
     seeds = expand_values(conf.get("seed")) or [0]
     dimensions = expand_values(conf.get("dimensions")) or ["512x512"]
-    init_img = expand_values(conf.get('init_img')) or ['']
-    perlin = expand_values(conf.get('perlin')) or [0]
-    threshold = expand_values(conf.get('threshold')) or [0]
-    strength = expand_values(conf.get('strength')) or [0.75]
+    init_img = expand_values(conf.get("init_img")) or [""]
+    perlin = expand_values(conf.get("perlin")) or [0]
+    threshold = expand_values(conf.get("threshold")) or [0]
+    strength = expand_values(conf.get("strength")) or [0.75]
     prompts = expand_prompt(conf.get("prompt")) or ["banana sushi"]
 
     cross_product = product(
-        *[models, seeds, prompts, samplers, cfgs, steps, perlin, threshold, init_img, strength, dimensions]
+        *[
+            models,
+            seeds,
+            prompts,
+            samplers,
+            cfgs,
+            steps,
+            perlin,
+            threshold,
+            init_img,
+            strength,
+            dimensions,
+        ]
     )
     previous_model = None
 
     result = list()
     for p in cross_product:
-        (model, seed, prompt, sampler, cfg, step, perlin, threshold, init_img, strength, dimensions) = tuple(p)
+        (
+            model,
+            seed,
+            prompt,
+            sampler,
+            cfg,
+            step,
+            perlin,
+            threshold,
+            init_img,
+            strength,
+            dimensions,
+        ) = tuple(p)
         (width, height) = dimensions.split("x")
-        switch_args = f"!switch {model}\n" if always_switch_models or previous_model != model else ''
-        image_args = f'-I{init_img} -f{strength}' if init_img else ''
-        command = f'{switch_args}{prompt} -S{seed} -A{sampler} -C{cfg} -s{step} {image_args} --perlin={perlin} --threshold={threshold} -W{width} -H{height}'
+        switch_args = (
+            f"!switch {model}\n"
+            if always_switch_models or previous_model != model
+            else ""
+        )
+        image_args = f"-I{init_img} -f{strength}" if init_img else ""
+        command = f"{switch_args}{prompt} -S{seed} -A{sampler} -C{cfg} -s{step} {image_args} --perlin={perlin} --threshold={threshold} -W{width} -H{height}"
         result.append(command)
         previous_model = model
     return result
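
Note: expanded_invokeai_commands() emits one CLI command per element of the cross product, so every templated axis multiplies the total run count. A worked miniature with two seeds and two cfg values, yielding four commands:

    from itertools import product

    seeds = [0, 1]
    cfgs = [7.5, 8.0]
    prompts = ["banana sushi"]

    for seed, cfg, prompt in product(seeds, cfgs, prompts):
        print(f"{prompt} -S{seed} -C{cfg}")
    # banana sushi -S0 -C7.5
    # banana sushi -S0 -C8.0
    # banana sushi -S1 -C7.5
    # banana sushi -S1 -C8.0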
@@ -282,11 +350,19 @@ def expand_values(stanza: str | dict | listconfig.ListConfig) -> list | range:
     if isinstance(stanza, listconfig.ListConfig):
         return stanza
     elif match := re.match("^(-?\d+);(-?\d+)(;(\d+))?", str(stanza)):
-        (start, stop, step) = (int(match.group(1)), int(match.group(2)), int(match.group(4)) or 1)
-        return range(start, stop+step, step)
+        (start, stop, step) = (
+            int(match.group(1)),
+            int(match.group(2)),
+            int(match.group(4)) or 1,
+        )
+        return range(start, stop + step, step)
     elif match := re.match("^(-?[\d.]+);(-?[\d.]+)(;([\d.]+))?", str(stanza)):
-        (start, stop, step) = (float(match.group(1)), float(match.group(2)), float(match.group(4)) or 1.0)
-        return np.arange(start, stop+step, step).tolist()
+        (start, stop, step) = (
+            float(match.group(1)),
+            float(match.group(2)),
+            float(match.group(4)) or 1.0,
+        )
+        return np.arange(start, stop + step, step).tolist()
     else:
         return [stanza]
 
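
Note: expand_values() parses numeric stanzas of the form start;stop;step, and the ranges it returns deliberately extend to stop + step so the endpoint itself is generated. Worked examples (assuming the step field is given explicitly, since match.group(4) is None when it is omitted):

    # "10;30;10" parses to start=10, stop=30, step=10 and expands inclusively:
    start, stop, step = 10, 30, 10
    assert list(range(start, stop + step, step)) == [10, 20, 30]

    # the float branch does the same through numpy.arange:
    import numpy as np

    assert np.arange(0.5, 1.0 + 0.25, 0.25).tolist() == [0.5, 0.75, 1.0]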