fix(nodes): fix revert conflict

This commit is contained in:
psychedelicious 2023-06-13 22:02:01 +10:00
parent e0c998d192
commit 34ebee67b7

View File

@ -1,11 +1,9 @@
import os
from typing import Literal, Optional
from typing import Literal
from pydantic.fields import Field
from pyparsing import ParseException
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from dynamicprompts import RandomPromptGenerator, CombinatorialPromptGenerator
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
@ -24,43 +22,36 @@ class PromptOutput(BaseInvocationOutput):
}
class PromptListOutput(BaseInvocationOutput):
"""Base class for invocations that output a list of prompts"""
class PromptCollectionOutput(BaseInvocationOutput):
"""Base class for invocations that output a collection of prompts"""
# fmt: off
type: Literal["prompt_list"] = "prompt_list"
type: Literal["prompt_collection_output"] = "prompt_collection_output"
prompts: list[str] = Field(description="The output prompts")
count: int = Field(description="The size of the prompts list")
prompt_collection: list[str] = Field(description="The output prompt collection")
count: int = Field(description="The size of the prompt collection")
# fmt: on
class Config:
schema_extra = {"required": ["type", "prompts", "count"]}
schema_extra = {"required": ["type", "prompt_collection", "count"]}
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
type: Literal["dynamic_prompt"] = "dynamic_prompt"
prompt: str = Field(
default=None, description="The prompt to parse with dynamicprompts"
)
prompt: str = Field(description="The prompt to parse with dynamicprompts")
max_prompts: int = Field(default=1, description="The number of prompts to generate")
combinatorial: bool = Field(
default=False, description="Whether to use the combinatorial generator"
)
def invoke(self, context: InvocationContext) -> PromptListOutput:
try:
if self.combinatorial:
generator = CombinatorialPromptGenerator()
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
else:
generator = RandomPromptGenerator()
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
except ParseException as e:
warning = f"Invalid dynamic prompt: {e}"
context.services.logger.warn(warning)
return PromptListOutput(prompts=[self.prompt], count=1)
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
if self.combinatorial:
generator = CombinatorialPromptGenerator()
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
else:
generator = RandomPromptGenerator()
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptListOutput(prompts=prompts, count=len(prompts))
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))