diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 2b72c808e4..c6b85d2bd6 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -592,54 +592,3 @@ class InvokeAIDiffuserComponent: self.last_percent_through = percent_through return latents.to(device=dev) - - # todo: make this work - @classmethod - def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale): - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) # aka sigmas - - deltas = None - uncond_latents = None - weighted_cond_list = ( - c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)] - ) - - # below is fugly omg - conditionings = [uc] + [c for c, weight in weighted_cond_list] - weights = [1] + [weight for c, weight in weighted_cond_list] - chunk_count = math.ceil(len(conditionings) / 2) - deltas = None - for chunk_index in range(chunk_count): - offset = chunk_index * 2 - chunk_size = min(2, len(conditionings) - offset) - - if chunk_size == 1: - c_in = conditionings[offset] - latents_a = forward_func(x_in[:-1], t_in[:-1], c_in) - latents_b = None - else: - c_in = torch.cat(conditionings[offset : offset + 2]) - latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2) - - # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining - if chunk_index == 0: - uncond_latents = latents_a - deltas = latents_b - uncond_latents - else: - deltas = torch.cat((deltas, latents_a - uncond_latents)) - if latents_b is not None: - deltas = torch.cat((deltas, latents_b - uncond_latents)) - - # merge the weighted deltas together into a single merged delta - per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device) - normalize = False - if normalize: - per_delta_weights /= torch.sum(per_delta_weights) - reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1)) - deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True) - - # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale) - # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale)))) - - return uncond_latents + deltas_merged * global_guidance_scale