Fix preview, inpaint

Sergey Borisov 2023-08-07 21:27:32 +03:00
parent 2539e26c18
commit 1db2c93f75
4 changed files with 58 additions and 36 deletions

View File

@@ -16,7 +16,7 @@ from ..util.step_callback import stable_diffusion_step_callback
 from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
 from .image import ImageOutput
-from ...backend.model_management.lora import ModelPatcher
+from ...backend.model_management import ModelPatcher, BaseModelType
 from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
 from .model import UNetField, VaeField
 from .compel import ConditioningField
@@ -140,6 +140,7 @@ class InpaintInvocation(BaseInvocation):
         self,
         context: InvocationContext,
         source_node_id: str,
+        base_model: BaseModelType,
         intermediate_state: PipelineIntermediateState,
     ) -> None:
         stable_diffusion_step_callback(
@@ -147,15 +148,16 @@ class InpaintInvocation(BaseInvocation):
             intermediate_state=intermediate_state,
             node=self.dict(),
             source_node_id=source_node_id,
+            base_model=base_model,
         )

     def get_conditioning(self, context, unet):
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
-        extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
+        c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
+        extra_conditioning_info = c.extra_conditioning

         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
+        uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)

         return (uc, c, extra_conditioning_info)
@@ -225,7 +227,7 @@ class InpaintInvocation(BaseInvocation):
             scheduler=scheduler,
             init_image=image,
             mask_image=mask,
-            step_callback=partial(self.dispatch_progress, context, source_node_id),
+            step_callback=partial(self.dispatch_progress, context, source_node_id, self.unet.unet.base_model),
             **self.dict(
                 exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
             ),  # Shorthand for passing all of the parameters above manually
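
Note on the callback wiring above: functools.partial freezes the leading arguments, so the denoising loop only has to supply the intermediate state. A minimal standalone sketch of the binding order (report_progress and the literal values are illustrative, not from this codebase):

    from functools import partial

    def report_progress(context, source_node_id, base_model, intermediate_state):
        # the pipeline invokes the partial with a single positional argument,
        # which lands in `intermediate_state`; the rest were bound up front
        print(context, source_node_id, base_model, intermediate_state)

    step_callback = partial(report_progress, "ctx", "node-1", "sdxl")
    step_callback("state@step-0")  # prints: ctx node-1 sdxl state@step-0

This matches InpaintInvocation.dispatch_progress above, where base_model is declared before intermediate_state so that partial() can bind it positionally.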

View File

@@ -24,7 +24,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 )
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.model_management import ModelPatcher
+from ...backend.model_management import ModelPatcher, BaseModelType
 from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
 from ..models.image import ImageCategory, ImageField, ResourceOrigin
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
@@ -160,12 +160,14 @@ class TextToLatentsInvocation(BaseInvocation):
         context: InvocationContext,
         source_node_id: str,
         intermediate_state: PipelineIntermediateState,
+        base_model: BaseModelType,
     ) -> None:
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
             node=self.dict(),
             source_node_id=source_node_id,
+            base_model=base_model,
         )

     def get_conditioning_data(
@@ -340,7 +342,7 @@ class TextToLatentsInvocation(BaseInvocation):
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]

         def step_callback(state: PipelineIntermediateState):
-            self.dispatch_progress(context, source_node_id, state)
+            self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)

         def _lora_loader():
             for lora in self.unet.loras:
@@ -448,7 +450,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]

         def step_callback(state: PipelineIntermediateState):
-            self.dispatch_progress(context, source_node_id, state)
+            self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)

         def _lora_loader():
             for lora in self.unet.loras:
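
The self.unet.unet.base_model chain reads the base-model tag off the inner model reference of the invocation's UNet field. A rough sketch of the shapes involved, under the assumption that UNetField wraps a model-info object carrying base_model (the class layout and enum values here are illustrative, not taken from this diff):

    from dataclasses import dataclass
    from enum import Enum

    class BaseModelType(str, Enum):
        StableDiffusion1 = "sd-1"
        StableDiffusionXL = "sdxl"
        StableDiffusionXLRefiner = "sdxl-refiner"

    @dataclass
    class ModelInfo:
        model_name: str
        base_model: BaseModelType

    @dataclass
    class UNetField:
        unet: ModelInfo  # hence the double attribute access: field.unet.base_model

    unet_field = UNetField(unet=ModelInfo("sdxl-base", BaseModelType.StableDiffusionXL))
    assert unet_field.unet.base_model is BaseModelType.StableDiffusionXL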

View File

@@ -7,6 +7,7 @@ from ...backend.util.util import image_to_dataURL
 from ...backend.generator.base import Generator
 from ...backend.stable_diffusion import PipelineIntermediateState
 from invokeai.app.services.config import InvokeAIAppConfig
+from ...backend.model_management.models import BaseModelType


 def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
@@ -29,6 +30,7 @@ def stable_diffusion_step_callback(
     intermediate_state: PipelineIntermediateState,
     node: dict,
     source_node_id: str,
+    base_model: BaseModelType,
 ):
     if context.services.queue.is_canceled(context.graph_execution_state_id):
         raise CanceledException
@@ -56,23 +58,51 @@ def stable_diffusion_step_callback(
     # TODO: only output a preview image when requested

-    # origingally adapted from code by @erucipe and @keturn here:
-    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
-
-    # these updated numbers for v1.5 are from @torridgristle
-    v1_5_latent_rgb_factors = torch.tensor(
-        [
-            #   R        G        B
-            [0.3444, 0.1385, 0.0670],  # L1
-            [0.1247, 0.4027, 0.1494],  # L2
-            [-0.3192, 0.2513, 0.2103],  # L3
-            [-0.1307, -0.1874, -0.7445],  # L4
-        ],
-        dtype=sample.dtype,
-        device=sample.device,
-    )
-
-    image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
+    if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
+        sdxl_latent_rgb_factors = torch.tensor(
+            [
+                #   R        G        B
+                [0.3816, 0.4930, 0.5320],
+                [-0.3753, 0.1631, 0.1739],
+                [0.1770, 0.3588, -0.2048],
+                [-0.4350, -0.2644, -0.4289],
+            ],
+            dtype=sample.dtype,
+            device=sample.device,
+        )
+
+        sdxl_smooth_matrix = torch.tensor(
+            [
+                # [ 0.0478,  0.1285,  0.0478],
+                # [ 0.1285,  0.2948,  0.1285],
+                # [ 0.0478,  0.1285,  0.0478],
+                [0.0358, 0.0964, 0.0358],
+                [0.0964, 0.4711, 0.0964],
+                [0.0358, 0.0964, 0.0358],
+            ],
+            dtype=sample.dtype,
+            device=sample.device,
+        )
+
+        image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
+    else:
+        # originally adapted from code by @erucipe and @keturn here:
+        # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
+
+        # these updated numbers for v1.5 are from @torridgristle
+        v1_5_latent_rgb_factors = torch.tensor(
+            [
+                #   R        G        B
+                [0.3444, 0.1385, 0.0670],  # L1
+                [0.1247, 0.4027, 0.1494],  # L2
+                [-0.3192, 0.2513, 0.2103],  # L3
+                [-0.1307, -0.1874, -0.7445],  # L4
+            ],
+            dtype=sample.dtype,
+            device=sample.device,
+        )
+
+        image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)

     (width, height) = image.size
     width *= 8
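
For context on what those factor matrices do: the preview path projects the 4-channel latent straight to RGB with a 4x3 matrix instead of running the VAE decoder, and the SDXL branch additionally smooths the result with a small convolution (the commented-out 3x3 block appears to be an earlier, stronger kernel kept for reference). A hedged sketch approximating sample_to_lowres_estimated_image (the real helper may differ in details such as normalization):

    import torch
    from PIL import Image

    def approx_latents_to_rgb(sample: torch.Tensor, rgb_factors: torch.Tensor,
                              smooth_matrix: torch.Tensor = None) -> Image.Image:
        # sample: (4, H, W) latent; rgb_factors: (4, 3) -> one RGB triple per latent channel
        latent_image = sample.permute(1, 2, 0) @ rgb_factors  # (H, W, 3)
        if smooth_matrix is not None:
            # depthwise 3x3 blur: the same kernel applied to each RGB channel
            img = latent_image.permute(2, 0, 1).unsqueeze(0)            # (1, 3, H, W)
            kernel = smooth_matrix.reshape(1, 1, 3, 3).expand(3, 1, 3, 3)
            img = torch.nn.functional.conv2d(img, kernel, padding=1, groups=3)
            latent_image = img.squeeze(0).permute(1, 2, 0)
        # map roughly [-1, 1] -> [0, 255] for a quick low-res preview
        ubyte = ((latent_image + 1) / 2).clamp(0, 1).mul(255).byte()
        return Image.fromarray(ubyte.cpu().numpy())

The preview comes out at 1/8 of the final resolution, which is why the trailing context multiplies width and height by 8.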

View File

@@ -50,7 +50,6 @@ from .offloading import FullyLoadedModelGroup, ModelGroup
 @dataclass
 class PipelineIntermediateState:
-    run_id: str
     step: int
     timestep: int
     latents: torch.Tensor
@@ -407,7 +406,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         noise: Optional[torch.Tensor],
         timesteps=None,
         additional_guidance: List[Callable] = None,
-        run_id=None,
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
@@ -427,7 +425,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             timesteps,
             conditioning_data,
             noise=noise,
-            run_id=run_id,
             additional_guidance=additional_guidance,
             control_data=control_data,
             callback=callback,
@@ -441,13 +438,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         conditioning_data: ConditioningData,
         *,
         noise: Optional[torch.Tensor],
-        run_id: str = None,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
     ):
         self._adjust_memory_efficient_attention(latents)
-        if run_id is None:
-            run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
@@ -468,7 +462,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             latents = self.scheduler.add_noise(latents, noise, batched_t)

         yield PipelineIntermediateState(
-            run_id=run_id,
             step=-1,
             timestep=self.scheduler.config.num_train_timesteps,
             latents=latents,
@@ -507,7 +500,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             # self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

             yield PipelineIntermediateState(
-                run_id=run_id,
                 step=i,
                 timestep=int(t),
                 latents=latents,
@@ -619,7 +611,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         conditioning_data: ConditioningData,
         *,
         callback: Callable[[PipelineIntermediateState], None] = None,
-        run_id=None,
         noise_func=None,
         seed=None,
     ) -> InvokeAIStableDiffusionPipelineOutput:
@@ -645,7 +636,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             conditioning_data,
             strength,
             noise,
-            run_id,
             callback,
         )
@@ -678,7 +668,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         conditioning_data: ConditioningData,
         *,
         callback: Callable[[PipelineIntermediateState], None] = None,
-        run_id=None,
         noise_func=None,
         seed=None,
     ) -> InvokeAIStableDiffusionPipelineOutput:
@@ -737,7 +726,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             noise=noise,
             timesteps=timesteps,
             additional_guidance=guidance,
-            run_id=run_id,
             callback=callback,
         )
     finally:
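
After this commit, PipelineIntermediateState no longer carries a per-run token; progress events are correlated through the source_node_id bound into the callback instead. A minimal sketch of the slimmed state and a consumer (field list abridged to the fields visible in this diff; the real dataclass has more):

    from dataclasses import dataclass
    import torch

    @dataclass
    class PipelineIntermediateState:
        step: int
        timestep: int
        latents: torch.Tensor

    def step_callback(state: PipelineIntermediateState) -> None:
        # no run_id: callers that need to attribute progress bind their own
        # identifiers via functools.partial, as the invocations above do
        print(f"step {state.step} (t={state.timestep}), latents {tuple(state.latents.shape)}")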