mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' of github.com:invoke-ai/InvokeAI into development
This commit is contained in:
commit
2d84e28d32
@ -640,9 +640,11 @@ class InvokeAIWebServer:
|
|||||||
if generation_parameters['progress_latents']:
|
if generation_parameters['progress_latents']:
|
||||||
image = self.generate.sample_to_lowres_estimated_image(sample)
|
image = self.generate.sample_to_lowres_estimated_image(sample)
|
||||||
(width, height) = image.size
|
(width, height) = image.size
|
||||||
|
width *= 8
|
||||||
|
height *= 8
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
image.save(buffered, format="PNG")
|
image.save(buffered, format="PNG")
|
||||||
img_base64 = "data:image/jpeg;base64," + base64.b64encode(buffered.getvalue()).decode('UTF-8')
|
img_base64 = "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode('UTF-8')
|
||||||
self.socketio.emit(
|
self.socketio.emit(
|
||||||
"intermediateResult",
|
"intermediateResult",
|
||||||
{
|
{
|
||||||
|
@ -60,6 +60,7 @@ We thank them for all of their time and hard work.
|
|||||||
- [Dmitry T.](https://github.com/ArDiouscuros)
|
- [Dmitry T.](https://github.com/ArDiouscuros)
|
||||||
- [Kent Keirsey](https://github.com/hipsterusername)
|
- [Kent Keirsey](https://github.com/hipsterusername)
|
||||||
- [psychedelicious](https://github.com/psychedelicious)
|
- [psychedelicious](https://github.com/psychedelicious)
|
||||||
|
- [damian0815](https://github.com/damian0815)
|
||||||
|
|
||||||
## **Original CompVis Authors:**
|
## **Original CompVis Authors:**
|
||||||
|
|
||||||
|
@ -119,19 +119,19 @@ class Generator():
|
|||||||
# write an approximate RGB image from latent samples for a single step to PNG
|
# write an approximate RGB image from latent samples for a single step to PNG
|
||||||
|
|
||||||
def sample_to_lowres_estimated_image(self,samples):
|
def sample_to_lowres_estimated_image(self,samples):
|
||||||
# adapted from code by @erucipe and @keturn here:
|
# origingally adapted from code by @erucipe and @keturn here:
|
||||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||||
|
|
||||||
# these numbers were determined empirically by @keturn
|
# these updated numbers for v1.5 are from @torridgristle
|
||||||
v1_4_latent_rgb_factors = torch.tensor([
|
v1_5_latent_rgb_factors = torch.tensor([
|
||||||
# R G B
|
# R G B
|
||||||
[ 0.298, 0.207, 0.208], # L1
|
[ 0.3444, 0.1385, 0.0670], # L1
|
||||||
[ 0.187, 0.286, 0.173], # L2
|
[ 0.1247, 0.4027, 0.1494], # L2
|
||||||
[-0.158, 0.189, 0.264], # L3
|
[-0.3192, 0.2513, 0.2103], # L3
|
||||||
[-0.184, -0.271, -0.473], # L4
|
[-0.1307, -0.1874, -0.7445] # L4
|
||||||
], dtype=samples.dtype, device=samples.device)
|
], dtype=samples.dtype, device=samples.device)
|
||||||
|
|
||||||
latent_image = samples[0].permute(1, 2, 0) @ v1_4_latent_rgb_factors
|
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
|
||||||
latents_ubyte = (((latent_image + 1) / 2)
|
latents_ubyte = (((latent_image + 1) / 2)
|
||||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
.mul(0xFF) # to 0..255
|
.mul(0xFF) # to 0..255
|
||||||
|
@ -68,6 +68,8 @@ class CrossAttentionControl:
|
|||||||
indices[b0:b1] = indices_target[a0:a1]
|
indices[b0:b1] = indices_target[a0:a1]
|
||||||
mask[b0:b1] = 1
|
mask[b0:b1] = 1
|
||||||
|
|
||||||
|
cls.inject_attention_function(model)
|
||||||
|
|
||||||
for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
|
for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
|
||||||
m.last_attn_slice_mask = None
|
m.last_attn_slice_mask = None
|
||||||
m.last_attn_slice_indices = None
|
m.last_attn_slice_indices = None
|
||||||
@ -76,8 +78,6 @@ class CrossAttentionControl:
|
|||||||
m.last_attn_slice_mask = mask.to(device)
|
m.last_attn_slice_mask = mask.to(device)
|
||||||
m.last_attn_slice_indices = indices.to(device)
|
m.last_attn_slice_indices = indices.to(device)
|
||||||
|
|
||||||
cls.inject_attention_function(model)
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionType(Enum):
|
class CrossAttentionType(Enum):
|
||||||
SELF = 1
|
SELF = 1
|
||||||
@ -151,72 +151,33 @@ class CrossAttentionControl:
|
|||||||
#else:
|
#else:
|
||||||
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
||||||
|
|
||||||
|
|
||||||
if self.use_last_attn_slice:
|
if self.use_last_attn_slice:
|
||||||
this_attn_slice = attn_slice
|
|
||||||
if self.last_attn_slice_mask is not None:
|
|
||||||
# indices and mask operate on dim=2, no need to slice
|
|
||||||
base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
|
|
||||||
base_attn_slice_mask = self.last_attn_slice_mask
|
|
||||||
if dim is None:
|
if dim is None:
|
||||||
base_attn_slice = base_attn_slice_full
|
last_attn_slice = self.last_attn_slice
|
||||||
#print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
|
|
||||||
elif dim == 0:
|
|
||||||
base_attn_slice = base_attn_slice_full[start:end]
|
|
||||||
#print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
|
|
||||||
elif dim == 1:
|
|
||||||
base_attn_slice = base_attn_slice_full[:, start:end]
|
|
||||||
#print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
|
|
||||||
|
|
||||||
attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \
|
|
||||||
base_attn_slice * base_attn_slice_mask
|
|
||||||
else:
|
|
||||||
if dim is None:
|
|
||||||
attn_slice = self.last_attn_slice
|
|
||||||
# print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
# print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
||||||
elif dim == 0:
|
else:
|
||||||
attn_slice = self.last_attn_slice[start:end]
|
last_attn_slice = self.last_attn_slice[offset]
|
||||||
#print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
|
||||||
elif dim == 1:
|
if self.last_attn_slice_mask is None:
|
||||||
attn_slice = self.last_attn_slice[:, start:end]
|
# just use everything
|
||||||
#print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
attn_slice = last_attn_slice
|
||||||
|
else:
|
||||||
|
last_attn_slice_mask = self.last_attn_slice_mask
|
||||||
|
remapped_last_attn_slice = torch.index_select(last_attn_slice, -1, self.last_attn_slice_indices)
|
||||||
|
|
||||||
|
this_attn_slice = attn_slice
|
||||||
|
this_attn_slice_mask = 1 - last_attn_slice_mask
|
||||||
|
attn_slice = this_attn_slice * this_attn_slice_mask + \
|
||||||
|
remapped_last_attn_slice * last_attn_slice_mask
|
||||||
|
|
||||||
if self.save_last_attn_slice:
|
if self.save_last_attn_slice:
|
||||||
if dim is None:
|
if dim is None:
|
||||||
self.last_attn_slice = attn_slice
|
self.last_attn_slice = attn_slice
|
||||||
elif dim == 0:
|
|
||||||
# dynamically grow last_attn_slice if needed
|
|
||||||
if self.last_attn_slice is None:
|
|
||||||
self.last_attn_slice = attn_slice
|
|
||||||
#print("no last_attn_slice: shape now", self.last_attn_slice.shape)
|
|
||||||
elif self.last_attn_slice.shape[0] == start:
|
|
||||||
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0)
|
|
||||||
assert(self.last_attn_slice.shape[0] == end)
|
|
||||||
#print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
|
|
||||||
else:
|
else:
|
||||||
# no need to grow
|
|
||||||
self.last_attn_slice[start:end] = attn_slice
|
|
||||||
#print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
|
|
||||||
|
|
||||||
elif dim == 1:
|
|
||||||
# dynamically grow last_attn_slice if needed
|
|
||||||
if self.last_attn_slice is None:
|
if self.last_attn_slice is None:
|
||||||
self.last_attn_slice = attn_slice
|
self.last_attn_slice = { offset: attn_slice }
|
||||||
elif self.last_attn_slice.shape[1] == start:
|
|
||||||
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1)
|
|
||||||
assert(self.last_attn_slice.shape[1] == end)
|
|
||||||
else:
|
else:
|
||||||
# no need to grow
|
self.last_attn_slice[offset] = attn_slice
|
||||||
self.last_attn_slice[:, start:end] = attn_slice
|
|
||||||
|
|
||||||
if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
|
|
||||||
if dim is None:
|
|
||||||
weights = self.last_attn_slice_weights
|
|
||||||
elif dim == 0:
|
|
||||||
weights = self.last_attn_slice_weights[start:end]
|
|
||||||
elif dim == 1:
|
|
||||||
weights = self.last_attn_slice_weights[:, start:end]
|
|
||||||
attn_slice = attn_slice * weights
|
|
||||||
|
|
||||||
return attn_slice
|
return attn_slice
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user