mirror of https://github.com/invoke-ai/InvokeAI
feat(nodes): cleanup unused params, seed generation
parent 5457c7f069
commit a1079e455a

@@ -3,12 +3,12 @@
 from typing import Literal, Optional

 import numpy as np
-import numpy.random
 from pydantic import Field

+from invokeai.app.util.misc import SEED_MAX, get_random_seed
+
 from .baseinvocation import (
     BaseInvocation,
-    InvocationConfig,
     InvocationContext,
     BaseInvocationOutput,
 )

@@ -52,9 +52,9 @@ class RandomRangeInvocation(BaseInvocation):
     size: int = Field(default=1, description="The number of values to generate")
     seed: Optional[int] = Field(
         ge=0,
-        le=np.iinfo(np.int32).max,
-        description="The seed for the RNG",
-        default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
+        le=SEED_MAX,
+        description="The seed for the RNG (omit for random)",
+        default_factory=get_random_seed,
     )

     def invoke(self, context: InvocationContext) -> IntCollectionOutput:

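The hunks above and below all apply the same pattern: seed fields become Optional[int] with default_factory=get_random_seed and an upper bound of SEED_MAX. Below is a minimal standalone sketch of that pattern; ExampleInvocation is a made-up stand-in rather than a class from this repo, and the Field(...) usage mirrors the pydantic style seen in the diff.

from typing import Optional

import numpy as np
from pydantic import BaseModel, Field

SEED_MAX = np.iinfo(np.int32).max  # same bound the commit introduces


def get_random_seed():
    return np.random.randint(0, SEED_MAX)


class ExampleInvocation(BaseModel):  # hypothetical stand-in for an invocation class
    seed: Optional[int] = Field(
        ge=0,
        le=SEED_MAX,
        description="The seed for the RNG (omit for random)",
        default_factory=get_random_seed,
    )


print(ExampleInvocation().seed)         # seed omitted -> a fresh random seed per instance
print(ExampleInvocation(seed=42).seed)  # explicit seeds are validated against 0..SEED_MAX
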
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field

 from invokeai.app.models.image import ColorField, ImageField, ImageType
 from invokeai.app.invocations.util.choose_model import choose_model
+from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.backend.generator.inpaint import infill_methods
 from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
 from .image import ImageOutput, build_image_output

@@ -46,15 +47,13 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     # TODO: consider making prompt optional to enable providing prompt through a link
     # fmt: off
     prompt: Optional[str] = Field(description="The prompt to generate an image from")
-    seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
+    seed: Optional[int] = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
     steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
     width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
     height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
     cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
-    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
     model: str = Field(default="", description="The model to use (currently ignored)")
-    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
     # fmt: on

     # TODO: pass this an emitter method or something? or a session for dispatching?

@@ -54,7 +54,6 @@ def build_image_output(
         image=image_field,
         width=image.width,
         height=image.height,
-        mode=image.mode,
     )


@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field
 import torch

 from invokeai.app.invocations.util.choose_model import choose_model
+from invokeai.app.util.misc import SEED_MAX, get_random_seed

 from invokeai.app.util.step_callback import stable_diffusion_step_callback

@@ -104,17 +105,13 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
     return x


-def random_seed():
-    return random.randint(0, np.iinfo(np.uint32).max)
-
-
 class NoiseInvocation(BaseInvocation):
     """Generates latent noise."""

     type: Literal["noise"] = "noise"

     # Inputs
-    seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
+    seed: Optional[int] = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
     width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
     height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )

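For context, the seed drawn by get_random_seed is what ultimately makes latent noise reproducible. The sketch below illustrates the general idea only; it is not the repository's get_noise implementation, and the shapes and device handling are simplified assumptions.

import torch


def make_noise(seed: int, width: int = 512, height: int = 512, latent_channels: int = 4):
    # Seed a dedicated generator so the draw is reproducible and does not
    # disturb torch's global RNG state.
    generator = torch.Generator(device="cpu").manual_seed(seed)
    # Stable Diffusion latents are 1/8 of the pixel resolution.
    return torch.randn(
        (1, latent_channels, height // 8, width // 8),
        generator=generator,
        device="cpu",
    )


a = make_noise(seed=123)
b = make_noise(seed=123)
assert torch.equal(a, b)  # the same seed yields identical noise
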
@@ -152,10 +149,7 @@ class TextToLatentsInvocation(BaseInvocation):
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
-    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
-    seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     model: str = Field(default="", description="The model to use (currently ignored)")
-    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
     # fmt: on

     # Schema customisation

@@ -262,6 +256,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):

     type: Literal["l2l"] = "l2l"

+    # Inputs
+    latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
+    strength: float = Field(default=0.5, description="The strength of the latents to use")
+
     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {

@@ -273,10 +271,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             },
         }

-    # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
-    strength: float = Field(default=0.5, description="The strength of the latents to use")
-
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)
         latent = context.services.latents.get(self.latents.latents_name)

@@ -8,6 +8,6 @@ def choose_model(model_manager: ModelManager, model_name: str):
         model = model_manager.get_model(model_name)
     else:
         model = model_manager.get_model()
-        logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
+        logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{model['model_name']}\' instead.")

     return model

@@ -1,5 +1,13 @@
 import datetime
+import numpy as np


 def get_timestamp():
     return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
+
+
+SEED_MAX = np.iinfo(np.int32).max
+
+
+def get_random_seed():
+    return np.random.randint(0, SEED_MAX)
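A quick standalone check of the new helper's semantics, with the definitions copied from the hunk above: numpy's randint treats the upper bound as exclusive, and the draw comes from numpy's global RandomState, so seeding that state makes get_random_seed reproducible in tests.

import numpy as np

SEED_MAX = np.iinfo(np.int32).max


def get_random_seed():
    return np.random.randint(0, SEED_MAX)


np.random.seed(0)
first = get_random_seed()
np.random.seed(0)
assert get_random_seed() == first  # same global RNG state -> same value
assert 0 <= first < SEED_MAX       # randint's upper bound is exclusive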