mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into development
This commit is contained in:
commit
2f93418095
137
README.md
137
README.md
@ -1,21 +1,36 @@
|
||||
<h1 align='center'><b>Stable Diffusion Dream Script</b></h1>
|
||||
<div align="center">
|
||||
|
||||
<p align='center'>
|
||||
<img src="docs/assets/logo.png"/>
|
||||
</p>
|
||||
# Stable Diffusion Dream Script
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/lstein/stable-diffusion/releases"><img src="https://flat.badgen.net/github/release/lstein/stable-diffusion/development?icon=github" alt="release"/></a>
|
||||
<a href="https://github.com/lstein/stable-diffusion/stargazers"><img src="https://flat.badgen.net/github/stars/lstein/stable-diffusion?icon=github" alt="stars"/></a>
|
||||
<a href="https://useful-forks.github.io/?repo=lstein%2Fstable-diffusion"><img src="https://flat.badgen.net/github/forks/lstein/stable-diffusion?icon=github" alt="forks"/></a>
|
||||
<br />
|
||||
<a href="https://github.com/lstein/stable-diffusion/actions/workflows/test-dream-conda.yml"><img src="https://flat.badgen.net/github/checks/lstein/stable-diffusion/main?label=CI%20status%20on%20main&cache=900&icon=github" alt="CI status on main"/></a>
|
||||
<a href="https://github.com/lstein/stable-diffusion/actions/workflows/test-dream-conda.yml"><img src="https://flat.badgen.net/github/checks/lstein/stable-diffusion/development?label=CI%20status%20on%20dev&cache=900&icon=github" alt="CI status on dev"/></a>
|
||||
<a href="https://github.com/lstein/stable-diffusion/commits/development"><img src="https://flat.badgen.net/github/last-commit/lstein/stable-diffusion/development?icon=github&color=yellow&label=last%20dev%20commit&cache=900" alt="last-dev-commit"/></a>
|
||||
<br />
|
||||
<a href="https://github.com/lstein/stable-diffusion/issues?q=is%3Aissue+is%3Aopen"><img src="https://flat.badgen.net/github/open-issues/lstein/stable-diffusion?icon=github" alt="open-issues"/></a>
|
||||
<a href="https://github.com/lstein/stable-diffusion/pulls?q=is%3Apr+is%3Aopen"><img src="https://flat.badgen.net/github/open-prs/lstein/stable-diffusion?icon=github" alt="open-prs"/></a>
|
||||
</p>
|
||||

