fix mishandling of embedded quotes in prompt

This commit is contained in:
Lincoln Stein 2022-10-26 18:27:35 -04:00
parent 2b6d78e436
commit 3f77b68a9d

View File

@ -83,16 +83,16 @@ with metadata_from_png():
import argparse import argparse
from argparse import Namespace, RawTextHelpFormatter from argparse import Namespace, RawTextHelpFormatter
import pydoc import pydoc
import shlex
import json import json
import hashlib import hashlib
import os import os
import re import re
import shlex
import copy import copy
import base64 import base64
import functools import functools
import ldm.invoke.pngwriter import ldm.invoke.pngwriter
from ldm.invoke.conditioning import split_weighted_subprompts from ldm.invoke.prompt_parser import split_weighted_subprompts
SAMPLER_CHOICES = [ SAMPLER_CHOICES = [
'ddim', 'ddim',
@ -169,28 +169,24 @@ class Args(object):
def parse_cmd(self,cmd_string): def parse_cmd(self,cmd_string):
'''Parse a invoke>-style command string ''' '''Parse a invoke>-style command string '''
command = cmd_string.replace("'", "\\'") # handle the case in which the prompt is enclosed by quotes
try: if cmd_string.startswith('"'):
elements = shlex.split(command) a = shlex.split(cmd_string)
elements = [x.replace("\\'","'") for x in elements] prompt = a[0]
except ValueError: switches = shlex.join(a[1:])
import sys, traceback
print(traceback.format_exc(), file=sys.stderr)
return
switches = ['']
switches_started = False
for element in elements:
if element[0] == '-' and not switches_started:
switches_started = True
if switches_started:
switches.append(element)
else: else:
switches[0] += element # no initial quote, so get everything up to the first thing
switches[0] += ' ' # that looks like a switch
switches[0] = switches[0][: len(switches[0]) - 1] match = re.match('^(.+?)\s(--?[a-zA-Z].+)',cmd_string)
if match:
prompt,switches = match.groups()
else:
prompt = cmd_string
switches = ''
try: try:
self._cmd_switches = self._cmd_parser.parse_args(switches) self._cmd_switches = self._cmd_parser.parse_args(shlex.split(switches))
setattr(self._cmd_switches,'prompt',prompt)
return self._cmd_switches return self._cmd_switches
except: except:
return None return None
@ -211,7 +207,9 @@ class Args(object):
a = vars(self) a = vars(self)
a.update(kwargs) a.update(kwargs)
switches = list() switches = list()
switches.append(f'"{a["prompt"]}"') prompt = a['prompt']
prompt.replace('"','\\"')
switches.append(prompt)
switches.append(f'-s {a["steps"]}') switches.append(f'-s {a["steps"]}')
switches.append(f'-S {a["seed"]}') switches.append(f'-S {a["seed"]}')
switches.append(f'-W {a["width"]}') switches.append(f'-W {a["width"]}')
@ -571,7 +569,11 @@ class Args(object):
variation_group = parser.add_argument_group('Creating and combining variations') variation_group = parser.add_argument_group('Creating and combining variations')
postprocessing_group = parser.add_argument_group('Post-processing') postprocessing_group = parser.add_argument_group('Post-processing')
special_effects_group = parser.add_argument_group('Special effects') special_effects_group = parser.add_argument_group('Special effects')
render_group.add_argument('prompt') render_group.add_argument(
'--prompt',
default='',
help='prompt string',
)
render_group.add_argument( render_group.add_argument(
'-s', '-s',
'--steps', '--steps',