Fix #1599 by relaxing the match_trigger regex (#1601)

* Fix #1599 by relaxing the `match_trigger` regex

Also simplify logic and reduce duplication.

* restrict trigger regex again (but not so far)
This commit is contained in:
Damian Stewart 2022-11-28 23:58:52 +01:00 committed by GitHub
parent a514f9b236
commit 45e51bac9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 19 deletions

View File

@ -308,7 +308,7 @@ def main_loop(gen, opt):
if use_prefix is not None: if use_prefix is not None:
prefix = use_prefix prefix = use_prefix
postprocessed = upscaled if upscaled else operation=='postprocess' postprocessed = upscaled if upscaled else operation=='postprocess'
opt.prompt = triggers_to_concepts(gen, opt.prompt) # to avoid the problem of non-unique concept triggers opt.prompt = gen.concept_lib().replace_triggers_with_concepts(opt.prompt) # to avoid the problem of non-unique concept triggers
filename, formatted_dream_prompt = prepare_image_metadata( filename, formatted_dream_prompt = prepare_image_metadata(
opt, opt,
prefix, prefix,
@ -353,7 +353,7 @@ def main_loop(gen, opt):
if operation == 'generate': if operation == 'generate':
# load any <embeddings> from the SD concepts library # load any <embeddings> from the SD concepts library
opt.prompt = concepts_to_triggers(gen, opt.prompt) opt.prompt = gen.concept_lib().replace_concepts_with_triggers(opt.prompt, lambda concepts: gen.load_concepts(concepts))
catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
opt.last_operation='generate' opt.last_operation='generate'
try: try:
@ -503,19 +503,6 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
command = '-h' command = '-h'
return command, operation return command, operation
def concepts_to_triggers(gen, prompt:str)->str:
concepts = re.findall('<([^>]+)>',prompt)
if not concepts:
return prompt
gen.load_concepts(concepts)
return gen.concept_lib().replace_concepts_with_triggers(prompt)
def triggers_to_concepts(gen,prompt:str)->str:
concepts = re.findall('<([^>]+)>',prompt)
if not concepts:
return prompt
return gen.concept_lib().replace_triggers_with_concepts(prompt)
def set_default_output_dir(opt:Args, completer:Completer): def set_default_output_dir(opt:Args, completer:Completer):
''' '''
If opt.outdir is relative, we add the root directory to it If opt.outdir is relative, we add the root directory to it

View File

@ -7,6 +7,7 @@ The interface is through the Concepts() object.
import os import os
import re import re
import traceback import traceback
from typing import Callable
from urllib import request, error as ul_error from urllib import request, error as ul_error
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
from ldm.invoke.globals import Globals from ldm.invoke.globals import Globals
@ -22,8 +23,8 @@ class Concepts(object):
self.concepts_loaded = dict() self.concepts_loaded = dict()
self.triggers = dict() # concept name to trigger phrase self.triggers = dict() # concept name to trigger phrase
self.concept_names = dict() # trigger phrase to concept name self.concept_names = dict() # trigger phrase to concept name
self.match_trigger = re.compile('(<[\w\-]+>)') self.match_trigger = re.compile('(<[\w\- >]+>)') # trigger is slightly less restrictive than HF concept name
self.match_concept = re.compile('<([\w\-]+)>') self.match_concept = re.compile('<([\w\-]+)>') # HF concept name can only contain A-Za-z0-9_-
def list_concepts(self)->list: def list_concepts(self)->list:
''' '''
@ -83,15 +84,27 @@ class Concepts(object):
better to store the concept name (unique) than the concept trigger better to store the concept name (unique) than the concept trigger
(not necessarily unique!) (not necessarily unique!)
''' '''
triggers = self.match_trigger.findall(prompt)
if not triggers:
return prompt
def do_replace(match)->str: def do_replace(match)->str:
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>' return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
return self.match_trigger.sub(do_replace, prompt) return self.match_trigger.sub(do_replace, prompt)
def replace_concepts_with_triggers(self, prompt:str)->str: def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
''' '''
Given a prompt string that contains <concept_name> tags, replace Given a prompt string that contains `<concept_name>` tags, replace
these tags with the appropriate trigger. these tags with the appropriate trigger.
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
of `concepts_name` strings.
''' '''
concepts = self.match_concept.findall(prompt)
if not concepts:
return prompt
load_concepts_callback(concepts)
def do_replace(match)->str: def do_replace(match)->str:
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>' return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
return self.match_concept.sub(do_replace, prompt) return self.match_concept.sub(do_replace, prompt)