diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 299cd2b462..6713a587f2 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -1,7 +1,8 @@ -from typing import Literal +from os.path import exists +from typing import Literal, Optional import numpy as np -from pydantic.fields import Field +from pydantic import Field, validator from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator @@ -64,27 +65,33 @@ class PromptsFromFileInvocation(BaseInvocation): type: Literal['prompt_from_file'] = 'prompt_from_file' # Inputs - filename: str = Field(default=None, description="Filename of prompt text file") - pre_prompt: str = Field(default=None, description="Add to start of prompt") - post_prompt: str = Field(default=None, description="Add to end of prompt") - start_line: int = Field(default=1, ge=1, description="Line in the file start start from") - max_prompts: int = Field(default=0, ge=0, description="Max lines to read from file (0=all)") + file_path: str = Field(description="Path to prompt text file") + pre_prompt: Optional[str] = Field(description="String to prepend to each prompt") + post_prompt: Optional[str] = Field(description="String to append to each prompt") + start_line: int = Field(default=1, ge=1, description="Line in the file to start start from") + max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)") #fmt: on - def promptsFromFile(self, filename: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts:int): + @validator("file_path") + def file_path_exists(cls, v): + if not exists(v): + raise ValueError("file path not found") + return v + + def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int): prompts = [] start_line -= 1 end_line = start_line + max_prompts if max_prompts <= 0: end_line = np.iinfo(np.int32).max - with open(filename) as f: + with open(file_path) as f: for i, line in enumerate(f): if i >= start_line and i < end_line: - prompts.append(pre_prompt + line.strip() + post_prompt) + prompts.append((pre_prompt or '') + line.strip() + (post_prompt or '')) if i >= end_line: break return prompts def invoke(self, context: InvocationContext) -> PromptCollectionOutput: - prompts = self.promptsFromFile(self.filename, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts) + prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts) return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))