apply changes as suggested @psychedelicious in PR comments.

- filename -> file_path
- pre and post prompt changed to optional
- clearer pre and post prompt descriptions
- handle pre and post prompt passed as None
- max_prompts defaults to 1 isted of 0 to avoid accidentally processing large prompt files with it set to 0 when adding a new node.
This commit is contained in:
skunkworxdark 2023-07-16 14:58:55 +01:00 committed by Kent Keirsey
parent 956011066d
commit b1008af696

View File

@ -1,7 +1,8 @@
from typing import Literal
from os.path import exists
from typing import Literal, Optional
import numpy as np
from pydantic.fields import Field
from pydantic import Field, validator
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
@ -64,27 +65,33 @@ class PromptsFromFileInvocation(BaseInvocation):
type: Literal['prompt_from_file'] = 'prompt_from_file'
# Inputs
filename: str = Field(default=None, description="Filename of prompt text file")
pre_prompt: str = Field(default=None, description="Add to start of prompt")
post_prompt: str = Field(default=None, description="Add to end of prompt")
start_line: int = Field(default=1, ge=1, description="Line in the file start start from")
max_prompts: int = Field(default=0, ge=0, description="Max lines to read from file (0=all)")
file_path: str = Field(description="Path to prompt text file")
pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
post_prompt: Optional[str] = Field(description="String to append to each prompt")
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
#fmt: on
def promptsFromFile(self, filename: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts:int):
@validator("file_path")
def file_path_exists(cls, v):
if not exists(v):
raise ValueError("file path not found")
return v
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
prompts = []
start_line -= 1
end_line = start_line + max_prompts
if max_prompts <= 0:
end_line = np.iinfo(np.int32).max
with open(filename) as f:
with open(file_path) as f:
for i, line in enumerate(f):
if i >= start_line and i < end_line:
prompts.append(pre_prompt + line.strip() + post_prompt)
prompts.append((pre_prompt or '') + line.strip() + (post_prompt or ''))
if i >= end_line:
break
return prompts
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
prompts = self.promptsFromFile(self.filename, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))