mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Removed duplicate fix_func for MPS
This commit is contained in:
parent
2e14ba8716
commit
8254ca9492
@ -58,24 +58,6 @@ torch.multinomial = fix_func(torch.multinomial)
|
|||||||
# this is fallback model in case no default is defined
|
# this is fallback model in case no default is defined
|
||||||
FALLBACK_MODEL_NAME='stable-diffusion-1.4'
|
FALLBACK_MODEL_NAME='stable-diffusion-1.4'
|
||||||
|
|
||||||
def fix_func(orig):
|
|
||||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
|
||||||
def new_func(*args, **kw):
|
|
||||||
device = kw.get("device", "mps")
|
|
||||||
kw["device"]="cpu"
|
|
||||||
return orig(*args, **kw).to(device)
|
|
||||||
return new_func
|
|
||||||
return orig
|
|
||||||
|
|
||||||
torch.rand = fix_func(torch.rand)
|
|
||||||
torch.rand_like = fix_func(torch.rand_like)
|
|
||||||
torch.randn = fix_func(torch.randn)
|
|
||||||
torch.randn_like = fix_func(torch.randn_like)
|
|
||||||
torch.randint = fix_func(torch.randint)
|
|
||||||
torch.randint_like = fix_func(torch.randint_like)
|
|
||||||
torch.bernoulli = fix_func(torch.bernoulli)
|
|
||||||
torch.multinomial = fix_func(torch.multinomial)
|
|
||||||
|
|
||||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||||
|
|
||||||
Example Usage:
|
Example Usage:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user