apply changes as suggested @psychedelicious in PR comments.

- filename -> file_path
- pre and post prompt changed to optional
- clearer pre and post prompt descriptions
- handle pre and post prompt passed as None
- max_prompts defaults to 1 isted of 0 to avoid accidentally processing large prompt files with it set to 0 when adding a new node.
This commit is contained in:
skunkworxdark 2023-07-16 14:58:55 +01:00 committed by Kent Keirsey
parent 956011066d
commit b1008af696

View File

@ -1,7 +1,8 @@
from typing import Literal from os.path import exists
from typing import Literal, Optional
import numpy as np import numpy as np
from pydantic.fields import Field from pydantic import Field, validator
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
@ -64,27 +65,33 @@ class PromptsFromFileInvocation(BaseInvocation):
type: Literal['prompt_from_file'] = 'prompt_from_file' type: Literal['prompt_from_file'] = 'prompt_from_file'
# Inputs # Inputs
filename: str = Field(default=None, description="Filename of prompt text file") file_path: str = Field(description="Path to prompt text file")
pre_prompt: str = Field(default=None, description="Add to start of prompt") pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
post_prompt: str = Field(default=None, description="Add to end of prompt") post_prompt: Optional[str] = Field(description="String to append to each prompt")
start_line: int = Field(default=1, ge=1, description="Line in the file start start from") start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
max_prompts: int = Field(default=0, ge=0, description="Max lines to read from file (0=all)") max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
#fmt: on #fmt: on
def promptsFromFile(self, filename: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts:int): @validator("file_path")
def file_path_exists(cls, v):
if not exists(v):
raise ValueError("file path not found")
return v
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
prompts = [] prompts = []
start_line -= 1 start_line -= 1
end_line = start_line + max_prompts end_line = start_line + max_prompts
if max_prompts <= 0: if max_prompts <= 0:
end_line = np.iinfo(np.int32).max end_line = np.iinfo(np.int32).max
with open(filename) as f: with open(file_path) as f:
for i, line in enumerate(f): for i, line in enumerate(f):
if i >= start_line and i < end_line: if i >= start_line and i < end_line:
prompts.append(pre_prompt + line.strip() + post_prompt) prompts.append((pre_prompt or '') + line.strip() + (post_prompt or ''))
if i >= end_line: if i >= end_line:
break break
return prompts return prompts
def invoke(self, context: InvocationContext) -> PromptCollectionOutput: def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
prompts = self.promptsFromFile(self.filename, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts) prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts)) return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))