fix(nodes): fix revert conflict

This commit is contained in:
psychedelicious 2023-06-13 22:02:01 +10:00
parent e0c998d192
commit 34ebee67b7

View File

@ -1,11 +1,9 @@
import os from typing import Literal
from typing import Literal, Optional
from pydantic.fields import Field from pydantic.fields import Field
from pyparsing import ParseException
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from dynamicprompts import RandomPromptGenerator, CombinatorialPromptGenerator from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
class PromptOutput(BaseInvocationOutput): class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt""" """Base class for invocations that output a prompt"""
@ -24,43 +22,36 @@ class PromptOutput(BaseInvocationOutput):
} }
class PromptListOutput(BaseInvocationOutput): class PromptCollectionOutput(BaseInvocationOutput):
"""Base class for invocations that output a list of prompts""" """Base class for invocations that output a collection of prompts"""
# fmt: off # fmt: off
type: Literal["prompt_list"] = "prompt_list" type: Literal["prompt_collection_output"] = "prompt_collection_output"
prompts: list[str] = Field(description="The output prompts") prompt_collection: list[str] = Field(description="The output prompt collection")
count: int = Field(description="The size of the prompts list") count: int = Field(description="The size of the prompt collection")
# fmt: on # fmt: on
class Config: class Config:
schema_extra = {"required": ["type", "prompts", "count"]} schema_extra = {"required": ["type", "prompt_collection", "count"]}
class DynamicPromptInvocation(BaseInvocation): class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator""" """Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
type: Literal["dynamic_prompt"] = "dynamic_prompt" type: Literal["dynamic_prompt"] = "dynamic_prompt"
prompt: str = Field( prompt: str = Field(description="The prompt to parse with dynamicprompts")
default=None, description="The prompt to parse with dynamicprompts"
)
max_prompts: int = Field(default=1, description="The number of prompts to generate") max_prompts: int = Field(default=1, description="The number of prompts to generate")
combinatorial: bool = Field( combinatorial: bool = Field(
default=False, description="Whether to use the combinatorial generator" default=False, description="Whether to use the combinatorial generator"
) )
def invoke(self, context: InvocationContext) -> PromptListOutput: def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
try: if self.combinatorial:
if self.combinatorial: generator = CombinatorialPromptGenerator()
generator = CombinatorialPromptGenerator() prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) else:
else: generator = RandomPromptGenerator()
generator = RandomPromptGenerator() prompts = generator.generate(self.prompt, num_images=self.max_prompts)
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
except ParseException as e:
warning = f"Invalid dynamic prompt: {e}"
context.services.logger.warn(warning)
return PromptListOutput(prompts=[self.prompt], count=1)
return PromptListOutput(prompts=prompts, count=len(prompts)) return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))