Fix #1599 by relaxing the match_trigger regex (#1601)

* Fix #1599 by relaxing the `match_trigger` regex

Also simplify logic and reduce duplication.

* restrict trigger regex again (but not so far)
This commit is contained in:
Damian Stewart 2022-11-28 23:58:52 +01:00 committed by GitHub
parent a514f9b236
commit 45e51bac9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 19 deletions

View File

@ -308,7 +308,7 @@ def main_loop(gen, opt):
if use_prefix is not None:
prefix = use_prefix
postprocessed = upscaled if upscaled else operation=='postprocess'
opt.prompt = triggers_to_concepts(gen, opt.prompt) # to avoid the problem of non-unique concept triggers
opt.prompt = gen.concept_lib().replace_triggers_with_concepts(opt.prompt) # to avoid the problem of non-unique concept triggers
filename, formatted_dream_prompt = prepare_image_metadata(
opt,
prefix,
@ -353,7 +353,7 @@ def main_loop(gen, opt):
if operation == 'generate':
# load any <embeddings> from the SD concepts library
opt.prompt = concepts_to_triggers(gen, opt.prompt)
opt.prompt = gen.concept_lib().replace_concepts_with_triggers(opt.prompt, lambda concepts: gen.load_concepts(concepts))
catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
opt.last_operation='generate'
try:
@ -503,19 +503,6 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
command = '-h'
return command, operation
def concepts_to_triggers(gen, prompt:str)->str:
concepts = re.findall('<([^>]+)>',prompt)
if not concepts:
return prompt
gen.load_concepts(concepts)
return gen.concept_lib().replace_concepts_with_triggers(prompt)
def triggers_to_concepts(gen,prompt:str)->str:
concepts = re.findall('<([^>]+)>',prompt)
if not concepts:
return prompt
return gen.concept_lib().replace_triggers_with_concepts(prompt)
def set_default_output_dir(opt:Args, completer:Completer):
'''
If opt.outdir is relative, we add the root directory to it

View File

@ -7,6 +7,7 @@ The interface is through the Concepts() object.
import os
import re
import traceback
from typing import Callable
from urllib import request, error as ul_error
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
from ldm.invoke.globals import Globals
@ -22,8 +23,8 @@ class Concepts(object):
self.concepts_loaded = dict()
self.triggers = dict() # concept name to trigger phrase
self.concept_names = dict() # trigger phrase to concept name
self.match_trigger = re.compile('(<[\w\-]+>)')
self.match_concept = re.compile('<([\w\-]+)>')
self.match_trigger = re.compile('(<[\w\- >]+>)') # trigger is slightly less restrictive than HF concept name
self.match_concept = re.compile('<([\w\-]+)>') # HF concept name can only contain A-Za-z0-9_-
def list_concepts(self)->list:
'''
@ -83,15 +84,27 @@ class Concepts(object):
better to store the concept name (unique) than the concept trigger
(not necessarily unique!)
'''
triggers = self.match_trigger.findall(prompt)
if not triggers:
return prompt
def do_replace(match)->str:
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
return self.match_trigger.sub(do_replace, prompt)
def replace_concepts_with_triggers(self, prompt:str)->str:
def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
'''
Given a prompt string that contains <concept_name> tags, replace
Given a prompt string that contains `<concept_name>` tags, replace
these tags with the appropriate trigger.
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
of `concepts_name` strings.
'''
concepts = self.match_concept.findall(prompt)
if not concepts:
return prompt
load_concepts_callback(concepts)
def do_replace(match)->str:
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
return self.match_concept.sub(do_replace, prompt)