mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix preview, inpaint
This commit is contained in:
parent
2539e26c18
commit
1db2c93f75
@ -16,7 +16,7 @@ from ..util.step_callback import stable_diffusion_step_callback
|
||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
||||
from .image import ImageOutput
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management import ModelPatcher, BaseModelType
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from .model import UNetField, VaeField
|
||||
from .compel import ConditioningField
|
||||
@ -140,6 +140,7 @@ class InpaintInvocation(BaseInvocation):
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
base_model: BaseModelType,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
@ -147,15 +148,16 @@ class InpaintInvocation(BaseInvocation):
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
base_model=base_model,
|
||||
)
|
||||
|
||||
def get_conditioning(self, context, unet):
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = c.extra_conditioning
|
||||
|
||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
return (uc, c, extra_conditioning_info)
|
||||
|
||||
@ -225,7 +227,7 @@ class InpaintInvocation(BaseInvocation):
|
||||
scheduler=scheduler,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id, self.unet.unet.base_model),
|
||||
**self.dict(
|
||||
exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
|
@ -24,7 +24,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
)
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.model_management import ModelPatcher
|
||||
from ...backend.model_management import ModelPatcher, BaseModelType
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
@ -160,12 +160,14 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
base_model: BaseModelType,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
base_model=base_model,
|
||||
)
|
||||
|
||||
def get_conditioning_data(
|
||||
@ -340,7 +342,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
@ -379,7 +381,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
|
||||
num_inference_steps, timesteps = self.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
@ -448,7 +450,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
|
@ -7,6 +7,7 @@ from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.generator.base import Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
||||
@ -29,6 +30,7 @@ def stable_diffusion_step_callback(
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
base_model: BaseModelType,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException
|
||||
@ -56,23 +58,51 @@ def stable_diffusion_step_callback(
|
||||
|
||||
# TODO: only output a preview image when requested
|
||||
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
|
||||
sdxl_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
sdxl_smooth_matrix = torch.tensor(
|
||||
[
|
||||
# [ 0.0478, 0.1285, 0.0478],
|
||||
# [ 0.1285, 0.2948, 0.1285],
|
||||
# [ 0.0478, 0.1285, 0.0478],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
|
||||
else:
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
|
@ -50,7 +50,6 @@ from .offloading import FullyLoadedModelGroup, ModelGroup
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
run_id: str
|
||||
step: int
|
||||
timestep: int
|
||||
latents: torch.Tensor
|
||||
@ -407,7 +406,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
noise: Optional[torch.Tensor],
|
||||
timesteps=None,
|
||||
additional_guidance: List[Callable] = None,
|
||||
run_id=None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||
@ -427,7 +425,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
timesteps,
|
||||
conditioning_data,
|
||||
noise=noise,
|
||||
run_id=run_id,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
callback=callback,
|
||||
@ -441,13 +438,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
noise: Optional[torch.Tensor],
|
||||
run_id: str = None,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
):
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
if run_id is None:
|
||||
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
extra_conditioning_info = conditioning_data.extra
|
||||
@ -468,7 +462,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
|
||||
yield PipelineIntermediateState(
|
||||
run_id=run_id,
|
||||
step=-1,
|
||||
timestep=self.scheduler.config.num_train_timesteps,
|
||||
latents=latents,
|
||||
@ -507,7 +500,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||
|
||||
yield PipelineIntermediateState(
|
||||
run_id=run_id,
|
||||
step=i,
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
@ -619,7 +611,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
run_id=None,
|
||||
noise_func=None,
|
||||
seed=None,
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
@ -645,7 +636,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
conditioning_data,
|
||||
strength,
|
||||
noise,
|
||||
run_id,
|
||||
callback,
|
||||
)
|
||||
|
||||
@ -678,7 +668,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
conditioning_data: ConditioningData,
|
||||
*,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
run_id=None,
|
||||
noise_func=None,
|
||||
seed=None,
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
@ -737,7 +726,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
additional_guidance=guidance,
|
||||
run_id=run_id,
|
||||
callback=callback,
|
||||
)
|
||||
finally:
|
||||
|
Loading…
Reference in New Issue
Block a user