working with 1.4, 1.5, not with inpainting 1.5

This commit is contained in:
Lincoln Stein 2022-10-26 18:25:48 -04:00
parent 9b7159720f
commit 2daf187bdb
3 changed files with 28 additions and 29 deletions

View File

@ -83,11 +83,11 @@ with metadata_from_png():
import argparse import argparse
from argparse import Namespace, RawTextHelpFormatter from argparse import Namespace, RawTextHelpFormatter
import pydoc import pydoc
import shlex
import json import json
import hashlib import hashlib
import os import os
import re import re
import shlex
import copy import copy
import base64 import base64
import functools import functools
@ -169,30 +169,24 @@ class Args(object):
def parse_cmd(self,cmd_string): def parse_cmd(self,cmd_string):
'''Parse a invoke>-style command string ''' '''Parse a invoke>-style command string '''
command = cmd_string.replace("'", "\\'") # handle the case in which the prompt is enclosed by quotes
try: if cmd_string.startswith('"'):
elements = shlex.split(command) a = shlex.split(cmd_string)
elements = [x.replace("\\'","'") for x in elements] prompt = a[0]
except ValueError: switches = shlex.join(a[1:])
import sys, traceback
print(traceback.format_exc(), file=sys.stderr)
return
switches = ['']
switches_started = False
for element in elements:
if len(element) == 0: # empty prompt
pass
elif element[0] == '-' and not switches_started:
switches_started = True
if switches_started:
switches.append(element)
else: else:
switches[0] += element # no initial quote, so get everything up to the first thing
switches[0] += ' ' # that looks like a switch
switches[0] = switches[0][: len(switches[0]) - 1] match = re.match('^(.+?)\s(--?[a-zA-Z].+)',cmd_string)
if match:
prompt,switches = match.groups()
else:
prompt = cmd_string
switches = ''
try: try:
self._cmd_switches = self._cmd_parser.parse_args(switches) self._cmd_switches = self._cmd_parser.parse_args(shlex.split(switches))
setattr(self._cmd_switches,'prompt',prompt)
return self._cmd_switches return self._cmd_switches
except: except:
return None return None
@ -213,7 +207,9 @@ class Args(object):
a = vars(self) a = vars(self)
a.update(kwargs) a.update(kwargs)
switches = list() switches = list()
switches.append(f'"{a["prompt"]}"') prompt = a['prompt']
prompt.replace('"','\\"')
switches.append(prompt)
switches.append(f'-s {a["steps"]}') switches.append(f'-s {a["steps"]}')
switches.append(f'-S {a["seed"]}') switches.append(f'-S {a["seed"]}')
switches.append(f'-W {a["width"]}') switches.append(f'-W {a["width"]}')
@ -573,7 +569,11 @@ class Args(object):
variation_group = parser.add_argument_group('Creating and combining variations') variation_group = parser.add_argument_group('Creating and combining variations')
postprocessing_group = parser.add_argument_group('Post-processing') postprocessing_group = parser.add_argument_group('Post-processing')
special_effects_group = parser.add_argument_group('Special effects') special_effects_group = parser.add_argument_group('Special effects')
render_group.add_argument('prompt') render_group.add_argument(
'--prompt',
default='',
help='prompt string',
)
render_group.add_argument( render_group.add_argument(
'-s', '-s',
'--steps', '--steps',

View File

@ -827,7 +827,7 @@ class LatentDiffusion(DDPM):
self.cond_stage_model.encode self.cond_stage_model.encode
): ):
c = self.cond_stage_model.encode( c = self.cond_stage_model.encode(
c, embedding_manager=self.embedding_manager, **kwargs c, embedding_manager=self.embedding_manager,**kwargs
) )
if isinstance(c, DiagonalGaussianDistribution): if isinstance(c, DiagonalGaussianDistribution):
c = c.mode() c = c.mode()

View File

@ -23,10 +23,9 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
class CFGDenoiser(nn.Module): class CFGDenoiser(nn.Module):
def __init__(self, sampler, threshold = 0, warmup = 0): def __init__(self, model, threshold = 0, warmup = 0):
super().__init__() super().__init__()
self.inner_model = sampler.model self.inner_model = model
self.sampler = sampler
self.threshold = threshold self.threshold = threshold
self.warmup_max = warmup self.warmup_max = warmup
self.warmup = max(warmup / 10, 1) self.warmup = max(warmup / 10, 1)