|
||||
|
||||
[![discord badge]][discord link]
|
||||
|
||||
[![latest release badge]][latest release link] [![github stars badge]][github stars link] [![github forks badge]][github forks link]
|
||||
|
||||
[![CI checks on main badge]][CI checks on main link] [![CI checks on dev badge]][CI checks on dev link] [![latest commit to dev badge]][latest commit to dev link]
|
||||
|
||||
[![github open issues badge]][github open issues link] [![github open prs badge]][github open prs link]
|
||||
|
||||
[CI checks on dev badge]: https://flat.badgen.net/github/checks/lstein/stable-diffusion/development?label=CI%20status%20on%20dev&cache=900&icon=github
|
||||
[CI checks on dev link]: https://github.com/lstein/stable-diffusion/actions?query=branch%3Adevelopment
|
||||
[CI checks on main badge]: https://flat.badgen.net/github/checks/lstein/stable-diffusion/main?label=CI%20status%20on%20main&cache=900&icon=github
|
||||
[CI checks on main link]: https://github.com/lstein/stable-diffusion/actions/workflows/test-dream-conda.yml
|
||||
[discord badge]: https://flat.badgen.net/discord/members/htRgbc7e?icon=discord
|
||||
[discord link]: https://discord.com/invite/htRgbc7e
|
||||
[github forks badge]: https://flat.badgen.net/github/forks/lstein/stable-diffusion?icon=github
|
||||
[github forks link]: https://useful-forks.github.io/?repo=lstein%2Fstable-diffusion
|
||||
[github open issues badge]: https://flat.badgen.net/github/open-issues/lstein/stable-diffusion?icon=github
|
||||
[github open issues link]: https://github.com/lstein/stable-diffusion/issues?q=is%3Aissue+is%3Aopen
|
||||
[github open prs badge]: https://flat.badgen.net/github/open-prs/lstein/stable-diffusion?icon=github
|
||||
[github open prs link]: https://github.com/lstein/stable-diffusion/pulls?q=is%3Apr+is%3Aopen
|
||||
[github stars badge]: https://flat.badgen.net/github/stars/lstein/stable-diffusion?icon=github
|
||||
[github stars link]: https://github.com/lstein/stable-diffusion/stargazers
|
||||
[latest commit to dev badge]: https://flat.badgen.net/github/last-commit/lstein/stable-diffusion/development?icon=github&color=yellow&label=last%20dev%20commit&cache=900
|
||||
[latest commit to dev link]: https://github.com/lstein/stable-diffusion/commits/development
|
||||
[latest release badge]: https://flat.badgen.net/github/release/lstein/stable-diffusion/development?icon=github
|
||||
[latest release link]: https://github.com/lstein/stable-diffusion/releases
|
||||
</div>
|
||||
|
||||
This is a fork of [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion), the open
|
||||
source text-to-image generator. It provides a streamlined process with various new features and
|
||||
@ -26,7 +41,7 @@ _Note: This fork is rapidly evolving. Please use the
|
||||
[Issues](https://github.com/lstein/stable-diffusion/issues) tab to report bugs and make feature
|
||||
requests. Be sure to use the provided templates. They will help aid diagnose issues faster._
|
||||
|
||||
**Table of Contents**
|
||||
## Table of Contents
|
||||
|
||||
1. [Installation](#installation)
|
||||
2. [Hardware Requirements](#hardware-requirements)
|
||||
@ -38,38 +53,38 @@ requests. Be sure to use the provided templates. They will help aid diagnose iss
|
||||
8. [Support](#support)
|
||||
9. [Further Reading](#further-reading)
|
||||
|
||||
## Installation
|
||||
### Installation
|
||||
|
||||
This fork is supported across multiple platforms. You can find individual installation instructions
|
||||
below.
|
||||
|
||||
- ### [Linux](docs/installation/INSTALL_LINUX.md)
|
||||
- #### [Linux](docs/installation/INSTALL_LINUX.md)
|
||||
|
||||
- ### [Windows](docs/installation/INSTALL_WINDOWS.md)
|
||||
- #### [Windows](docs/installation/INSTALL_WINDOWS.md)
|
||||
|
||||
- ### [Macintosh](docs/installation/INSTALL_MAC.md)
|
||||
- #### [Macintosh](docs/installation/INSTALL_MAC.md)
|
||||
|
||||
## Hardware Requirements
|
||||
### Hardware Requirements
|
||||
|
||||
**System**
|
||||
#### System
|
||||
|
||||
You wil need one of the following:
|
||||
|
||||
- An NVIDIA-based graphics card with 4 GB or more VRAM memory.
|
||||
- An Apple computer with an M1 chip.
|
||||
|
||||
**Memory**
|
||||
#### Memory
|
||||
|
||||
- At least 12 GB Main Memory RAM.
|
||||
|
||||
**Disk**
|
||||
#### Disk
|
||||
|
||||
- At least 6 GB of free disk space for the machine learning model, Python, and all its dependencies.
|
||||
|
||||
**Note**
|
||||
|
||||
If you are have a Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in
|
||||
full-precision mode as shown below.
|
||||
> Note
|
||||
>
|
||||
> If you have an Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in
|
||||
> full-precision mode as shown below.
|
||||
|
||||
Similarly, specify full-precision mode on Apple M1 hardware.
|
||||
|
||||
@ -79,45 +94,31 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag
|
||||
(ldm) ~/stable-diffusion$ python scripts/dream.py --full_precision
|
||||
```
|
||||
|
||||
## Features
|
||||
### Features
|
||||
|
||||
### Major Features
|
||||
#### Major Features
|
||||
|
||||
- #### [Interactive Command Line Interface](docs/features/CLI.md)
|
||||
- [Interactive Command Line Interface](docs/features/CLI.md)
|
||||
- [Image To Image](docs/features/IMG2IMG.md)
|
||||
- [Inpainting Support](docs/features/INPAINTING.md)
|
||||
- [GFPGAN and Real-ESRGAN Support](docs/features/UPSCALE.md)
|
||||
- [Seamless Tiling](docs/features/OTHER.md#seamless-tiling)
|
||||
- [Google Colab](docs/features/OTHER.md#google-colab)
|
||||
- [Web Server](docs/features/WEB.md)
|
||||
- [Reading Prompts From File](docs/features/OTHER.md#reading-prompts-from-a-file)
|
||||
- [Shortcut: Reusing Seeds](docs/features/OTHER.md#shortcuts-reusing-seeds)
|
||||
- [Weighted Prompts](docs/features/OTHER.md#weighted-prompts)
|
||||
- [Thresholding and Perlin Noise Initialization Options](/docs/features/OTHER.md#thresholding-and-perlin-noise-initialization-options)
|
||||
- [Variations](docs/features/VARIATIONS.md)
|
||||
- [Personalizing Text-to-Image Generation](docs/features/TEXTUAL_INVERSION.md)
|
||||
- [Simplified API for text to image generation](docs/features/OTHER.md#simplified-api)
|
||||
|
||||
- #### [Image To Image](docs/features/IMG2IMG.md)
|
||||
#### Other Features
|
||||
|
||||
- #### [Inpainting Support](docs/features/INPAINTING.md)
|
||||
- [Creating Transparent Regions for Inpainting](docs/features/INPAINTING.md#creating-transparent-regions-for-inpainting)
|
||||
- [Preload Models](docs/features/OTHER.md#preload-models)
|
||||
|
||||
- #### [GFPGAN and Real-ESRGAN Support](docs/features/UPSCALE.md)
|
||||
|
||||
- #### [Seamless Tiling](docs/features/OTHER.md#seamless-tiling)
|
||||
|
||||
- #### [Google Colab](docs/features/OTHER.md#google-colab)
|
||||
|
||||
- #### [Web Server](docs/features/WEB.md)
|
||||
|
||||
- #### [Reading Prompts From File](docs/features/OTHER.md#reading-prompts-from-a-file)
|
||||
|
||||
- #### [Shortcut: Reusing Seeds](docs/features/OTHER.md#shortcuts-reusing-seeds)
|
||||
|
||||
- #### [Weighted Prompts](docs/features/OTHER.md#weighted-prompts)
|
||||
|
||||
- #### [Thresholding and Perlin Noise Initialization Options](/docs/features/OTHER.md#thresholding-and-perlin-noise-initialization-options)
|
||||
|
||||
- #### [Variations](docs/features/VARIATIONS.md)
|
||||
|
||||
- #### [Personalizing Text-to-Image Generation](docs/features/TEXTUAL_INVERSION.md)
|
||||
|
||||
- #### [Simplified API for text to image generation](docs/features/OTHER.md#simplified-api)
|
||||
|
||||
### Other Features
|
||||
|
||||
- #### [Creating Transparent Regions for Inpainting](docs/features/INPAINTING.md#creating-transparent-regions-for-inpainting)
|
||||
|
||||
- #### [Preload Models](docs/features/OTHER.md#preload-models)
|
||||
|
||||
## Latest Changes
|
||||
### Latest Changes
|
||||
|
||||
- v1.14 (11 September 2022)
|
||||
|
||||
@ -149,12 +150,12 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag
|
||||
|
||||
For older changelogs, please visit the **[CHANGELOG](docs/features/CHANGELOG.md)**.
|
||||
|
||||
## Troubleshooting
|
||||
### Troubleshooting
|
||||
|
||||
Please check out our **[Q&A](docs/help/TROUBLESHOOT.md)** to get solutions for common installation
|
||||
problems and other issues.
|
||||
|
||||
## Contributing
|
||||
### Contributing
|
||||
|
||||
Anyone who wishes to contribute to this project, whether documentation, features, bug fixes, code
|
||||
cleanup, testing, or code reviews, is very much encouraged to do so. If you are unfamiliar with how
|
||||
@ -166,13 +167,13 @@ important thing is to **make your pull request against the "development" branch*
|
||||
"main". This will help keep public breakage to a minimum and will allow you to propose more radical
|
||||
changes.
|
||||
|
||||
## Contributors
|
||||
### Contributors
|
||||
|
||||
This fork is a combined effort of various people from across the world.
|
||||
[Check out the list of all these amazing people](docs/other/CONTRIBUTORS.md). We thank them for
|
||||
their time, hard work and effort.
|
||||
|
||||
## Support
|
||||
### Support
|
||||
|
||||
For support, please use this repository's GitHub Issues tracking service. Feel free to send me an
|
||||
email if you use and like the script.
|
||||
@ -180,7 +181,7 @@ email if you use and like the script.
|
||||
Original portions of the software are Copyright (c) 2020
|
||||
[Lincoln D. Stein](https://github.com/lstein)
|
||||
|
||||
## Further Reading
|
||||
### Further Reading
|
||||
|
||||
Please see the original README for more information on this software and underlying algorithm,
|
||||
located in the file [README-CompViz.md](docs/other/README-CompViz.md).
|
||||
|
@ -2,7 +2,10 @@
|
||||
|
||||
The Args class parses both the command line (shell) arguments, as well as the
|
||||
command string passed at the dream> prompt. It serves as the definitive repository
|
||||
of all the arguments used by Generate and their default values.
|
||||
of all the arguments used by Generate and their default values, and implements the
|
||||
preliminary metadata standards discussed here:
|
||||
|
||||
https://github.com/lstein/stable-diffusion/issues/266
|
||||
|
||||
To use:
|
||||
opt = Args()
|
||||
@ -52,10 +55,32 @@ you wish to apply logic as to which one to use. For example:
|
||||
To add new attributes, edit the _create_arg_parser() and
|
||||
_create_dream_cmd_parser() methods.
|
||||
|
||||
We also export the function build_metadata
|
||||
**Generating and retrieving sd-metadata**
|
||||
|
||||
To generate a dict representing RFC266 metadata:
|
||||
|
||||
metadata = metadata_dumps(opt,<seeds,model_hash,postprocesser>)
|
||||
|
||||
This will generate an RFC266 dictionary that can then be turned into a JSON
|
||||
and written to the PNG file. The optional seeds, weights, model_hash and
|
||||
postprocesser arguments are not available to the opt object and so must be
|
||||
provided externally. See how dream.py does it.
|
||||
|
||||
Note that this function was originally called format_metadata() and a wrapper
|
||||
is provided that issues a deprecation notice.
|
||||
|
||||
To retrieve a (series of) opt objects corresponding to the metadata, do this:
|
||||
|
||||
opt_list = metadata_loads(metadata)
|
||||
|
||||
The metadata should be pulled out of the PNG image. pngwriter has a method
|
||||
retrieve_metadata that will do this.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from argparse import Namespace
|
||||
import shlex
|
||||
import json
|
||||
import hashlib
|
||||
@ -540,17 +565,20 @@ class Args(object):
|
||||
)
|
||||
return parser
|
||||
|
||||
# very partial implementation of https://github.com/lstein/stable-diffusion/issues/266
|
||||
# it does not write all the required top-level metadata, writes too much image
|
||||
# data, and doesn't support grids yet. But you gotta start somewhere, no?
|
||||
def format_metadata(opt,
|
||||
seeds=[],
|
||||
weights=None,
|
||||
model_hash=None,
|
||||
postprocessing=None):
|
||||
def format_metadata(**kwargs):
|
||||
print(f'format_metadata() is deprecated. Please use metadata_dumps()')
|
||||
return metadata_dumps(kwargs)
|
||||
|
||||
def metadata_dumps(opt,
|
||||
seeds=[],
|
||||
model_hash=None,
|
||||
postprocessing=None):
|
||||
'''
|
||||
Given an Args object, returns a partial implementation of
|
||||
the stable diffusion metadata standard
|
||||
Given an Args object, returns a dict containing the keys and
|
||||
structure of the proposed stable diffusion metadata standard
|
||||
https://github.com/lstein/stable-diffusion/discussions/392
|
||||
This is intended to be turned into JSON and stored in the
|
||||
"sd
|
||||
'''
|
||||
# add some RFC266 fields that are generated internally, and not as
|
||||
# user args
|
||||
@ -598,6 +626,9 @@ def format_metadata(opt,
|
||||
rfc_dict['type'] = 'txt2img'
|
||||
|
||||
images = []
|
||||
if len(seeds)==0 and opt.seed:
|
||||
seeds=[seed]
|
||||
|
||||
for seed in seeds:
|
||||
rfc_dict['seed'] = seed
|
||||
images.append(copy.copy(rfc_dict))
|
||||
@ -611,6 +642,27 @@ def format_metadata(opt,
|
||||
'images' : images,
|
||||
}
|
||||
|
||||
def metadata_loads(metadata):
|
||||
'''
|
||||
Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266)
|
||||
and returns a series of opt objects for each of the images described in the dictionary.
|
||||
'''
|
||||
results = []
|
||||
try:
|
||||
images = metadata['sd-metadata']['images']
|
||||
for image in images:
|
||||
# repack the prompt and variations
|
||||
image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
|
||||
image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']])
|
||||
opt = Args()
|
||||
opt._cmd_switches = Namespace(**image)
|
||||
results.append(opt)
|
||||
except KeyError as e:
|
||||
import sys, traceback
|
||||
print('>> badly-formatted metadata',file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return results
|
||||
|
||||
# image can either be a file path on disk or a base64-encoded
|
||||
# representation of the file's contents
|
||||
def calculate_init_img_hash(image_string):
|
||||
|
@ -4,7 +4,7 @@ import copy
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from ldm.dream.args import Args, format_metadata
|
||||
from ldm.dream.args import Args, metadata_dumps
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from ldm.dream.pngwriter import PngWriter
|
||||
from threading import Event
|
||||
@ -177,10 +177,9 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
path = pngwriter.save_image_and_prompt_to_png(
|
||||
image,
|
||||
dream_prompt = formatted_prompt,
|
||||
metadata = format_metadata(iter_opt,
|
||||
seeds = [seed],
|
||||
weights = self.model.weights,
|
||||
model_hash = self.model.model_hash
|
||||
metadata = metadata_dumps(iter_opt,
|
||||
seeds = [seed],
|
||||
model_hash = self.model.model_hash
|
||||
),
|
||||
name = name,
|
||||
)
|
||||
|
@ -90,7 +90,7 @@ class LinearAttention(nn.Module):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||
@ -167,101 +167,85 @@ class CrossAttention(nn.Module):
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.einsum_op = self.einsum_op_cuda
|
||||
else:
|
||||
self.mem_total = psutil.virtual_memory().total / (1024**3)
|
||||
self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
|
||||
|
||||
def einsum_op_compvis(self, q, k, v, r1):
|
||||
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
r1 = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v, r1):
|
||||
def einsum_op_compvis(self, q, k, v):
|
||||
s = einsum('b i d, b j d -> b i j', q, k)
|
||||
s = s.softmax(dim=-1, dtype=s.dtype)
|
||||
return einsum('b i j, b j d -> b i d', s, v)
|
||||
|
||||
def einsum_op_slice_0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
|
||||
return r
|
||||
|
||||
def einsum_op_slice_1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
r1 = self.einsum_op_compvis(q, k, v, r1)
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
return self.einsum_op_slice_1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v, r1):
|
||||
if self.mem_total >= 8 and q.shape[1] <= 4096:
|
||||
r1 = self.einsum_op_compvis(q, k, v, r1)
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
else:
|
||||
slice_size = 1
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = min(q.shape[0], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
return r1
|
||||
|
||||
def einsum_op_cuda(self, q, k, v, r1):
|
||||
return self.einsum_op_slice_0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
def einsum_op(self, q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
if q.device.type == 'mps':
|
||||
if self.mem_total_gb >= 32:
|
||||
return self.einsum_op_mps_v1(q, k, v)
|
||||
return self.einsum_op_mps_v2(q, k, v)
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = min(q.shape[1], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k_in = self.to_k(context)
|
||||
v_in = self.to_v(context)
|
||||
device_type = 'mps' if x.device.type == 'mps' else 'cuda'
|
||||
k = self.to_k(context) * self.scale
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
r1 = self.einsum_op(q, k, v, r1)
|
||||
del q, k, v
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
||||
return self.to_out(r2)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
r = self.einsum_op(q, k, v)
|
||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
|
@ -3,6 +3,7 @@ import gc
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.functional import silu
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
|
||||
@ -32,11 +33,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
@ -122,14 +118,14 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = self.norm1(x)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
h = h + self.temb_proj(silu(temb))[:,:,None,None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
@ -368,7 +364,7 @@ class Model(nn.Module):
|
||||
assert t is not None
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = silu(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
else:
|
||||
temb = None
|
||||
@ -402,7 +398,7 @@ class Model(nn.Module):
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
@ -499,7 +495,7 @@ class Encoder(nn.Module):
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
@ -611,7 +607,7 @@ class Decoder(nn.Module):
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
@ -649,7 +645,7 @@ class SimpleDecoder(nn.Module):
|
||||
x = layer(x)
|
||||
|
||||
h = self.norm_out(x)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
x = self.conv_out(h)
|
||||
return x
|
||||
|
||||
@ -697,7 +693,7 @@ class UpsampleDecoder(nn.Module):
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.upsample_blocks[k](h)
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
@ -873,7 +869,7 @@ class FirstStagePostProcessor(nn.Module):
|
||||
z_fs = self.encode_with_pretrained(x)
|
||||
z = self.proj_norm(z_fs)
|
||||
z = self.proj(z)
|
||||
z = nonlinearity(z)
|
||||
z = silu(z)
|
||||
|
||||
for submodel, downmodel in zip(self.model,self.downsampler):
|
||||
z = submodel(z,temb=None)
|
||||
|
@ -252,12 +252,6 @@ def normalization(channels):
|
||||
return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
@ -8,7 +8,7 @@ import copy
|
||||
import warnings
|
||||
import time
|
||||
import ldm.dream.readline
|
||||
from ldm.dream.args import Args, format_metadata
|
||||
from ldm.dream.args import Args, metadata_dumps
|
||||
from ldm.dream.pngwriter import PngWriter
|
||||
from ldm.dream.server import DreamServer, ThreadingDreamServer
|
||||
from ldm.dream.image_util import make_grid
|
||||
@ -245,10 +245,9 @@ def main_loop(gen, opt, infile):
|
||||
path = file_writer.save_image_and_prompt_to_png(
|
||||
image = image,
|
||||
dream_prompt = formatted_dream_prompt,
|
||||
metadata = format_metadata(
|
||||
metadata = metadata_dumps(
|
||||
opt,
|
||||
seeds = [seed],
|
||||
weights = gen.weights,
|
||||
model_hash = gen.model_hash,
|
||||
),
|
||||
name = filename,
|
||||
@ -272,7 +271,7 @@ def main_loop(gen, opt, infile):
|
||||
filename = f'{prefix}.{first_seed}.png'
|
||||
formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed,grid=True,iterations=len(grid_images))
|
||||
formatted_dream_prompt += f' # {grid_seeds}'
|
||||
metadata = format_metadata(
|
||||
metadata = metadata.dumps(
|
||||
opt,
|
||||
seeds = grid_seeds,
|
||||
weights = gen.weights,
|
||||
|
Loading…
Reference in New Issue
Block a user