Merge branch 'main' into release/invokeai-3-0-beta

Lincoln Stein 2023-07-07 17:45:18 -04:00
commit 657e8031bb
9 changed files with 113 additions and 73 deletions

View File

@@ -28,6 +28,10 @@ from .api.routers import sessions, models, images, boards, board_images
 from .api.sockets import SocketIO
 from .invocations.baseinvocation import BaseInvocation
 
+import torch
+if torch.backends.mps.is_available():
+    import invokeai.backend.util.mps_fixes
+
 # Create the app
 # TODO: create this all in a method so configuration/etc. can be passed in?
 app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)

View File

@@ -52,6 +52,10 @@ from .services.processor import DefaultInvocationProcessor
 from .services.restoration_services import RestorationServices
 from .services.sqlite import SqliteItemStorage
 
+import torch
+if torch.backends.mps.is_available():
+    import invokeai.backend.util.mps_fixes
+
 
 class CliCommand(BaseModel):
     command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore

View File

@@ -32,7 +32,7 @@ def get_noise(
     perlin: float = 0.0,
 ):
     """Generate noise for a given image size."""
-    noise_device_type = "cpu" if (use_cpu or device.type == "mps") else device.type
+    noise_device_type = "cpu" if use_cpu else device.type
 
     # limit noise to only the diffusion image channels, not the mask channels
     input_channels = min(latent_channels, 4)
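For context, a minimal sketch of what this simplified selection implies (the tensor shape and the final transfer back to the target device are assumptions for illustration, not part of this hunk): noise is drawn on the CPU only when the caller explicitly asks for it, for example to keep seeds reproducible across backends, and MPS is otherwise treated like any other accelerator.

import torch

# Illustrative only: mirror the device selection above with assumed values.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
use_cpu = True  # request CPU noise for cross-device reproducibility
noise_device_type = "cpu" if use_cpu else device.type
noise = torch.randn([1, 4, 64, 64], device=torch.device(noise_device_type)).to(device)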

View File

@@ -570,28 +570,16 @@ class Generator:
         device = self.model.device
         # limit noise to only the diffusion image channels, not the mask channels
         input_channels = min(self.latent_channels, 4)
-        if self.use_mps_noise or device.type == "mps":
-            x = torch.randn(
-                [
-                    1,
-                    input_channels,
-                    height // self.downsampling_factor,
-                    width // self.downsampling_factor,
-                ],
-                dtype=self.torch_dtype(),
-                device="cpu",
-            ).to(device)
-        else:
-            x = torch.randn(
-                [
-                    1,
-                    input_channels,
-                    height // self.downsampling_factor,
-                    width // self.downsampling_factor,
-                ],
-                dtype=self.torch_dtype(),
-                device=device,
-            )
+        x = torch.randn(
+            [
+                1,
+                input_channels,
+                height // self.downsampling_factor,
+                width // self.downsampling_factor,
+            ],
+            dtype=self.torch_dtype(),
+            device=device,
+        )
         if self.perlin > 0.0:
             perlin_noise = self.get_perlin_noise(
                 width // self.downsampling_factor, height // self.downsampling_factor

View File

@@ -88,10 +88,7 @@ class Img2Img(Generator):
 
     def get_noise_like(self, like: torch.Tensor):
         device = like.device
-        if device.type == "mps":
-            x = torch.randn_like(like, device="cpu").to(device)
-        else:
-            x = torch.randn_like(like, device=device)
+        x = torch.randn_like(like, device=device)
         if self.perlin > 0.0:
             shape = like.shape
             x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(

View File

@@ -360,37 +360,34 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         ):
             self.enable_xformers_memory_efficient_attention()
         else:
-            if torch.backends.mps.is_available():
-                # until pytorch #91617 is fixed, slicing is borked on MPS
-                # https://github.com/pytorch/pytorch/issues/91617
-                # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
-                pass
-            else:
-                if self.device.type == "cpu" or self.device.type == "mps":
-                    mem_free = psutil.virtual_memory().free
-                elif self.device.type == "cuda":
-                    mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
-                else:
-                    raise ValueError(f"unrecognized device {self.device}")
-                # input tensor of [1, 4, h/8, w/8]
-                # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
-                bytes_per_element_needed_for_baddbmm_duplication = (
-                    latents.element_size() + 4
-                )
-                max_size_required_for_baddbmm = (
-                    16
-                    * latents.size(dim=2)
-                    * latents.size(dim=3)
-                    * latents.size(dim=2)
-                    * latents.size(dim=3)
-                    * bytes_per_element_needed_for_baddbmm_duplication
-                )
-                if max_size_required_for_baddbmm > (
-                    mem_free * 3.0 / 4.0
-                ):  # 3.3 / 4.0 is from old Invoke code
-                    self.enable_attention_slicing(slice_size="max")
-                else:
-                    self.disable_attention_slicing()
+            if self.device.type == "cpu" or self.device.type == "mps":
+                mem_free = psutil.virtual_memory().free
+            elif self.device.type == "cuda":
+                mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
+            else:
+                raise ValueError(f"unrecognized device {self.device}")
+            # input tensor of [1, 4, h/8, w/8]
+            # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
+            bytes_per_element_needed_for_baddbmm_duplication = (
+                latents.element_size() + 4
+            )
+            max_size_required_for_baddbmm = (
+                16
+                * latents.size(dim=2)
+                * latents.size(dim=3)
+                * latents.size(dim=2)
+                * latents.size(dim=3)
+                * bytes_per_element_needed_for_baddbmm_duplication
+            )
+            if max_size_required_for_baddbmm > (
+                mem_free * 3.0 / 4.0
+            ):  # 3.3 / 4.0 is from old Invoke code
+                self.enable_attention_slicing(slice_size="max")
+            elif torch.backends.mps.is_available():
+                # diffusers recommends always enabling for mps
+                self.enable_attention_slicing(slice_size="max")
+            else:
+                self.disable_attention_slicing()
 
     def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
         # overridden method; types match the superclass.
@@ -916,20 +913,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
-            if device.type == "mps":
-                # workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
-                # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
-                self.vae.to(CPU_DEVICE)
-                init_image = init_image.to(CPU_DEVICE)
-            else:
-                self._model_group.load(self.vae)
+            self._model_group.load(self.vae)
             init_latent_dist = self.vae.encode(init_image).latent_dist
             init_latents = init_latent_dist.sample().to(
                 dtype=dtype
             )  # FIXME: uses torch.randn. make reproducible!
-            if device.type == "mps":
-                self.vae.to(device)
-                init_latents = init_latents.to(device)
 
         init_latents = 0.18215 * init_latents
         return init_latents
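To make the baddbmm memory estimate in the first hunk above concrete, here is a rough worked example (the 512×512 / fp16 numbers are assumptions for illustration, not from the diff): latents of shape [1, 4, 64, 64] give h/8 = w/8 = 64, so the duplicated attention buffer needs about 16 · 64 · 64 · 64 · 64 · (2 + 4) bytes, roughly 1.5 GiB, and slicing is enabled whenever that exceeds three quarters of the free memory.

# Rough sanity check of the heuristic above, using assumed example values.
element_size = 2                       # fp16 latents
bytes_per_element = element_size + 4   # bytes_per_element_needed_for_baddbmm_duplication
h8 = w8 = 64                           # 512x512 image -> latents of shape [1, 4, 64, 64]
max_size_required_for_baddbmm = 16 * h8 * w8 * h8 * w8 * bytes_per_element
print(max_size_required_for_baddbmm / 2**30)                  # 1.5 (GiB)
mem_free = 4 * 2**30                   # pretend 4 GiB are free
print(max_size_required_for_baddbmm > mem_free * 3.0 / 4.0)   # False: 1.5 GiB < 3 GiB, no slicing needed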

View File

@@ -248,9 +248,6 @@ class InvokeAIDiffuserComponent:
             x_twice, sigma_twice, both_conditionings, **kwargs,
         )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
-        if conditioned_next_x.device.type == "mps":
-            # prevent a result filled with zeros. seems to be a torch bug.
-            conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x
 
     def _apply_standard_conditioning_sequentially(
@@ -264,9 +261,6 @@ class InvokeAIDiffuserComponent:
         # low-memory sequential path
         unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
         conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
-        if conditioned_next_x.device.type == "mps":
-            # prevent a result filled with zeros. seems to be a torch bug.
-            conditioned_next_x = conditioned_next_x.clone()
         return unconditioned_next_x, conditioned_next_x
 
     # TODO: looks unused

View File

@@ -29,6 +29,8 @@ def choose_precision(device: torch.device) -> str:
         device_name = torch.cuda.get_device_name(device)
         if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
             return "float16"
+    elif device.type == "mps":
+        return "float16"
     return "float32"
 
 

View File

@@ -0,0 +1,63 @@
+import torch
+
+if torch.backends.mps.is_available():
+    torch.empty = torch.zeros
+
+_torch_layer_norm = torch.nn.functional.layer_norm
+
+def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
+    if input.device.type == "mps" and input.dtype == torch.float16:
+        input = input.float()
+        if weight is not None:
+            weight = weight.float()
+        if bias is not None:
+            bias = bias.float()
+        return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
+    else:
+        return _torch_layer_norm(input, normalized_shape, weight, bias, eps)
+
+torch.nn.functional.layer_norm = new_layer_norm
+
+
+_torch_tensor_permute = torch.Tensor.permute
+
+def new_torch_tensor_permute(input, *dims):
+    result = _torch_tensor_permute(input, *dims)
+    if input.device == "mps" and input.dtype == torch.float16:
+        result = result.contiguous()
+    return result
+
+torch.Tensor.permute = new_torch_tensor_permute
+
+_torch_lerp = torch.lerp
+
+def new_torch_lerp(input, end, weight, *, out=None):
+    if input.device.type == "mps" and input.dtype == torch.float16:
+        input = input.float()
+        end = end.float()
+        if isinstance(weight, torch.Tensor):
+            weight = weight.float()
+        if out is not None:
+            out_fp32 = torch.zeros_like(out, dtype=torch.float32)
+        else:
+            out_fp32 = None
+        result = _torch_lerp(input, end, weight, out=out_fp32)
+        if out is not None:
+            out.copy_(out_fp32.half())
+            del out_fp32
+        return result.half()
+    else:
+        return _torch_lerp(input, end, weight, out=out)
+
+torch.lerp = new_torch_lerp
+
+
+_torch_interpolate = torch.nn.functional.interpolate
+
+def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
+    if input.device.type == "mps" and input.dtype == torch.float16:
+        return _torch_interpolate(input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).half()
+    else:
+        return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
+
+torch.nn.functional.interpolate = new_torch_interpolate
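A minimal sketch of how these monkey patches are exercised (the toy tensor and shapes below are assumptions, not part of the commit): once the module is imported on an MPS machine, float16 inputs are transparently upcast to float32 around layer_norm and lerp (and interpolate), and the results are cast back to float16.

import torch

if torch.backends.mps.is_available():
    import invokeai.backend.util.mps_fixes  # noqa: F401  (patches are applied on import)

    x = torch.randn(2, 8, dtype=torch.float16, device="mps")
    y = torch.nn.functional.layer_norm(x, (8,))   # runs in float32 internally
    z = torch.lerp(x, torch.ones_like(x), 0.5)    # same upcast/downcast treatment
    assert y.dtype == z.dtype == torch.float16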