Added class PromptsFromFileInvocation to prompt.py. A new PromptFromFile Custom node that reads prompts from a file one line per prompt and outputs them as a prompt collection. With inputs for filename, pre_prompt, post_prompt, start line number, and max_prompts

This commit is contained in:
skunkworxdark 2023-07-15 12:17:18 +01:00 committed by Kent Keirsey
parent ed88e72412
commit 956011066d

View File

@ -1,5 +1,6 @@
from typing import Literal from typing import Literal
import numpy as np
from pydantic.fields import Field from pydantic.fields import Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
@ -55,3 +56,35 @@ class DynamicPromptInvocation(BaseInvocation):
prompts = generator.generate(self.prompt, num_images=self.max_prompts) prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts)) return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
class PromptsFromFileInvocation(BaseInvocation):
'''Loads prompts from a text file'''
# fmt: off
type: Literal['prompt_from_file'] = 'prompt_from_file'
# Inputs
filename: str = Field(default=None, description="Filename of prompt text file")
pre_prompt: str = Field(default=None, description="Add to start of prompt")
post_prompt: str = Field(default=None, description="Add to end of prompt")
start_line: int = Field(default=1, ge=1, description="Line in the file start start from")
max_prompts: int = Field(default=0, ge=0, description="Max lines to read from file (0=all)")
#fmt: on
def promptsFromFile(self, filename: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts:int):
prompts = []
start_line -= 1
end_line = start_line + max_prompts
if max_prompts <= 0:
end_line = np.iinfo(np.int32).max
with open(filename) as f:
for i, line in enumerate(f):
if i >= start_line and i < end_line:
prompts.append(pre_prompt + line.strip() + post_prompt)
if i >= end_line:
break
return prompts
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
prompts = self.promptsFromFile(self.filename, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))