From 34ebee67b7e095235aa81ee33885085976f7dc34 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 13 Jun 2023 22:02:01 +1000 Subject: [PATCH] fix(nodes): fix revert conflict --- invokeai/app/invocations/prompt.py | 43 ++++++++++++------------------ 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index fd9e08912d..9af87e1ed4 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -1,11 +1,9 @@ -import os -from typing import Literal, Optional +from typing import Literal from pydantic.fields import Field -from pyparsing import ParseException from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext -from dynamicprompts import RandomPromptGenerator, CombinatorialPromptGenerator +from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator class PromptOutput(BaseInvocationOutput): """Base class for invocations that output a prompt""" @@ -24,43 +22,36 @@ class PromptOutput(BaseInvocationOutput): } -class PromptListOutput(BaseInvocationOutput): - """Base class for invocations that output a list of prompts""" +class PromptCollectionOutput(BaseInvocationOutput): + """Base class for invocations that output a collection of prompts""" # fmt: off - type: Literal["prompt_list"] = "prompt_list" + type: Literal["prompt_collection_output"] = "prompt_collection_output" - prompts: list[str] = Field(description="The output prompts") - count: int = Field(description="The size of the prompts list") + prompt_collection: list[str] = Field(description="The output prompt collection") + count: int = Field(description="The size of the prompt collection") # fmt: on class Config: - schema_extra = {"required": ["type", "prompts", "count"]} + schema_extra = {"required": ["type", "prompt_collection", "count"]} class DynamicPromptInvocation(BaseInvocation): """Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator""" type: Literal["dynamic_prompt"] = "dynamic_prompt" - prompt: str = Field( - default=None, description="The prompt to parse with dynamicprompts" - ) + prompt: str = Field(description="The prompt to parse with dynamicprompts") max_prompts: int = Field(default=1, description="The number of prompts to generate") combinatorial: bool = Field( default=False, description="Whether to use the combinatorial generator" ) - def invoke(self, context: InvocationContext) -> PromptListOutput: - try: - if self.combinatorial: - generator = CombinatorialPromptGenerator() - prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) - else: - generator = RandomPromptGenerator() - prompts = generator.generate(self.prompt, num_images=self.max_prompts) - except ParseException as e: - warning = f"Invalid dynamic prompt: {e}" - context.services.logger.warn(warning) - return PromptListOutput(prompts=[self.prompt], count=1) + def invoke(self, context: InvocationContext) -> PromptCollectionOutput: + if self.combinatorial: + generator = CombinatorialPromptGenerator() + prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) + else: + generator = RandomPromptGenerator() + prompts = generator.generate(self.prompt, num_images=self.max_prompts) - return PromptListOutput(prompts=prompts, count=len(prompts)) + return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))