mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
use 🧨diffusers model (#1583)
* initial commit of DiffusionPipeline class * spike: proof of concept using diffusers for txt2img * doc: type hints for Generator * refactor(model_cache): factor out load_ckpt * model_cache: add ability to load a diffusers model pipeline and update associated things in Generate & Generator to not instantly fail when that happens * model_cache: fix model default image dimensions * txt2img: support switching diffusers schedulers * diffusers: let the scheduler do its scaling of the initial latents Remove IPNDM scheduler; it is not behaving. * web server: update image_progress callback for diffusers data * diffusers: restore prompt weighting feature * diffusers: fix set-sampler error following model switch * diffusers: use InvokeAIDiffuserComponent for conditioning * cross_attention_control: stub (no-op) implementations for diffusers * model_cache: let offload_model work with DiffusionPipeline, sorta. * models.yaml.example: add diffusers-format model, set as default * test-invoke-conda: use diffusers-format model test-invoke-conda: put huggingface-token where the library can use it * environment-mac: upgrade to diffusers 0.7 (from 0.6) this was already done for linux; mac must have been lost in the merge. * preload_models: explicitly load diffusers models In non-interactive mode too, as long as you're logged in. * fix(model_cache): don't check `model.config` in diffusers format clean-up from recent merge. * diffusers integration: support img2img * dev: upgrade to diffusers 0.8 (from 0.7.1) We get to remove some code by using methods that were factored out in the base class. * refactor: remove backported img2img.get_timesteps now that we can use it directly from diffusers 0.8.1 * ci: use diffusers model * dev: upgrade to diffusers 0.9 (from 0.8.1) * lint: correct annotations for Python 3.9. * lint: correct AttributeError.name reference for Python 3.9. * CI: prefer diffusers-1.4 because it no longer requires a token The RunwayML models still do. * build: there's yet another place to update requirements? * configure: try to download models even without token Models in the CompVis and stabilityai repos no longer require them. (But runwayml still does.) * configure: add troubleshooting info for config-not-found * fix(configure): prepend root to config path * fix(configure): remove second `default: true` from models example * CI: simplify test-on-push logic now that we don't need secrets The "test on push but only in forks" logic was only necessary when tests didn't work for PRs-from-forks. * create an embedding_manager for diffusers * internal: avoid importing diffusers DummyObject see https://github.com/huggingface/diffusers/issues/1479 * fix "config attributes…not expected" diffusers warnings. * fix deprecated scheduler construction * work around an apparent MPS torch bug that causes conditioning to have no effect * 🚧 post-rebase repair * preliminary support for outpainting (no masking yet) * monkey-patch diffusers.attention and use Invoke lowvram code * add always_use_cpu arg to bypass MPS * add cross-attention control support to diffusers (fails on MPS) For unknown reasons MPS produces garbage output with .swap(). Use --always_use_cpu arg to invoke.py for now to test this code on MPS. * diffusers support for the inpainting model * fix debug_image to not crash with non-RGB images. * inpainting for the normal model [WIP] This seems to be performing well until the LAST STEP, at which point it dissolves to confetti. * fix off-by-one bug in cross-attention-control (#1774) prompt token sequences begin with a "beginning-of-sequence" marker <bos> and end with a repeated "end-of-sequence" marker <eos> - to make a default prompt length of <bos> + 75 prompt tokens + <eos>. the .swap() code was failing to take the column for <bos> at index 0 into account. the changes here do that, and also add extra handling for a single <eos> (which may be redundant but which is included for completeness). based on my understanding and some assumptions about how this all works, the reason .swap() nevertheless seemed to do the right thing, to some extent, is because over multiple steps the conditioning process in Stable Diffusion operates as a feedback loop. a change to token n-1 has flow-on effects to how the [1x4x64x64] latent tensor is modified by all the tokens after it, - and as the next step is processed, all the tokens before it as well. intuitively, a token's conditioning effects "echo" throughout the whole length of the prompt. so even though the token at n-1 was being edited when what the user actually wanted was to edit the token at n, it nevertheless still had some non-negligible effect, in roughly the right direction, often enough that it seemed like it was working properly. * refactor common CrossAttention stuff into a mixin so that the old ldm code can still work if necessary * inpainting for the normal model. I think it works this time. * diffusers: reset num_vectors_per_token sync with44a0055571
* diffusers: txt2img2img (hires_fix) with so much slicing and dicing of pipeline methods to stitch them together * refactor(diffusers): reduce some code duplication amongst the different tasks * fixup! refactor(diffusers): reduce some code duplication amongst the different tasks * diffusers: enable DPMSolver++ scheduler * diffusers: upgrade to diffusers 0.10, add Heun scheduler * diffusers(ModelCache): stopgap to make from_cpu compatible with diffusers * CI: default to diffusers-1.5 now that runwayml token requirement is gone * diffusers: update to 0.10 (and transformers to 4.25) * diffusers: use xformers when available diffusers no longer auto-enables this as of 0.10.2. * diffusers: make masked img2img behave better with multi-step schedulers re-randomizing the noise each step was confusing them. * diffusers: work more better with more models. fixed relative path problem with local models. fixed models on hub not always having a `fp16` branch. * diffusers: stopgap fix for attention_maps_callback crash after recent merge * fixup import merge conflicts correction for061c5369a2
* test: add tests/inpainting inputs for masked img2img * diffusers(AddsMaskedGuidance): partial fix for k-schedulers Prevents them from crashing, but results are still hot garbage. * fix --safety_checker arg parsing and add note to diffusers loader about where safety checker gets called * generate: fix import error * CI: don't try to read the old init location * diffusers: support loading an alternate VAE * CI: remove sh-syntax if-statement so it doesn't crash powershell * CI: fold strings in yaml because backslash is not line-continuation in powershell * attention maps callback stuff for diffusers * build: fix syntax error in environment-mac * diffusers: add INITIAL_MODELS with diffusers-compatible repos * re-enable the embedding manager; closes #1778 * Squashed commit of the following: commit e4a956abc37fcb5cf188388b76b617bc5c8fda7d Author: Damian Stewart <d@damianstewart.com> Date: Sun Dec 18 15:43:07 2022 +0100 import new load handling from EmbeddingManager and cleanup commit c4abe91a5ba0d415b45bf734068385668b7a66e6 Merge: 032e856e 1efc6397 Author: Damian Stewart <d@damianstewart.com> Date: Sun Dec 18 15:09:53 2022 +0100 Merge branch 'feature_textual_inversion_mgr' into dev/diffusers_with_textual_inversion_manager commit 032e856eefb3bbc39534f5daafd25764bcfcef8b Merge: 8b4f0fe9bc515e24
Author: Damian Stewart <d@damianstewart.com> Date: Sun Dec 18 15:08:01 2022 +0100 Merge remote-tracking branch 'upstream/dev/diffusers' into dev/diffusers_with_textual_inversion_manager commit 1efc6397fc6e61c1aff4b0258b93089d61de5955 Author: Damian Stewart <d@damianstewart.com> Date: Sun Dec 18 15:04:28 2022 +0100 cleanup and add performance notes commit e400f804ac471a0ca2ba432fd658778b20c7bdab Author: Damian Stewart <d@damianstewart.com> Date: Sun Dec 18 14:45:07 2022 +0100 fix bug and update unit tests commit deb9ae0ae1016750e93ce8275734061f7285a231 Author: Damian Stewart <d@damianstewart.com> Date: Sun Dec 18 14:28:29 2022 +0100 textual inversion manager seems to work commit 162e02505dec777e91a983c4d0fb52e950d25ff0 Merge: cbad4583 12769b3d Author: Damian Stewart <d@damianstewart.com> Date: Sun Dec 18 11:58:03 2022 +0100 Merge branch 'main' into feature_textual_inversion_mgr commit cbad45836c6aace6871a90f2621a953f49433131 Author: Damian Stewart <d@damianstewart.com> Date: Sun Dec 18 11:54:10 2022 +0100 use position embeddings commit 070344c69b0e0db340a183857d0a787b348681d3 Author: Damian Stewart <d@damianstewart.com> Date: Sun Dec 18 11:53:47 2022 +0100 Don't crash CLI on exceptions commit b035ac8c6772dfd9ba41b8eeb9103181cda028f8 Author: Damian Stewart <d@damianstewart.com> Date: Sun Dec 18 11:11:55 2022 +0100 add missing position_embeddings commit 12769b3d3562ef71e0f54946b532ad077e10043c Author: Damian Stewart <d@damianstewart.com> Date: Fri Dec 16 13:33:25 2022 +0100 debugging why it don't work commit bafb7215eabe1515ca5e8388fd3bb2f3ac5362cf Author: Damian Stewart <d@damianstewart.com> Date: Fri Dec 16 13:21:33 2022 +0100 debugging why it don't work commit664a6e9e14
Author: Damian Stewart <d@damianstewart.com> Date: Fri Dec 16 12:48:38 2022 +0100 use TextualInversionManager in place of embeddings (wip, doesn't work) commit 8b4f0fe9d6e4e2643b36dfa27864294785d7ba4e Author: Damian Stewart <d@damianstewart.com> Date: Fri Dec 16 12:48:38 2022 +0100 use TextualInversionManager in place of embeddings (wip, doesn't work) commit ffbe1ab11163ba712e353d89404e301d0e0c6cdf Merge:6e4dad60
023df37e
Author: Damian Stewart <d@damianstewart.com> Date: Fri Dec 16 02:37:31 2022 +0100 Merge branch 'feature_textual_inversion_mgr' into dev/diffusers commit023df37eff
Author: Damian Stewart <d@damianstewart.com> Date: Fri Dec 16 02:36:54 2022 +0100 cleanup commit05fac594ea
Author: Damian Stewart <d@damianstewart.com> Date: Fri Dec 16 02:07:49 2022 +0100 tweak error checking commit009f32ed39
Author: damian <null@damianstewart.com> Date: Thu Dec 15 21:29:47 2022 +0100 unit tests passing for embeddings with vector length >1 commitbeb1b08d9a
Author: Damian Stewart <d@damianstewart.com> Date: Thu Dec 15 13:39:09 2022 +0100 more explicit equality tests when overwriting commit44d8a5a7c8
Author: Damian Stewart <d@damianstewart.com> Date: Thu Dec 15 13:30:13 2022 +0100 wip textual inversion manager (unit tests passing for 1v embedding overwriting) commit417c2b57d9
Author: Damian Stewart <d@damianstewart.com> Date: Thu Dec 15 12:30:55 2022 +0100 wip textual inversion manager (unit tests passing for base stuff + padding) commit2e80872e3b
Author: Damian Stewart <d@damianstewart.com> Date: Thu Dec 15 10:57:57 2022 +0100 wip new TextualInversionManager * stop using WeightedFrozenCLIPEmbedder * store diffusion models locally - configure_invokeai.py reconfigured to store diffusion models rather than CompVis models - hugging face caching model is used, but cache is set to ~/invokeai/models/repo_id - models.yaml does **NOT** use path, just repo_id - "repo_name" changed to "repo_id" to following hugging face conventions - Models are loaded with full precision pending further work. * allow non-local files during development * path takes priority over repo_id * MVP for model_cache and configure_invokeai - Feature complete (almost) - configure_invokeai.py downloads both .ckpt and diffuser models, along with their VAEs. Both types of download are controlled by a unified INITIAL_MODELS.yaml file. - model_cache can load both type of model and switches back and forth in CPU. No memory leaks detected TO DO: 1. I have not yet turned on the LocalOnly flag for diffuser models, so the code will check the Hugging Face repo for updates before using the locally cached models. This will break firewalled systems. I am thinking of putting in a global check for internet connectivity at startup time and setting the LocalOnly flag based on this. It would be good to check updates if there is connectivity. 2. I have not gone completely through INITIAL_MODELS.yaml to check which models are available as diffusers and which are not. So models like PaperCut and VoxelArt may not load properly. The runway and stability models are checked, as well as the Trinart models. 3. Add stanzas for SD 2.0 and 2.1 in INITIAL_MODELS.yaml REMAINING PROBLEMS NOT DIRECTLY RELATED TO MODEL_CACHE: 1. When loading a .ckpt file there are lots of messages like this: Warning! ldm.modules.attention.CrossAttention is no longer being maintained. Please use InvokeAICrossAttention instead. I'm not sure how to address this. 2. The ckpt models ***don't actually run*** due to the lack of special-case support for them in the generator objects. For example, here's the hard crash you get when you run txt2img against the legacy waifu-diffusion-1.3 model: ``` >> An error occurred: Traceback (most recent call last): File "/data/lstein/InvokeAI/ldm/invoke/CLI.py", line 140, in main main_loop(gen, opt) File "/data/lstein/InvokeAI/ldm/invoke/CLI.py", line 371, in main_loop gen.prompt2image( File "/data/lstein/InvokeAI/ldm/generate.py", line 496, in prompt2image results = generator.generate( File "/data/lstein/InvokeAI/ldm/invoke/generator/base.py", line 108, in generate image = make_image(x_T) File "/data/lstein/InvokeAI/ldm/invoke/generator/txt2img.py", line 33, in make_image pipeline_output = pipeline.image_from_embeddings( File "/home/lstein/invokeai/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1265, in __getattr__ raise AttributeError("'{}' object has no attribute '{}'".format( AttributeError: 'LatentDiffusion' object has no attribute 'image_from_embeddings' ``` 3. The inpainting diffusion model isn't working. Here's the output of "banana sushi" when inpainting-1.5 is loaded: ``` Traceback (most recent call last): File "/data/lstein/InvokeAI/ldm/generate.py", line 496, in prompt2image results = generator.generate( File "/data/lstein/InvokeAI/ldm/invoke/generator/base.py", line 108, in generate image = make_image(x_T) File "/data/lstein/InvokeAI/ldm/invoke/generator/txt2img.py", line 33, in make_image pipeline_output = pipeline.image_from_embeddings( File "/data/lstein/InvokeAI/ldm/invoke/generator/diffusers_pipeline.py", line 301, in image_from_embeddings result_latents, result_attention_map_saver = self.latents_from_embeddings( File "/data/lstein/InvokeAI/ldm/invoke/generator/diffusers_pipeline.py", line 330, in latents_from_embeddings result: PipelineIntermediateState = infer_latents_from_embeddings( File "/data/lstein/InvokeAI/ldm/invoke/generator/diffusers_pipeline.py", line 185, in __call__ for result in self.generator_method(*args, **kwargs): File "/data/lstein/InvokeAI/ldm/invoke/generator/diffusers_pipeline.py", line 367, in generate_latents_from_embeddings step_output = self.step(batched_t, latents, guidance_scale, File "/home/lstein/invokeai/.venv/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context return func(*args, **kwargs) File "/data/lstein/InvokeAI/ldm/invoke/generator/diffusers_pipeline.py", line 409, in step step_output = self.scheduler.step(noise_pred, timestep, latents, **extra_step_kwargs) File "/home/lstein/invokeai/.venv/lib/python3.9/site-packages/diffusers/schedulers/scheduling_lms_discrete.py", line 223, in step pred_original_sample = sample - sigma * model_output RuntimeError: The size of tensor a (9) must match the size of tensor b (4) at non-singleton dimension 1 ``` * proper support for float32/float16 - configure script now correctly detects user's preference for fp16/32 and downloads the correct diffuser version. If fp16 version not available, falls back to fp32 version. - misc code cleanup and simplification in model_cache * add on-the-fly conversion of .ckpt to diffusers models 1. On-the-fly conversion code can be found in the file ldm/invoke/ckpt_to_diffusers.py. 2. A new !optimize command has been added to the CLI. Should be ported to Web GUI. User experience on the CLI is this: ``` invoke> !optimize /home/lstein/invokeai/models/ldm/stable-diffusion-v1/sd-v1-4.ckpt INFO: Converting legacy weights file /home/lstein/invokeai/models/ldm/stable-diffusion-v1/sd-v1-4.ckpt to optimized diffuser model. This operation will take 30-60s to complete. Success. Optimized model is now located at /home/lstein/tmp/invokeai/models/optimized-ckpts/sd-v1-4 Writing new config file entry for sd-v1-4... >> New configuration: sd-v1-4: description: Optimized version of sd-v1-4 format: diffusers path: /home/lstein/tmp/invokeai/models/optimized-ckpts/sd-v1-4 OK to import [n]? y >> Verifying that new model loads... >> Current VRAM usage: 2.60G >> Offloading stable-diffusion-2.1 to CPU >> Loading diffusers model from /home/lstein/tmp/invokeai/models/optimized-ckpts/sd-v1-4 | Using faster float16 precision You have disabled the safety checker for <class 'ldm.invoke.generator.diffusers_pipeline.StableDiffusionGeneratorPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion \ license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances,\ disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 . | training width x height = (512 x 512) >> Model loaded in 3.48s >> Max VRAM used to load the model: 2.17G >> Current VRAM usage:2.17G >> Textual inversions available: >> Setting Sampler to k_lms (LMSDiscreteScheduler) Keep model loaded? [y] ``` * add parallel set of generator files for ckpt legacy generation * generation using legacy ckpt models now working * diffusers: fix missing attention_maps_callback fix for23eb80b404
* associate legacy CrossAttention with .ckpt models * enable autoconvert New --autoconvert CLI option will scan a designated directory for new .ckpt files, convert them into diffuser models, and import them into models.yaml. Works like this: invoke.py --autoconvert /path/to/weights/directory In ModelCache added two new methods: autoconvert_weights(config_path, weights_directory_path, models_directory_path) convert_and_import(ckpt_path, diffuser_path) * diffusers: update to diffusers 0.11 (from 0.10.2) * fix vae loading & width/height calculation * refactor: encapsulate these conditioning data into one container * diffusers: fix some noise-scaling issues by pushing the noise-mixing down to the common function * add support for safetensors and accelerate * set local_files_only when internet unreachable * diffusers: fix error-handling path when model repo has no fp16 branch * fix generatorinpaint error Fixes : "ModuleNotFoundError: No module named 'ldm.invoke.generatorinpaint' https://github.com/invoke-ai/InvokeAI/pull/1583#issuecomment-1363634318 * quench diffuser safety-checker warning * diffusers: support stochastic DDIM eta parameter * fix conda env creation on macos * fix cross-attention with diffusers 0.11 * diffusers: the VAE needs to be tiling as well as the U-Net * diffusers: comment on subfolders * diffusers: embiggen! * diffusers: make model_cache.list_models serializable * diffusers(inpaint): restore scaling functionality * fix requirements clash between numba and numpy 1.24 * diffusers: allow inpainting model to do non-inpainting tasks * start expanding model_cache functionality * add import_ckpt_model() and import_diffuser_model() methods to model_manager - in addition, model_cache.py is now renamed to model_manager.py * allow "recommended" flag to be optional in INITIAL_MODELS.yaml * configure_invokeai now downloads VAE diffusers in advance * rename ModelCache to ModelManager * remove support for `repo_name` in models.yaml * check for and refuse to load embeddings trained on incompatible models * models.yaml.example: s/repo_name/repo_id and remove extra INITIAL_MODELS now that the main one has diffusers models in it. * add MVP textual inversion script * refactor(InvokeAIDiffuserComponent): factor out _combine() * InvokeAIDiffuserComponent: implement threshold * InvokeAIDiffuserComponent: diagnostic logs for threshold ...this does not look right * add a curses-based frontend to textual inversion - not quite working yet - requires npyscreen installed - on windows will also have the windows-curses requirement, but not added to requirements yet * add curses-based interface for textual inversion * fix crash in convert_and_import() - This corrects a "local variable referenced before assignment" error in model_manager.convert_and_import() * potential workaround for no 'state_dict' key error - As reported in https://github.com/huggingface/diffusers/issues/1876 * create TI output dir if needed * Update environment-lin-cuda.yml (#2159) Fixing line 42 to be the proper order to define the transformers requirement: ~= instead of =~ * diffusers: update sampler-to-scheduler mapping based on https://github.com/huggingface/diffusers/issues/277#issuecomment-1371428672 * improve user exp for ckt to diffusers conversion - !optimize_models command now operates on an existing ckpt file entry in models.yaml - replaces existing entry, rather than adding a new one - offers to delete the ckpt file after conversion * web: adapt progress callback to deal with old generator or new diffusers pipeline * clean-up model_manager code - add_model() verified to work for .ckpt local paths, .ckpt remote URLs, diffusers local paths, and diffusers repo_ids - convert_and_import() verified to work for local and remove .ckpt files * handle edge cases for import_model() and convert_model() * add support for safetensor .ckpt files * fix name error * code cleanup with pyflake * improve model setting behavior - If the user enters an invalid model name at startup time, will not try to load it, warn, and use default model - CLI UI enhancement: include currently active model in the command line prompt. * update test-invoke-pip.yml - fix model cache path to point to runwayml/stable-diffusion-v1-5 - remove `skip-sd-weights` from configure_invokeai.py args * exclude dev/diffusers from "fail for draft PRs" * disable "fail on PR jobs" * re-add `--skip-sd-weights` since no space * update workflow environments - include `INVOKE_MODEL_RECONFIGURE: '--yes'` * clean up model load failure handling - Allow CLI to run even when no model is defined or loadable. - Inhibit stack trace when model load fails - only show last error - Give user *option* to run configure_invokeai.py when no models successfully load. - Restart invokeai after reconfiguration. * further edge-case handling 1) only one model in models.yaml file, and that model is broken 2) no models in models.yaml 3) models.yaml doesn't exist at all * fix incorrect model status listing - "cached" was not being returned from list_models() - normalize handling of exceptions during model loading: - Passing an invalid model name to generate.set_model() will return a KeyError - All other exceptions are returned as the appropriate Exception * CI: do download weights (if not already cached) * diffusers: fix scheduler loading in offline mode * CI: fix model name (no longer has `diffusers-` prefix) * Update txt2img2img.py (#2256) * fixes to share models with HuggingFace cache system - If HF_HOME environment variable is defined, then all huggingface models are stored in that directory following the standard conventions. - For seamless interoperability, set HF_HOME to ~/.cache/huggingface - If HF_HOME not defined, then models are stored in ~/invokeai/models. This is equivalent to setting HF_HOME to ~/invokeai/models A future commit will add a migration mechanism so that this change doesn't break previous installs. * feat - make model storage compatible with hugging face caching system This commit alters the InvokeAI model directory to be compatible with hugging face, making it easier to share diffusers (and other models) across different programs. - If the HF_HOME environment variable is not set, then models are cached in ~/invokeai/models in a format that is identical to the HuggingFace cache. - If HF_HOME is set, then models are cached wherever HF_HOME points. - To enable sharing with other HuggingFace library clients, set HF_HOME to ~/.cache/huggingface to set the default cache location or to ~/invokeai/models to have huggingface cache inside InvokeAI. * fixes to share models with HuggingFace cache system - If HF_HOME environment variable is defined, then all huggingface models are stored in that directory following the standard conventions. - For seamless interoperability, set HF_HOME to ~/.cache/huggingface - If HF_HOME not defined, then models are stored in ~/invokeai/models. This is equivalent to setting HF_HOME to ~/invokeai/models A future commit will add a migration mechanism so that this change doesn't break previous installs. * fix error "no attribute CkptInpaint" * model_manager.list_models() returns entire model config stanza+status * Initial Draft - Model Manager Diffusers * added hash function to diffusers * implement sha256 hashes on diffusers models * Add Model Manager Support for Diffusers * fix various problems with model manager - in cli import functions, fix not enough values to unpack from _get_name_and_desc() - fix crash when using old-style vae: value with new-style diffuser * rebuild frontend * fix dictconfig-not-serializable issue * fix NoneType' object is not subscriptable crash in model_manager * fix "str has no attribute get" error in model_manager list_models() * Add path and repo_id support for Diffusers Model Manager Also fixes bugs * Fix tooltip IT localization not working * Add Version Number To WebUI * Optimize Model Search * Fix incorrect font on the Model Manager UI * Fix image degradation on merge fixes - [Experimental] This change should effectively fix a couple of things. - Fix image degradation on subsequent merges of the canvas layers. - Fix the slight transparent border that is left behind when filling the bounding box with a color. - Fix the left over line of color when filling a bounding box with color. So far there are no side effects for this. If any, please report. * Add local model filtering for Diffusers / Checkpoints * Go to home on modal close for the Add Modal UI * Styling Fixes * Model Manager Diffusers Localization Update * Add Safe Tensor scanning to Model Manager * Fix model edit form dispatching string values instead of numbers. * Resolve VAE handling / edge cases for supplied repos * defer injecting tokens for textual inversions until they're used for the first time * squash a console warning * implement model migration check * add_model() overwrites previous config rather than merges * fix model config file attribute merging * fix precision handling in textual inversion script * allow ckpt conversion script to work with safetensors .ckpts Applied patch here:beb932c5d1
* fix name "args" is not defined crash in textual_inversion_training * fix a second NameError: name 'args' is not defined crash * fix loading of the safety checker from the global cache dir * add installation step to textual inversion frontend - After a successful training run, the script will copy learned_embeds.bin to a subfolder of the embeddings directory. - User given the option to delete the logs and intermediate checkpoints (which together use 7-8G of space) - If textual inversion training fails, reports the error gracefully. * don't crash out on incompatible embeddings - put try: blocks around places where the system tries to load an embedding which is incompatible with the currently loaded model * add support for checkpoint resuming * textual inversion preferences are saved and restored between sessions - Preferences are stored in a file named text-inversion-training/preferences.conf - Currently the resume-from-checkpoint option is not working correctly. Possible bug in textual_inversion_training.py? * copy learned_embeddings.bin into right location * add front end for diffusers model merging - Front end doesn't do anything yet!!!! - Made change to model name parsing in CLI to support ability to have merged models with the "+" character in their names. * improve inpainting experience - recommend ckpt version of inpainting-1.5 to user - fix get_noise() bug in ckpt version of omnibus.py * update environment*yml * tweak instructions to install HuggingFace token * bump version number * enhance update scripts - update scripts will now fetch new INITIAL_MODELS.yaml so that configure_invokeai.py will know about the diffusers versions. * enhance invoke.sh/invoke.bat launchers - added configure_invokeai.py to menu - menu defaults to browser-based invoke * remove conda workflow (#2321) * fix `token_ids has shape torch.Size([79]) - expected [77]` * update CHANGELOG.md with 2.3.* info - Add information on how formats have changed and the upgrade process. - Add short bug list. Co-authored-by: Damian Stewart <d@damianstewart.com> Co-authored-by: Damian Stewart <null@damianstewart.com> Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com> Co-authored-by: Wybartel-luxmc <37852506+Wybartel-luxmc@users.noreply.github.com> Co-authored-by: mauwii <Mauwii@outlook.de> Co-authored-by: mickr777 <115216705+mickr777@users.noreply.github.com> Co-authored-by: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Co-authored-by: Eugene Brodsky <ebr@users.noreply.github.com> Co-authored-by: Matthias Wild <40327258+mauwii@users.noreply.github.com>
This commit is contained in:
parent
c855d2a350
commit
6fdbc1978d
161
.github/workflows/test-invoke-conda.yml
vendored
161
.github/workflows/test-invoke-conda.yml
vendored
@ -1,161 +0,0 @@
|
|||||||
name: Test invoke.py
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- 'main'
|
|
||||||
pull_request:
|
|
||||||
branches:
|
|
||||||
- 'main'
|
|
||||||
types:
|
|
||||||
- 'ready_for_review'
|
|
||||||
- 'opened'
|
|
||||||
- 'synchronize'
|
|
||||||
- 'converted_to_draft'
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
fail_if_pull_request_is_draft:
|
|
||||||
if: github.event.pull_request.draft == true
|
|
||||||
runs-on: ubuntu-22.04
|
|
||||||
steps:
|
|
||||||
- name: Fails in order to indicate that pull request needs to be marked as ready to review and unit tests workflow needs to pass.
|
|
||||||
run: exit 1
|
|
||||||
|
|
||||||
matrix:
|
|
||||||
if: github.event.pull_request.draft == false
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
stable-diffusion-model:
|
|
||||||
- 'stable-diffusion-1.5'
|
|
||||||
environment-yaml:
|
|
||||||
- environment-lin-amd.yml
|
|
||||||
- environment-lin-cuda.yml
|
|
||||||
- environment-mac.yml
|
|
||||||
- environment-win-cuda.yml
|
|
||||||
include:
|
|
||||||
- environment-yaml: environment-lin-amd.yml
|
|
||||||
os: ubuntu-22.04
|
|
||||||
curl-command: curl
|
|
||||||
github-env: $GITHUB_ENV
|
|
||||||
default-shell: bash -l {0}
|
|
||||||
- environment-yaml: environment-lin-cuda.yml
|
|
||||||
os: ubuntu-22.04
|
|
||||||
curl-command: curl
|
|
||||||
github-env: $GITHUB_ENV
|
|
||||||
default-shell: bash -l {0}
|
|
||||||
- environment-yaml: environment-mac.yml
|
|
||||||
os: macos-12
|
|
||||||
curl-command: curl
|
|
||||||
github-env: $GITHUB_ENV
|
|
||||||
default-shell: bash -l {0}
|
|
||||||
- environment-yaml: environment-win-cuda.yml
|
|
||||||
os: windows-2022
|
|
||||||
curl-command: curl.exe
|
|
||||||
github-env: $env:GITHUB_ENV
|
|
||||||
default-shell: pwsh
|
|
||||||
- stable-diffusion-model: stable-diffusion-1.5
|
|
||||||
stable-diffusion-model-url: https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt
|
|
||||||
stable-diffusion-model-dl-path: models/ldm/stable-diffusion-v1
|
|
||||||
stable-diffusion-model-dl-name: v1-5-pruned-emaonly.ckpt
|
|
||||||
name: ${{ matrix.environment-yaml }} on ${{ matrix.os }}
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
env:
|
|
||||||
CONDA_ENV_NAME: invokeai
|
|
||||||
INVOKEAI_ROOT: '${{ github.workspace }}/invokeai'
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: ${{ matrix.default-shell }}
|
|
||||||
steps:
|
|
||||||
- name: Checkout sources
|
|
||||||
id: checkout-sources
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
|
|
||||||
- name: create models.yaml from example
|
|
||||||
run: |
|
|
||||||
mkdir -p ${{ env.INVOKEAI_ROOT }}/configs
|
|
||||||
cp configs/models.yaml.example ${{ env.INVOKEAI_ROOT }}/configs/models.yaml
|
|
||||||
|
|
||||||
- name: create environment.yml
|
|
||||||
run: cp "environments-and-requirements/${{ matrix.environment-yaml }}" environment.yml
|
|
||||||
|
|
||||||
- name: Use cached conda packages
|
|
||||||
id: use-cached-conda-packages
|
|
||||||
uses: actions/cache@v3
|
|
||||||
with:
|
|
||||||
path: ~/conda_pkgs_dir
|
|
||||||
key: conda-pkgs-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles(matrix.environment-yaml) }}
|
|
||||||
|
|
||||||
- name: Activate Conda Env
|
|
||||||
id: activate-conda-env
|
|
||||||
uses: conda-incubator/setup-miniconda@v2
|
|
||||||
with:
|
|
||||||
activate-environment: ${{ env.CONDA_ENV_NAME }}
|
|
||||||
environment-file: environment.yml
|
|
||||||
miniconda-version: latest
|
|
||||||
|
|
||||||
- name: set test prompt to main branch validation
|
|
||||||
if: ${{ github.ref == 'refs/heads/main' }}
|
|
||||||
run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> ${{ matrix.github-env }}
|
|
||||||
|
|
||||||
- name: set test prompt to development branch validation
|
|
||||||
if: ${{ github.ref == 'refs/heads/development' }}
|
|
||||||
run: echo "TEST_PROMPTS=tests/dev_prompts.txt" >> ${{ matrix.github-env }}
|
|
||||||
|
|
||||||
- name: set test prompt to Pull Request validation
|
|
||||||
if: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/development' }}
|
|
||||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
|
||||||
|
|
||||||
- name: Use Cached Stable Diffusion Model
|
|
||||||
id: cache-sd-model
|
|
||||||
uses: actions/cache@v3
|
|
||||||
env:
|
|
||||||
cache-name: cache-${{ matrix.stable-diffusion-model }}
|
|
||||||
with:
|
|
||||||
path: ${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}
|
|
||||||
key: ${{ env.cache-name }}
|
|
||||||
|
|
||||||
- name: Download ${{ matrix.stable-diffusion-model }}
|
|
||||||
id: download-stable-diffusion-model
|
|
||||||
if: ${{ steps.cache-sd-model.outputs.cache-hit != 'true' }}
|
|
||||||
run: |
|
|
||||||
mkdir -p "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}"
|
|
||||||
${{ matrix.curl-command }} -H "Authorization: Bearer ${{ secrets.HUGGINGFACE_TOKEN }}" -o "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}/${{ matrix.stable-diffusion-model-dl-name }}" -L ${{ matrix.stable-diffusion-model-url }}
|
|
||||||
|
|
||||||
- name: run configure_invokeai.py
|
|
||||||
id: run-preload-models
|
|
||||||
run: |
|
|
||||||
python scripts/configure_invokeai.py --skip-sd-weights --yes
|
|
||||||
|
|
||||||
- name: cat invokeai.init
|
|
||||||
id: cat-invokeai
|
|
||||||
run: cat ${{ env.INVOKEAI_ROOT }}/invokeai.init
|
|
||||||
|
|
||||||
- name: Run the tests
|
|
||||||
id: run-tests
|
|
||||||
if: matrix.os != 'windows-2022'
|
|
||||||
run: |
|
|
||||||
time python scripts/invoke.py \
|
|
||||||
--no-patchmatch \
|
|
||||||
--no-nsfw_checker \
|
|
||||||
--model ${{ matrix.stable-diffusion-model }} \
|
|
||||||
--from_file ${{ env.TEST_PROMPTS }} \
|
|
||||||
--root="${{ env.INVOKEAI_ROOT }}" \
|
|
||||||
--outdir="${{ env.INVOKEAI_ROOT }}/outputs"
|
|
||||||
|
|
||||||
- name: export conda env
|
|
||||||
id: export-conda-env
|
|
||||||
if: matrix.os != 'windows-2022'
|
|
||||||
run: |
|
|
||||||
mkdir -p outputs/img-samples
|
|
||||||
conda env export --name ${{ env.CONDA_ENV_NAME }} > ${{ env.INVOKEAI_ROOT }}/outputs/environment-${{ runner.os }}-${{ runner.arch }}.yml
|
|
||||||
|
|
||||||
- name: Archive results
|
|
||||||
if: matrix.os != 'windows-2022'
|
|
||||||
id: archive-results
|
|
||||||
uses: actions/upload-artifact@v3
|
|
||||||
with:
|
|
||||||
name: results_${{ matrix.requirements-file }}_${{ matrix.python-version }}
|
|
||||||
path: ${{ env.INVOKEAI_ROOT }}/outputs
|
|
81
.github/workflows/test-invoke-pip.yml
vendored
81
.github/workflows/test-invoke-pip.yml
vendored
@ -4,8 +4,6 @@ on:
|
|||||||
branches:
|
branches:
|
||||||
- 'main'
|
- 'main'
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
|
||||||
- 'main'
|
|
||||||
types:
|
types:
|
||||||
- 'ready_for_review'
|
- 'ready_for_review'
|
||||||
- 'opened'
|
- 'opened'
|
||||||
@ -17,14 +15,14 @@ concurrency:
|
|||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
fail_if_pull_request_is_draft:
|
# fail_if_pull_request_is_draft:
|
||||||
if: github.event.pull_request.draft == true
|
# if: github.event.pull_request.draft == true && github.head_ref != 'dev/diffusers'
|
||||||
runs-on: ubuntu-18.04
|
# runs-on: ubuntu-18.04
|
||||||
steps:
|
# steps:
|
||||||
- name: Fails in order to indicate that pull request needs to be marked as ready to review and unit tests workflow needs to pass.
|
# - name: Fails in order to indicate that pull request needs to be marked as ready to review and unit tests workflow needs to pass.
|
||||||
run: exit 1
|
# run: exit 1
|
||||||
matrix:
|
matrix:
|
||||||
if: github.event.pull_request.draft == false
|
if: github.event.pull_request.draft == false || github.head_ref == 'dev/diffusers'
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
stable-diffusion-model:
|
stable-diffusion-model:
|
||||||
@ -40,26 +38,23 @@ jobs:
|
|||||||
include:
|
include:
|
||||||
- requirements-file: requirements-lin-cuda.txt
|
- requirements-file: requirements-lin-cuda.txt
|
||||||
os: ubuntu-22.04
|
os: ubuntu-22.04
|
||||||
curl-command: curl
|
|
||||||
github-env: $GITHUB_ENV
|
github-env: $GITHUB_ENV
|
||||||
- requirements-file: requirements-lin-amd.txt
|
- requirements-file: requirements-lin-amd.txt
|
||||||
os: ubuntu-22.04
|
os: ubuntu-22.04
|
||||||
curl-command: curl
|
|
||||||
github-env: $GITHUB_ENV
|
github-env: $GITHUB_ENV
|
||||||
- requirements-file: requirements-mac-mps-cpu.txt
|
- requirements-file: requirements-mac-mps-cpu.txt
|
||||||
os: macOS-12
|
os: macOS-12
|
||||||
curl-command: curl
|
|
||||||
github-env: $GITHUB_ENV
|
github-env: $GITHUB_ENV
|
||||||
- requirements-file: requirements-win-colab-cuda.txt
|
- requirements-file: requirements-win-colab-cuda.txt
|
||||||
os: windows-2022
|
os: windows-2022
|
||||||
curl-command: curl.exe
|
|
||||||
github-env: $env:GITHUB_ENV
|
github-env: $env:GITHUB_ENV
|
||||||
- stable-diffusion-model: stable-diffusion-1.5
|
|
||||||
stable-diffusion-model-url: https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt
|
|
||||||
stable-diffusion-model-dl-path: models/ldm/stable-diffusion-v1
|
|
||||||
stable-diffusion-model-dl-name: v1-5-pruned-emaonly.ckpt
|
|
||||||
name: ${{ matrix.requirements-file }} on ${{ matrix.python-version }}
|
name: ${{ matrix.requirements-file }} on ${{ matrix.python-version }}
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
|
env:
|
||||||
|
INVOKE_MODEL_RECONFIGURE: '--yes'
|
||||||
|
INVOKEAI_ROOT: '${{ github.workspace }}/invokeai'
|
||||||
|
PYTHONUNBUFFERED: 1
|
||||||
|
HAVE_SECRETS: ${{ secrets.HUGGINGFACE_TOKEN != '' }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout sources
|
- name: Checkout sources
|
||||||
id: checkout-sources
|
id: checkout-sources
|
||||||
@ -77,10 +72,17 @@ jobs:
|
|||||||
echo "INVOKEAI_ROOT=${{ github.workspace }}/invokeai" >> ${{ matrix.github-env }}
|
echo "INVOKEAI_ROOT=${{ github.workspace }}/invokeai" >> ${{ matrix.github-env }}
|
||||||
echo "INVOKEAI_OUTDIR=${{ github.workspace }}/invokeai/outputs" >> ${{ matrix.github-env }}
|
echo "INVOKEAI_OUTDIR=${{ github.workspace }}/invokeai/outputs" >> ${{ matrix.github-env }}
|
||||||
|
|
||||||
- name: create models.yaml from example
|
- name: Use Cached diffusers-1.5
|
||||||
run: |
|
id: cache-sd-model
|
||||||
mkdir -p ${{ env.INVOKEAI_ROOT }}/configs
|
uses: actions/cache@v3
|
||||||
cp configs/models.yaml.example ${{ env.INVOKEAI_ROOT }}/configs/models.yaml
|
env:
|
||||||
|
cache-name: huggingface-${{ matrix.stable-diffusion-model }}
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.INVOKEAI_ROOT }}/models/runwayml
|
||||||
|
${{ env.INVOKEAI_ROOT }}/models/stabilityai
|
||||||
|
${{ env.INVOKEAI_ROOT }}/models/CompVis
|
||||||
|
key: ${{ env.cache-name }}
|
||||||
|
|
||||||
- name: set test prompt to main branch validation
|
- name: set test prompt to main branch validation
|
||||||
if: ${{ github.ref == 'refs/heads/main' }}
|
if: ${{ github.ref == 'refs/heads/main' }}
|
||||||
@ -110,30 +112,31 @@ jobs:
|
|||||||
- name: install requirements
|
- name: install requirements
|
||||||
run: pip3 install -r '${{ matrix.requirements-file }}'
|
run: pip3 install -r '${{ matrix.requirements-file }}'
|
||||||
|
|
||||||
- name: Use Cached Stable Diffusion Model
|
|
||||||
id: cache-sd-model
|
|
||||||
uses: actions/cache@v3
|
|
||||||
env:
|
|
||||||
cache-name: cache-${{ matrix.stable-diffusion-model }}
|
|
||||||
with:
|
|
||||||
path: ${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}
|
|
||||||
key: ${{ env.cache-name }}
|
|
||||||
|
|
||||||
- name: Download ${{ matrix.stable-diffusion-model }}
|
|
||||||
id: download-stable-diffusion-model
|
|
||||||
if: ${{ steps.cache-sd-model.outputs.cache-hit != 'true' }}
|
|
||||||
run: |
|
|
||||||
mkdir -p "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}"
|
|
||||||
${{ matrix.curl-command }} -H "Authorization: Bearer ${{ secrets.HUGGINGFACE_TOKEN }}" -o "${{ env.INVOKEAI_ROOT }}/${{ matrix.stable-diffusion-model-dl-path }}/${{ matrix.stable-diffusion-model-dl-name }}" -L ${{ matrix.stable-diffusion-model-url }}
|
|
||||||
|
|
||||||
- name: run configure_invokeai.py
|
- name: run configure_invokeai.py
|
||||||
id: run-preload-models
|
id: run-preload-models
|
||||||
run: python3 scripts/configure_invokeai.py --skip-sd-weights --yes
|
env:
|
||||||
|
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
|
||||||
|
run: >
|
||||||
|
configure_invokeai.py
|
||||||
|
--yes
|
||||||
|
--full-precision # can't use fp16 weights without a GPU
|
||||||
|
|
||||||
- name: Run the tests
|
- name: Run the tests
|
||||||
id: run-tests
|
id: run-tests
|
||||||
if: matrix.os != 'windows-2022'
|
if: matrix.os != 'windows-2022'
|
||||||
run: python3 scripts/invoke.py --no-patchmatch --no-nsfw_checker --model ${{ matrix.stable-diffusion-model }} --from_file ${{ env.TEST_PROMPTS }} --root="${{ env.INVOKEAI_ROOT }}" --outdir="${{ env.INVOKEAI_OUTDIR }}"
|
env:
|
||||||
|
# Set offline mode to make sure configure preloaded successfully.
|
||||||
|
HF_HUB_OFFLINE: 1
|
||||||
|
HF_DATASETS_OFFLINE: 1
|
||||||
|
TRANSFORMERS_OFFLINE: 1
|
||||||
|
run: >
|
||||||
|
python3 scripts/invoke.py
|
||||||
|
--no-patchmatch
|
||||||
|
--no-nsfw_checker
|
||||||
|
--model ${{ matrix.stable-diffusion-model }}
|
||||||
|
--from_file ${{ env.TEST_PROMPTS }}
|
||||||
|
--root="${{ env.INVOKEAI_ROOT }}"
|
||||||
|
--outdir="${{ env.INVOKEAI_OUTDIR }}"
|
||||||
|
|
||||||
- name: Archive results
|
- name: Archive results
|
||||||
id: archive-results
|
id: archive-results
|
||||||
|
@ -100,7 +100,6 @@ to render 512x512 images.
|
|||||||
|
|
||||||
- At least 12 GB of free disk space for the machine learning model, Python, and all its dependencies.
|
- At least 12 GB of free disk space for the machine learning model, Python, and all its dependencies.
|
||||||
|
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
Feature documentation can be reviewed by navigating to [the InvokeAI Documentation page](https://invoke-ai.github.io/InvokeAI/features/)
|
Feature documentation can be reviewed by navigating to [the InvokeAI Documentation page](https://invoke-ai.github.io/InvokeAI/features/)
|
||||||
@ -118,6 +117,8 @@ InvokeAI's advanced prompt syntax allows for token weighting, cross-attention co
|
|||||||
For users utilizing a terminal-based environment, or who want to take advantage of CLI features, InvokeAI offers an extensive and actively supported command-line interface that provides the full suite of generation functionality available in the tool.
|
For users utilizing a terminal-based environment, or who want to take advantage of CLI features, InvokeAI offers an extensive and actively supported command-line interface that provides the full suite of generation functionality available in the tool.
|
||||||
|
|
||||||
### Other features
|
### Other features
|
||||||
|
- *Support for both ckpt and diffusers models*
|
||||||
|
- *SD 2.0, 2.1 support*
|
||||||
- *Noise Control & Tresholding*
|
- *Noise Control & Tresholding*
|
||||||
- *Popular Sampler Support*
|
- *Popular Sampler Support*
|
||||||
- *Upscaling & Face Restoration Tools*
|
- *Upscaling & Face Restoration Tools*
|
||||||
@ -125,14 +126,14 @@ For users utilizing a terminal-based environment, or who want to take advantage
|
|||||||
- *Model Manager & Support*
|
- *Model Manager & Support*
|
||||||
|
|
||||||
### Coming Soon
|
### Coming Soon
|
||||||
- *2.0/2.1 Model Support*
|
|
||||||
- *Depth2Img Support*
|
|
||||||
- *Node-Based Architecture & UI*
|
- *Node-Based Architecture & UI*
|
||||||
- And more...
|
- And more...
|
||||||
|
|
||||||
### Latest Changes
|
### Latest Changes
|
||||||
|
|
||||||
For our latest changes, view our [Release Notes](https://github.com/invoke-ai/InvokeAI/releases)
|
For our latest changes, view our [Release
|
||||||
|
Notes](https://github.com/invoke-ai/InvokeAI/releases) and the
|
||||||
|
[CHANGELOG](docs/CHANGELOG.md).
|
||||||
|
|
||||||
## Troubleshooting
|
## Troubleshooting
|
||||||
|
|
||||||
|
@ -1,35 +1,34 @@
|
|||||||
import eventlet
|
import base64
|
||||||
import glob
|
import glob
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import mimetypes
|
|
||||||
import traceback
|
import traceback
|
||||||
import math
|
from threading import Event
|
||||||
import io
|
from uuid import uuid4
|
||||||
import base64
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
|
|
||||||
from werkzeug.utils import secure_filename
|
import eventlet
|
||||||
|
from PIL import Image
|
||||||
|
from PIL.Image import Image as ImageType
|
||||||
from flask import Flask, redirect, send_from_directory, request, make_response
|
from flask import Flask, redirect, send_from_directory, request, make_response
|
||||||
from flask_socketio import SocketIO
|
from flask_socketio import SocketIO
|
||||||
from PIL import Image, ImageOps
|
from werkzeug.utils import secure_filename
|
||||||
from PIL.Image import Image as ImageType
|
|
||||||
from uuid import uuid4
|
|
||||||
from threading import Event
|
|
||||||
|
|
||||||
from ldm.generate import Generate
|
|
||||||
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
|
||||||
from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure
|
|
||||||
from ldm.invoke.globals import Globals
|
|
||||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
|
|
||||||
from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
|
|
||||||
from ldm.invoke.generator.inpaint import infill_methods
|
|
||||||
|
|
||||||
from backend.modules.parameters import parameters_to_command
|
|
||||||
from backend.modules.get_canvas_generation_mode import (
|
from backend.modules.get_canvas_generation_mode import (
|
||||||
get_canvas_generation_mode,
|
get_canvas_generation_mode,
|
||||||
)
|
)
|
||||||
|
from backend.modules.parameters import parameters_to_command
|
||||||
|
from ldm.generate import Generate
|
||||||
|
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||||
|
from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure
|
||||||
|
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
|
||||||
|
from ldm.invoke.generator.inpaint import infill_methods
|
||||||
|
from ldm.invoke.globals import Globals
|
||||||
|
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
|
||||||
|
from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
|
||||||
|
|
||||||
# Loading Arguments
|
# Loading Arguments
|
||||||
opt = Args()
|
opt = Args()
|
||||||
@ -304,7 +303,7 @@ class InvokeAIWebServer:
|
|||||||
def handle_request_capabilities():
|
def handle_request_capabilities():
|
||||||
print(f">> System config requested")
|
print(f">> System config requested")
|
||||||
config = self.get_system_config()
|
config = self.get_system_config()
|
||||||
config["model_list"] = self.generate.model_cache.list_models()
|
config["model_list"] = self.generate.model_manager.list_models()
|
||||||
config["infill_methods"] = infill_methods()
|
config["infill_methods"] = infill_methods()
|
||||||
socketio.emit("systemConfig", config)
|
socketio.emit("systemConfig", config)
|
||||||
|
|
||||||
@ -317,7 +316,7 @@ class InvokeAIWebServer:
|
|||||||
{'search_folder': None, 'found_models': None},
|
{'search_folder': None, 'found_models': None},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
search_folder, found_models = self.generate.model_cache.search_models(search_folder)
|
search_folder, found_models = self.generate.model_manager.search_models(search_folder)
|
||||||
socketio.emit(
|
socketio.emit(
|
||||||
"foundModels",
|
"foundModels",
|
||||||
{'search_folder': search_folder, 'found_models': found_models},
|
{'search_folder': search_folder, 'found_models': found_models},
|
||||||
@ -335,18 +334,20 @@ class InvokeAIWebServer:
|
|||||||
model_name = new_model_config['name']
|
model_name = new_model_config['name']
|
||||||
del new_model_config['name']
|
del new_model_config['name']
|
||||||
model_attributes = new_model_config
|
model_attributes = new_model_config
|
||||||
|
if len(model_attributes['vae']) == 0:
|
||||||
|
del model_attributes['vae']
|
||||||
update = False
|
update = False
|
||||||
current_model_list = self.generate.model_cache.list_models()
|
current_model_list = self.generate.model_manager.list_models()
|
||||||
if model_name in current_model_list:
|
if model_name in current_model_list:
|
||||||
update = True
|
update = True
|
||||||
|
|
||||||
print(f">> Adding New Model: {model_name}")
|
print(f">> Adding New Model: {model_name}")
|
||||||
|
|
||||||
self.generate.model_cache.add_model(
|
self.generate.model_manager.add_model(
|
||||||
model_name=model_name, model_attributes=model_attributes, clobber=True)
|
model_name=model_name, model_attributes=model_attributes, clobber=True)
|
||||||
self.generate.model_cache.commit(opt.conf)
|
self.generate.model_manager.commit(opt.conf)
|
||||||
|
|
||||||
new_model_list = self.generate.model_cache.list_models()
|
new_model_list = self.generate.model_manager.list_models()
|
||||||
socketio.emit(
|
socketio.emit(
|
||||||
"newModelAdded",
|
"newModelAdded",
|
||||||
{"new_model_name": model_name,
|
{"new_model_name": model_name,
|
||||||
@ -364,9 +365,9 @@ class InvokeAIWebServer:
|
|||||||
def handle_delete_model(model_name: str):
|
def handle_delete_model(model_name: str):
|
||||||
try:
|
try:
|
||||||
print(f">> Deleting Model: {model_name}")
|
print(f">> Deleting Model: {model_name}")
|
||||||
self.generate.model_cache.del_model(model_name)
|
self.generate.model_manager.del_model(model_name)
|
||||||
self.generate.model_cache.commit(opt.conf)
|
self.generate.model_manager.commit(opt.conf)
|
||||||
updated_model_list = self.generate.model_cache.list_models()
|
updated_model_list = self.generate.model_manager.list_models()
|
||||||
socketio.emit(
|
socketio.emit(
|
||||||
"modelDeleted",
|
"modelDeleted",
|
||||||
{"deleted_model_name": model_name,
|
{"deleted_model_name": model_name,
|
||||||
@ -385,7 +386,7 @@ class InvokeAIWebServer:
|
|||||||
try:
|
try:
|
||||||
print(f">> Model change requested: {model_name}")
|
print(f">> Model change requested: {model_name}")
|
||||||
model = self.generate.set_model(model_name)
|
model = self.generate.set_model(model_name)
|
||||||
model_list = self.generate.model_cache.list_models()
|
model_list = self.generate.model_manager.list_models()
|
||||||
if model is None:
|
if model is None:
|
||||||
socketio.emit(
|
socketio.emit(
|
||||||
"modelChangeFailed",
|
"modelChangeFailed",
|
||||||
@ -797,7 +798,7 @@ class InvokeAIWebServer:
|
|||||||
|
|
||||||
# App Functions
|
# App Functions
|
||||||
def get_system_config(self):
|
def get_system_config(self):
|
||||||
model_list: dict = self.generate.model_cache.list_models()
|
model_list: dict = self.generate.model_manager.list_models()
|
||||||
active_model_name = None
|
active_model_name = None
|
||||||
|
|
||||||
for model_name, model_dict in model_list.items():
|
for model_name, model_dict in model_list.items():
|
||||||
@ -1205,9 +1206,16 @@ class InvokeAIWebServer:
|
|||||||
|
|
||||||
print(generation_parameters)
|
print(generation_parameters)
|
||||||
|
|
||||||
|
def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
||||||
|
if isinstance(cb_args[0], PipelineIntermediateState):
|
||||||
|
progress_state: PipelineIntermediateState = cb_args[0]
|
||||||
|
return image_progress(progress_state.latents, progress_state.step)
|
||||||
|
else:
|
||||||
|
return image_progress(*cb_args, **kwargs)
|
||||||
|
|
||||||
self.generate.prompt2image(
|
self.generate.prompt2image(
|
||||||
**generation_parameters,
|
**generation_parameters,
|
||||||
step_callback=image_progress,
|
step_callback=diffusers_step_callback_adapter,
|
||||||
image_callback=image_done
|
image_callback=image_done
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,6 +12,8 @@ SAMPLER_CHOICES = [
|
|||||||
"k_heun",
|
"k_heun",
|
||||||
"k_lms",
|
"k_lms",
|
||||||
"plms",
|
"plms",
|
||||||
|
# diffusers:
|
||||||
|
"pndm",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,9 +2,10 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/torch_stable.html
|
--extra-index-url https://download.pytorch.org/whl/torch_stable.html
|
||||||
--extra-index-url https://download.pytorch.org/whl/cu116
|
--extra-index-url https://download.pytorch.org/whl/cu116
|
||||||
--trusted-host https://download.pytorch.org
|
--trusted-host https://download.pytorch.org
|
||||||
accelerate~=0.14
|
accelerate~=0.15
|
||||||
albumentations
|
albumentations
|
||||||
diffusers
|
diffusers[torch]~=0.11
|
||||||
|
einops
|
||||||
eventlet
|
eventlet
|
||||||
flask_cors
|
flask_cors
|
||||||
flask_socketio
|
flask_socketio
|
||||||
|
@ -1,9 +1,15 @@
|
|||||||
|
stable-diffusion-2.1:
|
||||||
|
description: Stable Diffusion version 2.1 diffusers model (5.21 GB)
|
||||||
|
repo_id: stabilityai/stable-diffusion-2-1
|
||||||
|
format: diffusers
|
||||||
|
recommended: True
|
||||||
stable-diffusion-1.5:
|
stable-diffusion-1.5:
|
||||||
description: The newest Stable Diffusion version 1.5 weight file (4.27 GB)
|
description: Stable Diffusion version 1.5 weight file (4.27 GB)
|
||||||
repo_id: runwayml/stable-diffusion-v1-5
|
repo_id: runwayml/stable-diffusion-v1-5
|
||||||
config: v1-inference.yaml
|
format: diffusers
|
||||||
file: v1-5-pruned-emaonly.ckpt
|
recommended: True
|
||||||
recommended: true
|
vae:
|
||||||
|
repo_id: stabilityai/sd-vae-ft-mse
|
||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
inpainting-1.5:
|
inpainting-1.5:
|
||||||
@ -11,23 +17,20 @@ inpainting-1.5:
|
|||||||
repo_id: runwayml/stable-diffusion-inpainting
|
repo_id: runwayml/stable-diffusion-inpainting
|
||||||
config: v1-inpainting-inference.yaml
|
config: v1-inpainting-inference.yaml
|
||||||
file: sd-v1-5-inpainting.ckpt
|
file: sd-v1-5-inpainting.ckpt
|
||||||
recommended: True
|
format: ckpt
|
||||||
width: 512
|
vae:
|
||||||
height: 512
|
|
||||||
ft-mse-improved-autoencoder-840000:
|
|
||||||
description: StabilityAI improved autoencoder fine-tuned for human faces (recommended; 335 MB)
|
|
||||||
repo_id: stabilityai/sd-vae-ft-mse-original
|
repo_id: stabilityai/sd-vae-ft-mse-original
|
||||||
config: VAE/default
|
|
||||||
file: vae-ft-mse-840000-ema-pruned.ckpt
|
file: vae-ft-mse-840000-ema-pruned.ckpt
|
||||||
recommended: True
|
recommended: True
|
||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
stable-diffusion-1.4:
|
stable-diffusion-1.4:
|
||||||
description: The original Stable Diffusion version 1.4 weight file (4.27 GB)
|
description: The original Stable Diffusion version 1.4 weight file (4.27 GB)
|
||||||
repo_id: CompVis/stable-diffusion-v-1-4-original
|
repo_id: CompVis/stable-diffusion-v1-4
|
||||||
config: v1-inference.yaml
|
|
||||||
file: sd-v1-4.ckpt
|
|
||||||
recommended: False
|
recommended: False
|
||||||
|
format: diffusers
|
||||||
|
vae:
|
||||||
|
repo_id: stabilityai/sd-vae-ft-mse
|
||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
waifu-diffusion-1.3:
|
waifu-diffusion-1.3:
|
||||||
@ -35,29 +38,30 @@ waifu-diffusion-1.3:
|
|||||||
repo_id: hakurei/waifu-diffusion-v1-3
|
repo_id: hakurei/waifu-diffusion-v1-3
|
||||||
config: v1-inference.yaml
|
config: v1-inference.yaml
|
||||||
file: model-epoch09-float32.ckpt
|
file: model-epoch09-float32.ckpt
|
||||||
|
format: ckpt
|
||||||
|
vae:
|
||||||
|
repo_id: stabilityai/sd-vae-ft-mse-original
|
||||||
|
file: vae-ft-mse-840000-ema-pruned.ckpt
|
||||||
recommended: False
|
recommended: False
|
||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
trinart-2.0:
|
trinart-2.0:
|
||||||
description: An SD model finetuned with ~40,000 assorted high resolution manga/anime-style pictures (2.13 GB)
|
description: An SD model finetuned with ~40,000 assorted high resolution manga/anime-style pictures (2.13 GB)
|
||||||
repo_id: naclbit/trinart_stable_diffusion_v2
|
repo_id: naclbit/trinart_stable_diffusion_v2
|
||||||
config: v1-inference.yaml
|
format: diffusers
|
||||||
file: trinart2_step95000.ckpt
|
|
||||||
recommended: False
|
recommended: False
|
||||||
|
vae:
|
||||||
|
repo_id: stabilityai/sd-vae-ft-mse
|
||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
trinart_characters-1.0:
|
trinart_characters-2.0:
|
||||||
description: An SD model finetuned with 19.2M anime/manga style images (2.13 GB)
|
description: An SD model finetuned with 19.2M anime/manga style images (4.27 GB)
|
||||||
repo_id: naclbit/trinart_characters_19.2m_stable_diffusion_v1
|
repo_id: naclbit/trinart_derrida_characters_v2_stable_diffusion
|
||||||
config: v1-inference.yaml
|
config: v1-inference.yaml
|
||||||
file: trinart_characters_it4_v1.ckpt
|
file: derrida_final.ckpt
|
||||||
recommended: False
|
format: ckpt
|
||||||
width: 512
|
vae:
|
||||||
height: 512
|
repo_id: naclbit/trinart_derrida_characters_v2_stable_diffusion
|
||||||
trinart_vae:
|
|
||||||
description: Custom autoencoder for trinart_characters
|
|
||||||
repo_id: naclbit/trinart_characters_19.2m_stable_diffusion_v1
|
|
||||||
config: VAE/trinart
|
|
||||||
file: autoencoder_fix_kl-f8-trinart_characters.ckpt
|
file: autoencoder_fix_kl-f8-trinart_characters.ckpt
|
||||||
recommended: False
|
recommended: False
|
||||||
width: 512
|
width: 512
|
||||||
@ -65,8 +69,9 @@ trinart_vae:
|
|||||||
papercut-1.0:
|
papercut-1.0:
|
||||||
description: SD 1.5 fine-tuned for papercut art (use "PaperCut" in your prompts) (2.13 GB)
|
description: SD 1.5 fine-tuned for papercut art (use "PaperCut" in your prompts) (2.13 GB)
|
||||||
repo_id: Fictiverse/Stable_Diffusion_PaperCut_Model
|
repo_id: Fictiverse/Stable_Diffusion_PaperCut_Model
|
||||||
config: v1-inference.yaml
|
format: diffusers
|
||||||
file: PaperCut_v1.ckpt
|
vae:
|
||||||
|
repo_id: stabilityai/sd-vae-ft-mse
|
||||||
recommended: False
|
recommended: False
|
||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
@ -75,6 +80,27 @@ voxel_art-1.0:
|
|||||||
repo_id: Fictiverse/Stable_Diffusion_VoxelArt_Model
|
repo_id: Fictiverse/Stable_Diffusion_VoxelArt_Model
|
||||||
config: v1-inference.yaml
|
config: v1-inference.yaml
|
||||||
file: VoxelArt_v1.ckpt
|
file: VoxelArt_v1.ckpt
|
||||||
|
format: ckpt
|
||||||
|
vae:
|
||||||
|
repo_id: stabilityai/sd-vae-ft-mse
|
||||||
|
recommended: False
|
||||||
|
width: 512
|
||||||
|
height: 512
|
||||||
|
ft-mse-improved-autoencoder-840000:
|
||||||
|
description: StabilityAI improved autoencoder fine-tuned for human faces. Use with legacy .ckpt models ONLY (335 MB)
|
||||||
|
repo_id: stabilityai/sd-vae-ft-mse-original
|
||||||
|
format: ckpt
|
||||||
|
config: VAE/default
|
||||||
|
file: vae-ft-mse-840000-ema-pruned.ckpt
|
||||||
|
recommended: False
|
||||||
|
width: 512
|
||||||
|
height: 512
|
||||||
|
trinart_vae:
|
||||||
|
description: Custom autoencoder for trinart_characters for legacy .ckpt models only (335 MB)
|
||||||
|
repo_id: naclbit/trinart_characters_19.2m_stable_diffusion_v1
|
||||||
|
config: VAE/trinart
|
||||||
|
format: ckpt
|
||||||
|
file: autoencoder_fix_kl-f8-trinart_characters.ckpt
|
||||||
recommended: False
|
recommended: False
|
||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
|
@ -5,6 +5,25 @@
|
|||||||
# model requires a model config file, a weights file,
|
# model requires a model config file, a weights file,
|
||||||
# and the width and height of the images it
|
# and the width and height of the images it
|
||||||
# was trained on.
|
# was trained on.
|
||||||
|
diffusers-1.4:
|
||||||
|
description: 🤗🧨 Stable Diffusion v1.4
|
||||||
|
format: diffusers
|
||||||
|
repo_id: CompVis/stable-diffusion-v1-4
|
||||||
|
diffusers-1.5:
|
||||||
|
description: 🤗🧨 Stable Diffusion v1.5
|
||||||
|
format: diffusers
|
||||||
|
repo_id: runwayml/stable-diffusion-v1-5
|
||||||
|
default: true
|
||||||
|
diffusers-1.5+mse:
|
||||||
|
description: 🤗🧨 Stable Diffusion v1.5 + MSE-finetuned VAE
|
||||||
|
format: diffusers
|
||||||
|
repo_id: runwayml/stable-diffusion-v1-5
|
||||||
|
vae:
|
||||||
|
repo_id: stabilityai/sd-vae-ft-mse
|
||||||
|
diffusers-inpainting-1.5:
|
||||||
|
description: 🤗🧨 inpainting for Stable Diffusion v1.5
|
||||||
|
format: diffusers
|
||||||
|
repo_id: runwayml/stable-diffusion-inpainting
|
||||||
stable-diffusion-1.5:
|
stable-diffusion-1.5:
|
||||||
description: The newest Stable Diffusion version 1.5 weight file (4.27 GB)
|
description: The newest Stable Diffusion version 1.5 weight file (4.27 GB)
|
||||||
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
||||||
@ -12,7 +31,6 @@ stable-diffusion-1.5:
|
|||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
vae: ./models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
vae: ./models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||||
default: true
|
|
||||||
stable-diffusion-1.4:
|
stable-diffusion-1.4:
|
||||||
description: Stable Diffusion inference model version 1.4
|
description: Stable Diffusion inference model version 1.4
|
||||||
config: configs/stable-diffusion/v1-inference.yaml
|
config: configs/stable-diffusion/v1-inference.yaml
|
||||||
|
68
configs/stable-diffusion/v2-inference-v.yaml
Normal file
68
configs/stable-diffusion/v2-inference-v.yaml
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-4
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
parameterization: "v"
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False # we set this to false because this is an inference only config
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
use_fp16: True
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1024
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
|
params:
|
||||||
|
freeze: True
|
||||||
|
layer: "penultimate"
|
@ -4,6 +4,97 @@ title: Changelog
|
|||||||
|
|
||||||
# :octicons-log-16: **Changelog**
|
# :octicons-log-16: **Changelog**
|
||||||
|
|
||||||
|
## v2.3.0 <small>(15 January 2023)</small>
|
||||||
|
|
||||||
|
**Transition to diffusers
|
||||||
|
|
||||||
|
Version 2.3 provides support for both the traditional `.ckpt` weight
|
||||||
|
checkpoint files as well as the HuggingFace `diffusers` format. This
|
||||||
|
introduces several changes you should know about.
|
||||||
|
|
||||||
|
1. The models.yaml format has been updated. There are now two
|
||||||
|
different type of configuration stanza. The traditional ckpt
|
||||||
|
one will look like this, with a `format` of `ckpt` and a
|
||||||
|
`weights` field that points to the absolute or ROOTDIR-relative
|
||||||
|
location of the ckpt file.
|
||||||
|
|
||||||
|
```
|
||||||
|
inpainting-1.5:
|
||||||
|
description: RunwayML SD 1.5 model optimized for inpainting (4.27 GB)
|
||||||
|
repo_id: runwayml/stable-diffusion-inpainting
|
||||||
|
format: ckpt
|
||||||
|
width: 512
|
||||||
|
height: 512
|
||||||
|
weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
|
||||||
|
config: configs/stable-diffusion/v1-inpainting-inference.yaml
|
||||||
|
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||||
|
```
|
||||||
|
|
||||||
|
A configuration stanza for a diffusers model hosted at HuggingFace will look like this,
|
||||||
|
with a `format` of `diffusers` and a `repo_id` that points to the
|
||||||
|
repository ID of the model on HuggingFace:
|
||||||
|
|
||||||
|
```
|
||||||
|
stable-diffusion-2.1:
|
||||||
|
description: Stable Diffusion version 2.1 diffusers model (5.21 GB)
|
||||||
|
repo_id: stabilityai/stable-diffusion-2-1
|
||||||
|
format: diffusers
|
||||||
|
```
|
||||||
|
|
||||||
|
A configuration stanza for a diffuers model stored locally should
|
||||||
|
look like this, with a `format` of `diffusers`, but a `path` field
|
||||||
|
that points at the directory that contains `model_index.json`:
|
||||||
|
|
||||||
|
```
|
||||||
|
waifu-diffusion:
|
||||||
|
description: Latest waifu diffusion 1.4
|
||||||
|
format: diffusers
|
||||||
|
path: models/diffusers/hakurei-haifu-diffusion-1.4
|
||||||
|
```
|
||||||
|
|
||||||
|
2. The format of the models directory has changed to mimic the
|
||||||
|
HuggingFace cache directory. By default, diffusers models are
|
||||||
|
now automatically downloaded and retrieved from the directory
|
||||||
|
`ROOTDIR/models/diffusers`, while other models are stored in
|
||||||
|
the directory `ROOTDIR/models/hub`. This organization is the
|
||||||
|
same as that used by HuggingFace for its cache management.
|
||||||
|
|
||||||
|
This allows you to share diffusers and ckpt model files easily with
|
||||||
|
other machine learning applications that use the HuggingFace
|
||||||
|
libraries. To do this, set the environment variable HF_HOME
|
||||||
|
before starting up InvokeAI to tell it what directory to
|
||||||
|
cache models in. To tell InvokeAI to use the standard HuggingFace
|
||||||
|
cache directory, you would set HF_HOME like this (Linux/Mac):
|
||||||
|
|
||||||
|
`export HF_HOME=~/.cache/hugging_face`
|
||||||
|
|
||||||
|
3. If you upgrade to InvokeAI 2.3.* from an earlier version, there
|
||||||
|
will be a one-time migration from the old models directory format
|
||||||
|
to the new one. You will see a message about this the first time
|
||||||
|
you start `invoke.py`.
|
||||||
|
|
||||||
|
4. Both the front end back ends of the model manager have been
|
||||||
|
rewritten to accommodate diffusers. You can import models using
|
||||||
|
their local file path, using their URLs, or their HuggingFace
|
||||||
|
repo_ids. On the command line, all these syntaxes work:
|
||||||
|
|
||||||
|
```
|
||||||
|
!import_model stabilityai/stable-diffusion-2-1-base
|
||||||
|
!import_model /opt/sd-models/sd-1.4.ckpt
|
||||||
|
!import_model https://huggingface.co/Fictiverse/Stable_Diffusion_PaperCut_Model/blob/main/PaperCut_v1.ckpt
|
||||||
|
```
|
||||||
|
|
||||||
|
**KNOWN BUGS (15 January 2023)
|
||||||
|
|
||||||
|
1. On CUDA systems, the 768 pixel stable-diffusion-2.0 and
|
||||||
|
stable-diffusion-2.1 models can only be run as `diffusers` models
|
||||||
|
when the `xformer` library is installed and configured. Without
|
||||||
|
`xformers`, InvokeAI returns black images.
|
||||||
|
|
||||||
|
2. Inpainting and outpainting have regressed in quality.
|
||||||
|
|
||||||
|
Both these issues are being actively worked on.
|
||||||
|
|
||||||
## v2.2.4 <small>(11 December 2022)</small>
|
## v2.2.4 <small>(11 December 2022)</small>
|
||||||
|
|
||||||
**the `invokeai` directory**
|
**the `invokeai` directory**
|
||||||
|
@ -28,13 +28,18 @@ dependencies:
|
|||||||
- torch-fidelity=0.3.0
|
- torch-fidelity=0.3.0
|
||||||
- torchmetrics=0.7.0
|
- torchmetrics=0.7.0
|
||||||
- torchvision
|
- torchvision
|
||||||
- transformers=4.21.3
|
- transformers~=4.25
|
||||||
- pip:
|
- pip:
|
||||||
|
- accelerate
|
||||||
|
- diffusers[torch]~=0.11
|
||||||
- getpass_asterisk
|
- getpass_asterisk
|
||||||
|
- huggingface-hub>=0.11.1
|
||||||
- omegaconf==2.1.1
|
- omegaconf==2.1.1
|
||||||
- picklescan
|
- picklescan
|
||||||
- pyreadline3
|
- pyreadline3
|
||||||
- realesrgan
|
- realesrgan
|
||||||
|
- requests==2.25.1
|
||||||
|
- safetensors
|
||||||
- taming-transformers-rom1504
|
- taming-transformers-rom1504
|
||||||
- test-tube>=0.7.5
|
- test-tube>=0.7.5
|
||||||
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||||
|
@ -9,14 +9,16 @@ dependencies:
|
|||||||
- numpy=1.23.3
|
- numpy=1.23.3
|
||||||
- pip:
|
- pip:
|
||||||
- --extra-index-url https://download.pytorch.org/whl/rocm5.2/
|
- --extra-index-url https://download.pytorch.org/whl/rocm5.2/
|
||||||
|
- accelerate
|
||||||
- albumentations==0.4.3
|
- albumentations==0.4.3
|
||||||
- diffusers==0.6.0
|
- diffusers[torch]~=0.11
|
||||||
- einops==0.3.0
|
- einops==0.3.0
|
||||||
- eventlet
|
- eventlet
|
||||||
- flask==2.1.3
|
- flask==2.1.3
|
||||||
- flask_cors==3.0.10
|
- flask_cors==3.0.10
|
||||||
- flask_socketio==5.3.0
|
- flask_socketio==5.3.0
|
||||||
- getpass_asterisk
|
- getpass_asterisk
|
||||||
|
- huggingface-hub>=0.11.1
|
||||||
- imageio-ffmpeg==0.4.2
|
- imageio-ffmpeg==0.4.2
|
||||||
- imageio==2.9.0
|
- imageio==2.9.0
|
||||||
- kornia==0.6.0
|
- kornia==0.6.0
|
||||||
@ -28,6 +30,8 @@ dependencies:
|
|||||||
- pyreadline3
|
- pyreadline3
|
||||||
- pytorch-lightning==1.7.7
|
- pytorch-lightning==1.7.7
|
||||||
- realesrgan
|
- realesrgan
|
||||||
|
- requests==2.25.1
|
||||||
|
- safetensors
|
||||||
- send2trash==1.8.0
|
- send2trash==1.8.0
|
||||||
- streamlit==1.12.0
|
- streamlit==1.12.0
|
||||||
- taming-transformers-rom1504
|
- taming-transformers-rom1504
|
||||||
@ -38,7 +42,7 @@ dependencies:
|
|||||||
- torchaudio
|
- torchaudio
|
||||||
- torchmetrics==0.7.0
|
- torchmetrics==0.7.0
|
||||||
- torchvision
|
- torchvision
|
||||||
- transformers==4.21.3
|
- transformers~=4.25
|
||||||
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||||
- git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
- git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
||||||
- git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg
|
- git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg
|
||||||
|
@ -12,14 +12,16 @@ dependencies:
|
|||||||
- pytorch=1.12.1
|
- pytorch=1.12.1
|
||||||
- cudatoolkit=11.6
|
- cudatoolkit=11.6
|
||||||
- pip:
|
- pip:
|
||||||
|
- accelerate~=0.13
|
||||||
- albumentations==0.4.3
|
- albumentations==0.4.3
|
||||||
- diffusers==0.6.0
|
- diffusers[torch]~=0.11
|
||||||
- einops==0.3.0
|
- einops==0.3.0
|
||||||
- eventlet
|
- eventlet
|
||||||
- flask==2.1.3
|
- flask==2.1.3
|
||||||
- flask_cors==3.0.10
|
- flask_cors==3.0.10
|
||||||
- flask_socketio==5.3.0
|
- flask_socketio==5.3.0
|
||||||
- getpass_asterisk
|
- getpass_asterisk
|
||||||
|
- huggingface-hub>=0.11.1
|
||||||
- imageio-ffmpeg==0.4.2
|
- imageio-ffmpeg==0.4.2
|
||||||
- imageio==2.9.0
|
- imageio==2.9.0
|
||||||
- kornia==0.6.0
|
- kornia==0.6.0
|
||||||
@ -31,13 +33,15 @@ dependencies:
|
|||||||
- pyreadline3
|
- pyreadline3
|
||||||
- pytorch-lightning==1.7.7
|
- pytorch-lightning==1.7.7
|
||||||
- realesrgan
|
- realesrgan
|
||||||
|
- requests==2.25.1
|
||||||
|
- safetensors~=0.2
|
||||||
- send2trash==1.8.0
|
- send2trash==1.8.0
|
||||||
- streamlit==1.12.0
|
- streamlit==1.12.0
|
||||||
- taming-transformers-rom1504
|
- taming-transformers-rom1504
|
||||||
- test-tube>=0.7.5
|
- test-tube>=0.7.5
|
||||||
- torch-fidelity==0.3.0
|
- torch-fidelity==0.3.0
|
||||||
- torchmetrics==0.7.0
|
- torchmetrics==0.7.0
|
||||||
- transformers==4.21.3
|
- transformers~=4.25
|
||||||
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||||
- git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
- git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
||||||
- git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg
|
- git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
name: invokeai
|
name: invokeai
|
||||||
channels:
|
channels:
|
||||||
- pytorch
|
- pytorch
|
||||||
|
- huggingface
|
||||||
- conda-forge
|
- conda-forge
|
||||||
- defaults
|
- defaults
|
||||||
dependencies:
|
dependencies:
|
||||||
@ -19,10 +20,9 @@ dependencies:
|
|||||||
# sed -E 's/invokeai/invokeai-updated/;20,99s/- ([^=]+)==.+/- \1/' environment-mac.yml > environment-mac-updated.yml
|
# sed -E 's/invokeai/invokeai-updated/;20,99s/- ([^=]+)==.+/- \1/' environment-mac.yml > environment-mac-updated.yml
|
||||||
# CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac-updated.yml && conda list -n invokeai-updated | awk ' {print " - " $1 "==" $2;} '
|
# CONDA_SUBDIR=osx-arm64 conda env create -f environment-mac-updated.yml && conda list -n invokeai-updated | awk ' {print " - " $1 "==" $2;} '
|
||||||
# ```
|
# ```
|
||||||
|
- accelerate
|
||||||
- albumentations=1.2
|
- albumentations=1.2
|
||||||
- coloredlogs=15.0
|
- coloredlogs=15.0
|
||||||
- diffusers=0.6
|
|
||||||
- einops=0.3
|
- einops=0.3
|
||||||
- eventlet
|
- eventlet
|
||||||
- grpcio=1.46
|
- grpcio=1.46
|
||||||
@ -49,10 +49,14 @@ dependencies:
|
|||||||
- sympy=1.10
|
- sympy=1.10
|
||||||
- send2trash=1.8
|
- send2trash=1.8
|
||||||
- tensorboard=2.10
|
- tensorboard=2.10
|
||||||
- transformers=4.23
|
- transformers~=4.25
|
||||||
- pip:
|
- pip:
|
||||||
|
- diffusers[torch]~=0.11
|
||||||
|
- safetensors~=0.2
|
||||||
- getpass_asterisk
|
- getpass_asterisk
|
||||||
|
- huggingface-hub
|
||||||
- picklescan
|
- picklescan
|
||||||
|
- requests==2.25.1
|
||||||
- taming-transformers-rom1504
|
- taming-transformers-rom1504
|
||||||
- test-tube==0.7.5
|
- test-tube==0.7.5
|
||||||
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||||
|
@ -12,14 +12,16 @@ dependencies:
|
|||||||
- pytorch=1.12.1
|
- pytorch=1.12.1
|
||||||
- cudatoolkit=11.6
|
- cudatoolkit=11.6
|
||||||
- pip:
|
- pip:
|
||||||
|
- accelerate
|
||||||
- albumentations==0.4.3
|
- albumentations==0.4.3
|
||||||
- diffusers==0.6.0
|
- diffusers[torch]~=0.11
|
||||||
- einops==0.3.0
|
- einops==0.3.0
|
||||||
- eventlet
|
- eventlet
|
||||||
- flask==2.1.3
|
- flask==2.1.3
|
||||||
- flask_cors==3.0.10
|
- flask_cors==3.0.10
|
||||||
- flask_socketio==5.3.0
|
- flask_socketio==5.3.0
|
||||||
- getpass_asterisk
|
- getpass_asterisk
|
||||||
|
- huggingface-hub>=0.11.1
|
||||||
- imageio-ffmpeg==0.4.2
|
- imageio-ffmpeg==0.4.2
|
||||||
- imageio==2.9.0
|
- imageio==2.9.0
|
||||||
- kornia==0.6.0
|
- kornia==0.6.0
|
||||||
@ -31,13 +33,16 @@ dependencies:
|
|||||||
- pyreadline3
|
- pyreadline3
|
||||||
- pytorch-lightning==1.7.7
|
- pytorch-lightning==1.7.7
|
||||||
- realesrgan
|
- realesrgan
|
||||||
|
- requests==2.25.1
|
||||||
|
- safetensors
|
||||||
- send2trash==1.8.0
|
- send2trash==1.8.0
|
||||||
- streamlit==1.12.0
|
- streamlit==1.12.0
|
||||||
- taming-transformers-rom1504
|
- taming-transformers-rom1504
|
||||||
- test-tube>=0.7.5
|
- test-tube>=0.7.5
|
||||||
- torch-fidelity==0.3.0
|
- torch-fidelity==0.3.0
|
||||||
- torchmetrics==0.7.0
|
- torchmetrics==0.7.0
|
||||||
- transformers==4.21.3
|
- transformers~=4.25
|
||||||
|
- windows-curses
|
||||||
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
- git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||||
- git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
|
- git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
|
||||||
- git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg
|
- git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
# pip will resolve the version which matches torch
|
# pip will resolve the version which matches torch
|
||||||
|
accelerate
|
||||||
albumentations
|
albumentations
|
||||||
diffusers==0.10.*
|
datasets
|
||||||
|
diffusers[torch]~=0.11
|
||||||
einops
|
einops
|
||||||
eventlet
|
eventlet
|
||||||
facexlib
|
facexlib
|
||||||
@ -14,6 +16,7 @@ huggingface-hub>=0.11.1
|
|||||||
imageio
|
imageio
|
||||||
imageio-ffmpeg
|
imageio-ffmpeg
|
||||||
kornia
|
kornia
|
||||||
|
npyscreen
|
||||||
numpy==1.23.*
|
numpy==1.23.*
|
||||||
omegaconf
|
omegaconf
|
||||||
opencv-python
|
opencv-python
|
||||||
@ -25,6 +28,7 @@ pyreadline3
|
|||||||
pytorch-lightning==1.7.7
|
pytorch-lightning==1.7.7
|
||||||
realesrgan
|
realesrgan
|
||||||
requests==2.25.1
|
requests==2.25.1
|
||||||
|
safetensors
|
||||||
scikit-image>=0.19
|
scikit-image>=0.19
|
||||||
send2trash
|
send2trash
|
||||||
streamlit
|
streamlit
|
||||||
@ -32,7 +36,8 @@ taming-transformers-rom1504
|
|||||||
test-tube>=0.7.5
|
test-tube>=0.7.5
|
||||||
torch-fidelity
|
torch-fidelity
|
||||||
torchmetrics
|
torchmetrics
|
||||||
transformers==4.25.*
|
transformers~=4.25
|
||||||
|
windows-curses; sys_platform == 'win32'
|
||||||
https://github.com/Birch-san/k-diffusion/archive/refs/heads/mps.zip#egg=k-diffusion
|
https://github.com/Birch-san/k-diffusion/archive/refs/heads/mps.zip#egg=k-diffusion
|
||||||
https://github.com/invoke-ai/PyPatchMatch/archive/refs/tags/0.1.5.zip#egg=pypatchmatch
|
https://github.com/invoke-ai/PyPatchMatch/archive/refs/tags/0.1.5.zip#egg=pypatchmatch
|
||||||
https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip#egg=clip
|
https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip#egg=clip
|
||||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
4
frontend/dist/index.html
vendored
4
frontend/dist/index.html
vendored
@ -7,7 +7,7 @@
|
|||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
<title>InvokeAI - A Stable Diffusion Toolkit</title>
|
<title>InvokeAI - A Stable Diffusion Toolkit</title>
|
||||||
<link rel="shortcut icon" type="icon" href="./assets/favicon.0d253ced.ico" />
|
<link rel="shortcut icon" type="icon" href="./assets/favicon.0d253ced.ico" />
|
||||||
<script type="module" crossorigin src="./assets/index.ec2d89c6.js"></script>
|
<script type="module" crossorigin src="./assets/index.1b59e83a.js"></script>
|
||||||
<link rel="stylesheet" href="./assets/index.0dadf5d0.css">
|
<link rel="stylesheet" href="./assets/index.0dadf5d0.css">
|
||||||
<script type="module">try{import.meta.url;import("_").catch(()=>1);}catch(e){}window.__vite_is_modern_browser=true;</script>
|
<script type="module">try{import.meta.url;import("_").catch(()=>1);}catch(e){}window.__vite_is_modern_browser=true;</script>
|
||||||
<script type="module">!function(){if(window.__vite_is_modern_browser)return;console.warn("vite: loading legacy build because dynamic import or import.meta.url is unsupported, syntax error above should be ignored");var e=document.getElementById("vite-legacy-polyfill"),n=document.createElement("script");n.src=e.src,n.onload=function(){System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))},document.body.appendChild(n)}();</script>
|
<script type="module">!function(){if(window.__vite_is_modern_browser)return;console.warn("vite: loading legacy build because dynamic import or import.meta.url is unsupported, syntax error above should be ignored");var e=document.getElementById("vite-legacy-polyfill"),n=document.createElement("script");n.src=e.src,n.onload=function(){System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))},document.body.appendChild(n)}();</script>
|
||||||
@ -18,6 +18,6 @@
|
|||||||
|
|
||||||
<script nomodule>!function(){var e=document,t=e.createElement("script");if(!("noModule"in t)&&"onbeforeload"in t){var n=!1;e.addEventListener("beforeload",(function(e){if(e.target===t)n=!0;else if(!e.target.hasAttribute("nomodule")||!n)return;e.preventDefault()}),!0),t.type="module",t.src=".",e.head.appendChild(t),t.remove()}}();</script>
|
<script nomodule>!function(){var e=document,t=e.createElement("script");if(!("noModule"in t)&&"onbeforeload"in t){var n=!1;e.addEventListener("beforeload",(function(e){if(e.target===t)n=!0;else if(!e.target.hasAttribute("nomodule")||!n)return;e.preventDefault()}),!0),t.type="module",t.src=".",e.head.appendChild(t),t.remove()}}();</script>
|
||||||
<script nomodule crossorigin id="vite-legacy-polyfill" src="./assets/polyfills-legacy-dde3a68a.js"></script>
|
<script nomodule crossorigin id="vite-legacy-polyfill" src="./assets/polyfills-legacy-dde3a68a.js"></script>
|
||||||
<script nomodule crossorigin id="vite-legacy-entry" data-src="./assets/index-legacy-5c5a479d.js">System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))</script>
|
<script nomodule crossorigin id="vite-legacy-entry" data-src="./assets/index-legacy-474a75fe.js">System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))</script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
@ -17,6 +17,8 @@
|
|||||||
"langPortuguese": "Portuguese",
|
"langPortuguese": "Portuguese",
|
||||||
"langFrench": "French",
|
"langFrench": "French",
|
||||||
"langPolish": "Polish",
|
"langPolish": "Polish",
|
||||||
|
"langSimplifiedChinese": "Simplified Chinese",
|
||||||
|
"langSpanish": "Spanish",
|
||||||
"text2img": "Text To Image",
|
"text2img": "Text To Image",
|
||||||
"img2img": "Image To Image",
|
"img2img": "Image To Image",
|
||||||
"unifiedCanvas": "Unified Canvas",
|
"unifiedCanvas": "Unified Canvas",
|
||||||
@ -32,6 +34,7 @@
|
|||||||
"upload": "Upload",
|
"upload": "Upload",
|
||||||
"close": "Close",
|
"close": "Close",
|
||||||
"load": "Load",
|
"load": "Load",
|
||||||
|
"back": "Back",
|
||||||
"statusConnected": "Connected",
|
"statusConnected": "Connected",
|
||||||
"statusDisconnected": "Disconnected",
|
"statusDisconnected": "Disconnected",
|
||||||
"statusError": "Error",
|
"statusError": "Error",
|
||||||
|
@ -34,6 +34,7 @@
|
|||||||
"upload": "Upload",
|
"upload": "Upload",
|
||||||
"close": "Close",
|
"close": "Close",
|
||||||
"load": "Load",
|
"load": "Load",
|
||||||
|
"back": "Back",
|
||||||
"statusConnected": "Connected",
|
"statusConnected": "Connected",
|
||||||
"statusDisconnected": "Disconnected",
|
"statusDisconnected": "Disconnected",
|
||||||
"statusError": "Error",
|
"statusError": "Error",
|
||||||
|
@ -1,12 +1,18 @@
|
|||||||
{
|
{
|
||||||
"modelManager": "Model Manager",
|
"modelManager": "Model Manager",
|
||||||
"model": "Model",
|
"model": "Model",
|
||||||
|
"allModels": "All Models",
|
||||||
|
"checkpointModels": "Checkpoints",
|
||||||
|
"diffusersModels": "Diffusers",
|
||||||
|
"safetensorModels": "SafeTensors",
|
||||||
"modelAdded": "Model Added",
|
"modelAdded": "Model Added",
|
||||||
"modelUpdated": "Model Updated",
|
"modelUpdated": "Model Updated",
|
||||||
"modelEntryDeleted": "Model Entry Deleted",
|
"modelEntryDeleted": "Model Entry Deleted",
|
||||||
"cannotUseSpaces": "Cannot Use Spaces",
|
"cannotUseSpaces": "Cannot Use Spaces",
|
||||||
"addNew": "Add New",
|
"addNew": "Add New",
|
||||||
"addNewModel": "Add New Model",
|
"addNewModel": "Add New Model",
|
||||||
|
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
|
||||||
|
"addDiffuserModel": "Add Diffusers",
|
||||||
"addManually": "Add Manually",
|
"addManually": "Add Manually",
|
||||||
"manual": "Manual",
|
"manual": "Manual",
|
||||||
"name": "Name",
|
"name": "Name",
|
||||||
@ -17,8 +23,12 @@
|
|||||||
"configValidationMsg": "Path to the config file of your model.",
|
"configValidationMsg": "Path to the config file of your model.",
|
||||||
"modelLocation": "Model Location",
|
"modelLocation": "Model Location",
|
||||||
"modelLocationValidationMsg": "Path to where your model is located.",
|
"modelLocationValidationMsg": "Path to where your model is located.",
|
||||||
|
"repo_id": "Repo ID",
|
||||||
|
"repoIDValidationMsg": "Online repository of your model",
|
||||||
"vaeLocation": "VAE Location",
|
"vaeLocation": "VAE Location",
|
||||||
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
||||||
|
"vaeRepoID": "VAE Repo ID",
|
||||||
|
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
||||||
"width": "Width",
|
"width": "Width",
|
||||||
"widthValidationMsg": "Default width of your model.",
|
"widthValidationMsg": "Default width of your model.",
|
||||||
"height": "Height",
|
"height": "Height",
|
||||||
@ -34,6 +44,7 @@
|
|||||||
"checkpointFolder": "Checkpoint Folder",
|
"checkpointFolder": "Checkpoint Folder",
|
||||||
"clearCheckpointFolder": "Clear Checkpoint Folder",
|
"clearCheckpointFolder": "Clear Checkpoint Folder",
|
||||||
"findModels": "Find Models",
|
"findModels": "Find Models",
|
||||||
|
"scanAgain": "Scan Again",
|
||||||
"modelsFound": "Models Found",
|
"modelsFound": "Models Found",
|
||||||
"selectFolder": "Select Folder",
|
"selectFolder": "Select Folder",
|
||||||
"selected": "Selected",
|
"selected": "Selected",
|
||||||
@ -42,9 +53,15 @@
|
|||||||
"showExisting": "Show Existing",
|
"showExisting": "Show Existing",
|
||||||
"addSelected": "Add Selected",
|
"addSelected": "Add Selected",
|
||||||
"modelExists": "Model Exists",
|
"modelExists": "Model Exists",
|
||||||
|
"selectAndAdd": "Select and Add Models Listed Below",
|
||||||
|
"noModelsFound": "No Models Found",
|
||||||
"delete": "Delete",
|
"delete": "Delete",
|
||||||
"deleteModel": "Delete Model",
|
"deleteModel": "Delete Model",
|
||||||
"deleteConfig": "Delete Config",
|
"deleteConfig": "Delete Config",
|
||||||
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?",
|
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?",
|
||||||
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to."
|
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to.",
|
||||||
|
"formMessageDiffusersModelLocation": "Diffusers Model Location",
|
||||||
|
"formMessageDiffusersModelLocationDesc": "Please enter at least one.",
|
||||||
|
"formMessageDiffusersVAELocation": "VAE Location",
|
||||||
|
"formMessageDiffusersVAELocationDesc": "If not provided, InvokeAI will look for the VAE file inside the model location given above."
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,18 @@
|
|||||||
{
|
{
|
||||||
"modelManager": "Model Manager",
|
"modelManager": "Model Manager",
|
||||||
"model": "Model",
|
"model": "Model",
|
||||||
|
"allModels": "All Models",
|
||||||
|
"checkpointModels": "Checkpoints",
|
||||||
|
"diffusersModels": "Diffusers",
|
||||||
|
"safetensorModels": "SafeTensors",
|
||||||
"modelAdded": "Model Added",
|
"modelAdded": "Model Added",
|
||||||
"modelUpdated": "Model Updated",
|
"modelUpdated": "Model Updated",
|
||||||
"modelEntryDeleted": "Model Entry Deleted",
|
"modelEntryDeleted": "Model Entry Deleted",
|
||||||
"cannotUseSpaces": "Cannot Use Spaces",
|
"cannotUseSpaces": "Cannot Use Spaces",
|
||||||
"addNew": "Add New",
|
"addNew": "Add New",
|
||||||
"addNewModel": "Add New Model",
|
"addNewModel": "Add New Model",
|
||||||
|
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
|
||||||
|
"addDiffuserModel": "Add Diffusers",
|
||||||
"addManually": "Add Manually",
|
"addManually": "Add Manually",
|
||||||
"manual": "Manual",
|
"manual": "Manual",
|
||||||
"name": "Name",
|
"name": "Name",
|
||||||
@ -17,8 +23,12 @@
|
|||||||
"configValidationMsg": "Path to the config file of your model.",
|
"configValidationMsg": "Path to the config file of your model.",
|
||||||
"modelLocation": "Model Location",
|
"modelLocation": "Model Location",
|
||||||
"modelLocationValidationMsg": "Path to where your model is located.",
|
"modelLocationValidationMsg": "Path to where your model is located.",
|
||||||
|
"repo_id": "Repo ID",
|
||||||
|
"repoIDValidationMsg": "Online repository of your model",
|
||||||
"vaeLocation": "VAE Location",
|
"vaeLocation": "VAE Location",
|
||||||
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
"vaeLocationValidationMsg": "Path to where your VAE is located.",
|
||||||
|
"vaeRepoID": "VAE Repo ID",
|
||||||
|
"vaeRepoIDValidationMsg": "Online repository of your VAE",
|
||||||
"width": "Width",
|
"width": "Width",
|
||||||
"widthValidationMsg": "Default width of your model.",
|
"widthValidationMsg": "Default width of your model.",
|
||||||
"height": "Height",
|
"height": "Height",
|
||||||
@ -49,5 +59,9 @@
|
|||||||
"deleteModel": "Delete Model",
|
"deleteModel": "Delete Model",
|
||||||
"deleteConfig": "Delete Config",
|
"deleteConfig": "Delete Config",
|
||||||
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?",
|
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?",
|
||||||
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to."
|
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to.",
|
||||||
|
"formMessageDiffusersModelLocation": "Diffusers Model Location",
|
||||||
|
"formMessageDiffusersModelLocationDesc": "Please enter at least one.",
|
||||||
|
"formMessageDiffusersVAELocation": "VAE Location",
|
||||||
|
"formMessageDiffusersVAELocationDesc": "If not provided, InvokeAI will look for the VAE file inside the model location given above."
|
||||||
}
|
}
|
||||||
|
@ -1 +1,15 @@
|
|||||||
{}
|
{
|
||||||
|
"feature": {
|
||||||
|
"prompt": "Questo è il campo del prompt. Il prompt include oggetti di generazione e termini stilistici. Puoi anche aggiungere il peso (importanza del token) nel prompt, ma i comandi e i parametri dell'interfaccia a linea di comando non funzioneranno.",
|
||||||
|
"gallery": "Galleria visualizza le generazioni dalla cartella degli output man mano che vengono create. Le impostazioni sono memorizzate all'interno di file e accessibili dal menu contestuale.",
|
||||||
|
"other": "Queste opzioni abiliteranno modalità di elaborazione alternative per Invoke. 'Piastrella senza cuciture' creerà modelli ripetuti nell'output. 'Ottimizzzazione Alta risoluzione' è la generazione in due passaggi con 'Immagine a Immagine': usa questa impostazione quando vuoi un'immagine più grande e più coerente senza artefatti. Ci vorrà più tempo del solito 'Testo a Immagine'.",
|
||||||
|
"seed": "Il valore del Seme influenza il rumore iniziale da cui è formata l'immagine. Puoi usare i semi già esistenti dalle immagini precedenti. 'Soglia del rumore' viene utilizzato per mitigare gli artefatti a valori CFG elevati (provare l'intervallo 0-10) e Perlin per aggiungere il rumore Perlin durante la generazione: entrambi servono per aggiungere variazioni ai risultati.",
|
||||||
|
"variations": "Prova una variazione con un valore compreso tra 0.1 e 1.0 per modificare il risultato per un dato seme. Variazioni interessanti del seme sono comprese tra 0.1 e 0.3.",
|
||||||
|
"upscale": "Utilizza ESRGAN per ingrandire l'immagine subito dopo la generazione.",
|
||||||
|
"faceCorrection": "Correzione del volto con GFPGAN o Codeformer: l'algoritmo rileva i volti nell'immagine e corregge eventuali difetti. Un valore alto cambierà maggiormente l'immagine, dando luogo a volti più attraenti. Codeformer con una maggiore fedeltà preserva l'immagine originale a scapito di una correzione facciale più forte.",
|
||||||
|
"imageToImage": "Da Immagine a Immagine carica qualsiasi immagine come iniziale, che viene quindi utilizzata per generarne una nuova in base al prompt. Più alto è il valore, più cambierà l'immagine risultante. Sono possibili valori da 0.0 a 1.0, l'intervallo consigliato è 0.25-0.75",
|
||||||
|
"boundingBox": "Il riquadro di selezione è lo stesso delle impostazioni Larghezza e Altezza per da Testo a Immagine o da Immagine a Immagine. Verrà elaborata solo l'area nella casella.",
|
||||||
|
"seamCorrection": "Controlla la gestione delle giunzioni visibili che si verificano tra le immagini generate sulla tela.",
|
||||||
|
"infillAndScaling": "Gestisce i metodi di riempimento (utilizzati su aree mascherate o cancellate dell'area di disegno) e il ridimensionamento (utile per i riquadri di selezione di piccole dimensioni)."
|
||||||
|
}
|
||||||
|
}
|
@ -1,15 +0,0 @@
|
|||||||
{
|
|
||||||
"feature": {
|
|
||||||
"prompt": "Questo è il campo del prompt. Il prompt include oggetti di generazione e termini stilistici. Puoi anche aggiungere il peso (importanza del token) nel prompt, ma i comandi e i parametri dell'interfaccia a linea di comando non funzioneranno.",
|
|
||||||
"gallery": "Galleria visualizza le generazioni dalla cartella degli output man mano che vengono create. Le impostazioni sono memorizzate all'interno di file e accessibili dal menu contestuale.",
|
|
||||||
"other": "Queste opzioni abiliteranno modalità di elaborazione alternative per Invoke. 'Piastrella senza cuciture' creerà modelli ripetuti nell'output. 'Ottimizzzazione Alta risoluzione' è la generazione in due passaggi con 'Immagine a Immagine': usa questa impostazione quando vuoi un'immagine più grande e più coerente senza artefatti. Ci vorrà più tempo del solito 'Testo a Immagine'.",
|
|
||||||
"seed": "Il valore del Seme influenza il rumore iniziale da cui è formata l'immagine. Puoi usare i semi già esistenti dalle immagini precedenti. 'Soglia del rumore' viene utilizzato per mitigare gli artefatti a valori CFG elevati (provare l'intervallo 0-10) e Perlin per aggiungere il rumore Perlin durante la generazione: entrambi servono per aggiungere variazioni ai risultati.",
|
|
||||||
"variations": "Prova una variazione con un valore compreso tra 0.1 e 1.0 per modificare il risultato per un dato seme. Variazioni interessanti del seme sono comprese tra 0.1 e 0.3.",
|
|
||||||
"upscale": "Utilizza ESRGAN per ingrandire l'immagine subito dopo la generazione.",
|
|
||||||
"faceCorrection": "Correzione del volto con GFPGAN o Codeformer: l'algoritmo rileva i volti nell'immagine e corregge eventuali difetti. Un valore alto cambierà maggiormente l'immagine, dando luogo a volti più attraenti. Codeformer con una maggiore fedeltà preserva l'immagine originale a scapito di una correzione facciale più forte.",
|
|
||||||
"imageToImage": "Da Immagine a Immagine carica qualsiasi immagine come iniziale, che viene quindi utilizzata per generarne una nuova in base al prompt. Più alto è il valore, più cambierà l'immagine risultante. Sono possibili valori da 0.0 a 1.0, l'intervallo consigliato è 0.25-0.75",
|
|
||||||
"boundingBox": "Il riquadro di selezione è lo stesso delle impostazioni Larghezza e Altezza per dat Testo a Immagine o da Immagine a Immagine. Verrà elaborata solo l'area nella casella.",
|
|
||||||
"seamCorrection": "Controlla la gestione delle giunzioni visibili che si verificano tra le immagini generate sulla tela.",
|
|
||||||
"infillAndScaling": "Gestisce i metodi di riempimento (utilizzati su aree mascherate o cancellate dell'area di disegno) e il ridimensionamento (utile per i riquadri di selezione di piccole dimensioni)."
|
|
||||||
}
|
|
||||||
}
|
|
30
frontend/src/app/invokeai.d.ts
vendored
30
frontend/src/app/invokeai.d.ts
vendored
@ -170,9 +170,23 @@ export declare type Model = {
|
|||||||
width?: number;
|
width?: number;
|
||||||
height?: number;
|
height?: number;
|
||||||
default?: boolean;
|
default?: boolean;
|
||||||
|
format?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export declare type ModelList = Record<string, Model>;
|
export declare type DiffusersModel = {
|
||||||
|
status: ModelStatus;
|
||||||
|
description: string;
|
||||||
|
repo_id?: string;
|
||||||
|
path?: string;
|
||||||
|
vae?: {
|
||||||
|
repo_id?: string;
|
||||||
|
path?: string;
|
||||||
|
};
|
||||||
|
format?: string;
|
||||||
|
default?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export declare type ModelList = Record<string, Model & DiffusersModel>;
|
||||||
|
|
||||||
export declare type FoundModel = {
|
export declare type FoundModel = {
|
||||||
name: string;
|
name: string;
|
||||||
@ -188,6 +202,20 @@ export declare type InvokeModelConfigProps = {
|
|||||||
width: number | undefined;
|
width: number | undefined;
|
||||||
height: number | undefined;
|
height: number | undefined;
|
||||||
default: boolean | undefined;
|
default: boolean | undefined;
|
||||||
|
format: string | undefined;
|
||||||
|
};
|
||||||
|
|
||||||
|
export declare type InvokeDiffusersModelConfigProps = {
|
||||||
|
name: string | undefined;
|
||||||
|
description: string | undefined;
|
||||||
|
repo_id: string | undefined;
|
||||||
|
path: string | undefined;
|
||||||
|
default: boolean | undefined;
|
||||||
|
format: string | undefined;
|
||||||
|
vae: {
|
||||||
|
repo_id: string | undefined;
|
||||||
|
path: string | undefined;
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -32,9 +32,9 @@ export const requestSystemConfig = createAction<undefined>(
|
|||||||
|
|
||||||
export const searchForModels = createAction<string>('socketio/searchForModels');
|
export const searchForModels = createAction<string>('socketio/searchForModels');
|
||||||
|
|
||||||
export const addNewModel = createAction<InvokeAI.InvokeModelConfigProps>(
|
export const addNewModel = createAction<
|
||||||
'socketio/addNewModel'
|
InvokeAI.InvokeModelConfigProps | InvokeAI.InvokeDiffusersModelConfigProps
|
||||||
);
|
>('socketio/addNewModel');
|
||||||
|
|
||||||
export const deleteModel = createAction<string>('socketio/deleteModel');
|
export const deleteModel = createAction<string>('socketio/deleteModel');
|
||||||
|
|
||||||
|
@ -22,16 +22,16 @@ const layerToDataURL = (
|
|||||||
const { x, y, width, height } = layer.getClientRect();
|
const { x, y, width, height } = layer.getClientRect();
|
||||||
const dataURLBoundingBox = boundingBox
|
const dataURLBoundingBox = boundingBox
|
||||||
? {
|
? {
|
||||||
x: Math.round(boundingBox.x + stageCoordinates.x),
|
x: boundingBox.x + stageCoordinates.x,
|
||||||
y: Math.round(boundingBox.y + stageCoordinates.y),
|
y: boundingBox.y + stageCoordinates.y,
|
||||||
width: Math.round(boundingBox.width),
|
width: boundingBox.width,
|
||||||
height: Math.round(boundingBox.height),
|
height: boundingBox.height,
|
||||||
}
|
}
|
||||||
: {
|
: {
|
||||||
x: Math.round(x),
|
x: x,
|
||||||
y: Math.round(y),
|
y: y,
|
||||||
width: Math.round(width),
|
width: width,
|
||||||
height: Math.round(height),
|
height: height,
|
||||||
};
|
};
|
||||||
|
|
||||||
const dataURL = layer.toDataURL(dataURLBoundingBox);
|
const dataURL = layer.toDataURL(dataURLBoundingBox);
|
||||||
@ -42,10 +42,10 @@ const layerToDataURL = (
|
|||||||
return {
|
return {
|
||||||
dataURL,
|
dataURL,
|
||||||
boundingBox: {
|
boundingBox: {
|
||||||
x: Math.round(relativeClientRect.x),
|
x: relativeClientRect.x,
|
||||||
y: Math.round(relativeClientRect.y),
|
y: relativeClientRect.y,
|
||||||
width: Math.round(width),
|
width: width,
|
||||||
height: Math.round(height),
|
height: height,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@ -57,6 +57,7 @@ export interface OptionsState {
|
|||||||
width: number;
|
width: number;
|
||||||
shouldUseCanvasBetaLayout: boolean;
|
shouldUseCanvasBetaLayout: boolean;
|
||||||
shouldShowExistingModelsInSearch: boolean;
|
shouldShowExistingModelsInSearch: boolean;
|
||||||
|
addNewModelUIOption: 'ckpt' | 'diffusers' | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
const initialOptionsState: OptionsState = {
|
const initialOptionsState: OptionsState = {
|
||||||
@ -105,6 +106,7 @@ const initialOptionsState: OptionsState = {
|
|||||||
width: 512,
|
width: 512,
|
||||||
shouldUseCanvasBetaLayout: false,
|
shouldUseCanvasBetaLayout: false,
|
||||||
shouldShowExistingModelsInSearch: false,
|
shouldShowExistingModelsInSearch: false,
|
||||||
|
addNewModelUIOption: null,
|
||||||
};
|
};
|
||||||
|
|
||||||
const initialState: OptionsState = initialOptionsState;
|
const initialState: OptionsState = initialOptionsState;
|
||||||
@ -412,6 +414,12 @@ export const optionsSlice = createSlice({
|
|||||||
) => {
|
) => {
|
||||||
state.shouldShowExistingModelsInSearch = action.payload;
|
state.shouldShowExistingModelsInSearch = action.payload;
|
||||||
},
|
},
|
||||||
|
setAddNewModelUIOption: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<'ckpt' | 'diffusers' | null>
|
||||||
|
) => {
|
||||||
|
state.addNewModelUIOption = action.payload;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -469,6 +477,7 @@ export const {
|
|||||||
setWidth,
|
setWidth,
|
||||||
setShouldUseCanvasBetaLayout,
|
setShouldUseCanvasBetaLayout,
|
||||||
setShouldShowExistingModelsInSearch,
|
setShouldShowExistingModelsInSearch,
|
||||||
|
setAddNewModelUIOption,
|
||||||
} = optionsSlice.actions;
|
} = optionsSlice.actions;
|
||||||
|
|
||||||
export default optionsSlice.reducer;
|
export default optionsSlice.reducer;
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
.modal {
|
.modal {
|
||||||
background-color: var(--background-color-secondary);
|
background-color: var(--background-color-secondary);
|
||||||
color: var(--text-color);
|
color: var(--text-color);
|
||||||
|
font-family: Inter;
|
||||||
}
|
}
|
||||||
|
|
||||||
.modal-close-btn {
|
.modal-close-btn {
|
||||||
|
@ -0,0 +1,328 @@
|
|||||||
|
import {
|
||||||
|
FormControl,
|
||||||
|
FormErrorMessage,
|
||||||
|
FormHelperText,
|
||||||
|
FormLabel,
|
||||||
|
HStack,
|
||||||
|
Text,
|
||||||
|
VStack,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
|
||||||
|
import React from 'react';
|
||||||
|
import IAIInput from 'common/components/IAIInput';
|
||||||
|
import IAINumberInput from 'common/components/IAINumberInput';
|
||||||
|
import IAICheckbox from 'common/components/IAICheckbox';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
|
||||||
|
import SearchModels from './SearchModels';
|
||||||
|
|
||||||
|
import { addNewModel } from 'app/socketio/actions';
|
||||||
|
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||||
|
|
||||||
|
import { Field, Formik } from 'formik';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
import type { FieldInputProps, FormikProps } from 'formik';
|
||||||
|
import type { RootState } from 'app/store';
|
||||||
|
import type { InvokeModelConfigProps } from 'app/invokeai';
|
||||||
|
import { setAddNewModelUIOption } from 'features/options/store/optionsSlice';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import { BiArrowBack } from 'react-icons/bi';
|
||||||
|
|
||||||
|
const MIN_MODEL_SIZE = 64;
|
||||||
|
const MAX_MODEL_SIZE = 2048;
|
||||||
|
|
||||||
|
export default function AddCheckpointModel() {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const isProcessing = useAppSelector(
|
||||||
|
(state: RootState) => state.system.isProcessing
|
||||||
|
);
|
||||||
|
|
||||||
|
function hasWhiteSpace(s: string) {
|
||||||
|
return /\s/.test(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
function baseValidation(value: string) {
|
||||||
|
let error;
|
||||||
|
if (hasWhiteSpace(value)) error = t('modelmanager:cannotUseSpaces');
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
|
||||||
|
const addModelFormValues: InvokeModelConfigProps = {
|
||||||
|
name: '',
|
||||||
|
description: '',
|
||||||
|
config: 'configs/stable-diffusion/v1-inference.yaml',
|
||||||
|
weights: '',
|
||||||
|
vae: '',
|
||||||
|
width: 512,
|
||||||
|
height: 512,
|
||||||
|
format: 'ckpt',
|
||||||
|
default: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
const addModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
|
||||||
|
dispatch(addNewModel(values));
|
||||||
|
dispatch(setAddNewModelUIOption(null));
|
||||||
|
};
|
||||||
|
|
||||||
|
const [addManually, setAddmanually] = React.useState<boolean>(false);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<IAIIconButton
|
||||||
|
aria-label={t('common:back')}
|
||||||
|
tooltip={t('common:back')}
|
||||||
|
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
||||||
|
width="max-content"
|
||||||
|
position="absolute"
|
||||||
|
zIndex={1}
|
||||||
|
size="sm"
|
||||||
|
right={12}
|
||||||
|
top={3}
|
||||||
|
icon={<BiArrowBack />}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<SearchModels />
|
||||||
|
<IAICheckbox
|
||||||
|
label={t('modelmanager:addManually')}
|
||||||
|
isChecked={addManually}
|
||||||
|
onChange={() => setAddmanually(!addManually)}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{addManually && (
|
||||||
|
<Formik
|
||||||
|
initialValues={addModelFormValues}
|
||||||
|
onSubmit={addModelFormSubmitHandler}
|
||||||
|
>
|
||||||
|
{({ handleSubmit, errors, touched }) => (
|
||||||
|
<form onSubmit={handleSubmit}>
|
||||||
|
<VStack rowGap={'0.5rem'}>
|
||||||
|
<Text fontSize={20} fontWeight="bold" alignSelf={'start'}>
|
||||||
|
{t('modelmanager:manual')}
|
||||||
|
</Text>
|
||||||
|
{/* Name */}
|
||||||
|
<FormControl
|
||||||
|
isInvalid={!!errors.name && touched.name}
|
||||||
|
isRequired
|
||||||
|
>
|
||||||
|
<FormLabel htmlFor="name" fontSize="sm">
|
||||||
|
{t('modelmanager:name')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="name"
|
||||||
|
name="name"
|
||||||
|
type="text"
|
||||||
|
validate={baseValidation}
|
||||||
|
width="2xl"
|
||||||
|
/>
|
||||||
|
{!!errors.name && touched.name ? (
|
||||||
|
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:nameValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
{/* Description */}
|
||||||
|
<FormControl
|
||||||
|
isInvalid={!!errors.description && touched.description}
|
||||||
|
isRequired
|
||||||
|
>
|
||||||
|
<FormLabel htmlFor="description" fontSize="sm">
|
||||||
|
{t('modelmanager:description')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="description"
|
||||||
|
name="description"
|
||||||
|
type="text"
|
||||||
|
width="2xl"
|
||||||
|
/>
|
||||||
|
{!!errors.description && touched.description ? (
|
||||||
|
<FormErrorMessage>{errors.description}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:descriptionValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
{/* Config */}
|
||||||
|
<FormControl
|
||||||
|
isInvalid={!!errors.config && touched.config}
|
||||||
|
isRequired
|
||||||
|
>
|
||||||
|
<FormLabel htmlFor="config" fontSize="sm">
|
||||||
|
{t('modelmanager:config')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="config"
|
||||||
|
name="config"
|
||||||
|
type="text"
|
||||||
|
width="2xl"
|
||||||
|
/>
|
||||||
|
{!!errors.config && touched.config ? (
|
||||||
|
<FormErrorMessage>{errors.config}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:configValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
{/* Weights */}
|
||||||
|
<FormControl
|
||||||
|
isInvalid={!!errors.weights && touched.weights}
|
||||||
|
isRequired
|
||||||
|
>
|
||||||
|
<FormLabel htmlFor="config" fontSize="sm">
|
||||||
|
{t('modelmanager:modelLocation')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="weights"
|
||||||
|
name="weights"
|
||||||
|
type="text"
|
||||||
|
width="2xl"
|
||||||
|
/>
|
||||||
|
{!!errors.weights && touched.weights ? (
|
||||||
|
<FormErrorMessage>{errors.weights}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:modelLocationValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
{/* VAE */}
|
||||||
|
<FormControl isInvalid={!!errors.vae && touched.vae}>
|
||||||
|
<FormLabel htmlFor="vae" fontSize="sm">
|
||||||
|
{t('modelmanager:vaeLocation')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="vae"
|
||||||
|
name="vae"
|
||||||
|
type="text"
|
||||||
|
width="2xl"
|
||||||
|
/>
|
||||||
|
{!!errors.vae && touched.vae ? (
|
||||||
|
<FormErrorMessage>{errors.vae}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:vaeLocationValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
<HStack width={'100%'}>
|
||||||
|
{/* Width */}
|
||||||
|
<FormControl isInvalid={!!errors.width && touched.width}>
|
||||||
|
<FormLabel htmlFor="width" fontSize="sm">
|
||||||
|
{t('modelmanager:width')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field id="width" name="width">
|
||||||
|
{({
|
||||||
|
field,
|
||||||
|
form,
|
||||||
|
}: {
|
||||||
|
field: FieldInputProps<number>;
|
||||||
|
form: FormikProps<InvokeModelConfigProps>;
|
||||||
|
}) => (
|
||||||
|
<IAINumberInput
|
||||||
|
id="width"
|
||||||
|
name="width"
|
||||||
|
min={MIN_MODEL_SIZE}
|
||||||
|
max={MAX_MODEL_SIZE}
|
||||||
|
step={64}
|
||||||
|
width="90%"
|
||||||
|
value={form.values.width}
|
||||||
|
onChange={(value) =>
|
||||||
|
form.setFieldValue(field.name, Number(value))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Field>
|
||||||
|
|
||||||
|
{!!errors.width && touched.width ? (
|
||||||
|
<FormErrorMessage>{errors.width}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:widthValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
{/* Height */}
|
||||||
|
<FormControl isInvalid={!!errors.height && touched.height}>
|
||||||
|
<FormLabel htmlFor="height" fontSize="sm">
|
||||||
|
{t('modelmanager:height')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field id="height" name="height">
|
||||||
|
{({
|
||||||
|
field,
|
||||||
|
form,
|
||||||
|
}: {
|
||||||
|
field: FieldInputProps<number>;
|
||||||
|
form: FormikProps<InvokeModelConfigProps>;
|
||||||
|
}) => (
|
||||||
|
<IAINumberInput
|
||||||
|
id="height"
|
||||||
|
name="height"
|
||||||
|
min={MIN_MODEL_SIZE}
|
||||||
|
max={MAX_MODEL_SIZE}
|
||||||
|
width="90%"
|
||||||
|
step={64}
|
||||||
|
value={form.values.height}
|
||||||
|
onChange={(value) =>
|
||||||
|
form.setFieldValue(field.name, Number(value))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Field>
|
||||||
|
|
||||||
|
{!!errors.height && touched.height ? (
|
||||||
|
<FormErrorMessage>{errors.height}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:heightValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
</HStack>
|
||||||
|
|
||||||
|
<IAIButton
|
||||||
|
type="submit"
|
||||||
|
className="modal-close-btn"
|
||||||
|
isLoading={isProcessing}
|
||||||
|
>
|
||||||
|
{t('modelmanager:addModel')}
|
||||||
|
</IAIButton>
|
||||||
|
</VStack>
|
||||||
|
</form>
|
||||||
|
)}
|
||||||
|
</Formik>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,310 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
FormControl,
|
||||||
|
FormErrorMessage,
|
||||||
|
FormHelperText,
|
||||||
|
FormLabel,
|
||||||
|
Text,
|
||||||
|
VStack,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import IAIInput from 'common/components/IAIInput';
|
||||||
|
import { setAddNewModelUIOption } from 'features/options/store/optionsSlice';
|
||||||
|
import { Field, Formik } from 'formik';
|
||||||
|
import React from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { BiArrowBack } from 'react-icons/bi';
|
||||||
|
import { InvokeDiffusersModelConfigProps } from 'app/invokeai';
|
||||||
|
import { addNewModel } from 'app/socketio/actions';
|
||||||
|
|
||||||
|
import type { RootState } from 'app/store';
|
||||||
|
import type { ReactElement } from 'react';
|
||||||
|
|
||||||
|
function FormItemWrapper({
|
||||||
|
children,
|
||||||
|
}: {
|
||||||
|
children: ReactElement | ReactElement[];
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
flexDirection="column"
|
||||||
|
backgroundColor="var(--background-color)"
|
||||||
|
padding="1rem 1rem"
|
||||||
|
borderRadius="0.5rem"
|
||||||
|
rowGap="1rem"
|
||||||
|
width="100%"
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function AddDiffusersModel() {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const isProcessing = useAppSelector(
|
||||||
|
(state: RootState) => state.system.isProcessing
|
||||||
|
);
|
||||||
|
|
||||||
|
function hasWhiteSpace(s: string) {
|
||||||
|
return /\s/.test(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
function baseValidation(value: string) {
|
||||||
|
let error;
|
||||||
|
if (hasWhiteSpace(value)) error = t('modelmanager:cannotUseSpaces');
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
|
||||||
|
const addModelFormValues: InvokeDiffusersModelConfigProps = {
|
||||||
|
name: '',
|
||||||
|
description: '',
|
||||||
|
repo_id: '',
|
||||||
|
path: '',
|
||||||
|
format: 'diffusers',
|
||||||
|
default: false,
|
||||||
|
vae: {
|
||||||
|
repo_id: '',
|
||||||
|
path: '',
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const addModelFormSubmitHandler = (
|
||||||
|
values: InvokeDiffusersModelConfigProps
|
||||||
|
) => {
|
||||||
|
const diffusersModelToAdd = values;
|
||||||
|
|
||||||
|
if (values.path === '') diffusersModelToAdd['path'] = undefined;
|
||||||
|
if (values.repo_id === '') diffusersModelToAdd['repo_id'] = undefined;
|
||||||
|
if (values.vae.path === '') {
|
||||||
|
if (values.path === undefined) {
|
||||||
|
diffusersModelToAdd['vae']['path'] = undefined;
|
||||||
|
} else {
|
||||||
|
diffusersModelToAdd['vae']['path'] = values.path + '/vae';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (values.vae.repo_id === '')
|
||||||
|
diffusersModelToAdd['vae']['repo_id'] = undefined;
|
||||||
|
|
||||||
|
dispatch(addNewModel(diffusersModelToAdd));
|
||||||
|
dispatch(setAddNewModelUIOption(null));
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex>
|
||||||
|
<IAIIconButton
|
||||||
|
aria-label={t('common:back')}
|
||||||
|
tooltip={t('common:back')}
|
||||||
|
onClick={() => dispatch(setAddNewModelUIOption(null))}
|
||||||
|
width="max-content"
|
||||||
|
position="absolute"
|
||||||
|
zIndex={1}
|
||||||
|
size="sm"
|
||||||
|
right={12}
|
||||||
|
top={3}
|
||||||
|
icon={<BiArrowBack />}
|
||||||
|
/>
|
||||||
|
<Formik
|
||||||
|
initialValues={addModelFormValues}
|
||||||
|
onSubmit={addModelFormSubmitHandler}
|
||||||
|
>
|
||||||
|
{({ handleSubmit, errors, touched }) => (
|
||||||
|
<form onSubmit={handleSubmit}>
|
||||||
|
<VStack rowGap={'0.5rem'}>
|
||||||
|
<FormItemWrapper>
|
||||||
|
{/* Name */}
|
||||||
|
<FormControl
|
||||||
|
isInvalid={!!errors.name && touched.name}
|
||||||
|
isRequired
|
||||||
|
>
|
||||||
|
<FormLabel htmlFor="name" fontSize="sm">
|
||||||
|
{t('modelmanager:name')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="name"
|
||||||
|
name="name"
|
||||||
|
type="text"
|
||||||
|
validate={baseValidation}
|
||||||
|
width="2xl"
|
||||||
|
isRequired
|
||||||
|
/>
|
||||||
|
{!!errors.name && touched.name ? (
|
||||||
|
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:nameValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
</FormItemWrapper>
|
||||||
|
|
||||||
|
<FormItemWrapper>
|
||||||
|
{/* Description */}
|
||||||
|
<FormControl
|
||||||
|
isInvalid={!!errors.description && touched.description}
|
||||||
|
isRequired
|
||||||
|
>
|
||||||
|
<FormLabel htmlFor="description" fontSize="sm">
|
||||||
|
{t('modelmanager:description')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="description"
|
||||||
|
name="description"
|
||||||
|
type="text"
|
||||||
|
width="2xl"
|
||||||
|
isRequired
|
||||||
|
/>
|
||||||
|
{!!errors.description && touched.description ? (
|
||||||
|
<FormErrorMessage>{errors.description}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:descriptionValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
</FormItemWrapper>
|
||||||
|
|
||||||
|
<FormItemWrapper>
|
||||||
|
<Text fontWeight="bold" fontSize="sm">
|
||||||
|
{t('modelmanager:formMessageDiffusersModelLocation')}
|
||||||
|
</Text>
|
||||||
|
<Text
|
||||||
|
fontSize="sm"
|
||||||
|
fontStyle="italic"
|
||||||
|
color="var(--text-color-secondary)"
|
||||||
|
>
|
||||||
|
{t('modelmanager:formMessageDiffusersModelLocationDesc')}
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
{/* Path */}
|
||||||
|
<FormControl isInvalid={!!errors.path && touched.path}>
|
||||||
|
<FormLabel htmlFor="path" fontSize="sm">
|
||||||
|
{t('modelmanager:modelLocation')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="path"
|
||||||
|
name="path"
|
||||||
|
type="text"
|
||||||
|
width="2xl"
|
||||||
|
/>
|
||||||
|
{!!errors.path && touched.path ? (
|
||||||
|
<FormErrorMessage>{errors.path}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:modelLocationValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
{/* Repo ID */}
|
||||||
|
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
|
||||||
|
<FormLabel htmlFor="repo_id" fontSize="sm">
|
||||||
|
{t('modelmanager:repo_id')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="repo_id"
|
||||||
|
name="repo_id"
|
||||||
|
type="text"
|
||||||
|
width="2xl"
|
||||||
|
/>
|
||||||
|
{!!errors.repo_id && touched.repo_id ? (
|
||||||
|
<FormErrorMessage>{errors.repo_id}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:repoIDValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
</FormItemWrapper>
|
||||||
|
|
||||||
|
<FormItemWrapper>
|
||||||
|
{/* VAE Path */}
|
||||||
|
<Text fontWeight="bold">
|
||||||
|
{t('modelmanager:formMessageDiffusersVAELocation')}
|
||||||
|
</Text>
|
||||||
|
<Text
|
||||||
|
fontSize="sm"
|
||||||
|
fontStyle="italic"
|
||||||
|
color="var(--text-color-secondary)"
|
||||||
|
>
|
||||||
|
{t('modelmanager:formMessageDiffusersVAELocationDesc')}
|
||||||
|
</Text>
|
||||||
|
<FormControl
|
||||||
|
isInvalid={!!errors.vae?.path && touched.vae?.path}
|
||||||
|
>
|
||||||
|
<FormLabel htmlFor="vae.path" fontSize="sm">
|
||||||
|
{t('modelmanager:vaeLocation')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="vae.path"
|
||||||
|
name="vae.path"
|
||||||
|
type="text"
|
||||||
|
width="2xl"
|
||||||
|
/>
|
||||||
|
{!!errors.vae?.path && touched.vae?.path ? (
|
||||||
|
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:vaeLocationValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
{/* VAE Repo ID */}
|
||||||
|
<FormControl
|
||||||
|
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
|
||||||
|
>
|
||||||
|
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
|
||||||
|
{t('modelmanager:vaeRepoID')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="vae.repo_id"
|
||||||
|
name="vae.repo_id"
|
||||||
|
type="text"
|
||||||
|
width="2xl"
|
||||||
|
/>
|
||||||
|
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
|
||||||
|
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:vaeRepoIDValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
</FormItemWrapper>
|
||||||
|
|
||||||
|
<IAIButton
|
||||||
|
type="submit"
|
||||||
|
className="modal-close-btn"
|
||||||
|
isLoading={isProcessing}
|
||||||
|
>
|
||||||
|
{t('modelmanager:addModel')}
|
||||||
|
</IAIButton>
|
||||||
|
</VStack>
|
||||||
|
</form>
|
||||||
|
)}
|
||||||
|
</Formik>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
@ -1,10 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
Flex,
|
Flex,
|
||||||
FormControl,
|
|
||||||
FormErrorMessage,
|
|
||||||
FormHelperText,
|
|
||||||
FormLabel,
|
|
||||||
HStack,
|
|
||||||
Modal,
|
Modal,
|
||||||
ModalBody,
|
ModalBody,
|
||||||
ModalCloseButton,
|
ModalCloseButton,
|
||||||
@ -13,72 +8,64 @@ import {
|
|||||||
ModalOverlay,
|
ModalOverlay,
|
||||||
Text,
|
Text,
|
||||||
useDisclosure,
|
useDisclosure,
|
||||||
VStack,
|
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
|
|
||||||
import React from 'react';
|
import React from 'react';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
|
||||||
import IAINumberInput from 'common/components/IAINumberInput';
|
|
||||||
import IAICheckbox from 'common/components/IAICheckbox';
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
|
||||||
import SearchModels from './SearchModels';
|
|
||||||
|
|
||||||
import { addNewModel } from 'app/socketio/actions';
|
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
|
||||||
import { FaPlus } from 'react-icons/fa';
|
import { FaPlus } from 'react-icons/fa';
|
||||||
import { Field, Formik } from 'formik';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||||
|
|
||||||
import type { FieldInputProps, FormikProps } from 'formik';
|
|
||||||
import type { RootState } from 'app/store';
|
import type { RootState } from 'app/store';
|
||||||
import type { InvokeModelConfigProps } from 'app/invokeai';
|
import { setAddNewModelUIOption } from 'features/options/store/optionsSlice';
|
||||||
|
import AddCheckpointModel from './AddCheckpointModel';
|
||||||
|
import AddDiffusersModel from './AddDiffusersModel';
|
||||||
|
|
||||||
const MIN_MODEL_SIZE = 64;
|
function AddModelBox({
|
||||||
const MAX_MODEL_SIZE = 2048;
|
text,
|
||||||
|
onClick,
|
||||||
|
}: {
|
||||||
|
text: string;
|
||||||
|
onClick?: () => void;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
position="relative"
|
||||||
|
width="50%"
|
||||||
|
height="200px"
|
||||||
|
backgroundColor="var(--background-color)"
|
||||||
|
borderRadius="0.5rem"
|
||||||
|
justifyContent="center"
|
||||||
|
alignItems="center"
|
||||||
|
_hover={{
|
||||||
|
cursor: 'pointer',
|
||||||
|
backgroundColor: 'var(--accent-color)',
|
||||||
|
}}
|
||||||
|
onClick={onClick}
|
||||||
|
>
|
||||||
|
<Text fontWeight="bold">{text}</Text>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export default function AddModel() {
|
export default function AddModel() {
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const isProcessing = useAppSelector(
|
const addNewModelUIOption = useAppSelector(
|
||||||
(state: RootState) => state.system.isProcessing
|
(state: RootState) => state.options.addNewModelUIOption
|
||||||
);
|
);
|
||||||
|
|
||||||
function hasWhiteSpace(s: string) {
|
const dispatch = useAppDispatch();
|
||||||
return /\\s/g.test(s);
|
|
||||||
}
|
|
||||||
|
|
||||||
function baseValidation(value: string) {
|
const { t } = useTranslation();
|
||||||
let error;
|
|
||||||
if (hasWhiteSpace(value)) error = t('modelmanager:cannotUseSpaces');
|
|
||||||
return error;
|
|
||||||
}
|
|
||||||
|
|
||||||
const addModelFormValues: InvokeModelConfigProps = {
|
|
||||||
name: '',
|
|
||||||
description: '',
|
|
||||||
config: 'configs/stable-diffusion/v1-inference.yaml',
|
|
||||||
weights: '',
|
|
||||||
vae: '',
|
|
||||||
width: 512,
|
|
||||||
height: 512,
|
|
||||||
default: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
const addModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
|
|
||||||
dispatch(addNewModel(values));
|
|
||||||
onClose();
|
|
||||||
};
|
|
||||||
|
|
||||||
const addModelModalClose = () => {
|
const addModelModalClose = () => {
|
||||||
onClose();
|
onClose();
|
||||||
|
dispatch(setAddNewModelUIOption(null));
|
||||||
};
|
};
|
||||||
|
|
||||||
const [addManually, setAddmanually] = React.useState<boolean>(false);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<IAIButton
|
<IAIButton
|
||||||
@ -101,266 +88,24 @@ export default function AddModel() {
|
|||||||
closeOnOverlayClick={false}
|
closeOnOverlayClick={false}
|
||||||
>
|
>
|
||||||
<ModalOverlay />
|
<ModalOverlay />
|
||||||
<ModalContent className="modal add-model-modal">
|
<ModalContent className="modal add-model-modal" fontFamily="Inter">
|
||||||
<ModalHeader>{t('modelmanager:addNewModel')}</ModalHeader>
|
<ModalHeader>{t('modelmanager:addNewModel')}</ModalHeader>
|
||||||
<ModalCloseButton />
|
<ModalCloseButton marginTop="0.3rem" />
|
||||||
<ModalBody className="add-model-modal-body">
|
<ModalBody className="add-model-modal-body">
|
||||||
<SearchModels />
|
{addNewModelUIOption == null && (
|
||||||
<IAICheckbox
|
<Flex columnGap="1rem">
|
||||||
label={t('modelmanager:addManually')}
|
<AddModelBox
|
||||||
isChecked={addManually}
|
text={t('modelmanager:addCheckpointModel')}
|
||||||
onChange={() => setAddmanually(!addManually)}
|
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
|
||||||
/>
|
/>
|
||||||
|
<AddModelBox
|
||||||
{addManually && (
|
text={t('modelmanager:addDiffuserModel')}
|
||||||
<Formik
|
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
|
||||||
initialValues={addModelFormValues}
|
|
||||||
onSubmit={addModelFormSubmitHandler}
|
|
||||||
>
|
|
||||||
{({ handleSubmit, errors, touched }) => (
|
|
||||||
<form onSubmit={handleSubmit}>
|
|
||||||
<VStack rowGap={'0.5rem'}>
|
|
||||||
<Text fontSize={20} fontWeight="bold" alignSelf={'start'}>
|
|
||||||
{t('modelmanager:manual')}
|
|
||||||
</Text>
|
|
||||||
{/* Name */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.name && touched.name}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="name" fontSize="sm">
|
|
||||||
{t('modelmanager:name')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems={'start'}>
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="name"
|
|
||||||
name="name"
|
|
||||||
type="text"
|
|
||||||
validate={baseValidation}
|
|
||||||
width="2xl"
|
|
||||||
/>
|
/>
|
||||||
{!!errors.name && touched.name ? (
|
</Flex>
|
||||||
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelmanager:nameValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Description */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.description && touched.description}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="description" fontSize="sm">
|
|
||||||
{t('modelmanager:description')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems={'start'}>
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="description"
|
|
||||||
name="description"
|
|
||||||
type="text"
|
|
||||||
width="2xl"
|
|
||||||
/>
|
|
||||||
{!!errors.description && touched.description ? (
|
|
||||||
<FormErrorMessage>
|
|
||||||
{errors.description}
|
|
||||||
</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelmanager:descriptionValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Config */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.config && touched.config}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="config" fontSize="sm">
|
|
||||||
{t('modelmanager:config')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems={'start'}>
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="config"
|
|
||||||
name="config"
|
|
||||||
type="text"
|
|
||||||
width="2xl"
|
|
||||||
/>
|
|
||||||
{!!errors.config && touched.config ? (
|
|
||||||
<FormErrorMessage>{errors.config}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelmanager:configValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Weights */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.weights && touched.weights}
|
|
||||||
isRequired
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="config" fontSize="sm">
|
|
||||||
{t('modelmanager:modelLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems={'start'}>
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="weights"
|
|
||||||
name="weights"
|
|
||||||
type="text"
|
|
||||||
width="2xl"
|
|
||||||
/>
|
|
||||||
{!!errors.weights && touched.weights ? (
|
|
||||||
<FormErrorMessage>
|
|
||||||
{errors.weights}
|
|
||||||
</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelmanager:modelLocationValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* VAE */}
|
|
||||||
<FormControl isInvalid={!!errors.vae && touched.vae}>
|
|
||||||
<FormLabel htmlFor="vae" fontSize="sm">
|
|
||||||
{t('modelmanager:vaeLocation')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems={'start'}>
|
|
||||||
<Field
|
|
||||||
as={IAIInput}
|
|
||||||
id="vae"
|
|
||||||
name="vae"
|
|
||||||
type="text"
|
|
||||||
width="2xl"
|
|
||||||
/>
|
|
||||||
{!!errors.vae && touched.vae ? (
|
|
||||||
<FormErrorMessage>{errors.vae}</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelmanager:vaeLocationValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
<HStack width={'100%'}>
|
|
||||||
{/* Width */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.width && touched.width}
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="width" fontSize="sm">
|
|
||||||
{t('modelmanager:width')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems={'start'}>
|
|
||||||
<Field id="width" name="width">
|
|
||||||
{({
|
|
||||||
field,
|
|
||||||
form,
|
|
||||||
}: {
|
|
||||||
field: FieldInputProps<number>;
|
|
||||||
form: FormikProps<InvokeModelConfigProps>;
|
|
||||||
}) => (
|
|
||||||
<IAINumberInput
|
|
||||||
id="width"
|
|
||||||
name="width"
|
|
||||||
min={MIN_MODEL_SIZE}
|
|
||||||
max={MAX_MODEL_SIZE}
|
|
||||||
step={64}
|
|
||||||
width="90%"
|
|
||||||
value={form.values.width}
|
|
||||||
onChange={(value) =>
|
|
||||||
form.setFieldValue(
|
|
||||||
field.name,
|
|
||||||
Number(value)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Field>
|
|
||||||
|
|
||||||
{!!errors.width && touched.width ? (
|
|
||||||
<FormErrorMessage>
|
|
||||||
{errors.width}
|
|
||||||
</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelmanager:widthValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
{/* Height */}
|
|
||||||
<FormControl
|
|
||||||
isInvalid={!!errors.height && touched.height}
|
|
||||||
>
|
|
||||||
<FormLabel htmlFor="height" fontSize="sm">
|
|
||||||
{t('modelmanager:height')}
|
|
||||||
</FormLabel>
|
|
||||||
<VStack alignItems={'start'}>
|
|
||||||
<Field id="height" name="height">
|
|
||||||
{({
|
|
||||||
field,
|
|
||||||
form,
|
|
||||||
}: {
|
|
||||||
field: FieldInputProps<number>;
|
|
||||||
form: FormikProps<InvokeModelConfigProps>;
|
|
||||||
}) => (
|
|
||||||
<IAINumberInput
|
|
||||||
id="height"
|
|
||||||
name="height"
|
|
||||||
min={MIN_MODEL_SIZE}
|
|
||||||
max={MAX_MODEL_SIZE}
|
|
||||||
width="90%"
|
|
||||||
step={64}
|
|
||||||
value={form.values.height}
|
|
||||||
onChange={(value) =>
|
|
||||||
form.setFieldValue(
|
|
||||||
field.name,
|
|
||||||
Number(value)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Field>
|
|
||||||
|
|
||||||
{!!errors.height && touched.height ? (
|
|
||||||
<FormErrorMessage>
|
|
||||||
{errors.height}
|
|
||||||
</FormErrorMessage>
|
|
||||||
) : (
|
|
||||||
<FormHelperText margin={0}>
|
|
||||||
{t('modelmanager:heightValidationMsg')}
|
|
||||||
</FormHelperText>
|
|
||||||
)}
|
|
||||||
</VStack>
|
|
||||||
</FormControl>
|
|
||||||
</HStack>
|
|
||||||
|
|
||||||
<IAIButton
|
|
||||||
type="submit"
|
|
||||||
className="modal-close-btn"
|
|
||||||
isLoading={isProcessing}
|
|
||||||
>
|
|
||||||
{t('modelmanager:addModel')}
|
|
||||||
</IAIButton>
|
|
||||||
</VStack>
|
|
||||||
</form>
|
|
||||||
)}
|
|
||||||
</Formik>
|
|
||||||
)}
|
)}
|
||||||
|
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
|
||||||
|
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
|
||||||
</ModalBody>
|
</ModalBody>
|
||||||
</ModalContent>
|
</ModalContent>
|
||||||
</Modal>
|
</Modal>
|
||||||
|
@ -48,7 +48,7 @@ const selector = createSelector(
|
|||||||
const MIN_MODEL_SIZE = 64;
|
const MIN_MODEL_SIZE = 64;
|
||||||
const MAX_MODEL_SIZE = 2048;
|
const MAX_MODEL_SIZE = 2048;
|
||||||
|
|
||||||
export default function ModelEdit() {
|
export default function CheckpointModelEdit() {
|
||||||
const { openModel, model_list } = useAppSelector(selector);
|
const { openModel, model_list } = useAppSelector(selector);
|
||||||
const isProcessing = useAppSelector(
|
const isProcessing = useAppSelector(
|
||||||
(state: RootState) => state.system.isProcessing
|
(state: RootState) => state.system.isProcessing
|
||||||
@ -68,6 +68,7 @@ export default function ModelEdit() {
|
|||||||
width: 512,
|
width: 512,
|
||||||
height: 512,
|
height: 512,
|
||||||
default: false,
|
default: false,
|
||||||
|
format: 'ckpt',
|
||||||
});
|
});
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@ -84,12 +85,19 @@ export default function ModelEdit() {
|
|||||||
width: retrievedModel[openModel]?.width,
|
width: retrievedModel[openModel]?.width,
|
||||||
height: retrievedModel[openModel]?.height,
|
height: retrievedModel[openModel]?.height,
|
||||||
default: retrievedModel[openModel]?.default,
|
default: retrievedModel[openModel]?.default,
|
||||||
|
format: 'ckpt',
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}, [model_list, openModel]);
|
}, [model_list, openModel]);
|
||||||
|
|
||||||
const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
|
const editModelFormSubmitHandler = (values: InvokeModelConfigProps) => {
|
||||||
dispatch(addNewModel(values));
|
dispatch(
|
||||||
|
addNewModel({
|
||||||
|
...values,
|
||||||
|
width: Number(values.width),
|
||||||
|
height: Number(values.height),
|
||||||
|
})
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
return openModel ? (
|
return openModel ? (
|
@ -0,0 +1,270 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
|
||||||
|
import React, { useEffect, useState } from 'react';
|
||||||
|
import IAIInput from 'common/components/IAIInput';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||||
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
|
|
||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
FormControl,
|
||||||
|
FormErrorMessage,
|
||||||
|
FormHelperText,
|
||||||
|
FormLabel,
|
||||||
|
Text,
|
||||||
|
VStack,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
|
||||||
|
import { Field, Formik } from 'formik';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { addNewModel } from 'app/socketio/actions';
|
||||||
|
|
||||||
|
import _ from 'lodash';
|
||||||
|
|
||||||
|
import type { RootState } from 'app/store';
|
||||||
|
import type { InvokeDiffusersModelConfigProps } from 'app/invokeai';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[systemSelector],
|
||||||
|
(system) => {
|
||||||
|
const { openModel, model_list } = system;
|
||||||
|
return {
|
||||||
|
model_list,
|
||||||
|
openModel,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: _.isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
export default function DiffusersModelEdit() {
|
||||||
|
const { openModel, model_list } = useAppSelector(selector);
|
||||||
|
const isProcessing = useAppSelector(
|
||||||
|
(state: RootState) => state.system.isProcessing
|
||||||
|
);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const [editModelFormValues, setEditModelFormValues] =
|
||||||
|
useState<InvokeDiffusersModelConfigProps>({
|
||||||
|
name: '',
|
||||||
|
description: '',
|
||||||
|
repo_id: '',
|
||||||
|
path: '',
|
||||||
|
vae: { repo_id: '', path: '' },
|
||||||
|
default: false,
|
||||||
|
format: 'diffusers',
|
||||||
|
});
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (openModel) {
|
||||||
|
const retrievedModel = _.pickBy(model_list, (val, key) => {
|
||||||
|
return _.isEqual(key, openModel);
|
||||||
|
});
|
||||||
|
|
||||||
|
setEditModelFormValues({
|
||||||
|
name: openModel,
|
||||||
|
description: retrievedModel[openModel]?.description,
|
||||||
|
path: retrievedModel[openModel]?.path,
|
||||||
|
repo_id: retrievedModel[openModel]?.repo_id,
|
||||||
|
vae: {
|
||||||
|
repo_id: retrievedModel[openModel]?.vae?.repo_id
|
||||||
|
? retrievedModel[openModel]?.vae?.repo_id
|
||||||
|
: '',
|
||||||
|
path: retrievedModel[openModel]?.vae?.path
|
||||||
|
? retrievedModel[openModel]?.vae?.path
|
||||||
|
: '',
|
||||||
|
},
|
||||||
|
default: retrievedModel[openModel]?.default,
|
||||||
|
format: 'diffusers',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [model_list, openModel]);
|
||||||
|
|
||||||
|
const editModelFormSubmitHandler = (
|
||||||
|
values: InvokeDiffusersModelConfigProps
|
||||||
|
) => {
|
||||||
|
dispatch(addNewModel(values));
|
||||||
|
};
|
||||||
|
|
||||||
|
return openModel ? (
|
||||||
|
<Flex flexDirection="column" rowGap="1rem" width="100%">
|
||||||
|
<Flex alignItems="center">
|
||||||
|
<Text fontSize="lg" fontWeight="bold">
|
||||||
|
{openModel}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
<Flex
|
||||||
|
flexDirection="column"
|
||||||
|
maxHeight={window.innerHeight - 270}
|
||||||
|
overflowY="scroll"
|
||||||
|
paddingRight="2rem"
|
||||||
|
>
|
||||||
|
<Formik
|
||||||
|
enableReinitialize={true}
|
||||||
|
initialValues={editModelFormValues}
|
||||||
|
onSubmit={editModelFormSubmitHandler}
|
||||||
|
>
|
||||||
|
{({ handleSubmit, errors, touched }) => (
|
||||||
|
<form onSubmit={handleSubmit}>
|
||||||
|
<VStack rowGap={'0.5rem'} alignItems="start">
|
||||||
|
{/* Description */}
|
||||||
|
<FormControl
|
||||||
|
isInvalid={!!errors.description && touched.description}
|
||||||
|
isRequired
|
||||||
|
>
|
||||||
|
<FormLabel htmlFor="description" fontSize="sm">
|
||||||
|
{t('modelmanager:description')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="description"
|
||||||
|
name="description"
|
||||||
|
type="text"
|
||||||
|
width="lg"
|
||||||
|
/>
|
||||||
|
{!!errors.description && touched.description ? (
|
||||||
|
<FormErrorMessage>{errors.description}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:descriptionValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
{/* Path */}
|
||||||
|
<FormControl
|
||||||
|
isInvalid={!!errors.path && touched.path}
|
||||||
|
isRequired
|
||||||
|
>
|
||||||
|
<FormLabel htmlFor="path" fontSize="sm">
|
||||||
|
{t('modelmanager:modelLocation')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="path"
|
||||||
|
name="path"
|
||||||
|
type="text"
|
||||||
|
width="lg"
|
||||||
|
/>
|
||||||
|
{!!errors.path && touched.path ? (
|
||||||
|
<FormErrorMessage>{errors.path}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:modelLocationValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
{/* Repo ID */}
|
||||||
|
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
|
||||||
|
<FormLabel htmlFor="repo_id" fontSize="sm">
|
||||||
|
{t('modelmanager:repo_id')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="repo_id"
|
||||||
|
name="repo_id"
|
||||||
|
type="text"
|
||||||
|
width="lg"
|
||||||
|
/>
|
||||||
|
{!!errors.repo_id && touched.repo_id ? (
|
||||||
|
<FormErrorMessage>{errors.repo_id}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:repoIDValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
{/* VAE Path */}
|
||||||
|
<FormControl
|
||||||
|
isInvalid={!!errors.vae?.path && touched.vae?.path}
|
||||||
|
>
|
||||||
|
<FormLabel htmlFor="vae.path" fontSize="sm">
|
||||||
|
{t('modelmanager:vaeLocation')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="vae.path"
|
||||||
|
name="vae.path"
|
||||||
|
type="text"
|
||||||
|
width="lg"
|
||||||
|
/>
|
||||||
|
{!!errors.vae?.path && touched.vae?.path ? (
|
||||||
|
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:vaeLocationValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
{/* VAE Repo ID */}
|
||||||
|
<FormControl
|
||||||
|
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
|
||||||
|
>
|
||||||
|
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
|
||||||
|
{t('modelmanager:vaeRepoID')}
|
||||||
|
</FormLabel>
|
||||||
|
<VStack alignItems={'start'}>
|
||||||
|
<Field
|
||||||
|
as={IAIInput}
|
||||||
|
id="vae.repo_id"
|
||||||
|
name="vae.repo_id"
|
||||||
|
type="text"
|
||||||
|
width="lg"
|
||||||
|
/>
|
||||||
|
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
|
||||||
|
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>
|
||||||
|
) : (
|
||||||
|
<FormHelperText margin={0}>
|
||||||
|
{t('modelmanager:vaeRepoIDValidationMsg')}
|
||||||
|
</FormHelperText>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
<IAIButton
|
||||||
|
type="submit"
|
||||||
|
className="modal-close-btn"
|
||||||
|
isLoading={isProcessing}
|
||||||
|
>
|
||||||
|
{t('modelmanager:updateModel')}
|
||||||
|
</IAIButton>
|
||||||
|
</VStack>
|
||||||
|
</form>
|
||||||
|
)}
|
||||||
|
</Formik>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
) : (
|
||||||
|
<Flex
|
||||||
|
width="100%"
|
||||||
|
height="250px"
|
||||||
|
justifyContent="center"
|
||||||
|
alignItems="center"
|
||||||
|
backgroundColor="var(--background-color)"
|
||||||
|
borderRadius="0.5rem"
|
||||||
|
>
|
||||||
|
<Text fontWeight="bold" color="var(--subtext-color-bright)">
|
||||||
|
Pick A Model To Edit
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
@ -1,5 +1,5 @@
|
|||||||
import { useState } from 'react';
|
import React, { useState, useTransition, useMemo } from 'react';
|
||||||
import { Flex, Text } from '@chakra-ui/react';
|
import { Box, Flex, Text } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
|
|
||||||
@ -14,6 +14,7 @@ import _ from 'lodash';
|
|||||||
import type { ChangeEvent, ReactNode } from 'react';
|
import type { ChangeEvent, ReactNode } from 'react';
|
||||||
import type { RootState } from 'app/store';
|
import type { RootState } from 'app/store';
|
||||||
import type { SystemState } from 'features/system/store/systemSlice';
|
import type { SystemState } from 'features/system/store/systemSlice';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
|
||||||
const modelListSelector = createSelector(
|
const modelListSelector = createSelector(
|
||||||
(state: RootState) => state.system,
|
(state: RootState) => state.system,
|
||||||
@ -21,33 +22,64 @@ const modelListSelector = createSelector(
|
|||||||
const models = _.map(system.model_list, (model, key) => {
|
const models = _.map(system.model_list, (model, key) => {
|
||||||
return { name: key, ...model };
|
return { name: key, ...model };
|
||||||
});
|
});
|
||||||
|
return models;
|
||||||
const activeModel = models.find((model) => model.status === 'active');
|
},
|
||||||
|
{
|
||||||
return {
|
memoizeOptions: {
|
||||||
models,
|
resultEqualityCheck: _.isEqual,
|
||||||
activeModel: activeModel,
|
},
|
||||||
};
|
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
function ModelFilterButton({
|
||||||
|
label,
|
||||||
|
isActive,
|
||||||
|
onClick,
|
||||||
|
}: {
|
||||||
|
label: string;
|
||||||
|
isActive: boolean;
|
||||||
|
onClick: () => void;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<IAIButton
|
||||||
|
onClick={onClick}
|
||||||
|
isActive={isActive}
|
||||||
|
_active={{
|
||||||
|
backgroundColor: 'var(--accent-color)',
|
||||||
|
_hover: { backgroundColor: 'var(--accent-color)' },
|
||||||
|
}}
|
||||||
|
size="sm"
|
||||||
|
>
|
||||||
|
{label}
|
||||||
|
</IAIButton>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
const ModelList = () => {
|
const ModelList = () => {
|
||||||
const { models } = useAppSelector(modelListSelector);
|
const models = useAppSelector(modelListSelector);
|
||||||
|
|
||||||
const [searchText, setSearchText] = useState<string>('');
|
const [searchText, setSearchText] = useState<string>('');
|
||||||
|
const [isSelectedFilter, setIsSelectedFilter] = useState<
|
||||||
|
'all' | 'ckpt' | 'diffusers'
|
||||||
|
>('all');
|
||||||
|
const [_, startTransition] = useTransition();
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const handleSearchFilter = _.debounce((e: ChangeEvent<HTMLInputElement>) => {
|
const handleSearchFilter = (e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
startTransition(() => {
|
||||||
setSearchText(e.target.value);
|
setSearchText(e.target.value);
|
||||||
}, 400);
|
});
|
||||||
|
};
|
||||||
|
|
||||||
const renderModelListItems = () => {
|
const renderModelListItems = useMemo(() => {
|
||||||
const modelListItemsToRender: ReactNode[] = [];
|
const ckptModelListItemsToRender: ReactNode[] = [];
|
||||||
|
const diffusersModelListItemsToRender: ReactNode[] = [];
|
||||||
const filteredModelListItemsToRender: ReactNode[] = [];
|
const filteredModelListItemsToRender: ReactNode[] = [];
|
||||||
|
const localFilteredModelListItemsToRender: ReactNode[] = [];
|
||||||
|
|
||||||
models.forEach((model, i) => {
|
models.forEach((model, i) => {
|
||||||
if (model.name.startsWith(searchText)) {
|
if (model.name.toLowerCase().includes(searchText.toLowerCase())) {
|
||||||
filteredModelListItemsToRender.push(
|
filteredModelListItemsToRender.push(
|
||||||
<ModelListItem
|
<ModelListItem
|
||||||
key={i}
|
key={i}
|
||||||
@ -56,8 +88,8 @@ const ModelList = () => {
|
|||||||
description={model.description}
|
description={model.description}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
if (model.format === isSelectedFilter) {
|
||||||
modelListItemsToRender.push(
|
localFilteredModelListItemsToRender.push(
|
||||||
<ModelListItem
|
<ModelListItem
|
||||||
key={i}
|
key={i}
|
||||||
name={model.name}
|
name={model.name}
|
||||||
@ -65,12 +97,84 @@ const ModelList = () => {
|
|||||||
description={model.description}
|
description={model.description}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (model.format !== 'diffusers') {
|
||||||
|
ckptModelListItemsToRender.push(
|
||||||
|
<ModelListItem
|
||||||
|
key={i}
|
||||||
|
name={model.name}
|
||||||
|
status={model.status}
|
||||||
|
description={model.description}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
diffusersModelListItemsToRender.push(
|
||||||
|
<ModelListItem
|
||||||
|
key={i}
|
||||||
|
name={model.name}
|
||||||
|
status={model.status}
|
||||||
|
description={model.description}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return searchText !== ''
|
return searchText !== '' ? (
|
||||||
? filteredModelListItemsToRender
|
isSelectedFilter === 'all' ? (
|
||||||
: modelListItemsToRender;
|
<Box marginTop="1rem">{filteredModelListItemsToRender}</Box>
|
||||||
};
|
) : (
|
||||||
|
<Box marginTop="1rem">{localFilteredModelListItemsToRender}</Box>
|
||||||
|
)
|
||||||
|
) : (
|
||||||
|
<Flex flexDirection="column" rowGap="1.5rem">
|
||||||
|
{isSelectedFilter === 'all' && (
|
||||||
|
<>
|
||||||
|
<Box>
|
||||||
|
<Text
|
||||||
|
fontWeight="bold"
|
||||||
|
backgroundColor="var(--background-color)"
|
||||||
|
padding="0.5rem 1rem"
|
||||||
|
borderRadius="0.5rem"
|
||||||
|
margin="1rem 0"
|
||||||
|
width="max-content"
|
||||||
|
fontSize="14"
|
||||||
|
>
|
||||||
|
{t('modelmanager:checkpointModels')}
|
||||||
|
</Text>
|
||||||
|
{ckptModelListItemsToRender}
|
||||||
|
</Box>
|
||||||
|
<Box>
|
||||||
|
<Text
|
||||||
|
fontWeight="bold"
|
||||||
|
backgroundColor="var(--background-color)"
|
||||||
|
padding="0.5rem 1rem"
|
||||||
|
borderRadius="0.5rem"
|
||||||
|
marginBottom="0.5rem"
|
||||||
|
width="max-content"
|
||||||
|
fontSize="14"
|
||||||
|
>
|
||||||
|
{t('modelmanager:diffusersModels')}
|
||||||
|
</Text>
|
||||||
|
{diffusersModelListItemsToRender}
|
||||||
|
</Box>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{isSelectedFilter === 'ckpt' && (
|
||||||
|
<Flex flexDirection="column" marginTop="1rem">
|
||||||
|
{ckptModelListItemsToRender}
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{isSelectedFilter === 'diffusers' && (
|
||||||
|
<Flex flexDirection="column" marginTop="1rem">
|
||||||
|
{diffusersModelListItemsToRender}
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}, [models, searchText, t, isSelectedFilter]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDirection={'column'} rowGap="2rem" width="50%" minWidth="50%">
|
<Flex flexDirection={'column'} rowGap="2rem" width="50%" minWidth="50%">
|
||||||
@ -93,7 +197,24 @@ const ModelList = () => {
|
|||||||
overflow={'scroll'}
|
overflow={'scroll'}
|
||||||
paddingRight="1rem"
|
paddingRight="1rem"
|
||||||
>
|
>
|
||||||
{renderModelListItems()}
|
<Flex columnGap="0.5rem">
|
||||||
|
<ModelFilterButton
|
||||||
|
label={t('modelmanager:allModels')}
|
||||||
|
onClick={() => setIsSelectedFilter('all')}
|
||||||
|
isActive={isSelectedFilter === 'all'}
|
||||||
|
/>
|
||||||
|
<ModelFilterButton
|
||||||
|
label={t('modelmanager:checkpointModels')}
|
||||||
|
onClick={() => setIsSelectedFilter('ckpt')}
|
||||||
|
isActive={isSelectedFilter === 'ckpt'}
|
||||||
|
/>
|
||||||
|
<ModelFilterButton
|
||||||
|
label={t('modelmanager:diffusersModels')}
|
||||||
|
onClick={() => setIsSelectedFilter('diffusers')}
|
||||||
|
isActive={isSelectedFilter === 'diffusers'}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
{renderModelListItems}
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
@ -8,13 +8,17 @@ import {
|
|||||||
useDisclosure,
|
useDisclosure,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import React, { cloneElement } from 'react';
|
import React, { cloneElement } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
import ModelEdit from './ModelEdit';
|
import { useTranslation } from 'react-i18next';
|
||||||
import ModelList from './ModelList';
|
import { useAppSelector } from 'app/storeHooks';
|
||||||
|
import { RootState } from 'app/store';
|
||||||
|
|
||||||
import type { ReactElement } from 'react';
|
import type { ReactElement } from 'react';
|
||||||
|
|
||||||
|
import ModelList from './ModelList';
|
||||||
|
import DiffusersModelEdit from './DiffusersModelEdit';
|
||||||
|
import CheckpointModelEdit from './CheckpointModelEdit';
|
||||||
|
|
||||||
type ModelManagerModalProps = {
|
type ModelManagerModalProps = {
|
||||||
children: ReactElement;
|
children: ReactElement;
|
||||||
};
|
};
|
||||||
@ -28,6 +32,14 @@ export default function ModelManagerModal({
|
|||||||
onClose: onModelManagerModalClose,
|
onClose: onModelManagerModalClose,
|
||||||
} = useDisclosure();
|
} = useDisclosure();
|
||||||
|
|
||||||
|
const model_list = useAppSelector(
|
||||||
|
(state: RootState) => state.system.model_list
|
||||||
|
);
|
||||||
|
|
||||||
|
const openModel = useAppSelector(
|
||||||
|
(state: RootState) => state.system.openModel
|
||||||
|
);
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -41,16 +53,22 @@ export default function ModelManagerModal({
|
|||||||
size="6xl"
|
size="6xl"
|
||||||
>
|
>
|
||||||
<ModalOverlay />
|
<ModalOverlay />
|
||||||
<ModalContent className=" modal">
|
<ModalContent className="modal" fontFamily="Inter">
|
||||||
<ModalCloseButton className="modal-close-btn" />
|
<ModalCloseButton className="modal-close-btn" />
|
||||||
<ModalHeader>{t('modelmanager:modelManager')}</ModalHeader>
|
<ModalHeader fontWeight="bold">
|
||||||
|
{t('modelmanager:modelManager')}
|
||||||
|
</ModalHeader>
|
||||||
<Flex
|
<Flex
|
||||||
padding={'0 1.5rem 1.5rem 1.5rem'}
|
padding={'0 1.5rem 1.5rem 1.5rem'}
|
||||||
width="100%"
|
width="100%"
|
||||||
columnGap={'2rem'}
|
columnGap={'2rem'}
|
||||||
>
|
>
|
||||||
<ModelList />
|
<ModelList />
|
||||||
<ModelEdit />
|
{openModel && model_list[openModel]['format'] === 'diffusers' ? (
|
||||||
|
<DiffusersModelEdit />
|
||||||
|
) : (
|
||||||
|
<CheckpointModelEdit />
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</ModalContent>
|
</ModalContent>
|
||||||
</Modal>
|
</Modal>
|
||||||
|
@ -178,6 +178,7 @@ export default function SearchModels() {
|
|||||||
width: 512,
|
width: 512,
|
||||||
height: 512,
|
height: 512,
|
||||||
default: false,
|
default: false,
|
||||||
|
format: 'ckpt',
|
||||||
};
|
};
|
||||||
dispatch(addNewModel(modelFormat));
|
dispatch(addNewModel(modelFormat));
|
||||||
});
|
});
|
||||||
|
@ -13,10 +13,6 @@
|
|||||||
width: 32px;
|
width: 32px;
|
||||||
height: 32px;
|
height: 32px;
|
||||||
}
|
}
|
||||||
|
|
||||||
h1 {
|
|
||||||
font-size: 1.4rem;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.site-header-right-side {
|
.site-header-right-side {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Link } from '@chakra-ui/react';
|
import { Flex, Link, Text } from '@chakra-ui/react';
|
||||||
|
|
||||||
import { FaGithub, FaDiscord, FaBug, FaKeyboard, FaCube } from 'react-icons/fa';
|
import { FaGithub, FaDiscord, FaBug, FaKeyboard, FaCube } from 'react-icons/fa';
|
||||||
|
|
||||||
@ -17,20 +17,34 @@ import LanguagePicker from './LanguagePicker';
|
|||||||
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { MdSettings } from 'react-icons/md';
|
import { MdSettings } from 'react-icons/md';
|
||||||
|
import { useAppSelector } from 'app/storeHooks';
|
||||||
|
import type { RootState } from 'app/store';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Header, includes color mode toggle, settings button, status message.
|
* Header, includes color mode toggle, settings button, status message.
|
||||||
*/
|
*/
|
||||||
const SiteHeader = () => {
|
const SiteHeader = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const appVersion = useAppSelector(
|
||||||
|
(state: RootState) => state.system.app_version
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="site-header">
|
<div className="site-header">
|
||||||
<div className="site-header-left-side">
|
<div className="site-header-left-side">
|
||||||
<img src={InvokeAILogo} alt="invoke-ai-logo" />
|
<img src={InvokeAILogo} alt="invoke-ai-logo" />
|
||||||
<h1>
|
<Flex alignItems="center" columnGap="0.6rem">
|
||||||
|
<Text fontSize="1.4rem">
|
||||||
invoke <strong>ai</strong>
|
invoke <strong>ai</strong>
|
||||||
</h1>
|
</Text>
|
||||||
|
<Text
|
||||||
|
fontWeight="bold"
|
||||||
|
color="var(--text-color-secondary)"
|
||||||
|
marginTop="0.2rem"
|
||||||
|
>
|
||||||
|
{appVersion}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="site-header-right-side">
|
<div className="site-header-right-side">
|
||||||
|
@ -9,8 +9,11 @@ set INVOKEAI_ROOT=.
|
|||||||
echo Do you want to generate images using the
|
echo Do you want to generate images using the
|
||||||
echo 1. command-line
|
echo 1. command-line
|
||||||
echo 2. browser-based UI
|
echo 2. browser-based UI
|
||||||
echo 3. open the developer console
|
echo 3. run textual inversion training
|
||||||
set /P restore="Please enter 1, 2 or 3: "
|
echo 4. open the developer console
|
||||||
|
echo 5. re-run the configure script to download new models
|
||||||
|
set /P restore="Please enter 1, 2, 3, 4 or 5: [5] "
|
||||||
|
if not defined restore set restore=2
|
||||||
IF /I "%restore%" == "1" (
|
IF /I "%restore%" == "1" (
|
||||||
echo Starting the InvokeAI command-line..
|
echo Starting the InvokeAI command-line..
|
||||||
python .venv\Scripts\invoke.py %*
|
python .venv\Scripts\invoke.py %*
|
||||||
@ -18,6 +21,9 @@ IF /I "%restore%" == "1" (
|
|||||||
echo Starting the InvokeAI browser-based UI..
|
echo Starting the InvokeAI browser-based UI..
|
||||||
python .venv\Scripts\invoke.py --web %*
|
python .venv\Scripts\invoke.py --web %*
|
||||||
) ELSE IF /I "%restore%" == "3" (
|
) ELSE IF /I "%restore%" == "3" (
|
||||||
|
echo Starting textual inversion training..
|
||||||
|
python .venv\Scripts\textual_inversion_fe.py --web %*
|
||||||
|
) ELSE IF /I "%restore%" == "4" (
|
||||||
echo Developer Console
|
echo Developer Console
|
||||||
echo Python command is:
|
echo Python command is:
|
||||||
where python
|
where python
|
||||||
@ -29,6 +35,9 @@ IF /I "%restore%" == "1" (
|
|||||||
echo *************************
|
echo *************************
|
||||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||||
call cmd /k
|
call cmd /k
|
||||||
|
) ELSE IF /I "%restore%" == "5" (
|
||||||
|
echo Running configure_invokeai.py...
|
||||||
|
python .venv\Scripts\configure_invokeai.py --web %*
|
||||||
) ELSE (
|
) ELSE (
|
||||||
echo Invalid selection
|
echo Invalid selection
|
||||||
pause
|
pause
|
||||||
|
@ -19,12 +19,17 @@ if [ "$0" != "bash" ]; then
|
|||||||
echo "Do you want to generate images using the"
|
echo "Do you want to generate images using the"
|
||||||
echo "1. command-line"
|
echo "1. command-line"
|
||||||
echo "2. browser-based UI"
|
echo "2. browser-based UI"
|
||||||
echo "3. open the developer console"
|
echo "3. run textual inversion training"
|
||||||
read -p "Please enter 1, 2, or 3: " yn
|
echo "4. open the developer console"
|
||||||
case $yn in
|
echo "5. re-run the configure script to download new models"
|
||||||
|
read -p "Please enter 1, 2, 3, 4 or 5: [1] " yn
|
||||||
|
choice=${yn:='2'}
|
||||||
|
case $choice in
|
||||||
1 ) printf "\nStarting the InvokeAI command-line..\n"; .venv/bin/python .venv/bin/invoke.py $*;;
|
1 ) printf "\nStarting the InvokeAI command-line..\n"; .venv/bin/python .venv/bin/invoke.py $*;;
|
||||||
2 ) printf "\nStarting the InvokeAI browser-based UI..\n"; .venv/bin/python .venv/bin/invoke.py --web $*;;
|
2 ) printf "\nStarting the InvokeAI browser-based UI..\n"; .venv/bin/python .venv/bin/invoke.py --web $*;;
|
||||||
3 ) printf "\nDeveloper Console:\n"; file_name=$(basename "${BASH_SOURCE[0]}"); bash --init-file "$file_name";;
|
3 ) printf "\nStarting Textual Inversion:\n"; .venv/bin/python .venv/bin/textual_inversion_fe.py $*;;
|
||||||
|
4 ) printf "\nDeveloper Console:\n"; file_name=$(basename "${BASH_SOURCE[0]}"); bash --init-file "$file_name";;
|
||||||
|
5 ) printf "\nRunning configure_invokeai.py:\n"; .venv/bin/python .venv/bin/configure_invokeai.py $*;;
|
||||||
* ) echo "Invalid selection"; exit;;
|
* ) echo "Invalid selection"; exit;;
|
||||||
esac
|
esac
|
||||||
else # in developer console
|
else # in developer console
|
||||||
|
@ -23,6 +23,7 @@ if "%arg%" neq "" (
|
|||||||
|
|
||||||
set INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive/!INVOKE_AI_VERSION!.zip"
|
set INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive/!INVOKE_AI_VERSION!.zip"
|
||||||
set INVOKE_AI_DEP=https://raw.githubusercontent.com/invoke-ai/InvokeAI/!INVOKE_AI_VERSION!/environments-and-requirements/requirements-base.txt
|
set INVOKE_AI_DEP=https://raw.githubusercontent.com/invoke-ai/InvokeAI/!INVOKE_AI_VERSION!/environments-and-requirements/requirements-base.txt
|
||||||
|
set INVOKE_AI_MODELS=https://raw.githubusercontent.com/invoke-ai/InvokeAI/$INVOKE_AI_VERSION/configs/INITIAL_MODELS.yaml
|
||||||
|
|
||||||
call curl -I "%INVOKE_AI_DEP%" -fs >.tmp.out
|
call curl -I "%INVOKE_AI_DEP%" -fs >.tmp.out
|
||||||
if %errorlevel% neq 0 (
|
if %errorlevel% neq 0 (
|
||||||
@ -38,6 +39,8 @@ echo If you do not want to do this, press control-C now!
|
|||||||
pause
|
pause
|
||||||
|
|
||||||
call curl -L "%INVOKE_AI_DEP%" > environments-and-requirements/requirements-base.txt
|
call curl -L "%INVOKE_AI_DEP%" > environments-and-requirements/requirements-base.txt
|
||||||
|
call curl -L "%INVOKE_AI_MODELS%" > configs/INITIAL_MODELS.yaml
|
||||||
|
|
||||||
|
|
||||||
call .venv\Scripts\activate.bat
|
call .venv\Scripts\activate.bat
|
||||||
call .venv\Scripts\python -mpip install -r requirements.txt
|
call .venv\Scripts\python -mpip install -r requirements.txt
|
||||||
|
@ -18,6 +18,7 @@ INVOKE_AI_VERSION=${1:-latest}
|
|||||||
|
|
||||||
INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive/$INVOKE_AI_VERSION.zip"
|
INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive/$INVOKE_AI_VERSION.zip"
|
||||||
INVOKE_AI_DEP=https://raw.githubusercontent.com/invoke-ai/InvokeAI/$INVOKE_AI_VERSION/environments-and-requirements/requirements-base.txt
|
INVOKE_AI_DEP=https://raw.githubusercontent.com/invoke-ai/InvokeAI/$INVOKE_AI_VERSION/environments-and-requirements/requirements-base.txt
|
||||||
|
INVOKE_AI_MODELS=https://raw.githubusercontent.com/invoke-ai/InvokeAI/$INVOKE_AI_VERSION/configs/INITIAL_MODELS.yaml
|
||||||
|
|
||||||
# ensure we're in the correct folder in case user's CWD is somewhere else
|
# ensure we're in the correct folder in case user's CWD is somewhere else
|
||||||
scriptdir=$(dirname "$0")
|
scriptdir=$(dirname "$0")
|
||||||
@ -44,6 +45,7 @@ echo If you do not want to do this, press control-C now!
|
|||||||
read -p "Press any key to continue, or CTRL-C to exit..."
|
read -p "Press any key to continue, or CTRL-C to exit..."
|
||||||
|
|
||||||
curl -L "$INVOKE_AI_DEP" > environments-and-requirements/requirements-base.txt
|
curl -L "$INVOKE_AI_DEP" > environments-and-requirements/requirements-base.txt
|
||||||
|
curl -L "$INVOKE_AI_MODELS" > configs/INITIAL_MODELS.yaml
|
||||||
|
|
||||||
. .venv/bin/activate
|
. .venv/bin/activate
|
||||||
|
|
||||||
|
240
ldm/generate.py
240
ldm/generate.py
@ -1,48 +1,44 @@
|
|||||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||||
import pyparsing
|
|
||||||
# Derived from source code carrying the following copyrights
|
# Derived from source code carrying the following copyrights
|
||||||
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
||||||
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
||||||
|
|
||||||
import torch
|
import gc
|
||||||
import numpy as np
|
import importlib
|
||||||
import random
|
|
||||||
import os
|
import os
|
||||||
import time
|
import random
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import transformers
|
|
||||||
import io
|
|
||||||
import gc
|
|
||||||
import hashlib
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import diffusers
|
||||||
|
import numpy as np
|
||||||
import skimage
|
import skimage
|
||||||
|
import torch
|
||||||
from omegaconf import OmegaConf
|
import transformers
|
||||||
|
|
||||||
import ldm.invoke.conditioning
|
|
||||||
from ldm.invoke.generator.base import downsampling
|
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from torch import nn
|
from diffusers.pipeline_utils import DiffusionPipeline
|
||||||
|
from omegaconf import OmegaConf
|
||||||
from pytorch_lightning import seed_everything, logging
|
from pytorch_lightning import seed_everything, logging
|
||||||
|
|
||||||
from ldm.invoke.prompt_parser import PromptParser
|
import ldm.invoke.conditioning
|
||||||
from ldm.util import instantiate_from_config
|
|
||||||
from ldm.invoke.globals import Globals
|
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
|
||||||
from ldm.models.diffusion.ksampler import KSampler
|
|
||||||
from ldm.invoke.pngwriter import PngWriter
|
|
||||||
from ldm.invoke.args import metadata_from_png
|
from ldm.invoke.args import metadata_from_png
|
||||||
from ldm.invoke.image_util import InitImageResizer
|
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||||
from ldm.invoke.devices import choose_torch_device, choose_precision
|
|
||||||
from ldm.invoke.conditioning import get_uc_and_c_and_ec
|
from ldm.invoke.conditioning import get_uc_and_c_and_ec
|
||||||
from ldm.invoke.model_cache import ModelCache
|
from ldm.invoke.devices import choose_torch_device, choose_precision
|
||||||
from ldm.invoke.seamless import configure_model_padding
|
|
||||||
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
|
||||||
from ldm.invoke.concepts_lib import Concepts
|
|
||||||
from ldm.invoke.generator.inpaint import infill_methods
|
from ldm.invoke.generator.inpaint import infill_methods
|
||||||
|
from ldm.invoke.globals import global_cache_dir
|
||||||
|
from ldm.invoke.image_util import InitImageResizer
|
||||||
|
from ldm.invoke.model_manager import ModelManager
|
||||||
|
from ldm.invoke.pngwriter import PngWriter
|
||||||
|
from ldm.invoke.seamless import configure_model_padding
|
||||||
|
from ldm.invoke.txt2mask import Txt2Mask
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.models.diffusion.ksampler import KSampler
|
||||||
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
|
|
||||||
|
|
||||||
def fix_func(orig):
|
def fix_func(orig):
|
||||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||||
@ -160,12 +156,12 @@ class Generate:
|
|||||||
mconfig = OmegaConf.load(conf)
|
mconfig = OmegaConf.load(conf)
|
||||||
self.height = None
|
self.height = None
|
||||||
self.width = None
|
self.width = None
|
||||||
self.model_cache = None
|
self.model_manager = None
|
||||||
self.iterations = 1
|
self.iterations = 1
|
||||||
self.steps = 50
|
self.steps = 50
|
||||||
self.cfg_scale = 7.5
|
self.cfg_scale = 7.5
|
||||||
self.sampler_name = sampler_name
|
self.sampler_name = sampler_name
|
||||||
self.ddim_eta = 0.0 # same seed always produces same image
|
self.ddim_eta = ddim_eta # same seed always produces same image
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.strength = 0.75
|
self.strength = 0.75
|
||||||
self.seamless = False
|
self.seamless = False
|
||||||
@ -177,7 +173,6 @@ class Generate:
|
|||||||
self.sampler = None
|
self.sampler = None
|
||||||
self.device = None
|
self.device = None
|
||||||
self.session_peakmem = None
|
self.session_peakmem = None
|
||||||
self.generators = {}
|
|
||||||
self.base_generator = None
|
self.base_generator = None
|
||||||
self.seed = None
|
self.seed = None
|
||||||
self.outdir = outdir
|
self.outdir = outdir
|
||||||
@ -208,8 +203,14 @@ class Generate:
|
|||||||
self.precision = choose_precision(self.device)
|
self.precision = choose_precision(self.device)
|
||||||
|
|
||||||
# model caching system for fast switching
|
# model caching system for fast switching
|
||||||
self.model_cache = ModelCache(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
|
self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
|
||||||
self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
|
# don't accept invalid models
|
||||||
|
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
||||||
|
model = model or fallback
|
||||||
|
if not self.model_manager.valid_model(model):
|
||||||
|
print(f'** "{model}" is not a known model name; falling back to {fallback}.')
|
||||||
|
model = None
|
||||||
|
self.model_name = model or fallback
|
||||||
|
|
||||||
# for VRAM usage statistics
|
# for VRAM usage statistics
|
||||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
||||||
@ -225,7 +226,7 @@ class Generate:
|
|||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
safety_model_path = os.path.join(Globals.root,'models',safety_model_id)
|
safety_model_path = global_cache_dir("hub")
|
||||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id,
|
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id,
|
||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
cache_dir=safety_model_path,
|
cache_dir=safety_model_path,
|
||||||
@ -404,6 +405,10 @@ class Generate:
|
|||||||
width = width or self.width
|
width = width or self.width
|
||||||
height = height or self.height
|
height = height or self.height
|
||||||
|
|
||||||
|
if isinstance(model, DiffusionPipeline):
|
||||||
|
configure_model_padding(model.unet, seamless, seamless_axes)
|
||||||
|
configure_model_padding(model.vae, seamless, seamless_axes)
|
||||||
|
else:
|
||||||
configure_model_padding(model, seamless, seamless_axes)
|
configure_model_padding(model, seamless, seamless_axes)
|
||||||
|
|
||||||
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
|
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
|
||||||
@ -439,7 +444,7 @@ class Generate:
|
|||||||
self._set_sampler()
|
self._set_sampler()
|
||||||
|
|
||||||
# apply the concepts library to the prompt
|
# apply the concepts library to the prompt
|
||||||
prompt = self.concept_lib().replace_concepts_with_triggers(prompt, lambda concepts: self.load_concepts(concepts))
|
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))
|
||||||
|
|
||||||
# bit of a hack to change the cached sampler's karras threshold to
|
# bit of a hack to change the cached sampler's karras threshold to
|
||||||
# whatever the user asked for
|
# whatever the user asked for
|
||||||
@ -546,7 +551,7 @@ class Generate:
|
|||||||
print('**Interrupted** Partial results will be returned.')
|
print('**Interrupted** Partial results will be returned.')
|
||||||
else:
|
else:
|
||||||
raise KeyboardInterrupt
|
raise KeyboardInterrupt
|
||||||
except RuntimeError as e:
|
except RuntimeError:
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
print('>> Could not generate image.')
|
print('>> Could not generate image.')
|
||||||
|
|
||||||
@ -558,7 +563,7 @@ class Generate:
|
|||||||
)
|
)
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
print(
|
print(
|
||||||
f'>> Max VRAM used for this generation:',
|
'>> Max VRAM used for this generation:',
|
||||||
'%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),
|
'%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||||
'Current VRAM utilization:',
|
'Current VRAM utilization:',
|
||||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||||
@ -568,7 +573,7 @@ class Generate:
|
|||||||
self.session_peakmem, torch.cuda.max_memory_allocated()
|
self.session_peakmem, torch.cuda.max_memory_allocated()
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f'>> Max VRAM used since script start: ',
|
'>> Max VRAM used since script start: ',
|
||||||
'%4.2fG' % (self.session_peakmem / 1e9),
|
'%4.2fG' % (self.session_peakmem / 1e9),
|
||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
@ -644,7 +649,7 @@ class Generate:
|
|||||||
try:
|
try:
|
||||||
extend_instructions[direction]=int(pixels)
|
extend_instructions[direction]=int(pixels)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print(f'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"')
|
print('** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"')
|
||||||
|
|
||||||
opt.seed = seed
|
opt.seed = seed
|
||||||
opt.prompt = prompt
|
opt.prompt = prompt
|
||||||
@ -692,7 +697,7 @@ class Generate:
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif tool is None:
|
elif tool is None:
|
||||||
print(f'* please provide at least one postprocessing option, such as -G or -U')
|
print('* please provide at least one postprocessing option, such as -G or -U')
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
print(f'* postprocessing tool {tool} is not yet supported')
|
print(f'* postprocessing tool {tool} is not yet supported')
|
||||||
@ -769,75 +774,62 @@ class Generate:
|
|||||||
|
|
||||||
return init_image,init_mask
|
return init_image,init_mask
|
||||||
|
|
||||||
# lots o' repeated code here! Turn into a make_func()
|
|
||||||
def _make_base(self):
|
def _make_base(self):
|
||||||
if not self.generators.get('base'):
|
return self._load_generator('','Generator')
|
||||||
from ldm.invoke.generator import Generator
|
|
||||||
self.generators['base'] = Generator(self.model, self.precision)
|
|
||||||
return self.generators['base']
|
|
||||||
|
|
||||||
def _make_img2img(self):
|
|
||||||
if not self.generators.get('img2img'):
|
|
||||||
from ldm.invoke.generator.img2img import Img2Img
|
|
||||||
self.generators['img2img'] = Img2Img(self.model, self.precision)
|
|
||||||
self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
|
|
||||||
return self.generators['img2img']
|
|
||||||
|
|
||||||
def _make_embiggen(self):
|
|
||||||
if not self.generators.get('embiggen'):
|
|
||||||
from ldm.invoke.generator.embiggen import Embiggen
|
|
||||||
self.generators['embiggen'] = Embiggen(self.model, self.precision)
|
|
||||||
return self.generators['embiggen']
|
|
||||||
|
|
||||||
def _make_txt2img(self):
|
def _make_txt2img(self):
|
||||||
if not self.generators.get('txt2img'):
|
return self._load_generator('.txt2img','Txt2Img')
|
||||||
from ldm.invoke.generator.txt2img import Txt2Img
|
|
||||||
self.generators['txt2img'] = Txt2Img(self.model, self.precision)
|
def _make_img2img(self):
|
||||||
self.generators['txt2img'].free_gpu_mem = self.free_gpu_mem
|
return self._load_generator('.img2img','Img2Img')
|
||||||
return self.generators['txt2img']
|
|
||||||
|
def _make_embiggen(self):
|
||||||
|
return self._load_generator('.embiggen','Embiggen')
|
||||||
|
|
||||||
def _make_txt2img2img(self):
|
def _make_txt2img2img(self):
|
||||||
if not self.generators.get('txt2img2'):
|
return self._load_generator('.txt2img2img','Txt2Img2Img')
|
||||||
from ldm.invoke.generator.txt2img2img import Txt2Img2Img
|
|
||||||
self.generators['txt2img2'] = Txt2Img2Img(self.model, self.precision)
|
|
||||||
self.generators['txt2img2'].free_gpu_mem = self.free_gpu_mem
|
|
||||||
return self.generators['txt2img2']
|
|
||||||
|
|
||||||
def _make_inpaint(self):
|
def _make_inpaint(self):
|
||||||
if not self.generators.get('inpaint'):
|
return self._load_generator('.inpaint','Inpaint')
|
||||||
from ldm.invoke.generator.inpaint import Inpaint
|
|
||||||
self.generators['inpaint'] = Inpaint(self.model, self.precision)
|
|
||||||
self.generators['inpaint'].free_gpu_mem = self.free_gpu_mem
|
|
||||||
return self.generators['inpaint']
|
|
||||||
|
|
||||||
# "omnibus" supports the runwayML custom inpainting model, which does
|
|
||||||
# txt2img, img2img and inpainting using slight variations on the same code
|
|
||||||
def _make_omnibus(self):
|
def _make_omnibus(self):
|
||||||
if not self.generators.get('omnibus'):
|
return self._load_generator('.omnibus','Omnibus')
|
||||||
from ldm.invoke.generator.omnibus import Omnibus
|
|
||||||
self.generators['omnibus'] = Omnibus(self.model, self.precision)
|
def _load_generator(self, module, class_name):
|
||||||
self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
|
if self.is_legacy_model(self.model_name):
|
||||||
return self.generators['omnibus']
|
mn = f'ldm.invoke.ckpt_generator{module}'
|
||||||
|
cn = f'Ckpt{class_name}'
|
||||||
|
else:
|
||||||
|
mn = f'ldm.invoke.generator{module}'
|
||||||
|
cn = class_name
|
||||||
|
module = importlib.import_module(mn)
|
||||||
|
constructor = getattr(module,cn)
|
||||||
|
return constructor(self.model, self.precision)
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
'''
|
'''
|
||||||
preload model identified in self.model_name
|
preload model identified in self.model_name
|
||||||
'''
|
'''
|
||||||
self.set_model(self.model_name)
|
return self.set_model(self.model_name)
|
||||||
|
|
||||||
def set_model(self,model_name):
|
def set_model(self,model_name):
|
||||||
"""
|
"""
|
||||||
Given the name of a model defined in models.yaml, will load and initialize it
|
Given the name of a model defined in models.yaml, will load and initialize it
|
||||||
and return the model object. Previously-used models will be cached.
|
and return the model object. Previously-used models will be cached.
|
||||||
|
|
||||||
|
If the passed model_name is invalid, raises a KeyError.
|
||||||
|
If the model fails to load for some reason, will attempt to load the previously-
|
||||||
|
loaded model (if any). If that fallback fails, will raise an AssertionError
|
||||||
"""
|
"""
|
||||||
if self.model_name == model_name and self.model is not None:
|
if self.model_name == model_name and self.model is not None:
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
previous_model_name = self.model_name
|
||||||
|
|
||||||
# the model cache does the loading and offloading
|
# the model cache does the loading and offloading
|
||||||
cache = self.model_cache
|
cache = self.model_manager
|
||||||
if not cache.valid_model(model_name):
|
if not cache.valid_model(model_name):
|
||||||
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
raise KeyError('** "{model_name}" is not a known model name. Cannot change.')
|
||||||
return self.model
|
|
||||||
|
|
||||||
cache.print_vram_usage()
|
cache.print_vram_usage()
|
||||||
|
|
||||||
@ -847,11 +839,17 @@ class Generate:
|
|||||||
self.sampler = None
|
self.sampler = None
|
||||||
self.generators = {}
|
self.generators = {}
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
try:
|
||||||
model_data = cache.get_model(model_name)
|
model_data = cache.get_model(model_name)
|
||||||
if model_data is None: # restore previous
|
except Exception as e:
|
||||||
model_data = cache.get_model(self.model_name)
|
print(f'** model {model_name} could not be loaded: {str(e)}')
|
||||||
model_name = self.model_name # addresses Issue #1547
|
if previous_model_name is None:
|
||||||
|
raise e
|
||||||
|
print(f'** trying to reload previous model')
|
||||||
|
model_data = cache.get_model(previous_model_name) # load previous
|
||||||
|
if model_data is None:
|
||||||
|
raise e
|
||||||
|
model_name = previous_model_name
|
||||||
|
|
||||||
self.model = model_data['model']
|
self.model = model_data['model']
|
||||||
self.width = model_data['width']
|
self.width = model_data['width']
|
||||||
@ -863,19 +861,23 @@ class Generate:
|
|||||||
|
|
||||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||||
if self.embedding_path is not None:
|
if self.embedding_path is not None:
|
||||||
self.model.embedding_manager.load(
|
for root, _, files in os.walk(self.embedding_path):
|
||||||
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
|
for name in files:
|
||||||
)
|
ti_path = os.path.join(root, name)
|
||||||
|
self.model.textual_inversion_manager.load_textual_inversion(ti_path,
|
||||||
|
defer_injecting_tokens=True)
|
||||||
|
print(f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}')
|
||||||
|
|
||||||
self._set_sampler()
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
self._set_sampler() # requires self.model_name to be set first
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def load_concepts(self,concepts:list[str]):
|
def load_huggingface_concepts(self, concepts:list[str]):
|
||||||
self.model.embedding_manager.load_concepts(concepts, self.precision=='float32' or self.precision=='autocast')
|
self.model.textual_inversion_manager.load_huggingface_concepts(concepts)
|
||||||
|
|
||||||
def concept_lib(self)->Concepts:
|
@property
|
||||||
return self.model.embedding_manager.concepts_library
|
def huggingface_concepts_library(self) -> HuggingFaceConceptsLibrary:
|
||||||
|
return self.model.textual_inversion_manager.hf_concepts_library
|
||||||
|
|
||||||
def correct_colors(self,
|
def correct_colors(self,
|
||||||
image_list,
|
image_list,
|
||||||
@ -970,9 +972,18 @@ class Generate:
|
|||||||
def sample_to_lowres_estimated_image(self, samples):
|
def sample_to_lowres_estimated_image(self, samples):
|
||||||
return self._make_base().sample_to_lowres_estimated_image(samples)
|
return self._make_base().sample_to_lowres_estimated_image(samples)
|
||||||
|
|
||||||
|
def is_legacy_model(self,model_name)->bool:
|
||||||
|
return self.model_manager.is_legacy(model_name)
|
||||||
|
|
||||||
|
def _set_sampler(self):
|
||||||
|
if isinstance(self.model, DiffusionPipeline):
|
||||||
|
return self._set_scheduler()
|
||||||
|
else:
|
||||||
|
return self._set_sampler_legacy()
|
||||||
|
|
||||||
# very repetitive code - can this be simplified? The KSampler names are
|
# very repetitive code - can this be simplified? The KSampler names are
|
||||||
# consistent, at least
|
# consistent, at least
|
||||||
def _set_sampler(self):
|
def _set_sampler_legacy(self):
|
||||||
msg = f'>> Setting Sampler to {self.sampler_name}'
|
msg = f'>> Setting Sampler to {self.sampler_name}'
|
||||||
if self.sampler_name == 'plms':
|
if self.sampler_name == 'plms':
|
||||||
self.sampler = PLMSSampler(self.model, device=self.device)
|
self.sampler = PLMSSampler(self.model, device=self.device)
|
||||||
@ -1000,6 +1011,41 @@ class Generate:
|
|||||||
|
|
||||||
print(msg)
|
print(msg)
|
||||||
|
|
||||||
|
def _set_scheduler(self):
|
||||||
|
default = self.model.scheduler
|
||||||
|
|
||||||
|
# See https://github.com/huggingface/diffusers/issues/277#issuecomment-1371428672
|
||||||
|
scheduler_map = dict(
|
||||||
|
ddim=diffusers.DDIMScheduler,
|
||||||
|
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||||
|
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||||
|
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||||
|
# DPMSolverMultistepScheduler is technically not `k_` anything, as it is neither
|
||||||
|
# the k-diffusers implementation nor included in EDM (Karras 2022), but we can
|
||||||
|
# provide an alias for compatibility.
|
||||||
|
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||||
|
k_euler=diffusers.EulerDiscreteScheduler,
|
||||||
|
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||||
|
k_heun=diffusers.HeunDiscreteScheduler,
|
||||||
|
k_lms=diffusers.LMSDiscreteScheduler,
|
||||||
|
plms=diffusers.PNDMScheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.sampler_name in scheduler_map:
|
||||||
|
sampler_class = scheduler_map[self.sampler_name]
|
||||||
|
msg = f'>> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})'
|
||||||
|
self.sampler = sampler_class.from_config(self.model.scheduler.config)
|
||||||
|
else:
|
||||||
|
msg = (f'>> Unsupported Sampler: {self.sampler_name} '
|
||||||
|
f'Defaulting to {default}')
|
||||||
|
self.sampler = default
|
||||||
|
|
||||||
|
print(msg)
|
||||||
|
|
||||||
|
if not hasattr(self.sampler, 'uses_inpainting_model'):
|
||||||
|
# FIXME: terrible kludge!
|
||||||
|
self.sampler.uses_inpainting_model = lambda: False
|
||||||
|
|
||||||
def _load_img(self, img)->Image:
|
def _load_img(self, img)->Image:
|
||||||
if isinstance(img, Image.Image):
|
if isinstance(img, Image.Image):
|
||||||
image = img
|
image = img
|
||||||
|
@ -2,11 +2,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import shlex
|
import shlex
|
||||||
import copy
|
|
||||||
import warnings
|
|
||||||
import time
|
|
||||||
import traceback
|
import traceback
|
||||||
import yaml
|
|
||||||
|
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals
|
||||||
from ldm.generate import Generate
|
from ldm.generate import Generate
|
||||||
@ -16,9 +12,9 @@ from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_f
|
|||||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
|
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
|
||||||
from ldm.invoke.image_util import make_grid
|
from ldm.invoke.image_util import make_grid
|
||||||
from ldm.invoke.log import write_log
|
from ldm.invoke.log import write_log
|
||||||
from ldm.invoke.concepts_lib import Concepts
|
from ldm.invoke.model_manager import ModelManager
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from argparse import Namespace
|
||||||
import pyparsing
|
import pyparsing
|
||||||
import ldm.invoke
|
import ldm.invoke
|
||||||
|
|
||||||
@ -45,14 +41,20 @@ def main():
|
|||||||
print('--max_loaded_models must be >= 1; using 1')
|
print('--max_loaded_models must be >= 1; using 1')
|
||||||
args.max_loaded_models = 1
|
args.max_loaded_models = 1
|
||||||
|
|
||||||
|
# alert - setting a global here
|
||||||
|
Globals.try_patchmatch = args.patchmatch
|
||||||
|
Globals.always_use_cpu = args.always_use_cpu
|
||||||
|
Globals.internet_available = args.internet_available and check_internet()
|
||||||
|
print(f'>> Internet connectivity is {Globals.internet_available}')
|
||||||
|
|
||||||
if not args.conf:
|
if not args.conf:
|
||||||
if not os.path.exists(os.path.join(Globals.root,'configs','models.yaml')):
|
if not os.path.exists(os.path.join(Globals.root,'configs','models.yaml')):
|
||||||
print(f"\n** Error. The file {os.path.join(Globals.root,'configs','models.yaml')} could not be found.")
|
print(f"\n** Error. The file {os.path.join(Globals.root,'configs','models.yaml')} could not be found.")
|
||||||
print(f'** Please check the location of your invokeai directory and use the --root_dir option to point to the correct path.')
|
print('** Please check the location of your invokeai directory and use the --root_dir option to point to the correct path.')
|
||||||
print(f'** This script will now exit.')
|
print('** This script will now exit.')
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
print(f'>> {ldm.invoke.__app_name__} {ldm.invoke.__version__}')
|
print(f'>> {ldm.invoke.__app_name__}, version {ldm.invoke.__version__}')
|
||||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||||
|
|
||||||
# loading here to avoid long delays on startup
|
# loading here to avoid long delays on startup
|
||||||
@ -78,6 +80,9 @@ def main():
|
|||||||
else:
|
else:
|
||||||
embedding_path = None
|
embedding_path = None
|
||||||
|
|
||||||
|
# migrate legacy models
|
||||||
|
ModelManager.migrate_models()
|
||||||
|
|
||||||
# load the infile as a list of lines
|
# load the infile as a list of lines
|
||||||
if opt.infile:
|
if opt.infile:
|
||||||
try:
|
try:
|
||||||
@ -107,9 +112,8 @@ def main():
|
|||||||
safety_checker=opt.safety_checker,
|
safety_checker=opt.safety_checker,
|
||||||
max_loaded_models=opt.max_loaded_models,
|
max_loaded_models=opt.max_loaded_models,
|
||||||
)
|
)
|
||||||
except (FileNotFoundError, TypeError, AssertionError):
|
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||||
emergency_model_reconfigure(opt)
|
report_model_error(opt,e)
|
||||||
sys.exit(-1)
|
|
||||||
except (IOError, KeyError) as e:
|
except (IOError, KeyError) as e:
|
||||||
print(f'{e}. Aborting.')
|
print(f'{e}. Aborting.')
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
@ -120,9 +124,18 @@ def main():
|
|||||||
# preload the model
|
# preload the model
|
||||||
try:
|
try:
|
||||||
gen.load_model()
|
gen.load_model()
|
||||||
except AssertionError:
|
except KeyError as e:
|
||||||
emergency_model_reconfigure(opt)
|
pass
|
||||||
sys.exit(-1)
|
except Exception as e:
|
||||||
|
report_model_error(opt, e)
|
||||||
|
|
||||||
|
# try to autoconvert new models
|
||||||
|
# autoimport new .ckpt files
|
||||||
|
if path := opt.autoconvert:
|
||||||
|
gen.model_manager.autoconvert_weights(
|
||||||
|
conf_path=opt.conf,
|
||||||
|
weights_directory=path,
|
||||||
|
)
|
||||||
|
|
||||||
# web server loops forever
|
# web server loops forever
|
||||||
if opt.web or opt.gui:
|
if opt.web or opt.gui:
|
||||||
@ -138,6 +151,9 @@ def main():
|
|||||||
main_loop(gen, opt)
|
main_loop(gen, opt)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print(f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}')
|
print(f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}')
|
||||||
|
except Exception:
|
||||||
|
print(">> An error occurred:")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
||||||
def main_loop(gen, opt):
|
def main_loop(gen, opt):
|
||||||
@ -147,13 +163,13 @@ def main_loop(gen, opt):
|
|||||||
doneAfterInFile = infile is not None
|
doneAfterInFile = infile is not None
|
||||||
path_filter = re.compile(r'[<>:"/\\|?*]')
|
path_filter = re.compile(r'[<>:"/\\|?*]')
|
||||||
last_results = list()
|
last_results = list()
|
||||||
model_config = OmegaConf.load(opt.conf)
|
|
||||||
|
|
||||||
# The readline completer reads history from the .dream_history file located in the
|
# The readline completer reads history from the .dream_history file located in the
|
||||||
# output directory specified at the time of script launch. We do not currently support
|
# output directory specified at the time of script launch. We do not currently support
|
||||||
# changing the history file midstream when the output directory is changed.
|
# changing the history file midstream when the output directory is changed.
|
||||||
completer = get_completer(opt, models=list(model_config.keys()))
|
completer = get_completer(opt, models=gen.model_manager.list_models())
|
||||||
set_default_output_dir(opt, completer)
|
set_default_output_dir(opt, completer)
|
||||||
|
if gen.model:
|
||||||
add_embedding_terms(gen, completer)
|
add_embedding_terms(gen, completer)
|
||||||
output_cntr = completer.get_current_history_length()+1
|
output_cntr = completer.get_current_history_length()+1
|
||||||
|
|
||||||
@ -170,7 +186,7 @@ def main_loop(gen, opt):
|
|||||||
operation = 'generate'
|
operation = 'generate'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
command = get_next_command(infile)
|
command = get_next_command(infile, gen.model_name)
|
||||||
except EOFError:
|
except EOFError:
|
||||||
done = infile is None or doneAfterInFile
|
done = infile is None or doneAfterInFile
|
||||||
infile = None
|
infile = None
|
||||||
@ -315,7 +331,7 @@ def main_loop(gen, opt):
|
|||||||
if use_prefix is not None:
|
if use_prefix is not None:
|
||||||
prefix = use_prefix
|
prefix = use_prefix
|
||||||
postprocessed = upscaled if upscaled else operation=='postprocess'
|
postprocessed = upscaled if upscaled else operation=='postprocess'
|
||||||
opt.prompt = gen.concept_lib().replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers
|
opt.prompt = gen.huggingface_concepts_library.replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers
|
||||||
filename, formatted_dream_prompt = prepare_image_metadata(
|
filename, formatted_dream_prompt = prepare_image_metadata(
|
||||||
opt,
|
opt,
|
||||||
prefix,
|
prefix,
|
||||||
@ -434,24 +450,50 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
|||||||
|
|
||||||
elif command.startswith('!switch'):
|
elif command.startswith('!switch'):
|
||||||
model_name = command.replace('!switch ','',1)
|
model_name = command.replace('!switch ','',1)
|
||||||
|
try:
|
||||||
gen.set_model(model_name)
|
gen.set_model(model_name)
|
||||||
add_embedding_terms(gen, completer)
|
add_embedding_terms(gen, completer)
|
||||||
|
except KeyError as e:
|
||||||
|
print(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
report_model_error(opt,e)
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
operation = None
|
operation = None
|
||||||
|
|
||||||
elif command.startswith('!models'):
|
elif command.startswith('!models'):
|
||||||
gen.model_cache.print_models()
|
gen.model_manager.print_models()
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
operation = None
|
operation = None
|
||||||
|
|
||||||
elif command.startswith('!import'):
|
elif command.startswith('!import'):
|
||||||
path = shlex.split(command)
|
path = shlex.split(command)
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
print('** please provide a path to a .ckpt or .vae model file')
|
print('** please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1')
|
||||||
elif not os.path.exists(path[1]):
|
|
||||||
print(f'** {path[1]}: file not found')
|
|
||||||
else:
|
else:
|
||||||
add_weights_to_config(path[1], gen, opt, completer)
|
import_model(path[1], gen, opt, completer)
|
||||||
|
completer.add_history(command)
|
||||||
|
operation = None
|
||||||
|
|
||||||
|
elif command.startswith('!convert'):
|
||||||
|
path = shlex.split(command)
|
||||||
|
if len(path) < 2:
|
||||||
|
print('** please provide the path to a .ckpt or .safetensors model')
|
||||||
|
elif not os.path.exists(path[1]):
|
||||||
|
print(f'** {path[1]}: model not found')
|
||||||
|
else:
|
||||||
|
optimize_model(path[1], gen, opt, completer)
|
||||||
|
completer.add_history(command)
|
||||||
|
operation = None
|
||||||
|
|
||||||
|
|
||||||
|
elif command.startswith('!optimize'):
|
||||||
|
path = shlex.split(command)
|
||||||
|
if len(path) < 2:
|
||||||
|
print('** please provide an installed model name')
|
||||||
|
elif not path[1] in gen.model_manager.list_models():
|
||||||
|
print(f'** {path[1]}: model not found')
|
||||||
|
else:
|
||||||
|
optimize_model(path[1], gen, opt, completer)
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
operation = None
|
operation = None
|
||||||
|
|
||||||
@ -460,7 +502,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
|||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
print('** please provide the name of a model')
|
print('** please provide the name of a model')
|
||||||
else:
|
else:
|
||||||
edit_config(path[1], gen, opt, completer)
|
edit_model(path[1], gen, opt, completer)
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
operation = None
|
operation = None
|
||||||
|
|
||||||
@ -521,121 +563,223 @@ def set_default_output_dir(opt:Args, completer:Completer):
|
|||||||
completer.set_default_dir(opt.outdir)
|
completer.set_default_dir(opt.outdir)
|
||||||
|
|
||||||
|
|
||||||
def add_weights_to_config(model_path:str, gen, opt, completer):
|
def import_model(model_path:str, gen, opt, completer):
|
||||||
print(f'>> Model import in process. Please enter the values needed to configure this model:')
|
'''
|
||||||
print()
|
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or
|
||||||
|
(3) a huggingface repository id
|
||||||
|
'''
|
||||||
|
model_name = None
|
||||||
|
|
||||||
new_config = {}
|
if model_path.startswith(('http:','https:','ftp:')):
|
||||||
new_config['weights'] = model_path
|
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
||||||
|
elif os.path.exists(model_path) and model_path.endswith('.ckpt') and os.path.isfile(model_path):
|
||||||
done = False
|
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
||||||
while not done:
|
elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
|
||||||
model_name = input('Short name for this model: ')
|
model_name = import_diffuser_model(model_path, gen, opt, completer)
|
||||||
if not re.match('^[\w._-]+$',model_name):
|
elif os.path.isdir(model_path):
|
||||||
print('** model name must contain only words, digits and the characters [._-] **')
|
model_name = import_diffuser_model(model_path, gen, opt, completer)
|
||||||
else:
|
else:
|
||||||
done = True
|
print(f'** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can\'t import.')
|
||||||
new_config['description'] = input('Description of this model: ')
|
|
||||||
|
if not model_name:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not _verify_load(model_name, gen):
|
||||||
|
print('** model failed to load. Discarding configuration entry')
|
||||||
|
gen.model_manager.del_model(model_name)
|
||||||
|
return
|
||||||
|
|
||||||
|
if input('Make this the default model? [n] ') in ('y','Y'):
|
||||||
|
gen.model_manager.set_default_model(model_name)
|
||||||
|
|
||||||
|
gen.model_manager.commit(opt.conf)
|
||||||
|
completer.update_models(gen.model_manager.list_models())
|
||||||
|
print(f'>> {model_name} successfully installed')
|
||||||
|
|
||||||
|
def import_diffuser_model(path_or_repo:str, gen, opt, completer)->str:
|
||||||
|
manager = gen.model_manager
|
||||||
|
default_name = Path(path_or_repo).stem
|
||||||
|
default_description = f'Imported model {default_name}'
|
||||||
|
model_name, model_description = _get_model_name_and_desc(
|
||||||
|
manager,
|
||||||
|
completer,
|
||||||
|
model_name=default_name,
|
||||||
|
model_description=default_description
|
||||||
|
)
|
||||||
|
|
||||||
|
if not manager.import_diffuser_model(
|
||||||
|
path_or_repo,
|
||||||
|
model_name = model_name,
|
||||||
|
description = model_description):
|
||||||
|
print('** model failed to import')
|
||||||
|
return None
|
||||||
|
if input('Make this the default model? [n] ').startswith(('y','Y')):
|
||||||
|
manager.set_default_model(model_name)
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
|
||||||
|
manager = gen.model_manager
|
||||||
|
default_name = Path(path_or_url).stem
|
||||||
|
default_description = f'Imported model {default_name}'
|
||||||
|
model_name, model_description = _get_model_name_and_desc(
|
||||||
|
manager,
|
||||||
|
completer,
|
||||||
|
model_name=default_name,
|
||||||
|
model_description=default_description
|
||||||
|
)
|
||||||
|
config_file = None
|
||||||
|
|
||||||
completer.complete_extensions(('.yaml','.yml'))
|
completer.complete_extensions(('.yaml','.yml'))
|
||||||
completer.linebuffer = 'configs/stable-diffusion/v1-inference.yaml'
|
completer.set_line('configs/stable-diffusion/v1-inference.yaml')
|
||||||
|
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
new_config['config'] = input('Configuration file for this model: ')
|
config_file = input('Configuration file for this model: ').strip()
|
||||||
done = os.path.exists(new_config['config'])
|
done = os.path.exists(config_file)
|
||||||
|
|
||||||
done = False
|
|
||||||
completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
|
|
||||||
while not done:
|
|
||||||
vae = input('VAE autoencoder file for this model [None]: ')
|
|
||||||
if os.path.exists(vae):
|
|
||||||
new_config['vae'] = vae
|
|
||||||
done = True
|
|
||||||
else:
|
|
||||||
done = len(vae)==0
|
|
||||||
|
|
||||||
completer.complete_extensions(None)
|
completer.complete_extensions(None)
|
||||||
|
|
||||||
for field in ('width','height'):
|
if not manager.import_ckpt_model(
|
||||||
done = False
|
path_or_url,
|
||||||
while not done:
|
config = config_file,
|
||||||
try:
|
model_name = model_name,
|
||||||
completer.linebuffer = '512'
|
model_description = model_description,
|
||||||
value = int(input(f'Default image {field}: '))
|
commit_to_conf = opt.conf,
|
||||||
assert value >= 64 and value <= 2048
|
):
|
||||||
new_config[field] = value
|
print('** model failed to import')
|
||||||
done = True
|
return None
|
||||||
except:
|
|
||||||
print('** Please enter a valid integer between 64 and 2048')
|
|
||||||
|
|
||||||
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
if input('Make this the default model? [n] ').startswith(('y','Y')):
|
||||||
|
manager.set_model_default(model_name)
|
||||||
|
return model_name
|
||||||
|
|
||||||
if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
|
def _verify_load(model_name:str, gen)->bool:
|
||||||
completer.add_model(model_name)
|
print('>> Verifying that new model loads...')
|
||||||
|
current_model = gen.model_name
|
||||||
|
if not gen.model_manager.get_model(model_name):
|
||||||
|
return False
|
||||||
|
do_switch = input('Keep model loaded? [y] ')
|
||||||
|
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
|
||||||
|
gen.set_model(model_name)
|
||||||
|
else:
|
||||||
|
print('>> Restoring previous model')
|
||||||
|
gen.set_model(current_model)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _get_model_name_and_desc(model_manager,completer,model_name:str='',model_description:str=''):
|
||||||
|
model_name = _get_model_name(model_manager.list_models(),completer,model_name)
|
||||||
|
completer.set_line(model_description)
|
||||||
|
model_description = input(f'Description for this model [{model_description}]: ').strip() or model_description
|
||||||
|
return model_name, model_description
|
||||||
|
|
||||||
|
def optimize_model(model_name_or_path:str, gen, opt, completer):
|
||||||
|
manager = gen.model_manager
|
||||||
|
ckpt_path = None
|
||||||
|
|
||||||
|
if (model_info := manager.model_info(model_name_or_path)):
|
||||||
|
if 'weights' in model_info:
|
||||||
|
ckpt_path = Path(model_info['weights'])
|
||||||
|
model_name = model_name_or_path
|
||||||
|
model_description = model_info['description']
|
||||||
|
else:
|
||||||
|
print(f'** {model_name_or_path} is not a legacy .ckpt weights file')
|
||||||
|
return
|
||||||
|
elif os.path.exists(model_name_or_path):
|
||||||
|
ckpt_path = Path(model_name_or_path)
|
||||||
|
model_name,model_description = _get_model_name_and_desc(
|
||||||
|
manager,
|
||||||
|
completer,
|
||||||
|
ckpt_path.stem,
|
||||||
|
f'Converted model {ckpt_path.stem}'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file')
|
||||||
|
return
|
||||||
|
|
||||||
|
if not ckpt_path.is_absolute():
|
||||||
|
ckpt_path = Path(Globals.root,ckpt_path)
|
||||||
|
|
||||||
|
diffuser_path = Path(Globals.root, 'models','optimized-ckpts',model_name)
|
||||||
|
if diffuser_path.exists():
|
||||||
|
print(f'** {model_name_or_path} is already optimized. Will not overwrite. If this is an error, please remove the directory {diffuser_path} and try again.')
|
||||||
|
return
|
||||||
|
|
||||||
|
new_config = gen.model_manager.convert_and_import(
|
||||||
|
ckpt_path,
|
||||||
|
diffuser_path,
|
||||||
|
model_name=model_name,
|
||||||
|
model_description=model_description,
|
||||||
|
commit_to_conf=opt.conf,
|
||||||
|
)
|
||||||
|
if not new_config:
|
||||||
|
return
|
||||||
|
|
||||||
|
completer.update_models(gen.model_manager.list_models())
|
||||||
|
if input(f'Load optimized model {model_name}? [y] ') not in ('n','N'):
|
||||||
|
gen.set_model(model_name)
|
||||||
|
|
||||||
|
response = input(f'Delete the original .ckpt file at ({ckpt_path} ? [n] ')
|
||||||
|
if response.startswith(('y','Y')):
|
||||||
|
ckpt_path.unlink(missing_ok=True)
|
||||||
|
print(f'{ckpt_path} deleted')
|
||||||
|
|
||||||
def del_config(model_name:str, gen, opt, completer):
|
def del_config(model_name:str, gen, opt, completer):
|
||||||
current_model = gen.model_name
|
current_model = gen.model_name
|
||||||
if model_name == current_model:
|
if model_name == current_model:
|
||||||
print("** Can't delete active model. !switch to another model first. **")
|
print("** Can't delete active model. !switch to another model first. **")
|
||||||
return
|
return
|
||||||
gen.model_cache.del_model(model_name)
|
gen.model_manager.del_model(model_name)
|
||||||
gen.model_cache.commit(opt.conf)
|
gen.model_manager.commit(opt.conf)
|
||||||
print(f'** {model_name} deleted')
|
print(f'** {model_name} deleted')
|
||||||
completer.del_model(model_name)
|
completer.update_models(gen.model_manager.list_models())
|
||||||
|
|
||||||
def edit_config(model_name:str, gen, opt, completer):
|
def edit_model(model_name:str, gen, opt, completer):
|
||||||
config = gen.model_cache.config
|
current_model = gen.model_name
|
||||||
|
# if model_name == current_model:
|
||||||
|
# print("** Can't edit the active model. !switch to another model first. **")
|
||||||
|
# return
|
||||||
|
|
||||||
if model_name not in config:
|
manager = gen.model_manager
|
||||||
|
if not (info := manager.model_info(model_name)):
|
||||||
print(f'** Unknown model {model_name}')
|
print(f'** Unknown model {model_name}')
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f'\n>> Editing model {model_name} from configuration file {opt.conf}')
|
print(f'\n>> Editing model {model_name} from configuration file {opt.conf}')
|
||||||
|
new_name = _get_model_name(manager.list_models(),completer,model_name)
|
||||||
|
|
||||||
conf = config[model_name]
|
for attribute in info.keys():
|
||||||
new_config = {}
|
if type(info[attribute]) != str:
|
||||||
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
|
continue
|
||||||
for field in ('description', 'weights', 'vae', 'config', 'width','height'):
|
if attribute == 'format':
|
||||||
completer.linebuffer = str(conf[field]) if field in conf else ''
|
continue
|
||||||
new_value = input(f'{field}: ')
|
completer.set_line(info[attribute])
|
||||||
new_config[field] = int(new_value) if field in ('width','height') else new_value
|
info[attribute] = input(f'{attribute}: ') or info[attribute]
|
||||||
make_default = input('Make this the default model? [n] ') in ('y','Y')
|
|
||||||
completer.complete_extensions(None)
|
|
||||||
write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
|
|
||||||
|
|
||||||
def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
|
if new_name != model_name:
|
||||||
current_model = gen.model_name
|
manager.del_model(model_name)
|
||||||
|
|
||||||
op = 'modify' if clobber else 'import'
|
# this does the update
|
||||||
print('\n>> New configuration:')
|
manager.add_model(new_name, info, True)
|
||||||
if make_default:
|
|
||||||
new_config['default'] = True
|
|
||||||
print(yaml.dump({model_name:new_config}))
|
|
||||||
if input(f'OK to {op} [n]? ') not in ('y','Y'):
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
if input('Make this the default model? [n] ').startswith(('y','Y')):
|
||||||
print('>> Verifying that new model loads...')
|
manager.set_default_model(new_name)
|
||||||
gen.model_cache.add_model(model_name, new_config, clobber)
|
manager.commit(opt.conf)
|
||||||
assert gen.set_model(model_name) is not None, 'model failed to load'
|
completer.update_models(manager.list_models())
|
||||||
except AssertionError as e:
|
print('>> Model successfully updated')
|
||||||
print(f'** aborting **')
|
|
||||||
gen.model_cache.del_model(model_name)
|
|
||||||
return False
|
|
||||||
|
|
||||||
if make_default:
|
def _get_model_name(existing_names,completer,default_name:str='')->str:
|
||||||
print('making this default')
|
done = False
|
||||||
gen.model_cache.set_default_model(model_name)
|
completer.set_line(default_name)
|
||||||
|
while not done:
|
||||||
gen.model_cache.commit(conf_path)
|
model_name = input(f'Short name for this model [{default_name}]: ').strip()
|
||||||
|
if len(model_name)==0:
|
||||||
do_switch = input(f'Keep model loaded? [y]')
|
model_name = default_name
|
||||||
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
|
if not re.match('^[\w._+-]+$',model_name):
|
||||||
pass
|
print('** model name must contain only words, digits and the characters "._+-" **')
|
||||||
|
elif model_name != default_name and model_name in existing_names:
|
||||||
|
print(f'** the name {model_name} is already in use. Pick another.')
|
||||||
else:
|
else:
|
||||||
gen.set_model(current_model)
|
done = True
|
||||||
return True
|
return model_name
|
||||||
|
|
||||||
|
|
||||||
def do_textmask(gen, opt, callback):
|
def do_textmask(gen, opt, callback):
|
||||||
image_path = opt.prompt
|
image_path = opt.prompt
|
||||||
@ -746,7 +890,7 @@ def prepare_image_metadata(
|
|||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
print(f'** The filename format contains an unknown key \'{e.args[0]}\'. Will use \'{{prefix}}.{{seed}}.png\' instead')
|
print(f'** The filename format contains an unknown key \'{e.args[0]}\'. Will use \'{{prefix}}.{{seed}}.png\' instead')
|
||||||
filename = f'{prefix}.{seed}.png'
|
filename = f'{prefix}.{seed}.png'
|
||||||
except IndexError as e:
|
except IndexError:
|
||||||
print(f'** The filename format is broken or complete. Will use \'{{prefix}}.{{seed}}.png\' instead')
|
print(f'** The filename format is broken or complete. Will use \'{{prefix}}.{{seed}}.png\' instead')
|
||||||
filename = f'{prefix}.{seed}.png'
|
filename = f'{prefix}.{seed}.png'
|
||||||
|
|
||||||
@ -782,9 +926,9 @@ def choose_postprocess_name(opt,prefix,seed) -> str:
|
|||||||
counter += 1
|
counter += 1
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
def get_next_command(infile=None) -> str: # command string
|
def get_next_command(infile=None, model_name='no model') -> str: # command string
|
||||||
if infile is None:
|
if infile is None:
|
||||||
command = input('invoke> ')
|
command = input(f'({model_name}) invoke> ').strip()
|
||||||
else:
|
else:
|
||||||
command = infile.readline()
|
command = infile.readline()
|
||||||
if not command:
|
if not command:
|
||||||
@ -815,7 +959,8 @@ def add_embedding_terms(gen,completer):
|
|||||||
Called after setting the model, updates the autocompleter with
|
Called after setting the model, updates the autocompleter with
|
||||||
any terms loaded by the embedding manager.
|
any terms loaded by the embedding manager.
|
||||||
'''
|
'''
|
||||||
completer.add_embedding_terms(gen.model.embedding_manager.list_terms())
|
trigger_strings = gen.model.textual_inversion_manager.get_all_trigger_strings()
|
||||||
|
completer.add_embedding_terms(trigger_strings)
|
||||||
|
|
||||||
def split_variations(variations_string) -> list:
|
def split_variations(variations_string) -> list:
|
||||||
# shotgun parsing, woo
|
# shotgun parsing, woo
|
||||||
@ -938,13 +1083,13 @@ def write_commands(opt, file_path:str, outfilepath:str):
|
|||||||
f.write('\n'.join(commands))
|
f.write('\n'.join(commands))
|
||||||
print(f'>> File {outfilepath} with commands created')
|
print(f'>> File {outfilepath} with commands created')
|
||||||
|
|
||||||
def emergency_model_reconfigure(opt):
|
def report_model_error(opt:Namespace, e:Exception):
|
||||||
print()
|
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
print('** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models.')
|
||||||
print(' You appear to have a missing or misconfigured model file(s). ')
|
response = input('Do you want to run configure_invokeai.py to select and/or reinstall models? [y] ')
|
||||||
print(' The script will now exit and run configure_invokeai.py to help fix the problem.')
|
if response.startswith(('n','N')):
|
||||||
print(' After reconfiguration is done, please relaunch invoke.py. ')
|
return
|
||||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
|
||||||
print('configure_invokeai is launching....\n')
|
print('configure_invokeai is launching....\n')
|
||||||
|
|
||||||
# Match arguments that were set on the CLI
|
# Match arguments that were set on the CLI
|
||||||
@ -952,7 +1097,7 @@ def emergency_model_reconfigure(opt):
|
|||||||
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
|
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
|
||||||
config = ["--config", opt.conf] if opt.conf is not None else []
|
config = ["--config", opt.conf] if opt.conf is not None else []
|
||||||
yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE')
|
yes_to_all = os.environ.get('INVOKE_MODEL_RECONFIGURE')
|
||||||
|
previous_args = sys.argv
|
||||||
sys.argv = [ 'configure_invokeai' ]
|
sys.argv = [ 'configure_invokeai' ]
|
||||||
sys.argv.extend(root_dir)
|
sys.argv.extend(root_dir)
|
||||||
sys.argv.extend(config)
|
sys.argv.extend(config)
|
||||||
@ -961,3 +1106,20 @@ def emergency_model_reconfigure(opt):
|
|||||||
|
|
||||||
import configure_invokeai
|
import configure_invokeai
|
||||||
configure_invokeai.main()
|
configure_invokeai.main()
|
||||||
|
print('** InvokeAI will now restart')
|
||||||
|
sys.argv = previous_args
|
||||||
|
main() # would rather do a os.exec(), but doesn't exist?
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
def check_internet()->bool:
|
||||||
|
'''
|
||||||
|
Return true if the internet is reachable.
|
||||||
|
It does this by pinging huggingface.co.
|
||||||
|
'''
|
||||||
|
import urllib.request
|
||||||
|
host = 'http://huggingface.co'
|
||||||
|
try:
|
||||||
|
urllib.request.urlopen(host,timeout=1)
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
@ -1 +1 @@
|
|||||||
__version__='2.2.6+a0'
|
__version__='2.3.0+a0'
|
||||||
|
@ -81,22 +81,23 @@ with metadata_from_png():
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from argparse import Namespace, RawTextHelpFormatter
|
|
||||||
import pydoc
|
|
||||||
import json
|
|
||||||
import hashlib
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
import shlex
|
|
||||||
import copy
|
|
||||||
import base64
|
import base64
|
||||||
|
import copy
|
||||||
import functools
|
import functools
|
||||||
import warnings
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pydoc
|
||||||
|
import re
|
||||||
|
import shlex
|
||||||
|
import sys
|
||||||
import ldm.invoke
|
import ldm.invoke
|
||||||
import ldm.invoke.pngwriter
|
import ldm.invoke.pngwriter
|
||||||
|
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals
|
||||||
from ldm.invoke.prompt_parser import split_weighted_subprompts
|
from ldm.invoke.prompt_parser import split_weighted_subprompts
|
||||||
|
from argparse import Namespace
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
APP_ID = ldm.invoke.__app_id__
|
APP_ID = ldm.invoke.__app_id__
|
||||||
APP_NAME = ldm.invoke.__app_name__
|
APP_NAME = ldm.invoke.__app_name__
|
||||||
@ -113,6 +114,8 @@ SAMPLER_CHOICES = [
|
|||||||
'k_heun',
|
'k_heun',
|
||||||
'k_lms',
|
'k_lms',
|
||||||
'plms',
|
'plms',
|
||||||
|
# diffusers:
|
||||||
|
"pndm",
|
||||||
]
|
]
|
||||||
|
|
||||||
PRECISION_CHOICES = [
|
PRECISION_CHOICES = [
|
||||||
@ -181,7 +184,7 @@ class Args(object):
|
|||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
print('* Initializing, be patient...')
|
print('* Initializing, be patient...')
|
||||||
Globals.root = os.path.abspath(switches.root_dir or Globals.root)
|
Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
|
||||||
Globals.try_patchmatch = switches.patchmatch
|
Globals.try_patchmatch = switches.patchmatch
|
||||||
|
|
||||||
# now use root directory to find the init file
|
# now use root directory to find the init file
|
||||||
@ -273,7 +276,7 @@ class Args(object):
|
|||||||
switches.append(f'-I {a["init_img"]}')
|
switches.append(f'-I {a["init_img"]}')
|
||||||
switches.append(f'-A {a["sampler_name"]}')
|
switches.append(f'-A {a["sampler_name"]}')
|
||||||
if a['fit']:
|
if a['fit']:
|
||||||
switches.append(f'--fit')
|
switches.append('--fit')
|
||||||
if a['init_mask'] and len(a['init_mask'])>0:
|
if a['init_mask'] and len(a['init_mask'])>0:
|
||||||
switches.append(f'-M {a["init_mask"]}')
|
switches.append(f'-M {a["init_mask"]}')
|
||||||
if a['init_color'] and len(a['init_color'])>0:
|
if a['init_color'] and len(a['init_color'])>0:
|
||||||
@ -281,7 +284,7 @@ class Args(object):
|
|||||||
if a['strength'] and a['strength']>0:
|
if a['strength'] and a['strength']>0:
|
||||||
switches.append(f'-f {a["strength"]}')
|
switches.append(f'-f {a["strength"]}')
|
||||||
if a['inpaint_replace']:
|
if a['inpaint_replace']:
|
||||||
switches.append(f'--inpaint_replace')
|
switches.append('--inpaint_replace')
|
||||||
if a['text_mask']:
|
if a['text_mask']:
|
||||||
switches.append(f'-tm {" ".join([str(u) for u in a["text_mask"]])}')
|
switches.append(f'-tm {" ".join([str(u) for u in a["text_mask"]])}')
|
||||||
else:
|
else:
|
||||||
@ -479,6 +482,12 @@ class Args(object):
|
|||||||
action='store_true',
|
action='store_true',
|
||||||
help='Force free gpu memory before final decoding',
|
help='Force free gpu memory before final decoding',
|
||||||
)
|
)
|
||||||
|
model_group.add_argument(
|
||||||
|
"--always_use_cpu",
|
||||||
|
dest="always_use_cpu",
|
||||||
|
action="store_true",
|
||||||
|
help="Force use of CPU even if GPU is available"
|
||||||
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--precision',
|
'--precision',
|
||||||
dest='precision',
|
dest='precision',
|
||||||
@ -489,13 +498,26 @@ class Args(object):
|
|||||||
default='auto',
|
default='auto',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--nsfw_checker'
|
'--internet',
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
dest='internet_available',
|
||||||
|
default=True,
|
||||||
|
help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).',
|
||||||
|
)
|
||||||
|
model_group.add_argument(
|
||||||
|
'--nsfw_checker',
|
||||||
'--safety_checker',
|
'--safety_checker',
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
dest='safety_checker',
|
dest='safety_checker',
|
||||||
default=False,
|
default=False,
|
||||||
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
|
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
|
||||||
)
|
)
|
||||||
|
model_group.add_argument(
|
||||||
|
'--autoconvert',
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help='Check the indicated directory for .ckpt weights files at startup and import as optimized diffuser models',
|
||||||
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--patchmatch',
|
'--patchmatch',
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
@ -720,7 +742,11 @@ class Args(object):
|
|||||||
*Model manipulation*
|
*Model manipulation*
|
||||||
!models -- list models in configs/models.yaml
|
!models -- list models in configs/models.yaml
|
||||||
!switch <model_name> -- switch to model named <model_name>
|
!switch <model_name> -- switch to model named <model_name>
|
||||||
!import_model path/to/weights/file.ckpt -- adds a model to your config
|
!import_model /path/to/weights/file.ckpt -- adds a .ckpt model to your config
|
||||||
|
!import_model http://path_to_model.ckpt -- downloads and adds a .ckpt model to your config
|
||||||
|
!import_model hakurei/waifu-diffusion -- downloads and adds a diffusers model to your config
|
||||||
|
!optimize_model <model_name> -- converts a .ckpt model to a diffusers model
|
||||||
|
!convert_model /path/to/weights/file.ckpt -- converts a .ckpt file path to a diffusers model
|
||||||
!edit_model <model_name> -- edit a model's description
|
!edit_model <model_name> -- edit a model's description
|
||||||
!del_model <model_name> -- delete a model
|
!del_model <model_name> -- delete a model
|
||||||
"""
|
"""
|
||||||
@ -1061,7 +1087,7 @@ class Args(object):
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
def format_metadata(**kwargs):
|
def format_metadata(**kwargs):
|
||||||
print(f'format_metadata() is deprecated. Please use metadata_dumps()')
|
print('format_metadata() is deprecated. Please use metadata_dumps()')
|
||||||
return metadata_dumps(kwargs)
|
return metadata_dumps(kwargs)
|
||||||
|
|
||||||
def metadata_dumps(opt,
|
def metadata_dumps(opt,
|
||||||
@ -1128,7 +1154,7 @@ def metadata_dumps(opt,
|
|||||||
rfc_dict.pop('strength')
|
rfc_dict.pop('strength')
|
||||||
|
|
||||||
if len(seeds)==0 and opt.seed:
|
if len(seeds)==0 and opt.seed:
|
||||||
seeds=[seed]
|
seeds=[opt.seed]
|
||||||
|
|
||||||
if opt.grid:
|
if opt.grid:
|
||||||
images = []
|
images = []
|
||||||
@ -1199,7 +1225,7 @@ def metadata_loads(metadata) -> list:
|
|||||||
opt = Args()
|
opt = Args()
|
||||||
opt._cmd_switches = Namespace(**image)
|
opt._cmd_switches = Namespace(**image)
|
||||||
results.append(opt)
|
results.append(opt)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
import sys, traceback
|
import sys, traceback
|
||||||
print('>> could not read metadata',file=sys.stderr)
|
print('>> could not read metadata',file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
4
ldm/invoke/ckpt_generator/__init__.py
Normal file
4
ldm/invoke/ckpt_generator/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
'''
|
||||||
|
Initialization file for the ldm.invoke.generator package
|
||||||
|
'''
|
||||||
|
from .base import CkptGenerator
|
338
ldm/invoke/ckpt_generator/base.py
Normal file
338
ldm/invoke/ckpt_generator/base.py
Normal file
@ -0,0 +1,338 @@
|
|||||||
|
'''
|
||||||
|
Base class for ldm.invoke.ckpt_generator.*
|
||||||
|
including img2img, txt2img, and inpaint
|
||||||
|
|
||||||
|
THESE MODULES ARE TRANSITIONAL AND WILL BE REMOVED AT A FUTURE DATE
|
||||||
|
WHEN LEGACY CKPT MODEL SUPPORT IS DISCONTINUED.
|
||||||
|
'''
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import traceback
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
from PIL import Image, ImageFilter, ImageChops
|
||||||
|
import cv2 as cv
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from pytorch_lightning import seed_everything
|
||||||
|
from ldm.invoke.devices import choose_autocast
|
||||||
|
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||||
|
from ldm.util import rand_perlin_2d
|
||||||
|
|
||||||
|
downsampling = 8
|
||||||
|
CAUTION_IMG = 'assets/caution.png'
|
||||||
|
|
||||||
|
class CkptGenerator():
|
||||||
|
def __init__(self, model, precision):
|
||||||
|
self.model = model
|
||||||
|
self.precision = precision
|
||||||
|
self.seed = None
|
||||||
|
self.latent_channels = model.channels
|
||||||
|
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||||
|
self.safety_checker = None
|
||||||
|
self.perlin = 0.0
|
||||||
|
self.threshold = 0
|
||||||
|
self.variation_amount = 0
|
||||||
|
self.with_variations = []
|
||||||
|
self.use_mps_noise = False
|
||||||
|
self.free_gpu_mem = None
|
||||||
|
self.caution_img = None
|
||||||
|
|
||||||
|
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||||
|
def get_make_image(self,prompt,**kwargs):
|
||||||
|
"""
|
||||||
|
Returns a function returning an image derived from the prompt and the initial image
|
||||||
|
Return value depends on the seed at the time you call it
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("image_iterator() must be implemented in a descendent class")
|
||||||
|
|
||||||
|
def set_variation(self, seed, variation_amount, with_variations):
|
||||||
|
self.seed = seed
|
||||||
|
self.variation_amount = variation_amount
|
||||||
|
self.with_variations = with_variations
|
||||||
|
|
||||||
|
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
||||||
|
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||||
|
safety_checker:dict=None,
|
||||||
|
attention_maps_callback = None,
|
||||||
|
**kwargs):
|
||||||
|
scope = choose_autocast(self.precision)
|
||||||
|
self.safety_checker = safety_checker
|
||||||
|
attention_maps_images = []
|
||||||
|
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
|
||||||
|
make_image = self.get_make_image(
|
||||||
|
prompt,
|
||||||
|
sampler = sampler,
|
||||||
|
init_image = init_image,
|
||||||
|
width = width,
|
||||||
|
height = height,
|
||||||
|
step_callback = step_callback,
|
||||||
|
threshold = threshold,
|
||||||
|
perlin = perlin,
|
||||||
|
attention_maps_callback = attention_maps_callback,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
seed = seed if seed is not None and seed >= 0 else self.new_seed()
|
||||||
|
first_seed = seed
|
||||||
|
seed, initial_noise = self.generate_initial_noise(seed, width, height)
|
||||||
|
|
||||||
|
# There used to be an additional self.model.ema_scope() here, but it breaks
|
||||||
|
# the inpaint-1.5 model. Not sure what it did.... ?
|
||||||
|
with scope(self.model.device.type):
|
||||||
|
for n in trange(iterations, desc='Generating'):
|
||||||
|
x_T = None
|
||||||
|
if self.variation_amount > 0:
|
||||||
|
seed_everything(seed)
|
||||||
|
target_noise = self.get_noise(width,height)
|
||||||
|
x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
|
||||||
|
elif initial_noise is not None:
|
||||||
|
# i.e. we specified particular variations
|
||||||
|
x_T = initial_noise
|
||||||
|
else:
|
||||||
|
seed_everything(seed)
|
||||||
|
try:
|
||||||
|
x_T = self.get_noise(width,height)
|
||||||
|
except:
|
||||||
|
print('** An error occurred while getting initial noise **')
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
image = make_image(x_T)
|
||||||
|
|
||||||
|
if self.safety_checker is not None:
|
||||||
|
image = self.safety_check(image)
|
||||||
|
|
||||||
|
results.append([image, seed])
|
||||||
|
|
||||||
|
if image_callback is not None:
|
||||||
|
attention_maps_image = None if len(attention_maps_images)==0 else attention_maps_images[-1]
|
||||||
|
image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_image)
|
||||||
|
|
||||||
|
seed = self.new_seed()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def sample_to_image(self,samples)->Image.Image:
|
||||||
|
"""
|
||||||
|
Given samples returned from a sampler, converts
|
||||||
|
it into a PIL Image
|
||||||
|
"""
|
||||||
|
x_samples = self.model.decode_first_stage(samples)
|
||||||
|
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
if len(x_samples) != 1:
|
||||||
|
raise Exception(
|
||||||
|
f'>> expected to get a single image, but got {len(x_samples)}')
|
||||||
|
x_sample = 255.0 * rearrange(
|
||||||
|
x_samples[0].cpu().numpy(), 'c h w -> h w c'
|
||||||
|
)
|
||||||
|
return Image.fromarray(x_sample.astype(np.uint8))
|
||||||
|
|
||||||
|
# write an approximate RGB image from latent samples for a single step to PNG
|
||||||
|
|
||||||
|
def repaste_and_color_correct(self, result: Image.Image, init_image: Image.Image, init_mask: Image.Image, mask_blur_radius: int = 8) -> Image.Image:
|
||||||
|
if init_image is None or init_mask is None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Get the original alpha channel of the mask if there is one.
|
||||||
|
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||||
|
pil_init_mask = init_mask.getchannel('A') if init_mask.mode == 'RGBA' else init_mask.convert('L')
|
||||||
|
pil_init_image = init_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
||||||
|
|
||||||
|
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||||
|
init_rgb_pixels = np.asarray(init_image.convert('RGB'), dtype=np.uint8)
|
||||||
|
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
|
||||||
|
init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||||
|
|
||||||
|
# Get numpy version of result
|
||||||
|
np_image = np.asarray(result, dtype=np.uint8)
|
||||||
|
|
||||||
|
# Mask and calculate mean and standard deviation
|
||||||
|
mask_pixels = init_a_pixels * init_mask_pixels > 0
|
||||||
|
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
|
||||||
|
np_image_masked = np_image[mask_pixels, :]
|
||||||
|
|
||||||
|
if np_init_rgb_pixels_masked.size > 0:
|
||||||
|
init_means = np_init_rgb_pixels_masked.mean(axis=0)
|
||||||
|
init_std = np_init_rgb_pixels_masked.std(axis=0)
|
||||||
|
gen_means = np_image_masked.mean(axis=0)
|
||||||
|
gen_std = np_image_masked.std(axis=0)
|
||||||
|
|
||||||
|
# Color correct
|
||||||
|
np_matched_result = np_image.copy()
|
||||||
|
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
|
||||||
|
matched_result = Image.fromarray(np_matched_result, mode='RGB')
|
||||||
|
else:
|
||||||
|
matched_result = Image.fromarray(np_image, mode='RGB')
|
||||||
|
|
||||||
|
# Blur the mask out (into init image) by specified amount
|
||||||
|
if mask_blur_radius > 0:
|
||||||
|
nm = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||||
|
nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
|
||||||
|
pmd = Image.fromarray(nmd, mode='L')
|
||||||
|
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
|
||||||
|
else:
|
||||||
|
blurred_init_mask = pil_init_mask
|
||||||
|
|
||||||
|
multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
|
||||||
|
|
||||||
|
# Paste original on color-corrected generation (using blurred mask)
|
||||||
|
matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask)
|
||||||
|
return matched_result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def sample_to_lowres_estimated_image(self,samples):
|
||||||
|
# origingally adapted from code by @erucipe and @keturn here:
|
||||||
|
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||||
|
|
||||||
|
# these updated numbers for v1.5 are from @torridgristle
|
||||||
|
v1_5_latent_rgb_factors = torch.tensor([
|
||||||
|
# R G B
|
||||||
|
[ 0.3444, 0.1385, 0.0670], # L1
|
||||||
|
[ 0.1247, 0.4027, 0.1494], # L2
|
||||||
|
[-0.3192, 0.2513, 0.2103], # L3
|
||||||
|
[-0.1307, -0.1874, -0.7445] # L4
|
||||||
|
], dtype=samples.dtype, device=samples.device)
|
||||||
|
|
||||||
|
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
|
||||||
|
latents_ubyte = (((latent_image + 1) / 2)
|
||||||
|
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
|
.mul(0xFF) # to 0..255
|
||||||
|
.byte()).cpu()
|
||||||
|
|
||||||
|
return Image.fromarray(latents_ubyte.numpy())
|
||||||
|
|
||||||
|
def generate_initial_noise(self, seed, width, height):
|
||||||
|
initial_noise = None
|
||||||
|
if self.variation_amount > 0 or len(self.with_variations) > 0:
|
||||||
|
# use fixed initial noise plus random noise per iteration
|
||||||
|
seed_everything(seed)
|
||||||
|
initial_noise = self.get_noise(width,height)
|
||||||
|
for v_seed, v_weight in self.with_variations:
|
||||||
|
seed = v_seed
|
||||||
|
seed_everything(seed)
|
||||||
|
next_noise = self.get_noise(width,height)
|
||||||
|
initial_noise = self.slerp(v_weight, initial_noise, next_noise)
|
||||||
|
if self.variation_amount > 0:
|
||||||
|
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
|
||||||
|
seed = random.randrange(0,np.iinfo(np.uint32).max)
|
||||||
|
return (seed, initial_noise)
|
||||||
|
else:
|
||||||
|
return (seed, None)
|
||||||
|
|
||||||
|
# returns a tensor filled with random numbers from a normal distribution
|
||||||
|
def get_noise(self,width,height):
|
||||||
|
"""
|
||||||
|
Returns a tensor filled with random numbers, either form a normal distribution
|
||||||
|
(txt2img) or from the latent image (img2img, inpaint)
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("get_noise() must be implemented in a descendent class")
|
||||||
|
|
||||||
|
def get_perlin_noise(self,width,height):
|
||||||
|
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
|
||||||
|
return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
|
||||||
|
|
||||||
|
def new_seed(self):
|
||||||
|
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||||
|
return self.seed
|
||||||
|
|
||||||
|
def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||||
|
'''
|
||||||
|
Spherical linear interpolation
|
||||||
|
Args:
|
||||||
|
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||||
|
v0 (np.ndarray): Starting vector
|
||||||
|
v1 (np.ndarray): Final vector
|
||||||
|
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||||
|
colineal. Not recommended to alter this.
|
||||||
|
Returns:
|
||||||
|
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||||
|
'''
|
||||||
|
inputs_are_torch = False
|
||||||
|
if not isinstance(v0, np.ndarray):
|
||||||
|
inputs_are_torch = True
|
||||||
|
v0 = v0.detach().cpu().numpy()
|
||||||
|
if not isinstance(v1, np.ndarray):
|
||||||
|
inputs_are_torch = True
|
||||||
|
v1 = v1.detach().cpu().numpy()
|
||||||
|
|
||||||
|
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||||
|
if np.abs(dot) > DOT_THRESHOLD:
|
||||||
|
v2 = (1 - t) * v0 + t * v1
|
||||||
|
else:
|
||||||
|
theta_0 = np.arccos(dot)
|
||||||
|
sin_theta_0 = np.sin(theta_0)
|
||||||
|
theta_t = theta_0 * t
|
||||||
|
sin_theta_t = np.sin(theta_t)
|
||||||
|
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||||
|
s1 = sin_theta_t / sin_theta_0
|
||||||
|
v2 = s0 * v0 + s1 * v1
|
||||||
|
|
||||||
|
if inputs_are_torch:
|
||||||
|
v2 = torch.from_numpy(v2).to(self.model.device)
|
||||||
|
|
||||||
|
return v2
|
||||||
|
|
||||||
|
def safety_check(self,image:Image.Image):
|
||||||
|
'''
|
||||||
|
If the CompViz safety checker flags an NSFW image, we
|
||||||
|
blur it out.
|
||||||
|
'''
|
||||||
|
import diffusers
|
||||||
|
|
||||||
|
checker = self.safety_checker['checker']
|
||||||
|
extractor = self.safety_checker['extractor']
|
||||||
|
features = extractor([image], return_tensors="pt")
|
||||||
|
features.to(self.model.device)
|
||||||
|
|
||||||
|
# unfortunately checker requires the numpy version, so we have to convert back
|
||||||
|
x_image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||||
|
|
||||||
|
diffusers.logging.set_verbosity_error()
|
||||||
|
checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
|
||||||
|
if has_nsfw_concept[0]:
|
||||||
|
print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
|
||||||
|
return self.blur(image)
|
||||||
|
else:
|
||||||
|
return image
|
||||||
|
|
||||||
|
def blur(self,input):
|
||||||
|
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||||
|
try:
|
||||||
|
caution = self.get_caution_img()
|
||||||
|
if caution:
|
||||||
|
blurry.paste(caution,(0,0),caution)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
return blurry
|
||||||
|
|
||||||
|
def get_caution_img(self):
|
||||||
|
path = None
|
||||||
|
if self.caution_img:
|
||||||
|
return self.caution_img
|
||||||
|
# Find the caution image. If we are installed in the package directory it will
|
||||||
|
# be six levels up. If we are in the repo directory it will be three levels up.
|
||||||
|
for dots in ('../../..','../../../../../..'):
|
||||||
|
caution_path = osp.join(osp.dirname(__file__),dots,CAUTION_IMG)
|
||||||
|
if osp.exists(caution_path):
|
||||||
|
path = caution_path
|
||||||
|
break
|
||||||
|
if not path:
|
||||||
|
return
|
||||||
|
caution = Image.open(path)
|
||||||
|
self.caution_img = caution.resize((caution.width // 2, caution.height //2))
|
||||||
|
return self.caution_img
|
||||||
|
|
||||||
|
# this is a handy routine for debugging use. Given a generated sample,
|
||||||
|
# convert it into a PNG image and store it at the indicated path
|
||||||
|
def save_sample(self, sample, filepath):
|
||||||
|
image = self.sample_to_image(sample)
|
||||||
|
dirname = os.path.dirname(filepath) or '.'
|
||||||
|
if not os.path.exists(dirname):
|
||||||
|
print(f'** creating directory {dirname}')
|
||||||
|
os.makedirs(dirname, exist_ok=True)
|
||||||
|
image.save(filepath,'PNG')
|
||||||
|
|
||||||
|
|
501
ldm/invoke/ckpt_generator/embiggen.py
Normal file
501
ldm/invoke/ckpt_generator/embiggen.py
Normal file
@ -0,0 +1,501 @@
|
|||||||
|
'''
|
||||||
|
ldm.invoke.ckpt_generator.embiggen descends from ldm.invoke.ckpt_generator
|
||||||
|
and generates with ldm.invoke.ckpt_generator.img2img
|
||||||
|
'''
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import trange
|
||||||
|
from PIL import Image
|
||||||
|
from ldm.invoke.ckpt_generator.base import CkptGenerator
|
||||||
|
from ldm.invoke.ckpt_generator.img2img import CkptImg2Img
|
||||||
|
from ldm.invoke.devices import choose_autocast
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
|
||||||
|
class CkptEmbiggen(CkptGenerator):
|
||||||
|
def __init__(self, model, precision):
|
||||||
|
super().__init__(model, precision)
|
||||||
|
self.init_latent = None
|
||||||
|
|
||||||
|
# Replace generate because Embiggen doesn't need/use most of what it does normallly
|
||||||
|
def generate(self,prompt,iterations=1,seed=None,
|
||||||
|
image_callback=None, step_callback=None,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
scope = choose_autocast(self.precision)
|
||||||
|
make_image = self.get_make_image(
|
||||||
|
prompt,
|
||||||
|
step_callback = step_callback,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
seed = seed if seed else self.new_seed()
|
||||||
|
|
||||||
|
# Noise will be generated by the Img2Img generator when called
|
||||||
|
with scope(self.model.device.type), self.model.ema_scope():
|
||||||
|
for n in trange(iterations, desc='Generating'):
|
||||||
|
# make_image will call Img2Img which will do the equivalent of get_noise itself
|
||||||
|
image = make_image()
|
||||||
|
results.append([image, seed])
|
||||||
|
if image_callback is not None:
|
||||||
|
image_callback(image, seed, prompt_in=prompt)
|
||||||
|
seed = self.new_seed()
|
||||||
|
return results
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_make_image(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
sampler,
|
||||||
|
steps,
|
||||||
|
cfg_scale,
|
||||||
|
ddim_eta,
|
||||||
|
conditioning,
|
||||||
|
init_img,
|
||||||
|
strength,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
embiggen,
|
||||||
|
embiggen_tiles,
|
||||||
|
step_callback=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
|
||||||
|
Return value depends on the seed at the time you call it
|
||||||
|
"""
|
||||||
|
assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"
|
||||||
|
|
||||||
|
# Construct embiggen arg array, and sanity check arguments
|
||||||
|
if embiggen == None: # embiggen can also be called with just embiggen_tiles
|
||||||
|
embiggen = [1.0] # If not specified, assume no scaling
|
||||||
|
elif embiggen[0] < 0:
|
||||||
|
embiggen[0] = 1.0
|
||||||
|
print(
|
||||||
|
'>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
|
||||||
|
if len(embiggen) < 2:
|
||||||
|
embiggen.append(0.75)
|
||||||
|
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||||
|
embiggen[1] = 0.75
|
||||||
|
print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !')
|
||||||
|
if len(embiggen) < 3:
|
||||||
|
embiggen.append(0.25)
|
||||||
|
elif embiggen[2] < 0:
|
||||||
|
embiggen[2] = 0.25
|
||||||
|
print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !')
|
||||||
|
|
||||||
|
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
||||||
|
# and then sort them, because... people.
|
||||||
|
if embiggen_tiles:
|
||||||
|
embiggen_tiles = list(map(lambda n: n-1, embiggen_tiles))
|
||||||
|
embiggen_tiles.sort()
|
||||||
|
|
||||||
|
if strength >= 0.5:
|
||||||
|
print(f'* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45.')
|
||||||
|
|
||||||
|
# Prep img2img generator, since we wrap over it
|
||||||
|
gen_img2img = CkptImg2Img(self.model,self.precision)
|
||||||
|
|
||||||
|
# Open original init image (not a tensor) to manipulate
|
||||||
|
initsuperimage = Image.open(init_img)
|
||||||
|
|
||||||
|
with Image.open(init_img) as img:
|
||||||
|
initsuperimage = img.convert('RGB')
|
||||||
|
|
||||||
|
# Size of the target super init image in pixels
|
||||||
|
initsuperwidth, initsuperheight = initsuperimage.size
|
||||||
|
|
||||||
|
# Increase by scaling factor if not already resized, using ESRGAN as able
|
||||||
|
if embiggen[0] != 1.0:
|
||||||
|
initsuperwidth = round(initsuperwidth*embiggen[0])
|
||||||
|
initsuperheight = round(initsuperheight*embiggen[0])
|
||||||
|
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
|
||||||
|
from ldm.invoke.restoration.realesrgan import ESRGAN
|
||||||
|
esrgan = ESRGAN()
|
||||||
|
print(
|
||||||
|
f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
|
||||||
|
if embiggen[0] > 2:
|
||||||
|
initsuperimage = esrgan.process(
|
||||||
|
initsuperimage,
|
||||||
|
embiggen[1], # upscale strength
|
||||||
|
self.seed,
|
||||||
|
4, # upscale scale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
initsuperimage = esrgan.process(
|
||||||
|
initsuperimage,
|
||||||
|
embiggen[1], # upscale strength
|
||||||
|
self.seed,
|
||||||
|
2, # upscale scale
|
||||||
|
)
|
||||||
|
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
|
||||||
|
# but from personal experiance it doesn't greatly improve anything after 4x
|
||||||
|
# Resize to target scaling factor resolution
|
||||||
|
initsuperimage = initsuperimage.resize(
|
||||||
|
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
|
# Use width and height as tile widths and height
|
||||||
|
# Determine buffer size in pixels
|
||||||
|
if embiggen[2] < 1:
|
||||||
|
if embiggen[2] < 0:
|
||||||
|
embiggen[2] = 0
|
||||||
|
overlap_size_x = round(embiggen[2] * width)
|
||||||
|
overlap_size_y = round(embiggen[2] * height)
|
||||||
|
else:
|
||||||
|
overlap_size_x = round(embiggen[2])
|
||||||
|
overlap_size_y = round(embiggen[2])
|
||||||
|
|
||||||
|
# With overall image width and height known, determine how many tiles we need
|
||||||
|
def ceildiv(a, b):
|
||||||
|
return -1 * (-a // b)
|
||||||
|
|
||||||
|
# X and Y needs to be determined independantly (we may have savings on one based on the buffer pixel count)
|
||||||
|
# (initsuperwidth - width) is the area remaining to the right that we need to layers tiles to fill
|
||||||
|
# (width - overlap_size_x) is how much new we can fill with a single tile
|
||||||
|
emb_tiles_x = 1
|
||||||
|
emb_tiles_y = 1
|
||||||
|
if (initsuperwidth - width) > 0:
|
||||||
|
emb_tiles_x = ceildiv(initsuperwidth - width,
|
||||||
|
width - overlap_size_x) + 1
|
||||||
|
if (initsuperheight - height) > 0:
|
||||||
|
emb_tiles_y = ceildiv(initsuperheight - height,
|
||||||
|
height - overlap_size_y) + 1
|
||||||
|
# Sanity
|
||||||
|
assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.'
|
||||||
|
|
||||||
|
# Prep alpha layers --------------
|
||||||
|
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
|
||||||
|
# agradientL is Left-side transparent
|
||||||
|
agradientL = Image.linear_gradient('L').rotate(
|
||||||
|
90).resize((overlap_size_x, height))
|
||||||
|
# agradientT is Top-side transparent
|
||||||
|
agradientT = Image.linear_gradient('L').resize((width, overlap_size_y))
|
||||||
|
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
|
||||||
|
agradientC = Image.new('L', (256, 256))
|
||||||
|
for y in range(256):
|
||||||
|
for x in range(256):
|
||||||
|
# Find distance to lower right corner (numpy takes arrays)
|
||||||
|
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
|
||||||
|
# Clamp values to max 255
|
||||||
|
if distanceToLR > 255:
|
||||||
|
distanceToLR = 255
|
||||||
|
#Place the pixel as invert of distance
|
||||||
|
agradientC.putpixel((x, y), round(255 - distanceToLR))
|
||||||
|
|
||||||
|
# Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges
|
||||||
|
# Fits for a left-fading gradient on the bottom side and full opacity on the right side.
|
||||||
|
agradientAsymC = Image.new('L', (256, 256))
|
||||||
|
for y in range(256):
|
||||||
|
for x in range(256):
|
||||||
|
value = round(max(0, x-(255-y)) * (255 / max(1,y)))
|
||||||
|
#Clamp values
|
||||||
|
value = max(0, value)
|
||||||
|
value = min(255, value)
|
||||||
|
agradientAsymC.putpixel((x, y), value)
|
||||||
|
|
||||||
|
# Create alpha layers default fully white
|
||||||
|
alphaLayerL = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerT = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerLTC = Image.new("L", (width, height), 255)
|
||||||
|
# Paste gradients into alpha layers
|
||||||
|
alphaLayerL.paste(agradientL, (0, 0))
|
||||||
|
alphaLayerT.paste(agradientT, (0, 0))
|
||||||
|
alphaLayerLTC.paste(agradientL, (0, 0))
|
||||||
|
alphaLayerLTC.paste(agradientT, (0, 0))
|
||||||
|
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
|
||||||
|
# make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
|
||||||
|
# to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
|
||||||
|
alphaLayerTaC = alphaLayerT.copy()
|
||||||
|
alphaLayerTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||||
|
alphaLayerLTaC = alphaLayerLTC.copy()
|
||||||
|
alphaLayerLTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||||
|
|
||||||
|
if embiggen_tiles:
|
||||||
|
# Individual unconnected sides
|
||||||
|
alphaLayerR = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerR.paste(agradientL.rotate(
|
||||||
|
180), (width - overlap_size_x, 0))
|
||||||
|
alphaLayerB = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerB.paste(agradientT.rotate(
|
||||||
|
180), (0, height - overlap_size_y))
|
||||||
|
alphaLayerTB = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerTB.paste(agradientT, (0, 0))
|
||||||
|
alphaLayerTB.paste(agradientT.rotate(
|
||||||
|
180), (0, height - overlap_size_y))
|
||||||
|
alphaLayerLR = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerLR.paste(agradientL, (0, 0))
|
||||||
|
alphaLayerLR.paste(agradientL.rotate(
|
||||||
|
180), (width - overlap_size_x, 0))
|
||||||
|
|
||||||
|
# Sides and corner Layers
|
||||||
|
alphaLayerRBC = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerRBC.paste(agradientL.rotate(
|
||||||
|
180), (width - overlap_size_x, 0))
|
||||||
|
alphaLayerRBC.paste(agradientT.rotate(
|
||||||
|
180), (0, height - overlap_size_y))
|
||||||
|
alphaLayerRBC.paste(agradientC.rotate(180).resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||||
|
alphaLayerLBC = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerLBC.paste(agradientL, (0, 0))
|
||||||
|
alphaLayerLBC.paste(agradientT.rotate(
|
||||||
|
180), (0, height - overlap_size_y))
|
||||||
|
alphaLayerLBC.paste(agradientC.rotate(90).resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
|
||||||
|
alphaLayerRTC = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerRTC.paste(agradientL.rotate(
|
||||||
|
180), (width - overlap_size_x, 0))
|
||||||
|
alphaLayerRTC.paste(agradientT, (0, 0))
|
||||||
|
alphaLayerRTC.paste(agradientC.rotate(270).resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||||
|
|
||||||
|
# All but X layers
|
||||||
|
alphaLayerABT = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
|
||||||
|
alphaLayerABT.paste(agradientL.rotate(
|
||||||
|
180), (width - overlap_size_x, 0))
|
||||||
|
alphaLayerABT.paste(agradientC.rotate(180).resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||||
|
alphaLayerABL = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
|
||||||
|
alphaLayerABL.paste(agradientT.rotate(
|
||||||
|
180), (0, height - overlap_size_y))
|
||||||
|
alphaLayerABL.paste(agradientC.rotate(180).resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
|
||||||
|
alphaLayerABR = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
|
||||||
|
alphaLayerABR.paste(agradientT, (0, 0))
|
||||||
|
alphaLayerABR.paste(agradientC.resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||||
|
alphaLayerABB = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
|
||||||
|
alphaLayerABB.paste(agradientL, (0, 0))
|
||||||
|
alphaLayerABB.paste(agradientC.resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||||
|
|
||||||
|
# All-around layer
|
||||||
|
alphaLayerAA = Image.new("L", (width, height), 255)
|
||||||
|
alphaLayerAA.paste(alphaLayerABT, (0, 0))
|
||||||
|
alphaLayerAA.paste(agradientT, (0, 0))
|
||||||
|
alphaLayerAA.paste(agradientC.resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (0, 0))
|
||||||
|
alphaLayerAA.paste(agradientC.rotate(270).resize(
|
||||||
|
(overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
|
||||||
|
|
||||||
|
# Clean up temporary gradients
|
||||||
|
del agradientL
|
||||||
|
del agradientT
|
||||||
|
del agradientC
|
||||||
|
|
||||||
|
def make_image():
|
||||||
|
# Make main tiles -------------------------------------------------
|
||||||
|
if embiggen_tiles:
|
||||||
|
print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...')
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f'>> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')
|
||||||
|
|
||||||
|
emb_tile_store = []
|
||||||
|
# Although we could use the same seed for every tile for determinism, at higher strengths this may
|
||||||
|
# produce duplicated structures for each tile and make the tiling effect more obvious
|
||||||
|
# instead track and iterate a local seed we pass to Img2Img
|
||||||
|
seed = self.seed
|
||||||
|
seedintlimit = np.iinfo(np.uint32).max - 1 # only retreive this one from numpy
|
||||||
|
|
||||||
|
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||||
|
# Don't iterate on first tile
|
||||||
|
if tile != 0:
|
||||||
|
if seed < seedintlimit:
|
||||||
|
seed += 1
|
||||||
|
else:
|
||||||
|
seed = 0
|
||||||
|
|
||||||
|
# Determine if this is a re-run and replace
|
||||||
|
if embiggen_tiles and not tile in embiggen_tiles:
|
||||||
|
continue
|
||||||
|
# Get row and column entries
|
||||||
|
emb_row_i = tile // emb_tiles_x
|
||||||
|
emb_column_i = tile % emb_tiles_x
|
||||||
|
# Determine bounds to cut up the init image
|
||||||
|
# Determine upper-left point
|
||||||
|
if emb_column_i + 1 == emb_tiles_x:
|
||||||
|
left = initsuperwidth - width
|
||||||
|
else:
|
||||||
|
left = round(emb_column_i * (width - overlap_size_x))
|
||||||
|
if emb_row_i + 1 == emb_tiles_y:
|
||||||
|
top = initsuperheight - height
|
||||||
|
else:
|
||||||
|
top = round(emb_row_i * (height - overlap_size_y))
|
||||||
|
right = left + width
|
||||||
|
bottom = top + height
|
||||||
|
|
||||||
|
# Cropped image of above dimension (does not modify the original)
|
||||||
|
newinitimage = initsuperimage.crop((left, top, right, bottom))
|
||||||
|
# DEBUG:
|
||||||
|
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
|
||||||
|
# newinitimage.save(newinitimagepath)
|
||||||
|
|
||||||
|
if embiggen_tiles:
|
||||||
|
print(
|
||||||
|
f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f'Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles')
|
||||||
|
|
||||||
|
# create a torch tensor from an Image
|
||||||
|
newinitimage = np.array(
|
||||||
|
newinitimage).astype(np.float32) / 255.0
|
||||||
|
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
|
||||||
|
newinitimage = torch.from_numpy(newinitimage)
|
||||||
|
newinitimage = 2.0 * newinitimage - 1.0
|
||||||
|
newinitimage = newinitimage.to(self.model.device)
|
||||||
|
|
||||||
|
tile_results = gen_img2img.generate(
|
||||||
|
prompt,
|
||||||
|
iterations = 1,
|
||||||
|
seed = seed,
|
||||||
|
sampler = DDIMSampler(self.model, device=self.model.device),
|
||||||
|
steps = steps,
|
||||||
|
cfg_scale = cfg_scale,
|
||||||
|
conditioning = conditioning,
|
||||||
|
ddim_eta = ddim_eta,
|
||||||
|
image_callback = None, # called only after the final image is generated
|
||||||
|
step_callback = step_callback, # called after each intermediate image is generated
|
||||||
|
width = width,
|
||||||
|
height = height,
|
||||||
|
init_image = newinitimage, # notice that init_image is different from init_img
|
||||||
|
mask_image = None,
|
||||||
|
strength = strength,
|
||||||
|
)
|
||||||
|
|
||||||
|
emb_tile_store.append(tile_results[0][0])
|
||||||
|
# DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
|
||||||
|
# emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
|
||||||
|
del newinitimage
|
||||||
|
|
||||||
|
# Sanity check we have them all
|
||||||
|
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)):
|
||||||
|
outputsuperimage = Image.new(
|
||||||
|
"RGBA", (initsuperwidth, initsuperheight))
|
||||||
|
if embiggen_tiles:
|
||||||
|
outputsuperimage.alpha_composite(
|
||||||
|
initsuperimage.convert('RGBA'), (0, 0))
|
||||||
|
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||||
|
if embiggen_tiles:
|
||||||
|
if tile in embiggen_tiles:
|
||||||
|
intileimage = emb_tile_store.pop(0)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
intileimage = emb_tile_store[tile]
|
||||||
|
intileimage = intileimage.convert('RGBA')
|
||||||
|
# Get row and column entries
|
||||||
|
emb_row_i = tile // emb_tiles_x
|
||||||
|
emb_column_i = tile % emb_tiles_x
|
||||||
|
if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles:
|
||||||
|
left = 0
|
||||||
|
top = 0
|
||||||
|
else:
|
||||||
|
# Determine upper-left point
|
||||||
|
if emb_column_i + 1 == emb_tiles_x:
|
||||||
|
left = initsuperwidth - width
|
||||||
|
else:
|
||||||
|
left = round(emb_column_i *
|
||||||
|
(width - overlap_size_x))
|
||||||
|
if emb_row_i + 1 == emb_tiles_y:
|
||||||
|
top = initsuperheight - height
|
||||||
|
else:
|
||||||
|
top = round(emb_row_i * (height - overlap_size_y))
|
||||||
|
# Handle gradients for various conditions
|
||||||
|
# Handle emb_rerun case
|
||||||
|
if embiggen_tiles:
|
||||||
|
# top of image
|
||||||
|
if emb_row_i == 0:
|
||||||
|
if emb_column_i == 0:
|
||||||
|
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||||
|
if (tile+emb_tiles_x) not in embiggen_tiles: # Look-ahead down
|
||||||
|
intileimage.putalpha(alphaLayerB)
|
||||||
|
# Otherwise do nothing on this tile
|
||||||
|
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||||
|
intileimage.putalpha(alphaLayerR)
|
||||||
|
else:
|
||||||
|
intileimage.putalpha(alphaLayerRBC)
|
||||||
|
elif emb_column_i == emb_tiles_x - 1:
|
||||||
|
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||||
|
intileimage.putalpha(alphaLayerL)
|
||||||
|
else:
|
||||||
|
intileimage.putalpha(alphaLayerLBC)
|
||||||
|
else:
|
||||||
|
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||||
|
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||||
|
intileimage.putalpha(alphaLayerL)
|
||||||
|
else:
|
||||||
|
intileimage.putalpha(alphaLayerLBC)
|
||||||
|
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||||
|
intileimage.putalpha(alphaLayerLR)
|
||||||
|
else:
|
||||||
|
intileimage.putalpha(alphaLayerABT)
|
||||||
|
# bottom of image
|
||||||
|
elif emb_row_i == emb_tiles_y - 1:
|
||||||
|
if emb_column_i == 0:
|
||||||
|
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||||
|
intileimage.putalpha(alphaLayerTaC)
|
||||||
|
else:
|
||||||
|
intileimage.putalpha(alphaLayerRTC)
|
||||||
|
elif emb_column_i == emb_tiles_x - 1:
|
||||||
|
# No tiles to look ahead to
|
||||||
|
intileimage.putalpha(alphaLayerLTC)
|
||||||
|
else:
|
||||||
|
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||||
|
intileimage.putalpha(alphaLayerLTaC)
|
||||||
|
else:
|
||||||
|
intileimage.putalpha(alphaLayerABB)
|
||||||
|
# vertical middle of image
|
||||||
|
else:
|
||||||
|
if emb_column_i == 0:
|
||||||
|
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||||
|
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||||
|
intileimage.putalpha(alphaLayerTaC)
|
||||||
|
else:
|
||||||
|
intileimage.putalpha(alphaLayerTB)
|
||||||
|
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||||
|
intileimage.putalpha(alphaLayerRTC)
|
||||||
|
else:
|
||||||
|
intileimage.putalpha(alphaLayerABL)
|
||||||
|
elif emb_column_i == emb_tiles_x - 1:
|
||||||
|
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||||
|
intileimage.putalpha(alphaLayerLTC)
|
||||||
|
else:
|
||||||
|
intileimage.putalpha(alphaLayerABR)
|
||||||
|
else:
|
||||||
|
if (tile+1) in embiggen_tiles: # Look-ahead right
|
||||||
|
if (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down
|
||||||
|
intileimage.putalpha(alphaLayerLTaC)
|
||||||
|
else:
|
||||||
|
intileimage.putalpha(alphaLayerABR)
|
||||||
|
elif (tile+emb_tiles_x) in embiggen_tiles: # Look-ahead down only
|
||||||
|
intileimage.putalpha(alphaLayerABB)
|
||||||
|
else:
|
||||||
|
intileimage.putalpha(alphaLayerAA)
|
||||||
|
# Handle normal tiling case (much simpler - since we tile left to right, top to bottom)
|
||||||
|
else:
|
||||||
|
if emb_row_i == 0 and emb_column_i >= 1:
|
||||||
|
intileimage.putalpha(alphaLayerL)
|
||||||
|
elif emb_row_i >= 1 and emb_column_i == 0:
|
||||||
|
if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right
|
||||||
|
intileimage.putalpha(alphaLayerT)
|
||||||
|
else:
|
||||||
|
intileimage.putalpha(alphaLayerTaC)
|
||||||
|
else:
|
||||||
|
if emb_column_i + 1 == emb_tiles_x: # If we don't have anything that can be placed to the right
|
||||||
|
intileimage.putalpha(alphaLayerLTC)
|
||||||
|
else:
|
||||||
|
intileimage.putalpha(alphaLayerLTaC)
|
||||||
|
# Layer tile onto final image
|
||||||
|
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||||
|
else:
|
||||||
|
print(f'Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.')
|
||||||
|
|
||||||
|
# after internal loops and patching up return Embiggen image
|
||||||
|
return outputsuperimage
|
||||||
|
# end of function declaration
|
||||||
|
return make_image
|
97
ldm/invoke/ckpt_generator/img2img.py
Normal file
97
ldm/invoke/ckpt_generator/img2img.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
'''
|
||||||
|
ldm.invoke.ckpt_generator.img2img descends from ldm.invoke.generator
|
||||||
|
'''
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
from torch import Tensor
|
||||||
|
from PIL import Image
|
||||||
|
from ldm.invoke.devices import choose_autocast
|
||||||
|
from ldm.invoke.ckpt_generator.base import CkptGenerator
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
|
|
||||||
|
class CkptImg2Img(CkptGenerator):
|
||||||
|
def __init__(self, model, precision):
|
||||||
|
super().__init__(model, precision)
|
||||||
|
self.init_latent = None # by get_noise()
|
||||||
|
|
||||||
|
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||||
|
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
|
||||||
|
"""
|
||||||
|
Returns a function returning an image derived from the prompt and the initial image
|
||||||
|
Return value depends on the seed at the time you call it.
|
||||||
|
"""
|
||||||
|
self.perlin = perlin
|
||||||
|
|
||||||
|
sampler.make_schedule(
|
||||||
|
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(init_image, PIL.Image.Image):
|
||||||
|
init_image = self._image_to_tensor(init_image.convert('RGB'))
|
||||||
|
|
||||||
|
scope = choose_autocast(self.precision)
|
||||||
|
with scope(self.model.device.type):
|
||||||
|
self.init_latent = self.model.get_first_stage_encoding(
|
||||||
|
self.model.encode_first_stage(init_image)
|
||||||
|
) # move to latent space
|
||||||
|
|
||||||
|
t_enc = int(strength * steps)
|
||||||
|
uc, c, extra_conditioning_info = conditioning
|
||||||
|
|
||||||
|
def make_image(x_T):
|
||||||
|
# encode (scaled latent)
|
||||||
|
z_enc = sampler.stochastic_encode(
|
||||||
|
self.init_latent,
|
||||||
|
torch.tensor([t_enc - 1]).to(self.model.device),
|
||||||
|
noise=x_T
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||||
|
self.model.model.to(self.model.device)
|
||||||
|
|
||||||
|
# decode it
|
||||||
|
samples = sampler.decode(
|
||||||
|
z_enc,
|
||||||
|
c,
|
||||||
|
t_enc,
|
||||||
|
img_callback = step_callback,
|
||||||
|
unconditional_guidance_scale=cfg_scale,
|
||||||
|
unconditional_conditioning=uc,
|
||||||
|
init_latent = self.init_latent, # changes how noising is performed in ksampler
|
||||||
|
extra_conditioning_info = extra_conditioning_info,
|
||||||
|
all_timesteps_count = steps
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.free_gpu_mem:
|
||||||
|
self.model.model.to("cpu")
|
||||||
|
|
||||||
|
return self.sample_to_image(samples)
|
||||||
|
|
||||||
|
return make_image
|
||||||
|
|
||||||
|
def get_noise(self,width,height):
|
||||||
|
device = self.model.device
|
||||||
|
init_latent = self.init_latent
|
||||||
|
assert init_latent is not None,'call to get_noise() when init_latent not set'
|
||||||
|
if device.type == 'mps':
|
||||||
|
x = torch.randn_like(init_latent, device='cpu').to(device)
|
||||||
|
else:
|
||||||
|
x = torch.randn_like(init_latent, device=device)
|
||||||
|
if self.perlin > 0.0:
|
||||||
|
shape = init_latent.shape
|
||||||
|
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
|
||||||
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
if len(image.shape) == 2: # 'L' image, as in a mask
|
||||||
|
image = image[None,None]
|
||||||
|
else: # 'RGB' image
|
||||||
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
|
image = torch.from_numpy(image)
|
||||||
|
if normalize:
|
||||||
|
image = 2.0 * image - 1.0
|
||||||
|
return image.to(self.model.device)
|
358
ldm/invoke/ckpt_generator/inpaint.py
Normal file
358
ldm/invoke/ckpt_generator/inpaint.py
Normal file
@ -0,0 +1,358 @@
|
|||||||
|
'''
|
||||||
|
ldm.invoke.ckpt_generator.inpaint descends from ldm.invoke.ckpt_generator
|
||||||
|
'''
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as T
|
||||||
|
import numpy as np
|
||||||
|
import cv2 as cv
|
||||||
|
import PIL
|
||||||
|
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||||
|
from skimage.exposure.histogram_matching import match_histograms
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from ldm.invoke.devices import choose_autocast
|
||||||
|
from ldm.invoke.ckpt_generator.img2img import CkptImg2Img
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.models.diffusion.ksampler import KSampler
|
||||||
|
from ldm.invoke.generator.base import downsampling
|
||||||
|
from ldm.util import debug_image
|
||||||
|
from ldm.invoke.patchmatch import PatchMatch
|
||||||
|
from ldm.invoke.globals import Globals
|
||||||
|
|
||||||
|
def infill_methods()->list[str]:
|
||||||
|
methods = list()
|
||||||
|
if PatchMatch.patchmatch_available():
|
||||||
|
methods.append('patchmatch')
|
||||||
|
methods.append('tile')
|
||||||
|
return methods
|
||||||
|
|
||||||
|
class CkptInpaint(CkptImg2Img):
|
||||||
|
def __init__(self, model, precision):
|
||||||
|
self.init_latent = None
|
||||||
|
self.pil_image = None
|
||||||
|
self.pil_mask = None
|
||||||
|
self.mask_blur_radius = 0
|
||||||
|
self.infill_method = None
|
||||||
|
super().__init__(model, precision)
|
||||||
|
|
||||||
|
# Outpaint support code
|
||||||
|
def get_tile_images(self, image: np.ndarray, width=8, height=8):
|
||||||
|
_nrows, _ncols, depth = image.shape
|
||||||
|
_strides = image.strides
|
||||||
|
|
||||||
|
nrows, _m = divmod(_nrows, height)
|
||||||
|
ncols, _n = divmod(_ncols, width)
|
||||||
|
if _m != 0 or _n != 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return np.lib.stride_tricks.as_strided(
|
||||||
|
np.ravel(image),
|
||||||
|
shape=(nrows, ncols, height, width, depth),
|
||||||
|
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||||
|
writeable=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def infill_patchmatch(self, im: Image.Image) -> Image:
|
||||||
|
if im.mode != 'RGBA':
|
||||||
|
return im
|
||||||
|
|
||||||
|
# Skip patchmatch if patchmatch isn't available
|
||||||
|
if not PatchMatch.patchmatch_available():
|
||||||
|
return im
|
||||||
|
|
||||||
|
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||||
|
im_patched_np = PatchMatch.inpaint(im.convert('RGB'), ImageOps.invert(im.split()[-1]), patch_size = 3)
|
||||||
|
im_patched = Image.fromarray(im_patched_np, mode = 'RGB')
|
||||||
|
return im_patched
|
||||||
|
|
||||||
|
def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image:
|
||||||
|
# Only fill if there's an alpha layer
|
||||||
|
if im.mode != 'RGBA':
|
||||||
|
return im
|
||||||
|
|
||||||
|
a = np.asarray(im, dtype=np.uint8)
|
||||||
|
|
||||||
|
tile_size = (tile_size, tile_size)
|
||||||
|
|
||||||
|
# Get the image as tiles of a specified size
|
||||||
|
tiles = self.get_tile_images(a,*tile_size).copy()
|
||||||
|
|
||||||
|
# Get the mask as tiles
|
||||||
|
tiles_mask = tiles[:,:,:,:,3]
|
||||||
|
|
||||||
|
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||||
|
tmask_shape = tiles_mask.shape
|
||||||
|
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||||
|
n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||||
|
tiles_mask = (tiles_mask > 0)
|
||||||
|
tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1)
|
||||||
|
|
||||||
|
# Get RGB tiles in single array and filter by the mask
|
||||||
|
tshape = tiles.shape
|
||||||
|
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:]))
|
||||||
|
filtered_tiles = tiles_all[tiles_mask]
|
||||||
|
|
||||||
|
if len(filtered_tiles) == 0:
|
||||||
|
return im
|
||||||
|
|
||||||
|
# Find all invalid tiles and replace with a random valid tile
|
||||||
|
replace_count = (tiles_mask == False).sum()
|
||||||
|
rng = np.random.default_rng(seed = seed)
|
||||||
|
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:]
|
||||||
|
|
||||||
|
# Convert back to an image
|
||||||
|
tiles_all = tiles_all.reshape(tshape)
|
||||||
|
tiles_all = tiles_all.swapaxes(1,2)
|
||||||
|
st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4]))
|
||||||
|
si = Image.fromarray(st, mode='RGBA')
|
||||||
|
|
||||||
|
return si
|
||||||
|
|
||||||
|
|
||||||
|
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
|
||||||
|
npimg = np.asarray(mask, dtype=np.uint8)
|
||||||
|
|
||||||
|
# Detect any partially transparent regions
|
||||||
|
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
|
||||||
|
|
||||||
|
# Detect hard edges
|
||||||
|
npedge = cv.Canny(npimg, threshold1=100, threshold2=200)
|
||||||
|
|
||||||
|
# Combine
|
||||||
|
npmask = npgradient + npedge
|
||||||
|
|
||||||
|
# Expand
|
||||||
|
npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
|
||||||
|
|
||||||
|
new_mask = Image.fromarray(npmask)
|
||||||
|
|
||||||
|
if edge_blur > 0:
|
||||||
|
new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))
|
||||||
|
|
||||||
|
return ImageOps.invert(new_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def seam_paint(self,
|
||||||
|
im: Image.Image,
|
||||||
|
seam_size: int,
|
||||||
|
seam_blur: int,
|
||||||
|
prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||||
|
conditioning,strength,
|
||||||
|
noise,
|
||||||
|
step_callback
|
||||||
|
) -> Image.Image:
|
||||||
|
hard_mask = self.pil_image.split()[-1].copy()
|
||||||
|
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
||||||
|
|
||||||
|
make_image = self.get_make_image(
|
||||||
|
prompt,
|
||||||
|
sampler,
|
||||||
|
steps,
|
||||||
|
cfg_scale,
|
||||||
|
ddim_eta,
|
||||||
|
conditioning,
|
||||||
|
init_image = im.copy().convert('RGBA'),
|
||||||
|
mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
|
||||||
|
strength = strength,
|
||||||
|
mask_blur_radius = 0,
|
||||||
|
seam_size = 0,
|
||||||
|
step_callback = step_callback,
|
||||||
|
inpaint_width = im.width,
|
||||||
|
inpaint_height = im.height
|
||||||
|
)
|
||||||
|
|
||||||
|
seam_noise = self.get_noise(im.width, im.height)
|
||||||
|
|
||||||
|
result = make_image(seam_noise)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||||
|
conditioning,init_image,mask_image,strength,
|
||||||
|
mask_blur_radius: int = 8,
|
||||||
|
# Seam settings - when 0, doesn't fill seam
|
||||||
|
seam_size: int = 0,
|
||||||
|
seam_blur: int = 0,
|
||||||
|
seam_strength: float = 0.7,
|
||||||
|
seam_steps: int = 10,
|
||||||
|
tile_size: int = 32,
|
||||||
|
step_callback=None,
|
||||||
|
inpaint_replace=False, enable_image_debugging=False,
|
||||||
|
infill_method = None,
|
||||||
|
inpaint_width=None,
|
||||||
|
inpaint_height=None,
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
Returns a function returning an image derived from the prompt and
|
||||||
|
the initial image + mask. Return value depends on the seed at
|
||||||
|
the time you call it. kwargs are 'init_latent' and 'strength'
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.enable_image_debugging = enable_image_debugging
|
||||||
|
self.infill_method = infill_method or infill_methods()[0], # The infill method to use
|
||||||
|
|
||||||
|
self.inpaint_width = inpaint_width
|
||||||
|
self.inpaint_height = inpaint_height
|
||||||
|
|
||||||
|
if isinstance(init_image, PIL.Image.Image):
|
||||||
|
self.pil_image = init_image.copy()
|
||||||
|
|
||||||
|
# Do infill
|
||||||
|
if infill_method == 'patchmatch' and PatchMatch.patchmatch_available():
|
||||||
|
init_filled = self.infill_patchmatch(self.pil_image.copy())
|
||||||
|
else: # if infill_method == 'tile': # Only two methods right now, so always use 'tile' if not patchmatch
|
||||||
|
init_filled = self.tile_fill_missing(
|
||||||
|
self.pil_image.copy(),
|
||||||
|
seed = self.seed,
|
||||||
|
tile_size = tile_size
|
||||||
|
)
|
||||||
|
init_filled.paste(init_image, (0,0), init_image.split()[-1])
|
||||||
|
|
||||||
|
# Resize if requested for inpainting
|
||||||
|
if inpaint_width and inpaint_height:
|
||||||
|
init_filled = init_filled.resize((inpaint_width, inpaint_height))
|
||||||
|
|
||||||
|
debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
|
||||||
|
|
||||||
|
# Create init tensor
|
||||||
|
init_image = self._image_to_tensor(init_filled.convert('RGB'))
|
||||||
|
|
||||||
|
if isinstance(mask_image, PIL.Image.Image):
|
||||||
|
self.pil_mask = mask_image.copy()
|
||||||
|
debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging)
|
||||||
|
|
||||||
|
mask_image = ImageChops.multiply(mask_image, self.pil_image.split()[-1].convert('RGB'))
|
||||||
|
self.pil_mask = mask_image
|
||||||
|
|
||||||
|
# Resize if requested for inpainting
|
||||||
|
if inpaint_width and inpaint_height:
|
||||||
|
mask_image = mask_image.resize((inpaint_width, inpaint_height))
|
||||||
|
|
||||||
|
debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging)
|
||||||
|
mask_image = mask_image.resize(
|
||||||
|
(
|
||||||
|
mask_image.width // downsampling,
|
||||||
|
mask_image.height // downsampling
|
||||||
|
),
|
||||||
|
resample=Image.Resampling.NEAREST
|
||||||
|
)
|
||||||
|
mask_image = self._image_to_tensor(mask_image,normalize=False)
|
||||||
|
|
||||||
|
self.mask_blur_radius = mask_blur_radius
|
||||||
|
|
||||||
|
# klms samplers not supported yet, so ignore previous sampler
|
||||||
|
if isinstance(sampler,KSampler):
|
||||||
|
print(
|
||||||
|
f">> Using recommended DDIM sampler for inpainting."
|
||||||
|
)
|
||||||
|
sampler = DDIMSampler(self.model, device=self.model.device)
|
||||||
|
|
||||||
|
sampler.make_schedule(
|
||||||
|
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
|
||||||
|
mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
|
||||||
|
|
||||||
|
scope = choose_autocast(self.precision)
|
||||||
|
with scope(self.model.device.type):
|
||||||
|
self.init_latent = self.model.get_first_stage_encoding(
|
||||||
|
self.model.encode_first_stage(init_image)
|
||||||
|
) # move to latent space
|
||||||
|
|
||||||
|
t_enc = int(strength * steps)
|
||||||
|
# todo: support cross-attention control
|
||||||
|
uc, c, _ = conditioning
|
||||||
|
|
||||||
|
print(f">> target t_enc is {t_enc} steps")
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def make_image(x_T):
|
||||||
|
# encode (scaled latent)
|
||||||
|
z_enc = sampler.stochastic_encode(
|
||||||
|
self.init_latent,
|
||||||
|
torch.tensor([t_enc - 1]).to(self.model.device),
|
||||||
|
noise=x_T
|
||||||
|
)
|
||||||
|
|
||||||
|
# to replace masked area with latent noise, weighted by inpaint_replace strength
|
||||||
|
if inpaint_replace > 0.0:
|
||||||
|
print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
|
||||||
|
l_noise = self.get_noise(kwargs['width'],kwargs['height'])
|
||||||
|
inverted_mask = 1.0-mask_image # there will be 1s where the mask is
|
||||||
|
masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
|
||||||
|
z_enc = z_enc * mask_image + masked_region
|
||||||
|
|
||||||
|
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||||
|
self.model.model.to(self.model.device)
|
||||||
|
|
||||||
|
# decode it
|
||||||
|
samples = sampler.decode(
|
||||||
|
z_enc,
|
||||||
|
c,
|
||||||
|
t_enc,
|
||||||
|
img_callback = step_callback,
|
||||||
|
unconditional_guidance_scale = cfg_scale,
|
||||||
|
unconditional_conditioning = uc,
|
||||||
|
mask = mask_image,
|
||||||
|
init_latent = self.init_latent
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.sample_to_image(samples)
|
||||||
|
|
||||||
|
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
|
||||||
|
if seam_size > 0:
|
||||||
|
old_image = self.pil_image or init_image
|
||||||
|
old_mask = self.pil_mask or mask_image
|
||||||
|
|
||||||
|
result = self.seam_paint(
|
||||||
|
result,
|
||||||
|
seam_size,
|
||||||
|
seam_blur,
|
||||||
|
prompt,
|
||||||
|
sampler,
|
||||||
|
seam_steps,
|
||||||
|
cfg_scale,
|
||||||
|
ddim_eta,
|
||||||
|
conditioning,
|
||||||
|
seam_strength,
|
||||||
|
x_T,
|
||||||
|
step_callback)
|
||||||
|
|
||||||
|
# Restore original settings
|
||||||
|
self.get_make_image(prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||||
|
conditioning,
|
||||||
|
old_image,
|
||||||
|
old_mask,
|
||||||
|
strength,
|
||||||
|
mask_blur_radius, seam_size, seam_blur, seam_strength,
|
||||||
|
seam_steps, tile_size, step_callback,
|
||||||
|
inpaint_replace, enable_image_debugging,
|
||||||
|
inpaint_width = inpaint_width,
|
||||||
|
inpaint_height = inpaint_height,
|
||||||
|
infill_method = infill_method,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return make_image
|
||||||
|
|
||||||
|
|
||||||
|
def sample_to_image(self, samples)->Image.Image:
|
||||||
|
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||||
|
debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging)
|
||||||
|
|
||||||
|
# Resize if necessary
|
||||||
|
if self.inpaint_width and self.inpaint_height:
|
||||||
|
gen_result = gen_result.resize(self.pil_image.size)
|
||||||
|
|
||||||
|
if self.pil_image is None or self.pil_mask is None:
|
||||||
|
return gen_result
|
||||||
|
|
||||||
|
corrected_result = super().repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
|
||||||
|
debug_image(corrected_result, "corrected_result", debug_status=self.enable_image_debugging)
|
||||||
|
|
||||||
|
return corrected_result
|
175
ldm/invoke/ckpt_generator/omnibus.py
Normal file
175
ldm/invoke/ckpt_generator/omnibus.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from einops import repeat
|
||||||
|
from PIL import Image, ImageOps, ImageChops
|
||||||
|
from ldm.invoke.devices import choose_autocast
|
||||||
|
from ldm.invoke.ckpt_generator.base import downsampling
|
||||||
|
from ldm.invoke.ckpt_generator.img2img import CkptImg2Img
|
||||||
|
from ldm.invoke.ckpt_generator.txt2img import CkptTxt2Img
|
||||||
|
|
||||||
|
class CkptOmnibus(CkptImg2Img,CkptTxt2Img):
|
||||||
|
def __init__(self, model, precision):
|
||||||
|
super().__init__(model, precision)
|
||||||
|
self.pil_mask = None
|
||||||
|
self.pil_image = None
|
||||||
|
|
||||||
|
def get_make_image(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
sampler,
|
||||||
|
steps,
|
||||||
|
cfg_scale,
|
||||||
|
ddim_eta,
|
||||||
|
conditioning,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
init_image = None,
|
||||||
|
mask_image = None,
|
||||||
|
strength = None,
|
||||||
|
step_callback=None,
|
||||||
|
threshold=0.0,
|
||||||
|
perlin=0.0,
|
||||||
|
mask_blur_radius: int = 8,
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
Returns a function returning an image derived from the prompt and the initial image
|
||||||
|
Return value depends on the seed at the time you call it.
|
||||||
|
"""
|
||||||
|
self.perlin = perlin
|
||||||
|
num_samples = 1
|
||||||
|
|
||||||
|
sampler.make_schedule(
|
||||||
|
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(init_image, Image.Image):
|
||||||
|
self.pil_image = init_image
|
||||||
|
if init_image.mode != 'RGB':
|
||||||
|
init_image = init_image.convert('RGB')
|
||||||
|
init_image = self._image_to_tensor(init_image)
|
||||||
|
|
||||||
|
if isinstance(mask_image, Image.Image):
|
||||||
|
self.pil_mask = mask_image
|
||||||
|
|
||||||
|
mask_image = ImageChops.multiply(mask_image.convert('L'), self.pil_image.split()[-1])
|
||||||
|
mask_image = self._image_to_tensor(ImageOps.invert(mask_image), normalize=False)
|
||||||
|
|
||||||
|
self.mask_blur_radius = mask_blur_radius
|
||||||
|
|
||||||
|
t_enc = steps
|
||||||
|
|
||||||
|
if init_image is not None and mask_image is not None: # inpainting
|
||||||
|
masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero
|
||||||
|
|
||||||
|
elif init_image is not None: # img2img
|
||||||
|
scope = choose_autocast(self.precision)
|
||||||
|
|
||||||
|
with scope(self.model.device.type):
|
||||||
|
self.init_latent = self.model.get_first_stage_encoding(
|
||||||
|
self.model.encode_first_stage(init_image)
|
||||||
|
) # move to latent space
|
||||||
|
|
||||||
|
# create a completely black mask (1s)
|
||||||
|
mask_image = torch.ones(1, 1, init_image.shape[2], init_image.shape[3], device=self.model.device)
|
||||||
|
# and the masked image is just a copy of the original
|
||||||
|
masked_image = init_image
|
||||||
|
|
||||||
|
else: # txt2img
|
||||||
|
init_image = torch.zeros(1, 3, height, width, device=self.model.device)
|
||||||
|
mask_image = torch.ones(1, 1, height, width, device=self.model.device)
|
||||||
|
masked_image = init_image
|
||||||
|
|
||||||
|
self.init_latent = init_image
|
||||||
|
height = init_image.shape[2]
|
||||||
|
width = init_image.shape[3]
|
||||||
|
model = self.model
|
||||||
|
|
||||||
|
def make_image(x_T):
|
||||||
|
with torch.no_grad():
|
||||||
|
scope = choose_autocast(self.precision)
|
||||||
|
with scope(self.model.device.type):
|
||||||
|
|
||||||
|
batch = self.make_batch_sd(
|
||||||
|
init_image,
|
||||||
|
mask_image,
|
||||||
|
masked_image,
|
||||||
|
prompt=prompt,
|
||||||
|
device=model.device,
|
||||||
|
num_samples=num_samples,
|
||||||
|
)
|
||||||
|
|
||||||
|
c = model.cond_stage_model.encode(batch["txt"])
|
||||||
|
c_cat = list()
|
||||||
|
for ck in model.concat_keys:
|
||||||
|
cc = batch[ck].float()
|
||||||
|
if ck != model.masked_image_key:
|
||||||
|
bchw = [num_samples, 4, height//8, width//8]
|
||||||
|
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
||||||
|
else:
|
||||||
|
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
|
||||||
|
c_cat.append(cc)
|
||||||
|
c_cat = torch.cat(c_cat, dim=1)
|
||||||
|
|
||||||
|
# cond
|
||||||
|
cond={"c_concat": [c_cat], "c_crossattn": [c]}
|
||||||
|
|
||||||
|
# uncond cond
|
||||||
|
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
||||||
|
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
||||||
|
shape = [model.channels, height//8, width//8]
|
||||||
|
|
||||||
|
samples, _ = sampler.sample(
|
||||||
|
batch_size = 1,
|
||||||
|
S = steps,
|
||||||
|
x_T = x_T,
|
||||||
|
conditioning = cond,
|
||||||
|
shape = shape,
|
||||||
|
verbose = False,
|
||||||
|
unconditional_guidance_scale = cfg_scale,
|
||||||
|
unconditional_conditioning = uc_full,
|
||||||
|
eta = 1.0,
|
||||||
|
img_callback = step_callback,
|
||||||
|
threshold = threshold,
|
||||||
|
)
|
||||||
|
if self.free_gpu_mem:
|
||||||
|
self.model.model.to("cpu")
|
||||||
|
return self.sample_to_image(samples)
|
||||||
|
|
||||||
|
return make_image
|
||||||
|
|
||||||
|
def make_batch_sd(
|
||||||
|
self,
|
||||||
|
image,
|
||||||
|
mask,
|
||||||
|
masked_image,
|
||||||
|
prompt,
|
||||||
|
device,
|
||||||
|
num_samples=1):
|
||||||
|
batch = {
|
||||||
|
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||||
|
"txt": num_samples * [prompt],
|
||||||
|
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||||
|
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||||
|
}
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def get_noise(self, width:int, height:int):
|
||||||
|
if self.init_latent is not None:
|
||||||
|
height = self.init_latent.shape[2]
|
||||||
|
width = self.init_latent.shape[3]
|
||||||
|
return CkptTxt2Img.get_noise(self,width,height)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_to_image(self, samples)->Image.Image:
|
||||||
|
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||||
|
|
||||||
|
if self.pil_image is None or self.pil_mask is None:
|
||||||
|
return gen_result
|
||||||
|
if self.pil_image.size != self.pil_mask.size:
|
||||||
|
return gen_result
|
||||||
|
|
||||||
|
corrected_result = super(CkptImg2Img, self).repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
|
||||||
|
|
||||||
|
return corrected_result
|
88
ldm/invoke/ckpt_generator/txt2img.py
Normal file
88
ldm/invoke/ckpt_generator/txt2img.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
'''
|
||||||
|
ldm.invoke.ckpt_generator.txt2img inherits from ldm.invoke.ckpt_generator
|
||||||
|
'''
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from ldm.invoke.ckpt_generator.base import CkptGenerator
|
||||||
|
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
|
import gc
|
||||||
|
|
||||||
|
|
||||||
|
class CkptTxt2Img(CkptGenerator):
|
||||||
|
def __init__(self, model, precision):
|
||||||
|
super().__init__(model, precision)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||||
|
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
|
||||||
|
attention_maps_callback=None,
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
Returns a function returning an image derived from the prompt and the initial image
|
||||||
|
Return value depends on the seed at the time you call it
|
||||||
|
kwargs are 'width' and 'height'
|
||||||
|
"""
|
||||||
|
self.perlin = perlin
|
||||||
|
uc, c, extra_conditioning_info = conditioning
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def make_image(x_T):
|
||||||
|
shape = [
|
||||||
|
self.latent_channels,
|
||||||
|
height // self.downsampling_factor,
|
||||||
|
width // self.downsampling_factor,
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||||
|
self.model.model.to(self.model.device)
|
||||||
|
|
||||||
|
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
|
||||||
|
|
||||||
|
samples, _ = sampler.sample(
|
||||||
|
batch_size = 1,
|
||||||
|
S = steps,
|
||||||
|
x_T = x_T,
|
||||||
|
conditioning = c,
|
||||||
|
shape = shape,
|
||||||
|
verbose = False,
|
||||||
|
unconditional_guidance_scale = cfg_scale,
|
||||||
|
unconditional_conditioning = uc,
|
||||||
|
extra_conditioning_info = extra_conditioning_info,
|
||||||
|
eta = ddim_eta,
|
||||||
|
img_callback = step_callback,
|
||||||
|
threshold = threshold,
|
||||||
|
attention_maps_callback = attention_maps_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.free_gpu_mem:
|
||||||
|
self.model.model.to('cpu')
|
||||||
|
self.model.cond_stage_model.device = 'cpu'
|
||||||
|
self.model.cond_stage_model.to('cpu')
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return self.sample_to_image(samples)
|
||||||
|
|
||||||
|
return make_image
|
||||||
|
|
||||||
|
|
||||||
|
# returns a tensor filled with random numbers from a normal distribution
|
||||||
|
def get_noise(self,width,height):
|
||||||
|
device = self.model.device
|
||||||
|
if self.use_mps_noise or device.type == 'mps':
|
||||||
|
x = torch.randn([1,
|
||||||
|
self.latent_channels,
|
||||||
|
height // self.downsampling_factor,
|
||||||
|
width // self.downsampling_factor],
|
||||||
|
device='cpu').to(device)
|
||||||
|
else:
|
||||||
|
x = torch.randn([1,
|
||||||
|
self.latent_channels,
|
||||||
|
height // self.downsampling_factor,
|
||||||
|
width // self.downsampling_factor],
|
||||||
|
device=device)
|
||||||
|
if self.perlin > 0.0:
|
||||||
|
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
||||||
|
return x
|
||||||
|
|
182
ldm/invoke/ckpt_generator/txt2img2img.py
Normal file
182
ldm/invoke/ckpt_generator/txt2img2img.py
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
'''
|
||||||
|
ldm.invoke.ckpt_generator.txt2img inherits from ldm.invoke.ckpt_generator
|
||||||
|
'''
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
import gc
|
||||||
|
from ldm.invoke.ckpt_generator.base import CkptGenerator
|
||||||
|
from ldm.invoke.ckpt_generator.omnibus import CkptOmnibus
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
class CkptTxt2Img2Img(CkptGenerator):
|
||||||
|
def __init__(self, model, precision):
|
||||||
|
super().__init__(model, precision)
|
||||||
|
self.init_latent = None # for get_noise()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||||
|
conditioning,width,height,strength,step_callback=None,**kwargs):
|
||||||
|
"""
|
||||||
|
Returns a function returning an image derived from the prompt and the initial image
|
||||||
|
Return value depends on the seed at the time you call it
|
||||||
|
kwargs are 'width' and 'height'
|
||||||
|
"""
|
||||||
|
uc, c, extra_conditioning_info = conditioning
|
||||||
|
scale_dim = min(width, height)
|
||||||
|
scale = 512 / scale_dim
|
||||||
|
|
||||||
|
init_width = math.ceil(scale * width / 64) * 64
|
||||||
|
init_height = math.ceil(scale * height / 64) * 64
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def make_image(x_T):
|
||||||
|
|
||||||
|
shape = [
|
||||||
|
self.latent_channels,
|
||||||
|
init_height // self.downsampling_factor,
|
||||||
|
init_width // self.downsampling_factor,
|
||||||
|
]
|
||||||
|
|
||||||
|
sampler.make_schedule(
|
||||||
|
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
#x = self.get_noise(init_width, init_height)
|
||||||
|
x = x_T
|
||||||
|
|
||||||
|
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
||||||
|
self.model.model.to(self.model.device)
|
||||||
|
|
||||||
|
samples, _ = sampler.sample(
|
||||||
|
batch_size = 1,
|
||||||
|
S = steps,
|
||||||
|
x_T = x,
|
||||||
|
conditioning = c,
|
||||||
|
shape = shape,
|
||||||
|
verbose = False,
|
||||||
|
unconditional_guidance_scale = cfg_scale,
|
||||||
|
unconditional_conditioning = uc,
|
||||||
|
eta = ddim_eta,
|
||||||
|
img_callback = step_callback,
|
||||||
|
extra_conditioning_info = extra_conditioning_info
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||||
|
)
|
||||||
|
|
||||||
|
# resizing
|
||||||
|
samples = torch.nn.functional.interpolate(
|
||||||
|
samples,
|
||||||
|
size=(height // self.downsampling_factor, width // self.downsampling_factor),
|
||||||
|
mode="bilinear"
|
||||||
|
)
|
||||||
|
|
||||||
|
t_enc = int(strength * steps)
|
||||||
|
ddim_sampler = DDIMSampler(self.model, device=self.model.device)
|
||||||
|
ddim_sampler.make_schedule(
|
||||||
|
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
z_enc = ddim_sampler.stochastic_encode(
|
||||||
|
samples,
|
||||||
|
torch.tensor([t_enc-1]).to(self.model.device),
|
||||||
|
noise=self.get_noise(width,height,False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# decode it
|
||||||
|
samples = ddim_sampler.decode(
|
||||||
|
z_enc,
|
||||||
|
c,
|
||||||
|
t_enc,
|
||||||
|
img_callback = step_callback,
|
||||||
|
unconditional_guidance_scale=cfg_scale,
|
||||||
|
unconditional_conditioning=uc,
|
||||||
|
extra_conditioning_info=extra_conditioning_info,
|
||||||
|
all_timesteps_count=steps
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.free_gpu_mem:
|
||||||
|
self.model.model.to('cpu')
|
||||||
|
self.model.cond_stage_model.device = 'cpu'
|
||||||
|
self.model.cond_stage_model.to('cpu')
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return self.sample_to_image(samples)
|
||||||
|
|
||||||
|
# in the case of the inpainting model being loaded, the trick of
|
||||||
|
# providing an interpolated latent doesn't work, so we transiently
|
||||||
|
# create a 512x512 PIL image, upscale it, and run the inpainting
|
||||||
|
# over it in img2img mode. Because the inpaing model is so conservative
|
||||||
|
# it doesn't change the image (much)
|
||||||
|
def inpaint_make_image(x_T):
|
||||||
|
omnibus = CkptOmnibus(self.model,self.precision)
|
||||||
|
result = omnibus.generate(
|
||||||
|
prompt,
|
||||||
|
sampler=sampler,
|
||||||
|
width=init_width,
|
||||||
|
height=init_height,
|
||||||
|
step_callback=step_callback,
|
||||||
|
steps = steps,
|
||||||
|
cfg_scale = cfg_scale,
|
||||||
|
ddim_eta = ddim_eta,
|
||||||
|
conditioning = conditioning,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
assert result is not None and len(result)>0,'** txt2img failed **'
|
||||||
|
image = result[0][0]
|
||||||
|
interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
|
||||||
|
print(kwargs.pop('init_image',None))
|
||||||
|
result = omnibus.generate(
|
||||||
|
prompt,
|
||||||
|
sampler=sampler,
|
||||||
|
init_image=interpolated_image,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
seed=result[0][1],
|
||||||
|
step_callback=step_callback,
|
||||||
|
steps = steps,
|
||||||
|
cfg_scale = cfg_scale,
|
||||||
|
ddim_eta = ddim_eta,
|
||||||
|
conditioning = conditioning,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return result[0][0]
|
||||||
|
|
||||||
|
if sampler.uses_inpainting_model():
|
||||||
|
return inpaint_make_image
|
||||||
|
else:
|
||||||
|
return make_image
|
||||||
|
|
||||||
|
# returns a tensor filled with random numbers from a normal distribution
|
||||||
|
def get_noise(self,width,height,scale = True):
|
||||||
|
# print(f"Get noise: {width}x{height}")
|
||||||
|
if scale:
|
||||||
|
trained_square = 512 * 512
|
||||||
|
actual_square = width * height
|
||||||
|
scale = math.sqrt(trained_square / actual_square)
|
||||||
|
scaled_width = math.ceil(scale * width / 64) * 64
|
||||||
|
scaled_height = math.ceil(scale * height / 64) * 64
|
||||||
|
else:
|
||||||
|
scaled_width = width
|
||||||
|
scaled_height = height
|
||||||
|
|
||||||
|
device = self.model.device
|
||||||
|
if self.use_mps_noise or device.type == 'mps':
|
||||||
|
return torch.randn([1,
|
||||||
|
self.latent_channels,
|
||||||
|
scaled_height // self.downsampling_factor,
|
||||||
|
scaled_width // self.downsampling_factor],
|
||||||
|
device='cpu').to(device)
|
||||||
|
else:
|
||||||
|
return torch.randn([1,
|
||||||
|
self.latent_channels,
|
||||||
|
scaled_height // self.downsampling_factor,
|
||||||
|
scaled_width // self.downsampling_factor],
|
||||||
|
device=device)
|
||||||
|
|
953
ldm/invoke/ckpt_to_diffuser.py
Normal file
953
ldm/invoke/ckpt_to_diffuser.py
Normal file
@ -0,0 +1,953 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# Adapted for use as a module by Lincoln Stein <lstein@gmail.com>
|
||||||
|
# Original file at: https://github.com/huggingface/diffusers/blob/main/scripts/convert_ldm_original_checkpoint_to_diffusers.py
|
||||||
|
""" Conversion script for the LDM checkpoints. """
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
from ldm.invoke.globals import Globals
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
try:
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
|
||||||
|
)
|
||||||
|
|
||||||
|
from diffusers import (
|
||||||
|
AutoencoderKL,
|
||||||
|
DDIMScheduler,
|
||||||
|
DPMSolverMultistepScheduler,
|
||||||
|
EulerAncestralDiscreteScheduler,
|
||||||
|
EulerDiscreteScheduler,
|
||||||
|
HeunDiscreteScheduler,
|
||||||
|
LDMTextToImagePipeline,
|
||||||
|
LMSDiscreteScheduler,
|
||||||
|
PNDMScheduler,
|
||||||
|
StableDiffusionPipeline,
|
||||||
|
UNet2DConditionModel,
|
||||||
|
)
|
||||||
|
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||||
|
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
|
||||||
|
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||||
|
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
|
||||||
|
|
||||||
|
def shave_segments(path, n_shave_prefix_segments=1):
|
||||||
|
"""
|
||||||
|
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
||||||
|
"""
|
||||||
|
if n_shave_prefix_segments >= 0:
|
||||||
|
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
||||||
|
else:
|
||||||
|
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
||||||
|
|
||||||
|
|
||||||
|
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
||||||
|
"""
|
||||||
|
Updates paths inside resnets to the new naming scheme (local renaming)
|
||||||
|
"""
|
||||||
|
mapping = []
|
||||||
|
for old_item in old_list:
|
||||||
|
new_item = old_item.replace("in_layers.0", "norm1")
|
||||||
|
new_item = new_item.replace("in_layers.2", "conv1")
|
||||||
|
|
||||||
|
new_item = new_item.replace("out_layers.0", "norm2")
|
||||||
|
new_item = new_item.replace("out_layers.3", "conv2")
|
||||||
|
|
||||||
|
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
||||||
|
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
||||||
|
|
||||||
|
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||||
|
|
||||||
|
mapping.append({"old": old_item, "new": new_item})
|
||||||
|
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
|
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
||||||
|
"""
|
||||||
|
Updates paths inside resnets to the new naming scheme (local renaming)
|
||||||
|
"""
|
||||||
|
mapping = []
|
||||||
|
for old_item in old_list:
|
||||||
|
new_item = old_item
|
||||||
|
|
||||||
|
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
||||||
|
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||||
|
|
||||||
|
mapping.append({"old": old_item, "new": new_item})
|
||||||
|
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
|
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||||
|
"""
|
||||||
|
Updates paths inside attentions to the new naming scheme (local renaming)
|
||||||
|
"""
|
||||||
|
mapping = []
|
||||||
|
for old_item in old_list:
|
||||||
|
new_item = old_item
|
||||||
|
|
||||||
|
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
||||||
|
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
||||||
|
|
||||||
|
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
||||||
|
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
||||||
|
|
||||||
|
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||||
|
|
||||||
|
mapping.append({"old": old_item, "new": new_item})
|
||||||
|
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
|
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||||
|
"""
|
||||||
|
Updates paths inside attentions to the new naming scheme (local renaming)
|
||||||
|
"""
|
||||||
|
mapping = []
|
||||||
|
for old_item in old_list:
|
||||||
|
new_item = old_item
|
||||||
|
|
||||||
|
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
||||||
|
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
||||||
|
|
||||||
|
new_item = new_item.replace("q.weight", "query.weight")
|
||||||
|
new_item = new_item.replace("q.bias", "query.bias")
|
||||||
|
|
||||||
|
new_item = new_item.replace("k.weight", "key.weight")
|
||||||
|
new_item = new_item.replace("k.bias", "key.bias")
|
||||||
|
|
||||||
|
new_item = new_item.replace("v.weight", "value.weight")
|
||||||
|
new_item = new_item.replace("v.bias", "value.bias")
|
||||||
|
|
||||||
|
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
||||||
|
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
||||||
|
|
||||||
|
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||||
|
|
||||||
|
mapping.append({"old": old_item, "new": new_item})
|
||||||
|
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
|
def assign_to_checkpoint(
|
||||||
|
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
This does the final conversion step: take locally converted weights and apply a global renaming
|
||||||
|
to them. It splits attention layers, and takes into account additional replacements
|
||||||
|
that may arise.
|
||||||
|
|
||||||
|
Assigns the weights to the new checkpoint.
|
||||||
|
"""
|
||||||
|
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
||||||
|
|
||||||
|
# Splits the attention layers into three variables.
|
||||||
|
if attention_paths_to_split is not None:
|
||||||
|
for path, path_map in attention_paths_to_split.items():
|
||||||
|
old_tensor = old_checkpoint[path]
|
||||||
|
channels = old_tensor.shape[0] // 3
|
||||||
|
|
||||||
|
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
||||||
|
|
||||||
|
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
||||||
|
|
||||||
|
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
||||||
|
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
||||||
|
|
||||||
|
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
||||||
|
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
||||||
|
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
||||||
|
|
||||||
|
for path in paths:
|
||||||
|
new_path = path["new"]
|
||||||
|
|
||||||
|
# These have already been assigned
|
||||||
|
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Global renaming happens here
|
||||||
|
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
||||||
|
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
||||||
|
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
||||||
|
|
||||||
|
if additional_replacements is not None:
|
||||||
|
for replacement in additional_replacements:
|
||||||
|
new_path = new_path.replace(replacement["old"], replacement["new"])
|
||||||
|
|
||||||
|
# proj_attn.weight has to be converted from conv 1D to linear
|
||||||
|
if "proj_attn.weight" in new_path:
|
||||||
|
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
||||||
|
else:
|
||||||
|
checkpoint[new_path] = old_checkpoint[path["old"]]
|
||||||
|
|
||||||
|
|
||||||
|
def conv_attn_to_linear(checkpoint):
|
||||||
|
keys = list(checkpoint.keys())
|
||||||
|
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||||
|
for key in keys:
|
||||||
|
if ".".join(key.split(".")[-2:]) in attn_keys:
|
||||||
|
if checkpoint[key].ndim > 2:
|
||||||
|
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||||
|
elif "proj_attn.weight" in key:
|
||||||
|
if checkpoint[key].ndim > 2:
|
||||||
|
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
|
def create_unet_diffusers_config(original_config, image_size: int):
|
||||||
|
"""
|
||||||
|
Creates a config for the diffusers based on the config of the LDM model.
|
||||||
|
"""
|
||||||
|
unet_params = original_config.model.params.unet_config.params
|
||||||
|
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||||
|
|
||||||
|
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
||||||
|
|
||||||
|
down_block_types = []
|
||||||
|
resolution = 1
|
||||||
|
for i in range(len(block_out_channels)):
|
||||||
|
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
||||||
|
down_block_types.append(block_type)
|
||||||
|
if i != len(block_out_channels) - 1:
|
||||||
|
resolution *= 2
|
||||||
|
|
||||||
|
up_block_types = []
|
||||||
|
for i in range(len(block_out_channels)):
|
||||||
|
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
||||||
|
up_block_types.append(block_type)
|
||||||
|
resolution //= 2
|
||||||
|
|
||||||
|
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
||||||
|
|
||||||
|
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
|
||||||
|
use_linear_projection = (
|
||||||
|
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
|
||||||
|
)
|
||||||
|
if use_linear_projection:
|
||||||
|
# stable diffusion 2-base-512 and 2-768
|
||||||
|
if head_dim is None:
|
||||||
|
head_dim = [5, 10, 20, 20]
|
||||||
|
|
||||||
|
config = dict(
|
||||||
|
sample_size=image_size // vae_scale_factor,
|
||||||
|
in_channels=unet_params.in_channels,
|
||||||
|
out_channels=unet_params.out_channels,
|
||||||
|
down_block_types=tuple(down_block_types),
|
||||||
|
up_block_types=tuple(up_block_types),
|
||||||
|
block_out_channels=tuple(block_out_channels),
|
||||||
|
layers_per_block=unet_params.num_res_blocks,
|
||||||
|
cross_attention_dim=unet_params.context_dim,
|
||||||
|
attention_head_dim=head_dim,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def create_vae_diffusers_config(original_config, image_size: int):
|
||||||
|
"""
|
||||||
|
Creates a config for the diffusers based on the config of the LDM model.
|
||||||
|
"""
|
||||||
|
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||||
|
_ = original_config.model.params.first_stage_config.params.embed_dim
|
||||||
|
|
||||||
|
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
||||||
|
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
||||||
|
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
||||||
|
|
||||||
|
config = dict(
|
||||||
|
sample_size=image_size,
|
||||||
|
in_channels=vae_params.in_channels,
|
||||||
|
out_channels=vae_params.out_ch,
|
||||||
|
down_block_types=tuple(down_block_types),
|
||||||
|
up_block_types=tuple(up_block_types),
|
||||||
|
block_out_channels=tuple(block_out_channels),
|
||||||
|
latent_channels=vae_params.z_channels,
|
||||||
|
layers_per_block=vae_params.num_res_blocks,
|
||||||
|
)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def create_diffusers_schedular(original_config):
|
||||||
|
schedular = DDIMScheduler(
|
||||||
|
num_train_timesteps=original_config.model.params.timesteps,
|
||||||
|
beta_start=original_config.model.params.linear_start,
|
||||||
|
beta_end=original_config.model.params.linear_end,
|
||||||
|
beta_schedule="scaled_linear",
|
||||||
|
)
|
||||||
|
return schedular
|
||||||
|
|
||||||
|
|
||||||
|
def create_ldm_bert_config(original_config):
|
||||||
|
bert_params = original_config.model.params.cond_stage_config.params
|
||||||
|
config = LDMBertConfig(
|
||||||
|
d_model=bert_params.n_embed,
|
||||||
|
encoder_layers=bert_params.n_layer,
|
||||||
|
encoder_ffn_dim=bert_params.n_embed * 4,
|
||||||
|
)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
|
||||||
|
"""
|
||||||
|
Takes a state dict and a config, and returns a converted checkpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# extract state_dict for UNet
|
||||||
|
unet_state_dict = {}
|
||||||
|
keys = list(checkpoint.keys())
|
||||||
|
|
||||||
|
unet_key = "model.diffusion_model."
|
||||||
|
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
||||||
|
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||||
|
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
||||||
|
if extract_ema:
|
||||||
|
print(
|
||||||
|
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
||||||
|
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
||||||
|
)
|
||||||
|
for key in keys:
|
||||||
|
if key.startswith("model.diffusion_model"):
|
||||||
|
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||||
|
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
||||||
|
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
||||||
|
)
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
if key.startswith(unet_key):
|
||||||
|
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
||||||
|
|
||||||
|
new_checkpoint = {}
|
||||||
|
|
||||||
|
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
||||||
|
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
||||||
|
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
||||||
|
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
||||||
|
|
||||||
|
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
||||||
|
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
||||||
|
|
||||||
|
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
||||||
|
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
||||||
|
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
||||||
|
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
||||||
|
|
||||||
|
# Retrieves the keys for the input blocks only
|
||||||
|
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
||||||
|
input_blocks = {
|
||||||
|
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
||||||
|
for layer_id in range(num_input_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Retrieves the keys for the middle blocks only
|
||||||
|
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
||||||
|
middle_blocks = {
|
||||||
|
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
||||||
|
for layer_id in range(num_middle_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Retrieves the keys for the output blocks only
|
||||||
|
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
||||||
|
output_blocks = {
|
||||||
|
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
||||||
|
for layer_id in range(num_output_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in range(1, num_input_blocks):
|
||||||
|
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
||||||
|
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
||||||
|
|
||||||
|
resnets = [
|
||||||
|
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
||||||
|
]
|
||||||
|
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
||||||
|
|
||||||
|
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
||||||
|
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
||||||
|
f"input_blocks.{i}.0.op.weight"
|
||||||
|
)
|
||||||
|
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
||||||
|
f"input_blocks.{i}.0.op.bias"
|
||||||
|
)
|
||||||
|
|
||||||
|
paths = renew_resnet_paths(resnets)
|
||||||
|
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||||
|
assign_to_checkpoint(
|
||||||
|
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(attentions):
|
||||||
|
paths = renew_attention_paths(attentions)
|
||||||
|
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
||||||
|
assign_to_checkpoint(
|
||||||
|
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||||
|
)
|
||||||
|
|
||||||
|
resnet_0 = middle_blocks[0]
|
||||||
|
attentions = middle_blocks[1]
|
||||||
|
resnet_1 = middle_blocks[2]
|
||||||
|
|
||||||
|
resnet_0_paths = renew_resnet_paths(resnet_0)
|
||||||
|
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
||||||
|
|
||||||
|
resnet_1_paths = renew_resnet_paths(resnet_1)
|
||||||
|
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
||||||
|
|
||||||
|
attentions_paths = renew_attention_paths(attentions)
|
||||||
|
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
||||||
|
assign_to_checkpoint(
|
||||||
|
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(num_output_blocks):
|
||||||
|
block_id = i // (config["layers_per_block"] + 1)
|
||||||
|
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
||||||
|
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
||||||
|
output_block_list = {}
|
||||||
|
|
||||||
|
for layer in output_block_layers:
|
||||||
|
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
||||||
|
if layer_id in output_block_list:
|
||||||
|
output_block_list[layer_id].append(layer_name)
|
||||||
|
else:
|
||||||
|
output_block_list[layer_id] = [layer_name]
|
||||||
|
|
||||||
|
if len(output_block_list) > 1:
|
||||||
|
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
||||||
|
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
||||||
|
|
||||||
|
resnet_0_paths = renew_resnet_paths(resnets)
|
||||||
|
paths = renew_resnet_paths(resnets)
|
||||||
|
|
||||||
|
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||||
|
assign_to_checkpoint(
|
||||||
|
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||||
|
)
|
||||||
|
|
||||||
|
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
||||||
|
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
||||||
|
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
||||||
|
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
||||||
|
f"output_blocks.{i}.{index}.conv.weight"
|
||||||
|
]
|
||||||
|
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
||||||
|
f"output_blocks.{i}.{index}.conv.bias"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Clear attentions as they have been attributed above.
|
||||||
|
if len(attentions) == 2:
|
||||||
|
attentions = []
|
||||||
|
|
||||||
|
if len(attentions):
|
||||||
|
paths = renew_attention_paths(attentions)
|
||||||
|
meta_path = {
|
||||||
|
"old": f"output_blocks.{i}.1",
|
||||||
|
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
||||||
|
}
|
||||||
|
assign_to_checkpoint(
|
||||||
|
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
||||||
|
for path in resnet_0_paths:
|
||||||
|
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
||||||
|
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
||||||
|
|
||||||
|
new_checkpoint[new_path] = unet_state_dict[old_path]
|
||||||
|
|
||||||
|
return new_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||||
|
# extract state dict for VAE
|
||||||
|
vae_state_dict = {}
|
||||||
|
vae_key = "first_stage_model."
|
||||||
|
keys = list(checkpoint.keys())
|
||||||
|
for key in keys:
|
||||||
|
if key.startswith(vae_key):
|
||||||
|
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||||
|
|
||||||
|
new_checkpoint = {}
|
||||||
|
|
||||||
|
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
||||||
|
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
||||||
|
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
||||||
|
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
||||||
|
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
||||||
|
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
||||||
|
|
||||||
|
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
||||||
|
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
||||||
|
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
||||||
|
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
||||||
|
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
||||||
|
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
||||||
|
|
||||||
|
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
||||||
|
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
||||||
|
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
||||||
|
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
||||||
|
|
||||||
|
# Retrieves the keys for the encoder down blocks only
|
||||||
|
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
||||||
|
down_blocks = {
|
||||||
|
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Retrieves the keys for the decoder up blocks only
|
||||||
|
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
||||||
|
up_blocks = {
|
||||||
|
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in range(num_down_blocks):
|
||||||
|
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
||||||
|
|
||||||
|
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
||||||
|
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
||||||
|
f"encoder.down.{i}.downsample.conv.weight"
|
||||||
|
)
|
||||||
|
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
||||||
|
f"encoder.down.{i}.downsample.conv.bias"
|
||||||
|
)
|
||||||
|
|
||||||
|
paths = renew_vae_resnet_paths(resnets)
|
||||||
|
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
||||||
|
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||||
|
|
||||||
|
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
||||||
|
num_mid_res_blocks = 2
|
||||||
|
for i in range(1, num_mid_res_blocks + 1):
|
||||||
|
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
||||||
|
|
||||||
|
paths = renew_vae_resnet_paths(resnets)
|
||||||
|
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
||||||
|
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||||
|
|
||||||
|
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
||||||
|
paths = renew_vae_attention_paths(mid_attentions)
|
||||||
|
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
||||||
|
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||||
|
conv_attn_to_linear(new_checkpoint)
|
||||||
|
|
||||||
|
for i in range(num_up_blocks):
|
||||||
|
block_id = num_up_blocks - 1 - i
|
||||||
|
resnets = [
|
||||||
|
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
||||||
|
]
|
||||||
|
|
||||||
|
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
||||||
|
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
||||||
|
f"decoder.up.{block_id}.upsample.conv.weight"
|
||||||
|
]
|
||||||
|
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
||||||
|
f"decoder.up.{block_id}.upsample.conv.bias"
|
||||||
|
]
|
||||||
|
|
||||||
|
paths = renew_vae_resnet_paths(resnets)
|
||||||
|
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
||||||
|
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||||
|
|
||||||
|
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
||||||
|
num_mid_res_blocks = 2
|
||||||
|
for i in range(1, num_mid_res_blocks + 1):
|
||||||
|
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
||||||
|
|
||||||
|
paths = renew_vae_resnet_paths(resnets)
|
||||||
|
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
||||||
|
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||||
|
|
||||||
|
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
||||||
|
paths = renew_vae_attention_paths(mid_attentions)
|
||||||
|
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
||||||
|
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||||
|
conv_attn_to_linear(new_checkpoint)
|
||||||
|
return new_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def convert_ldm_bert_checkpoint(checkpoint, config):
|
||||||
|
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
||||||
|
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
||||||
|
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
|
||||||
|
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
|
||||||
|
|
||||||
|
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
|
||||||
|
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
|
||||||
|
|
||||||
|
def _copy_linear(hf_linear, pt_linear):
|
||||||
|
hf_linear.weight = pt_linear.weight
|
||||||
|
hf_linear.bias = pt_linear.bias
|
||||||
|
|
||||||
|
def _copy_layer(hf_layer, pt_layer):
|
||||||
|
# copy layer norms
|
||||||
|
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
|
||||||
|
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
|
||||||
|
|
||||||
|
# copy attn
|
||||||
|
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
|
||||||
|
|
||||||
|
# copy MLP
|
||||||
|
pt_mlp = pt_layer[1][1]
|
||||||
|
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
|
||||||
|
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
|
||||||
|
|
||||||
|
def _copy_layers(hf_layers, pt_layers):
|
||||||
|
for i, hf_layer in enumerate(hf_layers):
|
||||||
|
if i != 0:
|
||||||
|
i += i
|
||||||
|
pt_layer = pt_layers[i : i + 2]
|
||||||
|
_copy_layer(hf_layer, pt_layer)
|
||||||
|
|
||||||
|
hf_model = LDMBertModel(config).eval()
|
||||||
|
|
||||||
|
# copy embeds
|
||||||
|
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
|
||||||
|
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
|
||||||
|
|
||||||
|
# copy layer norm
|
||||||
|
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
|
||||||
|
|
||||||
|
# copy hidden layers
|
||||||
|
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
|
||||||
|
|
||||||
|
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
|
||||||
|
|
||||||
|
return hf_model
|
||||||
|
|
||||||
|
|
||||||
|
def convert_ldm_clip_checkpoint(checkpoint):
|
||||||
|
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||||
|
|
||||||
|
keys = list(checkpoint.keys())
|
||||||
|
|
||||||
|
text_model_dict = {}
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
if key.startswith("cond_stage_model.transformer"):
|
||||||
|
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||||
|
|
||||||
|
text_model.load_state_dict(text_model_dict)
|
||||||
|
|
||||||
|
return text_model
|
||||||
|
|
||||||
|
|
||||||
|
textenc_conversion_lst = [
|
||||||
|
("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
||||||
|
("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
||||||
|
("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
|
||||||
|
("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
|
||||||
|
]
|
||||||
|
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
|
||||||
|
|
||||||
|
textenc_transformer_conversion_lst = [
|
||||||
|
# (stable-diffusion, HF Diffusers)
|
||||||
|
("resblocks.", "text_model.encoder.layers."),
|
||||||
|
("ln_1", "layer_norm1"),
|
||||||
|
("ln_2", "layer_norm2"),
|
||||||
|
(".c_fc.", ".fc1."),
|
||||||
|
(".c_proj.", ".fc2."),
|
||||||
|
(".attn", ".self_attn"),
|
||||||
|
("ln_final.", "transformer.text_model.final_layer_norm."),
|
||||||
|
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
||||||
|
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
||||||
|
]
|
||||||
|
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
|
||||||
|
textenc_pattern = re.compile("|".join(protected.keys()))
|
||||||
|
|
||||||
|
|
||||||
|
def convert_paint_by_example_checkpoint(checkpoint):
|
||||||
|
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
|
||||||
|
model = PaintByExampleImageEncoder(config)
|
||||||
|
|
||||||
|
keys = list(checkpoint.keys())
|
||||||
|
|
||||||
|
text_model_dict = {}
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
if key.startswith("cond_stage_model.transformer"):
|
||||||
|
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||||
|
|
||||||
|
# load clip vision
|
||||||
|
model.model.load_state_dict(text_model_dict)
|
||||||
|
|
||||||
|
# load mapper
|
||||||
|
keys_mapper = {
|
||||||
|
k[len("cond_stage_model.mapper.res") :]: v
|
||||||
|
for k, v in checkpoint.items()
|
||||||
|
if k.startswith("cond_stage_model.mapper")
|
||||||
|
}
|
||||||
|
|
||||||
|
MAPPING = {
|
||||||
|
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
|
||||||
|
"attn.c_proj": ["attn1.to_out.0"],
|
||||||
|
"ln_1": ["norm1"],
|
||||||
|
"ln_2": ["norm3"],
|
||||||
|
"mlp.c_fc": ["ff.net.0.proj"],
|
||||||
|
"mlp.c_proj": ["ff.net.2"],
|
||||||
|
}
|
||||||
|
|
||||||
|
mapped_weights = {}
|
||||||
|
for key, value in keys_mapper.items():
|
||||||
|
prefix = key[: len("blocks.i")]
|
||||||
|
suffix = key.split(prefix)[-1].split(".")[-1]
|
||||||
|
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
|
||||||
|
mapped_names = MAPPING[name]
|
||||||
|
|
||||||
|
num_splits = len(mapped_names)
|
||||||
|
for i, mapped_name in enumerate(mapped_names):
|
||||||
|
new_name = ".".join([prefix, mapped_name, suffix])
|
||||||
|
shape = value.shape[0] // num_splits
|
||||||
|
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
|
||||||
|
|
||||||
|
model.mapper.load_state_dict(mapped_weights)
|
||||||
|
|
||||||
|
# load final layer norm
|
||||||
|
model.final_layer_norm.load_state_dict(
|
||||||
|
{
|
||||||
|
"bias": checkpoint["cond_stage_model.final_ln.bias"],
|
||||||
|
"weight": checkpoint["cond_stage_model.final_ln.weight"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# load final proj
|
||||||
|
model.proj_out.load_state_dict(
|
||||||
|
{
|
||||||
|
"bias": checkpoint["proj_out.bias"],
|
||||||
|
"weight": checkpoint["proj_out.weight"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# load uncond vector
|
||||||
|
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def convert_open_clip_checkpoint(checkpoint):
|
||||||
|
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
|
||||||
|
|
||||||
|
keys = list(checkpoint.keys())
|
||||||
|
|
||||||
|
text_model_dict = {}
|
||||||
|
|
||||||
|
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
|
||||||
|
|
||||||
|
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
|
||||||
|
continue
|
||||||
|
if key in textenc_conversion_map:
|
||||||
|
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
|
||||||
|
if key.startswith("cond_stage_model.model.transformer."):
|
||||||
|
new_key = key[len("cond_stage_model.model.transformer.") :]
|
||||||
|
if new_key.endswith(".in_proj_weight"):
|
||||||
|
new_key = new_key[: -len(".in_proj_weight")]
|
||||||
|
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
||||||
|
text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
||||||
|
text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
|
||||||
|
text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
|
||||||
|
elif new_key.endswith(".in_proj_bias"):
|
||||||
|
new_key = new_key[: -len(".in_proj_bias")]
|
||||||
|
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
||||||
|
text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
|
||||||
|
text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
|
||||||
|
text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
|
||||||
|
else:
|
||||||
|
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
||||||
|
|
||||||
|
text_model_dict[new_key] = checkpoint[key]
|
||||||
|
|
||||||
|
text_model.load_state_dict(text_model_dict)
|
||||||
|
|
||||||
|
return text_model
|
||||||
|
|
||||||
|
def convert_ckpt_to_diffuser(checkpoint_path:str,
|
||||||
|
dump_path:str,
|
||||||
|
original_config_file:str=None,
|
||||||
|
num_in_channels:int=None,
|
||||||
|
scheduler_type:str='pndm',
|
||||||
|
pipeline_type:str=None,
|
||||||
|
image_size:int=None,
|
||||||
|
prediction_type:str=None,
|
||||||
|
extract_ema:bool=False,
|
||||||
|
upcast_attn:bool=False,
|
||||||
|
):
|
||||||
|
|
||||||
|
checkpoint = load_file(checkpoint_path) if Path(checkpoint_path).suffix == '.safetensors' else torch.load(checkpoint_path)
|
||||||
|
|
||||||
|
# Sometimes models don't have the global_step item
|
||||||
|
if "global_step" in checkpoint:
|
||||||
|
global_step = checkpoint["global_step"]
|
||||||
|
else:
|
||||||
|
print("global_step key not found in model")
|
||||||
|
global_step = None
|
||||||
|
|
||||||
|
# sometimes there is a state_dict key and sometimes not
|
||||||
|
if 'state_dict' in checkpoint:
|
||||||
|
checkpoint = checkpoint["state_dict"]
|
||||||
|
|
||||||
|
upcast_attention = False
|
||||||
|
if original_config_file is None:
|
||||||
|
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
|
|
||||||
|
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||||
|
original_config_file = os.path.join(Globals.root,'configs','stable-diffusion','v2-inference-v.yaml')
|
||||||
|
|
||||||
|
if global_step == 110000:
|
||||||
|
# v2.1 needs to upcast attention
|
||||||
|
upcast_attention = True
|
||||||
|
else:
|
||||||
|
original_config_file = os.path.join(Globals.root,'configs','stable-diffusion','v1-inference.yaml')
|
||||||
|
|
||||||
|
original_config = OmegaConf.load(original_config_file)
|
||||||
|
|
||||||
|
if num_in_channels is not None:
|
||||||
|
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
||||||
|
|
||||||
|
if (
|
||||||
|
"parameterization" in original_config["model"]["params"]
|
||||||
|
and original_config["model"]["params"]["parameterization"] == "v"
|
||||||
|
):
|
||||||
|
if prediction_type is None:
|
||||||
|
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
|
||||||
|
# as it relies on a brittle global step parameter here
|
||||||
|
prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
|
||||||
|
if image_size is None:
|
||||||
|
# NOTE: For stable diffusion 2 base one has to pass `image_size==512`
|
||||||
|
# as it relies on a brittle global step parameter here
|
||||||
|
image_size = 512 if global_step == 875000 else 768
|
||||||
|
else:
|
||||||
|
if prediction_type is None:
|
||||||
|
prediction_type = "epsilon"
|
||||||
|
if image_size is None:
|
||||||
|
image_size = 512
|
||||||
|
|
||||||
|
num_train_timesteps = original_config.model.params.timesteps
|
||||||
|
beta_start = original_config.model.params.linear_start
|
||||||
|
beta_end = original_config.model.params.linear_end
|
||||||
|
|
||||||
|
scheduler = DDIMScheduler(
|
||||||
|
beta_end=beta_end,
|
||||||
|
beta_schedule="scaled_linear",
|
||||||
|
beta_start=beta_start,
|
||||||
|
num_train_timesteps=num_train_timesteps,
|
||||||
|
steps_offset=1,
|
||||||
|
clip_sample=False,
|
||||||
|
set_alpha_to_one=False,
|
||||||
|
prediction_type=prediction_type,
|
||||||
|
)
|
||||||
|
# make sure scheduler works correctly with DDIM
|
||||||
|
scheduler.register_to_config(clip_sample=False)
|
||||||
|
|
||||||
|
if scheduler_type == "pndm":
|
||||||
|
config = dict(scheduler.config)
|
||||||
|
config["skip_prk_steps"] = True
|
||||||
|
scheduler = PNDMScheduler.from_config(config)
|
||||||
|
elif scheduler_type == "lms":
|
||||||
|
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
|
||||||
|
elif scheduler_type == "heun":
|
||||||
|
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
|
||||||
|
elif scheduler_type == "euler":
|
||||||
|
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
|
||||||
|
elif scheduler_type == "euler-ancestral":
|
||||||
|
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
||||||
|
elif scheduler_type == "dpm":
|
||||||
|
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||||
|
elif scheduler_type == "ddim":
|
||||||
|
scheduler = scheduler
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
||||||
|
|
||||||
|
# Convert the UNet2DConditionModel model.
|
||||||
|
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||||
|
unet_config["upcast_attention"] = upcast_attention
|
||||||
|
unet = UNet2DConditionModel(**unet_config)
|
||||||
|
|
||||||
|
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
|
||||||
|
checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
|
||||||
|
)
|
||||||
|
|
||||||
|
unet.load_state_dict(converted_unet_checkpoint)
|
||||||
|
|
||||||
|
# Convert the VAE model.
|
||||||
|
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||||
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||||
|
|
||||||
|
vae = AutoencoderKL(**vae_config)
|
||||||
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
|
|
||||||
|
# Convert the text model.
|
||||||
|
model_type = pipeline_type
|
||||||
|
if model_type is None:
|
||||||
|
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||||
|
|
||||||
|
if model_type == "FrozenOpenCLIPEmbedder":
|
||||||
|
text_model = convert_open_clip_checkpoint(checkpoint)
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
|
||||||
|
pipe = StableDiffusionPipeline(
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=text_model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False,
|
||||||
|
)
|
||||||
|
elif model_type == "PaintByExample":
|
||||||
|
vision_model = convert_paint_by_example_checkpoint(checkpoint)
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||||
|
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||||
|
pipe = PaintByExamplePipeline(
|
||||||
|
vae=vae,
|
||||||
|
image_encoder=vision_model,
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
)
|
||||||
|
elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
|
||||||
|
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||||
|
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||||
|
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||||
|
pipe = StableDiffusionPipeline(
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=text_model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
safety_checker=safety_checker,
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
text_config = create_ldm_bert_config(original_config)
|
||||||
|
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||||
|
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||||
|
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||||
|
|
||||||
|
pipe.save_pretrained(
|
||||||
|
dump_path,
|
||||||
|
safe_serialization=1,
|
||||||
|
)
|
@ -12,7 +12,7 @@ from urllib import request, error as ul_error
|
|||||||
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
|
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals
|
||||||
|
|
||||||
class Concepts(object):
|
class HuggingFaceConceptsLibrary(object):
|
||||||
def __init__(self, root=None):
|
def __init__(self, root=None):
|
||||||
'''
|
'''
|
||||||
Initialize the Concepts object. May optionally pass a root directory.
|
Initialize the Concepts object. May optionally pass a root directory.
|
||||||
|
@ -16,9 +16,15 @@ from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
|
|||||||
from ..models.diffusion import cross_attention_control
|
from ..models.diffusion import cross_attention_control
|
||||||
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
|
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
|
||||||
|
from ..modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
|
||||||
|
|
||||||
|
|
||||||
def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
|
def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
|
||||||
|
|
||||||
|
# lazy-load any deferred textual inversions.
|
||||||
|
# this might take a couple of seconds the first time a textual inversion is used.
|
||||||
|
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
|
||||||
|
|
||||||
prompt, negative_prompt = get_prompt_structure(prompt_string,
|
prompt, negative_prompt = get_prompt_structure(prompt_string,
|
||||||
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
|
skip_normalize_legacy_blend=skip_normalize_legacy_blend)
|
||||||
conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
|
conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
|
||||||
@ -216,7 +222,7 @@ def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
|
|||||||
log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
|
log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
|
||||||
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
|
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
|
||||||
(embeddings_to_blend, this_embedding))
|
(embeddings_to_blend, this_embedding))
|
||||||
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
|
conditioning = WeightedPromptFragmentsToEmbeddingsConverter.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
|
||||||
blend.weights,
|
blend.weights,
|
||||||
normalize=blend.normalize_weights)
|
normalize=blend.normalize_weights)
|
||||||
return conditioning
|
return conditioning
|
||||||
@ -238,7 +244,7 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedProm
|
|||||||
|
|
||||||
def _get_tokens_length(model, fragments: list[Fragment]):
|
def _get_tokens_length(model, fragments: list[Fragment]):
|
||||||
fragment_texts = [x.text for x in fragments]
|
fragment_texts = [x.text for x in fragments]
|
||||||
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
|
tokens = model.cond_stage_model.get_token_ids(fragment_texts, include_start_and_end_markers=False)
|
||||||
return sum([len(x) for x in tokens])
|
return sum([len(x) for x in tokens])
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
from ldm.invoke.globals import Globals
|
||||||
|
|
||||||
def choose_torch_device() -> str:
|
def choose_torch_device() -> str:
|
||||||
'''Convenience routine for guessing which GPU device to run model on'''
|
'''Convenience routine for guessing which GPU device to run model on'''
|
||||||
|
if Globals.always_use_cpu:
|
||||||
|
return "cpu"
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
return 'cuda'
|
return 'cuda'
|
||||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||||
|
@ -2,26 +2,37 @@
|
|||||||
Base class for ldm.invoke.generator.*
|
Base class for ldm.invoke.generator.*
|
||||||
including img2img, txt2img, and inpaint
|
including img2img, txt2img, and inpaint
|
||||||
'''
|
'''
|
||||||
import torch
|
from __future__ import annotations
|
||||||
import numpy as np
|
|
||||||
import random
|
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
import random
|
||||||
import traceback
|
import traceback
|
||||||
from tqdm import tqdm, trange
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
from PIL import Image, ImageFilter, ImageChops
|
from PIL import Image, ImageFilter, ImageChops
|
||||||
import cv2 as cv
|
from diffusers import DiffusionPipeline
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
from ldm.invoke.devices import choose_autocast
|
from ldm.invoke.devices import choose_autocast
|
||||||
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||||
|
from ldm.models.diffusion.ddpm import DiffusionWrapper
|
||||||
from ldm.util import rand_perlin_2d
|
from ldm.util import rand_perlin_2d
|
||||||
|
|
||||||
downsampling = 8
|
downsampling = 8
|
||||||
CAUTION_IMG = 'assets/caution.png'
|
CAUTION_IMG = 'assets/caution.png'
|
||||||
|
|
||||||
class Generator():
|
class Generator:
|
||||||
def __init__(self, model, precision):
|
downsampling_factor: int
|
||||||
|
latent_channels: int
|
||||||
|
precision: str
|
||||||
|
model: DiffusionWrapper | DiffusionPipeline
|
||||||
|
|
||||||
|
def __init__(self, model: DiffusionWrapper | DiffusionPipeline, precision: str):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.seed = None
|
self.seed = None
|
||||||
@ -52,7 +63,6 @@ class Generator():
|
|||||||
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
||||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||||
safety_checker:dict=None,
|
safety_checker:dict=None,
|
||||||
attention_maps_callback = None,
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
scope = choose_autocast(self.precision)
|
scope = choose_autocast(self.precision)
|
||||||
self.safety_checker = safety_checker
|
self.safety_checker = safety_checker
|
||||||
@ -165,7 +175,7 @@ class Generator():
|
|||||||
# Blur the mask out (into init image) by specified amount
|
# Blur the mask out (into init image) by specified amount
|
||||||
if mask_blur_radius > 0:
|
if mask_blur_radius > 0:
|
||||||
nm = np.asarray(pil_init_mask, dtype=np.uint8)
|
nm = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||||
nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
|
nmd = cv2.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
|
||||||
pmd = Image.fromarray(nmd, mode='L')
|
pmd = Image.fromarray(nmd, mode='L')
|
||||||
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
|
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
|
||||||
else:
|
else:
|
||||||
@ -177,8 +187,6 @@ class Generator():
|
|||||||
matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask)
|
matched_result.paste(init_image, (0,0), mask = multiplied_blurred_init_mask)
|
||||||
return matched_result
|
return matched_result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def sample_to_lowres_estimated_image(self,samples):
|
def sample_to_lowres_estimated_image(self,samples):
|
||||||
# origingally adapted from code by @erucipe and @keturn here:
|
# origingally adapted from code by @erucipe and @keturn here:
|
||||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||||
|
632
ldm/invoke/generator/diffusers_pipeline.py
Normal file
632
ldm/invoke/generator/diffusers_pipeline.py
Normal file
@ -0,0 +1,632 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import inspect
|
||||||
|
import secrets
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any
|
||||||
|
|
||||||
|
if sys.version_info < (3, 10):
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
else:
|
||||||
|
from typing import ParamSpec
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as T
|
||||||
|
from diffusers.models import attention
|
||||||
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
|
|
||||||
|
from ...models.diffusion import cross_attention_control
|
||||||
|
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||||
|
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
|
||||||
|
|
||||||
|
# monkeypatch diffusers CrossAttention 🙈
|
||||||
|
# this is to make prompt2prompt and (future) attention maps work
|
||||||
|
attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention
|
||||||
|
|
||||||
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||||
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||||
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
|
from diffusers.utils.outputs import BaseOutput
|
||||||
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings
|
||||||
|
from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineIntermediateState:
|
||||||
|
run_id: str
|
||||||
|
step: int
|
||||||
|
timestep: int
|
||||||
|
latents: torch.Tensor
|
||||||
|
predicted_original: Optional[torch.Tensor] = None
|
||||||
|
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||||
|
|
||||||
|
|
||||||
|
# copied from configs/stable-diffusion/v1-inference.yaml
|
||||||
|
_default_personalization_config_params = dict(
|
||||||
|
placeholder_strings=["*"],
|
||||||
|
initializer_wods=["sculpture"],
|
||||||
|
per_image_tokens=False,
|
||||||
|
num_vectors_per_token=1,
|
||||||
|
progressive_words=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AddsMaskLatents:
|
||||||
|
"""Add the channels required for inpainting model input.
|
||||||
|
|
||||||
|
The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
|
||||||
|
and the latent encoding of the base image.
|
||||||
|
|
||||||
|
This class assumes the same mask and base image should apply to all items in the batch.
|
||||||
|
"""
|
||||||
|
forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||||
|
mask: torch.Tensor
|
||||||
|
initial_image_latents: torch.Tensor
|
||||||
|
|
||||||
|
def __call__(self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor) -> torch.Tensor:
|
||||||
|
model_input = self.add_mask_channels(latents)
|
||||||
|
return self.forward(model_input, t, text_embeddings)
|
||||||
|
|
||||||
|
def add_mask_channels(self, latents):
|
||||||
|
batch_size = latents.size(0)
|
||||||
|
# duplicate mask and latents for each batch
|
||||||
|
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||||
|
image_latents = einops.repeat(self.initial_image_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||||
|
# add mask and image as additional channels
|
||||||
|
model_input, _ = einops.pack([latents, mask, image_latents], 'b * h w')
|
||||||
|
return model_input
|
||||||
|
|
||||||
|
|
||||||
|
def are_like_tensors(a: torch.Tensor, b: object) -> bool:
|
||||||
|
return (
|
||||||
|
isinstance(b, torch.Tensor)
|
||||||
|
and (a.size() == b.size())
|
||||||
|
)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AddsMaskGuidance:
|
||||||
|
mask: torch.FloatTensor
|
||||||
|
mask_latents: torch.FloatTensor
|
||||||
|
scheduler: SchedulerMixin
|
||||||
|
noise: torch.Tensor
|
||||||
|
_debug: Optional[Callable] = None
|
||||||
|
|
||||||
|
def __call__(self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning) -> BaseOutput:
|
||||||
|
output_class = step_output.__class__ # We'll create a new one with masked data.
|
||||||
|
|
||||||
|
# The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it.
|
||||||
|
# It's reasonable to assume the first thing is prev_sample, but then does it have other things
|
||||||
|
# like pred_original_sample? Should we apply the mask to them too?
|
||||||
|
# But what if there's just some other random field?
|
||||||
|
prev_sample = step_output[0]
|
||||||
|
# Mask anything that has the same shape as prev_sample, return others as-is.
|
||||||
|
return output_class(
|
||||||
|
{k: (self.apply_mask(v, self._t_for_field(k, t))
|
||||||
|
if are_like_tensors(prev_sample, v) else v)
|
||||||
|
for k, v in step_output.items()}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _t_for_field(self, field_name:str, t):
|
||||||
|
if field_name == "pred_original_sample":
|
||||||
|
return torch.zeros_like(t, dtype=t.dtype) # it represents t=0
|
||||||
|
return t
|
||||||
|
|
||||||
|
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
|
||||||
|
batch_size = latents.size(0)
|
||||||
|
mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||||
|
if t.dim() == 0:
|
||||||
|
# some schedulers expect t to be one-dimensional.
|
||||||
|
# TODO: file diffusers bug about inconsistency?
|
||||||
|
t = einops.repeat(t, '-> batch', batch=batch_size)
|
||||||
|
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
|
||||||
|
# get very confused about what is happening from step to step when we do that.
|
||||||
|
mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
|
||||||
|
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
||||||
|
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||||
|
mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
|
||||||
|
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
|
||||||
|
if self._debug:
|
||||||
|
self._debug(masked_input, f"t={t} lerped")
|
||||||
|
return masked_input
|
||||||
|
|
||||||
|
|
||||||
|
def trim_to_multiple_of(*args, multiple_of=8):
|
||||||
|
return tuple((x - x % multiple_of) for x in args)
|
||||||
|
|
||||||
|
|
||||||
|
def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param image: input image
|
||||||
|
:param normalize: scale the range to [-1, 1] instead of [0, 1]
|
||||||
|
:param multiple_of: resize the input so both dimensions are a multiple of this
|
||||||
|
"""
|
||||||
|
w, h = trim_to_multiple_of(*image.size)
|
||||||
|
transformation = T.Compose([
|
||||||
|
T.Resize((h, w), T.InterpolationMode.LANCZOS),
|
||||||
|
T.ToTensor(),
|
||||||
|
])
|
||||||
|
tensor = transformation(image)
|
||||||
|
if normalize:
|
||||||
|
tensor = tensor * 2.0 - 1.0
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def is_inpainting_model(unet: UNet2DConditionModel):
|
||||||
|
return unet.conv_in.in_channels == 9
|
||||||
|
|
||||||
|
CallbackType = TypeVar('CallbackType')
|
||||||
|
ReturnType = TypeVar('ReturnType')
|
||||||
|
ParamType = ParamSpec('ParamType')
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
||||||
|
"""Convert a generator to a function with a callback and a return value."""
|
||||||
|
|
||||||
|
generator_method: Callable[ParamType, ReturnType]
|
||||||
|
callback_arg_type: Type[CallbackType]
|
||||||
|
|
||||||
|
def __call__(self, *args: ParamType.args,
|
||||||
|
callback:Callable[[CallbackType], Any]=None,
|
||||||
|
**kwargs: ParamType.kwargs) -> ReturnType:
|
||||||
|
result = None
|
||||||
|
for result in self.generator_method(*args, **kwargs):
|
||||||
|
if callback is not None and isinstance(result, self.callback_arg_type):
|
||||||
|
callback(result)
|
||||||
|
if result is None:
|
||||||
|
raise AssertionError("why was that an empty generator?")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ConditioningData:
|
||||||
|
unconditioned_embeddings: torch.Tensor
|
||||||
|
text_embeddings: torch.Tensor
|
||||||
|
guidance_scale: float
|
||||||
|
"""
|
||||||
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||||
|
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||||
|
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||||
|
"""
|
||||||
|
extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
|
||||||
|
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||||
|
"""Additional arguments to pass to scheduler.step."""
|
||||||
|
threshold: Optional[ThresholdSettings] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self.text_embeddings.dtype
|
||||||
|
|
||||||
|
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
||||||
|
scheduler_args = dict(self.scheduler_args)
|
||||||
|
step_method = inspect.signature(scheduler.step)
|
||||||
|
for name, value in kwargs.items():
|
||||||
|
try:
|
||||||
|
step_method.bind_partial(**{name: value})
|
||||||
|
except TypeError:
|
||||||
|
# FIXME: don't silently discard arguments
|
||||||
|
pass # debug("%s does not accept argument named %r", scheduler, name)
|
||||||
|
else:
|
||||||
|
scheduler_args[name] = value
|
||||||
|
return dataclasses.replace(self, scheduler_args=scheduler_args)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
||||||
|
r"""
|
||||||
|
Output class for InvokeAI's Stable Diffusion pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
|
||||||
|
after generation completes. Optional.
|
||||||
|
"""
|
||||||
|
attention_map_saver: Optional[AttentionMapSaver]
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||||
|
r"""
|
||||||
|
Pipeline for text-to-image generation using Stable Diffusion.
|
||||||
|
|
||||||
|
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||||
|
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||||
|
|
||||||
|
Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline.
|
||||||
|
Hopefully future versions of diffusers provide access to more of these functions so that we don't
|
||||||
|
need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vae ([`AutoencoderKL`]):
|
||||||
|
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||||
|
text_encoder ([`CLIPTextModel`]):
|
||||||
|
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||||
|
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||||
|
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||||
|
tokenizer (`CLIPTokenizer`):
|
||||||
|
Tokenizer of class
|
||||||
|
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||||
|
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||||
|
scheduler ([`SchedulerMixin`]):
|
||||||
|
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||||
|
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||||
|
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||||
|
Classification module that estimates whether generated images could be considered offsensive or harmful.
|
||||||
|
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
||||||
|
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||||
|
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ID_LENGTH = 8
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vae: AutoencoderKL,
|
||||||
|
text_encoder: CLIPTextModel,
|
||||||
|
tokenizer: CLIPTokenizer,
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||||
|
safety_checker: Optional[StableDiffusionSafetyChecker],
|
||||||
|
feature_extractor: Optional[CLIPFeatureExtractor],
|
||||||
|
requires_safety_checker: bool = False,
|
||||||
|
precision: str = 'float32',
|
||||||
|
):
|
||||||
|
super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
|
||||||
|
safety_checker, feature_extractor, requires_safety_checker)
|
||||||
|
|
||||||
|
self.register_modules(
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
safety_checker=safety_checker,
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
)
|
||||||
|
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
|
||||||
|
use_full_precision = (precision == 'float32' or precision == 'autocast')
|
||||||
|
self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
|
||||||
|
text_encoder=self.text_encoder,
|
||||||
|
full_precision=use_full_precision)
|
||||||
|
# InvokeAI's interface for text embeddings and whatnot
|
||||||
|
self.prompt_fragments_to_embeddings_converter = WeightedPromptFragmentsToEmbeddingsConverter(
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
text_encoder=self.text_encoder,
|
||||||
|
textual_inversion_manager=self.textual_inversion_manager
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_xformers_available():
|
||||||
|
self.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
|
||||||
|
conditioning_data: ConditioningData,
|
||||||
|
*,
|
||||||
|
noise: torch.Tensor,
|
||||||
|
callback: Callable[[PipelineIntermediateState], None]=None,
|
||||||
|
run_id=None) -> InvokeAIStableDiffusionPipelineOutput:
|
||||||
|
r"""
|
||||||
|
Function invoked when calling the pipeline for generation.
|
||||||
|
|
||||||
|
:param conditioning_data:
|
||||||
|
:param latents: Pre-generated un-noised latents, to be used as inputs for
|
||||||
|
image generation. Can be used to tweak the same generation with different prompts.
|
||||||
|
:param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
|
||||||
|
image at the expense of slower inference.
|
||||||
|
:param noise: Noise to add to the latents, sampled from a Gaussian distribution.
|
||||||
|
:param callback:
|
||||||
|
:param run_id:
|
||||||
|
"""
|
||||||
|
result_latents, result_attention_map_saver = self.latents_from_embeddings(
|
||||||
|
latents, num_inference_steps,
|
||||||
|
conditioning_data,
|
||||||
|
noise=noise,
|
||||||
|
run_id=run_id,
|
||||||
|
callback=callback)
|
||||||
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
image = self.decode_latents(result_latents)
|
||||||
|
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_map_saver)
|
||||||
|
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
||||||
|
|
||||||
|
def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
|
||||||
|
conditioning_data: ConditioningData,
|
||||||
|
*,
|
||||||
|
noise: torch.Tensor,
|
||||||
|
timesteps=None,
|
||||||
|
additional_guidance: List[Callable] = None, run_id=None,
|
||||||
|
callback: Callable[[PipelineIntermediateState], None] = None
|
||||||
|
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||||
|
if timesteps is None:
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
|
||||||
|
timesteps = self.scheduler.timesteps
|
||||||
|
infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
|
||||||
|
result: PipelineIntermediateState = infer_latents_from_embeddings(
|
||||||
|
latents, timesteps, conditioning_data,
|
||||||
|
noise=noise,
|
||||||
|
additional_guidance=additional_guidance,
|
||||||
|
run_id=run_id,
|
||||||
|
callback=callback)
|
||||||
|
return result.latents, result.attention_map_saver
|
||||||
|
|
||||||
|
def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
|
||||||
|
conditioning_data: ConditioningData,
|
||||||
|
*,
|
||||||
|
noise: torch.Tensor,
|
||||||
|
run_id: str = None,
|
||||||
|
additional_guidance: List[Callable] = None):
|
||||||
|
if run_id is None:
|
||||||
|
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
||||||
|
if additional_guidance is None:
|
||||||
|
additional_guidance = []
|
||||||
|
extra_conditioning_info = conditioning_data.extra
|
||||||
|
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||||
|
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
|
||||||
|
step_count=len(self.scheduler.timesteps))
|
||||||
|
else:
|
||||||
|
self.invokeai_diffuser.remove_cross_attention_control()
|
||||||
|
|
||||||
|
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
|
||||||
|
latents=latents)
|
||||||
|
|
||||||
|
batch_size = latents.shape[0]
|
||||||
|
batched_t = torch.full((batch_size,), timesteps[0],
|
||||||
|
dtype=timesteps.dtype, device=self.unet.device)
|
||||||
|
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||||
|
|
||||||
|
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||||
|
self.invokeai_diffuser.remove_attention_map_saving()
|
||||||
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
|
batched_t.fill_(t)
|
||||||
|
step_output = self.step(batched_t, latents, conditioning_data,
|
||||||
|
i, additional_guidance=additional_guidance)
|
||||||
|
latents = step_output.prev_sample
|
||||||
|
predicted_original = getattr(step_output, 'pred_original_sample', None)
|
||||||
|
|
||||||
|
if i == len(timesteps)-1 and extra_conditioning_info is not None:
|
||||||
|
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||||
|
attention_map_token_ids = range(1, eos_token_index)
|
||||||
|
attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
|
||||||
|
self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||||
|
|
||||||
|
yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
|
||||||
|
predicted_original=predicted_original, attention_map_saver=attention_map_saver)
|
||||||
|
|
||||||
|
self.invokeai_diffuser.remove_attention_map_saving()
|
||||||
|
return latents, attention_map_saver
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def step(self, t: torch.Tensor, latents: torch.Tensor,
|
||||||
|
conditioning_data: ConditioningData,
|
||||||
|
step_index:int | None = None, additional_guidance: List[Callable] = None):
|
||||||
|
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||||
|
timestep = t[0]
|
||||||
|
|
||||||
|
if additional_guidance is None:
|
||||||
|
additional_guidance = []
|
||||||
|
|
||||||
|
# TODO: should this scaling happen here or inside self._unet_forward?
|
||||||
|
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||||
|
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
noise_pred = self.invokeai_diffuser.do_diffusion_step(
|
||||||
|
latent_model_input, t,
|
||||||
|
conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings,
|
||||||
|
conditioning_data.guidance_scale,
|
||||||
|
step_index=step_index,
|
||||||
|
threshold=conditioning_data.threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
step_output = self.scheduler.step(noise_pred, timestep, latents,
|
||||||
|
**conditioning_data.scheduler_args)
|
||||||
|
|
||||||
|
# TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent.
|
||||||
|
# But the way things are now, scheduler runs _after_ that, so there was
|
||||||
|
# no way to use it to apply an operation that happens after the last scheduler.step.
|
||||||
|
for guidance in additional_guidance:
|
||||||
|
step_output = guidance(step_output, timestep, conditioning_data)
|
||||||
|
|
||||||
|
return step_output
|
||||||
|
|
||||||
|
def _unet_forward(self, latents, t, text_embeddings):
|
||||||
|
"""predict the noise residual"""
|
||||||
|
if is_inpainting_model(self.unet) and latents.size(1) == 4:
|
||||||
|
# Pad out normal non-inpainting inputs for an inpainting model.
|
||||||
|
# FIXME: There are too many layers of functions and we have too many different ways of
|
||||||
|
# overriding things! This should get handled in a way more consistent with the other
|
||||||
|
# use of AddsMaskLatents.
|
||||||
|
latents = AddsMaskLatents(
|
||||||
|
self._unet_forward,
|
||||||
|
mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
|
||||||
|
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
|
||||||
|
).add_mask_channels(latents)
|
||||||
|
|
||||||
|
return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
|
||||||
|
|
||||||
|
def img2img_from_embeddings(self,
|
||||||
|
init_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||||
|
strength: float,
|
||||||
|
num_inference_steps: int,
|
||||||
|
conditioning_data: ConditioningData,
|
||||||
|
*, callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
|
run_id=None,
|
||||||
|
noise_func=None
|
||||||
|
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||||
|
if isinstance(init_image, PIL.Image.Image):
|
||||||
|
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
|
||||||
|
|
||||||
|
if init_image.dim() == 3:
|
||||||
|
init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
|
||||||
|
|
||||||
|
# 6. Prepare latent variables
|
||||||
|
device = self.unet.device
|
||||||
|
latents_dtype = self.unet.dtype
|
||||||
|
initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
|
||||||
|
noise = noise_func(initial_latents)
|
||||||
|
|
||||||
|
return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
|
||||||
|
conditioning_data,
|
||||||
|
strength,
|
||||||
|
noise, run_id, callback)
|
||||||
|
|
||||||
|
def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps,
|
||||||
|
conditioning_data: ConditioningData,
|
||||||
|
strength,
|
||||||
|
noise: torch.Tensor, run_id=None, callback=None
|
||||||
|
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||||
|
device = self.unet.device
|
||||||
|
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||||
|
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
|
timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
|
||||||
|
|
||||||
|
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||||
|
initial_latents, num_inference_steps, conditioning_data,
|
||||||
|
timesteps=timesteps,
|
||||||
|
noise=noise,
|
||||||
|
run_id=run_id,
|
||||||
|
callback=callback)
|
||||||
|
|
||||||
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
image = self.decode_latents(result_latents)
|
||||||
|
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
|
||||||
|
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
||||||
|
|
||||||
|
def inpaint_from_embeddings(
|
||||||
|
self,
|
||||||
|
init_image: torch.FloatTensor,
|
||||||
|
mask: torch.FloatTensor,
|
||||||
|
strength: float,
|
||||||
|
num_inference_steps: int,
|
||||||
|
conditioning_data: ConditioningData,
|
||||||
|
*, callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
|
run_id=None,
|
||||||
|
noise_func=None,
|
||||||
|
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||||
|
device = self.unet.device
|
||||||
|
latents_dtype = self.unet.dtype
|
||||||
|
|
||||||
|
if isinstance(init_image, PIL.Image.Image):
|
||||||
|
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
|
||||||
|
|
||||||
|
init_image = init_image.to(device=device, dtype=latents_dtype)
|
||||||
|
|
||||||
|
if init_image.dim() == 3:
|
||||||
|
init_image = init_image.unsqueeze(0)
|
||||||
|
|
||||||
|
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||||
|
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
|
timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
|
||||||
|
|
||||||
|
assert img2img_pipeline.scheduler is self.scheduler
|
||||||
|
|
||||||
|
# 6. Prepare latent variables
|
||||||
|
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
|
||||||
|
# because we have our own noise function
|
||||||
|
init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
|
||||||
|
noise = noise_func(init_image_latents)
|
||||||
|
|
||||||
|
if mask.dim() == 3:
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \
|
||||||
|
.to(device=device, dtype=latents_dtype)
|
||||||
|
|
||||||
|
guidance: List[Callable] = []
|
||||||
|
|
||||||
|
if is_inpainting_model(self.unet):
|
||||||
|
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
|
||||||
|
self.invokeai_diffuser.model_forward_callback = \
|
||||||
|
AddsMaskLatents(self._unet_forward, mask, init_image_latents)
|
||||||
|
else:
|
||||||
|
guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise))
|
||||||
|
|
||||||
|
try:
|
||||||
|
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||||
|
init_image_latents, num_inference_steps,
|
||||||
|
conditioning_data, noise=noise, timesteps=timesteps,
|
||||||
|
additional_guidance=guidance,
|
||||||
|
run_id=run_id, callback=callback)
|
||||||
|
finally:
|
||||||
|
self.invokeai_diffuser.model_forward_callback = self._unet_forward
|
||||||
|
|
||||||
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
image = self.decode_latents(result_latents)
|
||||||
|
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
|
||||||
|
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
||||||
|
|
||||||
|
def non_noised_latents_from_image(self, init_image, *, device, dtype):
|
||||||
|
init_image = init_image.to(device=device, dtype=dtype)
|
||||||
|
with torch.inference_mode():
|
||||||
|
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||||
|
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
|
||||||
|
init_latents = 0.18215 * init_latents
|
||||||
|
return init_latents
|
||||||
|
|
||||||
|
def check_for_safety(self, output, dtype):
|
||||||
|
with torch.inference_mode():
|
||||||
|
screened_images, has_nsfw_concept = self.run_safety_checker(
|
||||||
|
output.images, device=self._execution_device, dtype=dtype)
|
||||||
|
screened_attention_map_saver = None
|
||||||
|
if has_nsfw_concept is None or not has_nsfw_concept:
|
||||||
|
screened_attention_map_saver = output.attention_map_saver
|
||||||
|
return InvokeAIStableDiffusionPipelineOutput(screened_images,
|
||||||
|
has_nsfw_concept,
|
||||||
|
# block the attention maps if NSFW content is detected
|
||||||
|
attention_map_saver=screened_attention_map_saver)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
|
||||||
|
"""
|
||||||
|
Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
|
||||||
|
"""
|
||||||
|
return self.prompt_fragments_to_embeddings_converter.get_embeddings_for_weighted_prompt_fragments(
|
||||||
|
text=c,
|
||||||
|
fragment_weights=fragment_weights,
|
||||||
|
should_return_tokens=return_tokens,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cond_stage_model(self):
|
||||||
|
warnings.warn("legacy compatibility layer", DeprecationWarning)
|
||||||
|
return self.prompt_fragments_to_embeddings_converter
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def _tokenize(self, prompt: Union[str, List[str]]):
|
||||||
|
return self.tokenizer(
|
||||||
|
prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=self.tokenizer.model_max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels(self) -> int:
|
||||||
|
"""Compatible with DiffusionWrapper"""
|
||||||
|
return self.unet.in_channels
|
||||||
|
|
||||||
|
def debug_latents(self, latents, msg):
|
||||||
|
with torch.inference_mode():
|
||||||
|
from ldm.util import debug_image
|
||||||
|
decoded = self.numpy_to_pil(self.decode_latents(latents))
|
||||||
|
for i, img in enumerate(decoded):
|
||||||
|
debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True)
|
@ -3,14 +3,14 @@ ldm.invoke.generator.embiggen descends from ldm.invoke.generator
|
|||||||
and generates with ldm.invoke.generator.img2img
|
and generates with ldm.invoke.generator.img2img
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import trange
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
from ldm.invoke.generator.base import Generator
|
from ldm.invoke.generator.base import Generator
|
||||||
from ldm.invoke.generator.img2img import Img2Img
|
from ldm.invoke.generator.img2img import Img2Img
|
||||||
from ldm.invoke.devices import choose_autocast
|
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
|
||||||
|
|
||||||
class Embiggen(Generator):
|
class Embiggen(Generator):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
@ -22,7 +22,6 @@ class Embiggen(Generator):
|
|||||||
image_callback=None, step_callback=None,
|
image_callback=None, step_callback=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
scope = choose_autocast(self.precision)
|
|
||||||
make_image = self.get_make_image(
|
make_image = self.get_make_image(
|
||||||
prompt,
|
prompt,
|
||||||
step_callback = step_callback,
|
step_callback = step_callback,
|
||||||
@ -32,8 +31,7 @@ class Embiggen(Generator):
|
|||||||
seed = seed if seed else self.new_seed()
|
seed = seed if seed else self.new_seed()
|
||||||
|
|
||||||
# Noise will be generated by the Img2Img generator when called
|
# Noise will be generated by the Img2Img generator when called
|
||||||
with scope(self.model.device.type), self.model.ema_scope():
|
for _ in trange(iterations, desc='Generating'):
|
||||||
for n in trange(iterations, desc='Generating'):
|
|
||||||
# make_image will call Img2Img which will do the equivalent of get_noise itself
|
# make_image will call Img2Img which will do the equivalent of get_noise itself
|
||||||
image = make_image()
|
image = make_image()
|
||||||
results.append([image, seed])
|
results.append([image, seed])
|
||||||
@ -353,7 +351,7 @@ class Embiggen(Generator):
|
|||||||
prompt,
|
prompt,
|
||||||
iterations = 1,
|
iterations = 1,
|
||||||
seed = seed,
|
seed = seed,
|
||||||
sampler = DDIMSampler(self.model, device=self.model.device),
|
sampler = sampler,
|
||||||
steps = steps,
|
steps = steps,
|
||||||
cfg_scale = cfg_scale,
|
cfg_scale = cfg_scale,
|
||||||
conditioning = conditioning,
|
conditioning = conditioning,
|
||||||
@ -493,7 +491,7 @@ class Embiggen(Generator):
|
|||||||
# Layer tile onto final image
|
# Layer tile onto final image
|
||||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||||
else:
|
else:
|
||||||
print(f'Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.')
|
print('Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.')
|
||||||
|
|
||||||
# after internal loops and patching up return Embiggen image
|
# after internal loops and patching up return Embiggen image
|
||||||
return outputsuperimage
|
return outputsuperimage
|
||||||
|
@ -3,14 +3,12 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
from diffusers import logging
|
||||||
import PIL
|
|
||||||
from torch import Tensor
|
|
||||||
from PIL import Image
|
|
||||||
from ldm.invoke.devices import choose_autocast
|
|
||||||
from ldm.invoke.generator.base import Generator
|
from ldm.invoke.generator.base import Generator
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
|
||||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
|
||||||
|
|
||||||
|
|
||||||
class Img2Img(Generator):
|
class Img2Img(Generator):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
@ -18,80 +16,69 @@ class Img2Img(Generator):
|
|||||||
self.init_latent = None # by get_noise()
|
self.init_latent = None # by get_noise()
|
||||||
|
|
||||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||||
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
|
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,
|
||||||
|
attention_maps_callback=None,
|
||||||
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
Returns a function returning an image derived from the prompt and the initial image
|
Returns a function returning an image derived from the prompt and the initial image
|
||||||
Return value depends on the seed at the time you call it.
|
Return value depends on the seed at the time you call it.
|
||||||
"""
|
"""
|
||||||
self.perlin = perlin
|
self.perlin = perlin
|
||||||
|
|
||||||
sampler.make_schedule(
|
# noinspection PyTypeChecker
|
||||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||||
)
|
pipeline.scheduler = sampler
|
||||||
|
|
||||||
if isinstance(init_image, PIL.Image.Image):
|
|
||||||
init_image = self._image_to_tensor(init_image.convert('RGB'))
|
|
||||||
|
|
||||||
scope = choose_autocast(self.precision)
|
|
||||||
with scope(self.model.device.type):
|
|
||||||
self.init_latent = self.model.get_first_stage_encoding(
|
|
||||||
self.model.encode_first_stage(init_image)
|
|
||||||
) # move to latent space
|
|
||||||
|
|
||||||
t_enc = int(strength * steps)
|
|
||||||
uc, c, extra_conditioning_info = conditioning
|
uc, c, extra_conditioning_info = conditioning
|
||||||
|
conditioning_data = (
|
||||||
|
ConditioningData(
|
||||||
|
uc, c, cfg_scale, extra_conditioning_info,
|
||||||
|
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
|
||||||
|
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||||
|
|
||||||
|
|
||||||
def make_image(x_T):
|
def make_image(x_T):
|
||||||
# encode (scaled latent)
|
# FIXME: use x_T for initial seeded noise
|
||||||
z_enc = sampler.stochastic_encode(
|
# We're not at the moment because the pipeline automatically resizes init_image if
|
||||||
self.init_latent,
|
# necessary, which the x_T input might not match.
|
||||||
torch.tensor([t_enc - 1]).to(self.model.device),
|
logging.set_verbosity_error() # quench safety check warnings
|
||||||
noise=x_T
|
pipeline_output = pipeline.img2img_from_embeddings(
|
||||||
|
init_image, strength, steps, conditioning_data,
|
||||||
|
noise_func=self.get_noise_like,
|
||||||
|
callback=step_callback
|
||||||
)
|
)
|
||||||
|
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||||
self.model.model.to(self.model.device)
|
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||||
|
|
||||||
# decode it
|
|
||||||
samples = sampler.decode(
|
|
||||||
z_enc,
|
|
||||||
c,
|
|
||||||
t_enc,
|
|
||||||
img_callback = step_callback,
|
|
||||||
unconditional_guidance_scale=cfg_scale,
|
|
||||||
unconditional_conditioning=uc,
|
|
||||||
init_latent = self.init_latent, # changes how noising is performed in ksampler
|
|
||||||
extra_conditioning_info = extra_conditioning_info,
|
|
||||||
all_timesteps_count = steps
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.free_gpu_mem:
|
|
||||||
self.model.model.to("cpu")
|
|
||||||
|
|
||||||
return self.sample_to_image(samples)
|
|
||||||
|
|
||||||
return make_image
|
return make_image
|
||||||
|
|
||||||
def get_noise(self,width,height):
|
def get_noise_like(self, like: torch.Tensor):
|
||||||
device = self.model.device
|
device = like.device
|
||||||
init_latent = self.init_latent
|
|
||||||
assert init_latent is not None,'call to get_noise() when init_latent not set'
|
|
||||||
if device.type == 'mps':
|
if device.type == 'mps':
|
||||||
x = torch.randn_like(init_latent, device='cpu').to(device)
|
x = torch.randn_like(like, device='cpu').to(device)
|
||||||
else:
|
else:
|
||||||
x = torch.randn_like(init_latent, device=device)
|
x = torch.randn_like(like, device=device)
|
||||||
if self.perlin > 0.0:
|
if self.perlin > 0.0:
|
||||||
shape = init_latent.shape
|
shape = like.shape
|
||||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
|
def get_noise(self,width,height):
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
# copy of the Txt2Img.get_noise
|
||||||
if len(image.shape) == 2: # 'L' image, as in a mask
|
device = self.model.device
|
||||||
image = image[None,None]
|
if self.use_mps_noise or device.type == 'mps':
|
||||||
else: # 'RGB' image
|
x = torch.randn([1,
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
self.latent_channels,
|
||||||
image = torch.from_numpy(image)
|
height // self.downsampling_factor,
|
||||||
if normalize:
|
width // self.downsampling_factor],
|
||||||
image = 2.0 * image - 1.0
|
device='cpu').to(device)
|
||||||
return image.to(self.model.device)
|
else:
|
||||||
|
x = torch.randn([1,
|
||||||
|
self.latent_channels,
|
||||||
|
height // self.downsampling_factor,
|
||||||
|
width // self.downsampling_factor],
|
||||||
|
device=device)
|
||||||
|
if self.perlin > 0.0:
|
||||||
|
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
||||||
|
return x
|
||||||
|
@ -1,24 +1,22 @@
|
|||||||
'''
|
'''
|
||||||
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
|
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
|
||||||
'''
|
'''
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import torch
|
|
||||||
import torchvision.transforms as T
|
|
||||||
import numpy as np
|
|
||||||
import cv2 as cv
|
|
||||||
import PIL
|
import PIL
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||||
from skimage.exposure.histogram_matching import match_histograms
|
|
||||||
from einops import rearrange, repeat
|
from ldm.invoke.generator.diffusers_pipeline import image_resized_to_grid_as_tensor, StableDiffusionGeneratorPipeline, \
|
||||||
from ldm.invoke.devices import choose_autocast
|
ConditioningData
|
||||||
from ldm.invoke.generator.img2img import Img2Img
|
from ldm.invoke.generator.img2img import Img2Img
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
|
||||||
from ldm.models.diffusion.ksampler import KSampler
|
|
||||||
from ldm.invoke.generator.base import downsampling
|
|
||||||
from ldm.util import debug_image
|
|
||||||
from ldm.invoke.patchmatch import PatchMatch
|
from ldm.invoke.patchmatch import PatchMatch
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.util import debug_image
|
||||||
|
|
||||||
|
|
||||||
def infill_methods()->list[str]:
|
def infill_methods()->list[str]:
|
||||||
methods = list()
|
methods = list()
|
||||||
@ -29,6 +27,9 @@ def infill_methods()->list[str]:
|
|||||||
|
|
||||||
class Inpaint(Img2Img):
|
class Inpaint(Img2Img):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
|
self.inpaint_height = 0
|
||||||
|
self.inpaint_width = 0
|
||||||
|
self.enable_image_debugging = False
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
self.pil_image = None
|
self.pil_image = None
|
||||||
self.pil_mask = None
|
self.pil_mask = None
|
||||||
@ -117,13 +118,13 @@ class Inpaint(Img2Img):
|
|||||||
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
|
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
|
||||||
|
|
||||||
# Detect hard edges
|
# Detect hard edges
|
||||||
npedge = cv.Canny(npimg, threshold1=100, threshold2=200)
|
npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
|
||||||
|
|
||||||
# Combine
|
# Combine
|
||||||
npmask = npgradient + npedge
|
npmask = npgradient + npedge
|
||||||
|
|
||||||
# Expand
|
# Expand
|
||||||
npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
|
npmask = cv2.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))
|
||||||
|
|
||||||
new_mask = Image.fromarray(npmask)
|
new_mask = Image.fromarray(npmask)
|
||||||
|
|
||||||
@ -133,15 +134,8 @@ class Inpaint(Img2Img):
|
|||||||
return ImageOps.invert(new_mask)
|
return ImageOps.invert(new_mask)
|
||||||
|
|
||||||
|
|
||||||
def seam_paint(self,
|
def seam_paint(self, im: Image.Image, seam_size: int, seam_blur: int, prompt, sampler, steps, cfg_scale, ddim_eta,
|
||||||
im: Image.Image,
|
conditioning, strength, noise, infill_method, step_callback) -> Image.Image:
|
||||||
seam_size: int,
|
|
||||||
seam_blur: int,
|
|
||||||
prompt,sampler,steps,cfg_scale,ddim_eta,
|
|
||||||
conditioning,strength,
|
|
||||||
noise,
|
|
||||||
step_callback
|
|
||||||
) -> Image.Image:
|
|
||||||
hard_mask = self.pil_image.split()[-1].copy()
|
hard_mask = self.pil_image.split()[-1].copy()
|
||||||
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
||||||
|
|
||||||
@ -153,13 +147,14 @@ class Inpaint(Img2Img):
|
|||||||
ddim_eta,
|
ddim_eta,
|
||||||
conditioning,
|
conditioning,
|
||||||
init_image = im.copy().convert('RGBA'),
|
init_image = im.copy().convert('RGBA'),
|
||||||
mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
|
mask_image = mask,
|
||||||
strength = strength,
|
strength = strength,
|
||||||
mask_blur_radius = 0,
|
mask_blur_radius = 0,
|
||||||
seam_size = 0,
|
seam_size = 0,
|
||||||
step_callback = step_callback,
|
step_callback = step_callback,
|
||||||
inpaint_width = im.width,
|
inpaint_width = im.width,
|
||||||
inpaint_height = im.height
|
inpaint_height = im.height,
|
||||||
|
infill_method = infill_method
|
||||||
)
|
)
|
||||||
|
|
||||||
seam_noise = self.get_noise(im.width, im.height)
|
seam_noise = self.get_noise(im.width, im.height)
|
||||||
@ -171,7 +166,10 @@ class Inpaint(Img2Img):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||||
conditioning,init_image,mask_image,strength,
|
conditioning,
|
||||||
|
init_image: PIL.Image.Image | torch.FloatTensor,
|
||||||
|
mask_image: PIL.Image.Image | torch.FloatTensor,
|
||||||
|
strength: float,
|
||||||
mask_blur_radius: int = 8,
|
mask_blur_radius: int = 8,
|
||||||
# Seam settings - when 0, doesn't fill seam
|
# Seam settings - when 0, doesn't fill seam
|
||||||
seam_size: int = 0,
|
seam_size: int = 0,
|
||||||
@ -184,6 +182,7 @@ class Inpaint(Img2Img):
|
|||||||
infill_method = None,
|
infill_method = None,
|
||||||
inpaint_width=None,
|
inpaint_width=None,
|
||||||
inpaint_height=None,
|
inpaint_height=None,
|
||||||
|
attention_maps_callback=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
Returns a function returning an image derived from the prompt and
|
Returns a function returning an image derived from the prompt and
|
||||||
@ -218,13 +217,17 @@ class Inpaint(Img2Img):
|
|||||||
debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
|
debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
|
||||||
|
|
||||||
# Create init tensor
|
# Create init tensor
|
||||||
init_image = self._image_to_tensor(init_filled.convert('RGB'))
|
init_image = image_resized_to_grid_as_tensor(init_filled.convert('RGB'))
|
||||||
|
|
||||||
if isinstance(mask_image, PIL.Image.Image):
|
if isinstance(mask_image, PIL.Image.Image):
|
||||||
self.pil_mask = mask_image.copy()
|
self.pil_mask = mask_image.copy()
|
||||||
debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging)
|
debug_image(mask_image, "mask_image BEFORE multiply with pil_image", debug_status=self.enable_image_debugging)
|
||||||
|
|
||||||
mask_image = ImageChops.multiply(mask_image, self.pil_image.split()[-1].convert('RGB'))
|
init_alpha = self.pil_image.getchannel("A")
|
||||||
|
if mask_image.mode != "L":
|
||||||
|
# FIXME: why do we get passed an RGB image here? We can only use single-channel.
|
||||||
|
mask_image = mask_image.convert("L")
|
||||||
|
mask_image = ImageChops.multiply(mask_image, init_alpha)
|
||||||
self.pil_mask = mask_image
|
self.pil_mask = mask_image
|
||||||
|
|
||||||
# Resize if requested for inpainting
|
# Resize if requested for inpainting
|
||||||
@ -232,95 +235,45 @@ class Inpaint(Img2Img):
|
|||||||
mask_image = mask_image.resize((inpaint_width, inpaint_height))
|
mask_image = mask_image.resize((inpaint_width, inpaint_height))
|
||||||
|
|
||||||
debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging)
|
debug_image(mask_image, "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging)
|
||||||
mask_image = mask_image.resize(
|
mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||||
(
|
else:
|
||||||
mask_image.width // downsampling,
|
mask: torch.FloatTensor = mask_image
|
||||||
mask_image.height // downsampling
|
|
||||||
),
|
|
||||||
resample=Image.Resampling.NEAREST
|
|
||||||
)
|
|
||||||
mask_image = self._image_to_tensor(mask_image,normalize=False)
|
|
||||||
|
|
||||||
self.mask_blur_radius = mask_blur_radius
|
self.mask_blur_radius = mask_blur_radius
|
||||||
|
|
||||||
# klms samplers not supported yet, so ignore previous sampler
|
# noinspection PyTypeChecker
|
||||||
if isinstance(sampler,KSampler):
|
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||||
print(
|
pipeline.scheduler = sampler
|
||||||
f">> Using recommended DDIM sampler for inpainting."
|
|
||||||
)
|
|
||||||
sampler = DDIMSampler(self.model, device=self.model.device)
|
|
||||||
|
|
||||||
sampler.make_schedule(
|
|
||||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
|
||||||
)
|
|
||||||
|
|
||||||
mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
|
|
||||||
mask_image = repeat(mask_image, '1 ... -> b ...', b=1)
|
|
||||||
|
|
||||||
scope = choose_autocast(self.precision)
|
|
||||||
with scope(self.model.device.type):
|
|
||||||
self.init_latent = self.model.get_first_stage_encoding(
|
|
||||||
self.model.encode_first_stage(init_image)
|
|
||||||
) # move to latent space
|
|
||||||
|
|
||||||
t_enc = int(strength * steps)
|
|
||||||
# todo: support cross-attention control
|
# todo: support cross-attention control
|
||||||
uc, c, _ = conditioning
|
uc, c, _ = conditioning
|
||||||
|
conditioning_data = (ConditioningData(uc, c, cfg_scale)
|
||||||
|
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||||
|
|
||||||
print(f">> target t_enc is {t_enc} steps")
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def make_image(x_T):
|
def make_image(x_T):
|
||||||
# encode (scaled latent)
|
pipeline_output = pipeline.inpaint_from_embeddings(
|
||||||
z_enc = sampler.stochastic_encode(
|
init_image=init_image,
|
||||||
self.init_latent,
|
mask=1 - mask, # expects white means "paint here."
|
||||||
torch.tensor([t_enc - 1]).to(self.model.device),
|
strength=strength,
|
||||||
noise=x_T
|
num_inference_steps=steps,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
|
noise_func=self.get_noise_like,
|
||||||
|
callback=step_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
# to replace masked area with latent noise, weighted by inpaint_replace strength
|
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||||
if inpaint_replace > 0.0:
|
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||||
print(f'>> inpaint will replace what was under the mask with a strength of {inpaint_replace}')
|
|
||||||
l_noise = self.get_noise(kwargs['width'],kwargs['height'])
|
|
||||||
inverted_mask = 1.0-mask_image # there will be 1s where the mask is
|
|
||||||
masked_region = (1.0-inpaint_replace) * inverted_mask * z_enc + inpaint_replace * inverted_mask * l_noise
|
|
||||||
z_enc = z_enc * mask_image + masked_region
|
|
||||||
|
|
||||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
result = self.postprocess_size_and_mask(pipeline.numpy_to_pil(pipeline_output.images)[0])
|
||||||
self.model.model.to(self.model.device)
|
|
||||||
|
|
||||||
# decode it
|
|
||||||
samples = sampler.decode(
|
|
||||||
z_enc,
|
|
||||||
c,
|
|
||||||
t_enc,
|
|
||||||
img_callback = step_callback,
|
|
||||||
unconditional_guidance_scale = cfg_scale,
|
|
||||||
unconditional_conditioning = uc,
|
|
||||||
mask = mask_image,
|
|
||||||
init_latent = self.init_latent
|
|
||||||
)
|
|
||||||
|
|
||||||
result = self.sample_to_image(samples)
|
|
||||||
|
|
||||||
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
|
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
|
||||||
if seam_size > 0:
|
if seam_size > 0:
|
||||||
old_image = self.pil_image or init_image
|
old_image = self.pil_image or init_image
|
||||||
old_mask = self.pil_mask or mask_image
|
old_mask = self.pil_mask or mask_image
|
||||||
|
|
||||||
result = self.seam_paint(
|
result = self.seam_paint(result, seam_size, seam_blur, prompt, sampler, seam_steps, cfg_scale, ddim_eta,
|
||||||
result,
|
conditioning, seam_strength, x_T, infill_method, step_callback)
|
||||||
seam_size,
|
|
||||||
seam_blur,
|
|
||||||
prompt,
|
|
||||||
sampler,
|
|
||||||
seam_steps,
|
|
||||||
cfg_scale,
|
|
||||||
ddim_eta,
|
|
||||||
conditioning,
|
|
||||||
seam_strength,
|
|
||||||
x_T,
|
|
||||||
step_callback)
|
|
||||||
|
|
||||||
# Restore original settings
|
# Restore original settings
|
||||||
self.get_make_image(prompt,sampler,steps,cfg_scale,ddim_eta,
|
self.get_make_image(prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||||
@ -343,6 +296,10 @@ class Inpaint(Img2Img):
|
|||||||
|
|
||||||
def sample_to_image(self, samples)->Image.Image:
|
def sample_to_image(self, samples)->Image.Image:
|
||||||
gen_result = super().sample_to_image(samples).convert('RGB')
|
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||||
|
return self.postprocess_size_and_mask(gen_result)
|
||||||
|
|
||||||
|
|
||||||
|
def postprocess_size_and_mask(self, gen_result: Image.Image) -> Image.Image:
|
||||||
debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging)
|
debug_image(gen_result, "gen_result", debug_status=self.enable_image_debugging)
|
||||||
|
|
||||||
# Resize if necessary
|
# Resize if necessary
|
||||||
@ -352,7 +309,7 @@ class Inpaint(Img2Img):
|
|||||||
if self.pil_image is None or self.pil_mask is None:
|
if self.pil_image is None or self.pil_mask is None:
|
||||||
return gen_result
|
return gen_result
|
||||||
|
|
||||||
corrected_result = super().repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
|
corrected_result = self.repaste_and_color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)
|
||||||
debug_image(corrected_result, "corrected_result", debug_status=self.enable_image_debugging)
|
debug_image(corrected_result, "corrected_result", debug_status=self.enable_image_debugging)
|
||||||
|
|
||||||
return corrected_result
|
return corrected_result
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
|
"""omnibus module to be used with the runwayml 9-channel custom inpainting model"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
from PIL import Image, ImageOps
|
||||||
from einops import repeat
|
from einops import repeat
|
||||||
from PIL import Image, ImageOps, ImageChops
|
|
||||||
from ldm.invoke.devices import choose_autocast
|
from ldm.invoke.devices import choose_autocast
|
||||||
from ldm.invoke.generator.base import downsampling
|
|
||||||
from ldm.invoke.generator.img2img import Img2Img
|
from ldm.invoke.generator.img2img import Img2Img
|
||||||
from ldm.invoke.generator.txt2img import Txt2Img
|
from ldm.invoke.generator.txt2img import Txt2Img
|
||||||
|
|
||||||
|
|
||||||
class Omnibus(Img2Img,Txt2Img):
|
class Omnibus(Img2Img,Txt2Img):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
super().__init__(model, precision)
|
super().__init__(model, precision)
|
||||||
@ -40,6 +40,8 @@ class Omnibus(Img2Img,Txt2Img):
|
|||||||
self.perlin = perlin
|
self.perlin = perlin
|
||||||
num_samples = 1
|
num_samples = 1
|
||||||
|
|
||||||
|
print('DEBUG: IN OMNIBUS')
|
||||||
|
|
||||||
sampler.make_schedule(
|
sampler.make_schedule(
|
||||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||||
)
|
)
|
||||||
@ -58,8 +60,6 @@ class Omnibus(Img2Img,Txt2Img):
|
|||||||
|
|
||||||
self.mask_blur_radius = mask_blur_radius
|
self.mask_blur_radius = mask_blur_radius
|
||||||
|
|
||||||
t_enc = steps
|
|
||||||
|
|
||||||
if init_image is not None and mask_image is not None: # inpainting
|
if init_image is not None and mask_image is not None: # inpainting
|
||||||
masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero
|
masked_image = init_image * (1 - mask_image) # masked image is the image masked by mask - masked regions zero
|
||||||
|
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
'''
|
'''
|
||||||
ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||||
'''
|
'''
|
||||||
|
import PIL.Image
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
|
||||||
from ldm.invoke.generator.base import Generator
|
from .base import Generator
|
||||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
from .diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
|
||||||
import gc
|
from ...models.diffusion.shared_invokeai_diffusion import ThresholdSettings
|
||||||
|
|
||||||
|
|
||||||
class Txt2Img(Generator):
|
class Txt2Img(Generator):
|
||||||
@ -24,45 +24,30 @@ class Txt2Img(Generator):
|
|||||||
kwargs are 'width' and 'height'
|
kwargs are 'width' and 'height'
|
||||||
"""
|
"""
|
||||||
self.perlin = perlin
|
self.perlin = perlin
|
||||||
|
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||||
|
pipeline.scheduler = sampler
|
||||||
|
|
||||||
uc, c, extra_conditioning_info = conditioning
|
uc, c, extra_conditioning_info = conditioning
|
||||||
|
conditioning_data = (
|
||||||
|
ConditioningData(
|
||||||
|
uc, c, cfg_scale, extra_conditioning_info,
|
||||||
|
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
|
||||||
|
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def make_image(x_T):
|
|
||||||
shape = [
|
|
||||||
self.latent_channels,
|
|
||||||
height // self.downsampling_factor,
|
|
||||||
width // self.downsampling_factor,
|
|
||||||
]
|
|
||||||
|
|
||||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
def make_image(x_T) -> PIL.Image.Image:
|
||||||
self.model.model.to(self.model.device)
|
pipeline_output = pipeline.image_from_embeddings(
|
||||||
|
latents=torch.zeros_like(x_T),
|
||||||
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
|
noise=x_T,
|
||||||
|
num_inference_steps=steps,
|
||||||
samples, _ = sampler.sample(
|
conditioning_data=conditioning_data,
|
||||||
batch_size = 1,
|
callback=step_callback,
|
||||||
S = steps,
|
|
||||||
x_T = x_T,
|
|
||||||
conditioning = c,
|
|
||||||
shape = shape,
|
|
||||||
verbose = False,
|
|
||||||
unconditional_guidance_scale = cfg_scale,
|
|
||||||
unconditional_conditioning = uc,
|
|
||||||
extra_conditioning_info = extra_conditioning_info,
|
|
||||||
eta = ddim_eta,
|
|
||||||
img_callback = step_callback,
|
|
||||||
threshold = threshold,
|
|
||||||
attention_maps_callback = attention_maps_callback,
|
|
||||||
)
|
)
|
||||||
|
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
||||||
if self.free_gpu_mem:
|
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||||
self.model.model.to('cpu')
|
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||||
self.model.cond_stage_model.device = 'cpu'
|
|
||||||
self.model.cond_stage_model.to('cpu')
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
return self.sample_to_image(samples)
|
|
||||||
|
|
||||||
return make_image
|
return make_image
|
||||||
|
|
||||||
@ -70,15 +55,17 @@ class Txt2Img(Generator):
|
|||||||
# returns a tensor filled with random numbers from a normal distribution
|
# returns a tensor filled with random numbers from a normal distribution
|
||||||
def get_noise(self,width,height):
|
def get_noise(self,width,height):
|
||||||
device = self.model.device
|
device = self.model.device
|
||||||
|
# limit noise to only the diffusion image channels, not the mask channels
|
||||||
|
input_channels = min(self.latent_channels, 4)
|
||||||
if self.use_mps_noise or device.type == 'mps':
|
if self.use_mps_noise or device.type == 'mps':
|
||||||
x = torch.randn([1,
|
x = torch.randn([1,
|
||||||
self.latent_channels,
|
input_channels,
|
||||||
height // self.downsampling_factor,
|
height // self.downsampling_factor,
|
||||||
width // self.downsampling_factor],
|
width // self.downsampling_factor],
|
||||||
device='cpu').to(device)
|
device='cpu').to(device)
|
||||||
else:
|
else:
|
||||||
x = torch.randn([1,
|
x = torch.randn([1,
|
||||||
self.latent_channels,
|
input_channels,
|
||||||
height // self.downsampling_factor,
|
height // self.downsampling_factor,
|
||||||
width // self.downsampling_factor],
|
width // self.downsampling_factor],
|
||||||
device=device)
|
device=device)
|
||||||
|
@ -2,67 +2,55 @@
|
|||||||
ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import math
|
import math
|
||||||
import gc
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from ldm.invoke.generator.base import Generator
|
from ldm.invoke.generator.base import Generator
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \
|
||||||
from ldm.invoke.generator.omnibus import Omnibus
|
ConditioningData
|
||||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
class Txt2Img2Img(Generator):
|
class Txt2Img2Img(Generator):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
super().__init__(model, precision)
|
super().__init__(model, precision)
|
||||||
self.init_latent = None # for get_noise()
|
self.init_latent = None # for get_noise()
|
||||||
|
|
||||||
@torch.no_grad()
|
def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta,
|
||||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
conditioning, width:int, height:int, strength:float,
|
||||||
conditioning,width,height,strength,step_callback=None,**kwargs):
|
step_callback:Optional[Callable]=None, threshold=0.0, **kwargs):
|
||||||
"""
|
"""
|
||||||
Returns a function returning an image derived from the prompt and the initial image
|
Returns a function returning an image derived from the prompt and the initial image
|
||||||
Return value depends on the seed at the time you call it
|
Return value depends on the seed at the time you call it
|
||||||
kwargs are 'width' and 'height'
|
kwargs are 'width' and 'height'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||||
|
pipeline.scheduler = sampler
|
||||||
|
|
||||||
uc, c, extra_conditioning_info = conditioning
|
uc, c, extra_conditioning_info = conditioning
|
||||||
|
conditioning_data = (
|
||||||
|
ConditioningData(
|
||||||
|
uc, c, cfg_scale, extra_conditioning_info,
|
||||||
|
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
|
||||||
|
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||||
scale_dim = min(width, height)
|
scale_dim = min(width, height)
|
||||||
scale = 512 / scale_dim
|
scale = 512 / scale_dim
|
||||||
|
|
||||||
init_width = math.ceil(scale * width / 64) * 64
|
init_width, init_height = trim_to_multiple_of(scale * width, scale * height)
|
||||||
init_height = math.ceil(scale * height / 64) * 64
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def make_image(x_T):
|
def make_image(x_T):
|
||||||
|
|
||||||
shape = [
|
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
|
||||||
self.latent_channels,
|
latents=torch.zeros_like(x_T),
|
||||||
init_height // self.downsampling_factor,
|
num_inference_steps=steps,
|
||||||
init_width // self.downsampling_factor,
|
conditioning_data=conditioning_data,
|
||||||
]
|
noise=x_T,
|
||||||
|
callback=step_callback,
|
||||||
sampler.make_schedule(
|
# TODO: threshold = threshold,
|
||||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
|
||||||
)
|
|
||||||
|
|
||||||
#x = self.get_noise(init_width, init_height)
|
|
||||||
x = x_T
|
|
||||||
|
|
||||||
if self.free_gpu_mem and self.model.model.device != self.model.device:
|
|
||||||
self.model.model.to(self.model.device)
|
|
||||||
|
|
||||||
samples, _ = sampler.sample(
|
|
||||||
batch_size = 1,
|
|
||||||
S = steps,
|
|
||||||
x_T = x,
|
|
||||||
conditioning = c,
|
|
||||||
shape = shape,
|
|
||||||
verbose = False,
|
|
||||||
unconditional_guidance_scale = cfg_scale,
|
|
||||||
unconditional_conditioning = uc,
|
|
||||||
eta = ddim_eta,
|
|
||||||
img_callback = step_callback,
|
|
||||||
extra_conditioning_info = extra_conditioning_info
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
@ -70,89 +58,46 @@ class Txt2Img2Img(Generator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# resizing
|
# resizing
|
||||||
samples = torch.nn.functional.interpolate(
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
samples,
|
first_pass_latent_output,
|
||||||
size=(height // self.downsampling_factor, width // self.downsampling_factor),
|
size=(height // self.downsampling_factor, width // self.downsampling_factor),
|
||||||
mode="bilinear"
|
mode="bilinear"
|
||||||
)
|
)
|
||||||
|
|
||||||
t_enc = int(strength * steps)
|
second_pass_noise = self.get_noise_like(resized_latents)
|
||||||
ddim_sampler = DDIMSampler(self.model, device=self.model.device)
|
|
||||||
ddim_sampler.make_schedule(
|
|
||||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
|
||||||
)
|
|
||||||
|
|
||||||
z_enc = ddim_sampler.stochastic_encode(
|
pipeline_output = pipeline.img2img_from_latents_and_embeddings(
|
||||||
samples,
|
resized_latents,
|
||||||
torch.tensor([t_enc-1]).to(self.model.device),
|
num_inference_steps=steps,
|
||||||
noise=self.get_noise(width,height,False)
|
conditioning_data=conditioning_data,
|
||||||
)
|
strength=strength,
|
||||||
|
noise=second_pass_noise,
|
||||||
|
callback=step_callback)
|
||||||
|
|
||||||
# decode it
|
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||||
samples = ddim_sampler.decode(
|
|
||||||
z_enc,
|
|
||||||
c,
|
|
||||||
t_enc,
|
|
||||||
img_callback = step_callback,
|
|
||||||
unconditional_guidance_scale=cfg_scale,
|
|
||||||
unconditional_conditioning=uc,
|
|
||||||
extra_conditioning_info=extra_conditioning_info,
|
|
||||||
all_timesteps_count=steps
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.free_gpu_mem:
|
|
||||||
self.model.model.to('cpu')
|
|
||||||
self.model.cond_stage_model.device = 'cpu'
|
|
||||||
self.model.cond_stage_model.to('cpu')
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
return self.sample_to_image(samples)
|
# FIXME: do we really need something entirely different for the inpainting model?
|
||||||
|
|
||||||
# in the case of the inpainting model being loaded, the trick of
|
# in the case of the inpainting model being loaded, the trick of
|
||||||
# providing an interpolated latent doesn't work, so we transiently
|
# providing an interpolated latent doesn't work, so we transiently
|
||||||
# create a 512x512 PIL image, upscale it, and run the inpainting
|
# create a 512x512 PIL image, upscale it, and run the inpainting
|
||||||
# over it in img2img mode. Because the inpaing model is so conservative
|
# over it in img2img mode. Because the inpaing model is so conservative
|
||||||
# it doesn't change the image (much)
|
# it doesn't change the image (much)
|
||||||
def inpaint_make_image(x_T):
|
|
||||||
omnibus = Omnibus(self.model,self.precision)
|
|
||||||
result = omnibus.generate(
|
|
||||||
prompt,
|
|
||||||
sampler=sampler,
|
|
||||||
width=init_width,
|
|
||||||
height=init_height,
|
|
||||||
step_callback=step_callback,
|
|
||||||
steps = steps,
|
|
||||||
cfg_scale = cfg_scale,
|
|
||||||
ddim_eta = ddim_eta,
|
|
||||||
conditioning = conditioning,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
assert result is not None and len(result)>0,'** txt2img failed **'
|
|
||||||
image = result[0][0]
|
|
||||||
interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
|
|
||||||
print(kwargs.pop('init_image',None))
|
|
||||||
result = omnibus.generate(
|
|
||||||
prompt,
|
|
||||||
sampler=sampler,
|
|
||||||
init_image=interpolated_image,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
seed=result[0][1],
|
|
||||||
step_callback=step_callback,
|
|
||||||
steps = steps,
|
|
||||||
cfg_scale = cfg_scale,
|
|
||||||
ddim_eta = ddim_eta,
|
|
||||||
conditioning = conditioning,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
return result[0][0]
|
|
||||||
|
|
||||||
if sampler.uses_inpainting_model():
|
|
||||||
return inpaint_make_image
|
|
||||||
else:
|
|
||||||
return make_image
|
return make_image
|
||||||
|
|
||||||
|
def get_noise_like(self, like: torch.Tensor):
|
||||||
|
device = like.device
|
||||||
|
if device.type == 'mps':
|
||||||
|
x = torch.randn_like(like, device='cpu').to(device)
|
||||||
|
else:
|
||||||
|
x = torch.randn_like(like, device=device)
|
||||||
|
if self.perlin > 0.0:
|
||||||
|
shape = like.shape
|
||||||
|
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||||
|
return x
|
||||||
|
|
||||||
# returns a tensor filled with random numbers from a normal distribution
|
# returns a tensor filled with random numbers from a normal distribution
|
||||||
def get_noise(self,width,height,scale = True):
|
def get_noise(self,width,height,scale = True):
|
||||||
# print(f"Get noise: {width}x{height}")
|
# print(f"Get noise: {width}x{height}")
|
||||||
@ -179,4 +124,3 @@ class Txt2Img2Img(Generator):
|
|||||||
scaled_height // self.downsampling_factor,
|
scaled_height // self.downsampling_factor,
|
||||||
scaled_width // self.downsampling_factor],
|
scaled_width // self.downsampling_factor],
|
||||||
device=device)
|
device=device)
|
||||||
|
|
||||||
|
@ -8,11 +8,14 @@ the attributes:
|
|||||||
- root - the root directory under which "models" and "outputs" can be found
|
- root - the root directory under which "models" and "outputs" can be found
|
||||||
- initfile - path to the initialization file
|
- initfile - path to the initialization file
|
||||||
- try_patchmatch - option to globally disable loading of 'patchmatch' module
|
- try_patchmatch - option to globally disable loading of 'patchmatch' module
|
||||||
|
- always_use_cpu - force use of CPU even if GPU is available
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
from pathlib import Path
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
Globals = Namespace()
|
Globals = Namespace()
|
||||||
|
|
||||||
@ -26,6 +29,41 @@ else:
|
|||||||
|
|
||||||
# Where to look for the initialization file
|
# Where to look for the initialization file
|
||||||
Globals.initfile = 'invokeai.init'
|
Globals.initfile = 'invokeai.init'
|
||||||
|
Globals.models_dir = 'models'
|
||||||
|
Globals.config_dir = 'configs'
|
||||||
|
Globals.autoscan_dir = 'weights'
|
||||||
|
|
||||||
# Try loading patchmatch
|
# Try loading patchmatch
|
||||||
Globals.try_patchmatch = True
|
Globals.try_patchmatch = True
|
||||||
|
|
||||||
|
# Use CPU even if GPU is available (main use case is for debugging MPS issues)
|
||||||
|
Globals.always_use_cpu = False
|
||||||
|
|
||||||
|
# Whether the internet is reachable for dynamic downloads
|
||||||
|
# The CLI will test connectivity at startup time.
|
||||||
|
Globals.internet_available = True
|
||||||
|
|
||||||
|
def global_config_dir()->Path:
|
||||||
|
return Path(Globals.root, Globals.config_dir)
|
||||||
|
|
||||||
|
def global_models_dir()->Path:
|
||||||
|
return Path(Globals.root, Globals.models_dir)
|
||||||
|
|
||||||
|
def global_autoscan_dir()->Path:
|
||||||
|
return Path(Globals.root, Globals.autoscan_dir)
|
||||||
|
|
||||||
|
def global_set_root(root_dir:Union[str,Path]):
|
||||||
|
Globals.root = root_dir
|
||||||
|
|
||||||
|
def global_cache_dir(subdir:Union[str,Path]='')->Path:
|
||||||
|
'''
|
||||||
|
Returns Path to the model cache directory. If a subdirectory
|
||||||
|
is provided, it will be appended to the end of the path, allowing
|
||||||
|
for huggingface-style conventions:
|
||||||
|
global_cache_dir('diffusers')
|
||||||
|
global_cache_dir('transformers')
|
||||||
|
'''
|
||||||
|
if (home := os.environ.get('HF_HOME')):
|
||||||
|
return Path(home,subdir)
|
||||||
|
else:
|
||||||
|
return Path(Globals.root,'models',subdir)
|
||||||
|
@ -1,451 +0,0 @@
|
|||||||
'''
|
|
||||||
Manage a cache of Stable Diffusion model files for fast switching.
|
|
||||||
They are moved between GPU and CPU as necessary. If CPU memory falls
|
|
||||||
below a preset minimum, the least recently used model will be
|
|
||||||
cleared and loaded from disk when next needed.
|
|
||||||
'''
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import os
|
|
||||||
import io
|
|
||||||
import time
|
|
||||||
import gc
|
|
||||||
import hashlib
|
|
||||||
import psutil
|
|
||||||
import sys
|
|
||||||
import transformers
|
|
||||||
import traceback
|
|
||||||
import textwrap
|
|
||||||
import contextlib
|
|
||||||
from typing import Union
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from omegaconf.errors import ConfigAttributeError
|
|
||||||
from ldm.util import instantiate_from_config, ask_user
|
|
||||||
from ldm.invoke.globals import Globals
|
|
||||||
from picklescan.scanner import scan_file_path
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
DEFAULT_MAX_MODELS=2
|
|
||||||
|
|
||||||
class ModelCache(object):
|
|
||||||
def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
|
|
||||||
'''
|
|
||||||
Initialize with the path to the models.yaml config file,
|
|
||||||
the torch device type, and precision. The optional
|
|
||||||
min_avail_mem argument specifies how much unused system
|
|
||||||
(CPU) memory to preserve. The cache of models in RAM will
|
|
||||||
grow until this value is approached. Default is 2G.
|
|
||||||
'''
|
|
||||||
# prevent nasty-looking CLIP log message
|
|
||||||
transformers.logging.set_verbosity_error()
|
|
||||||
self.config = config
|
|
||||||
self.precision = precision
|
|
||||||
self.device = torch.device(device_type)
|
|
||||||
self.max_loaded_models = max_loaded_models
|
|
||||||
self.models = {}
|
|
||||||
self.stack = [] # this is an LRU FIFO
|
|
||||||
self.current_model = None
|
|
||||||
|
|
||||||
def valid_model(self, model_name:str)->bool:
|
|
||||||
'''
|
|
||||||
Given a model name, returns True if it is a valid
|
|
||||||
identifier.
|
|
||||||
'''
|
|
||||||
return model_name in self.config
|
|
||||||
|
|
||||||
def get_model(self, model_name:str):
|
|
||||||
'''
|
|
||||||
Given a model named identified in models.yaml, return
|
|
||||||
the model object. If in RAM will load into GPU VRAM.
|
|
||||||
If on disk, will load from there.
|
|
||||||
'''
|
|
||||||
if not self.valid_model(model_name):
|
|
||||||
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
|
||||||
return self.current_model
|
|
||||||
|
|
||||||
if self.current_model != model_name:
|
|
||||||
if model_name not in self.models: # make room for a new one
|
|
||||||
self._make_cache_room()
|
|
||||||
self.offload_model(self.current_model)
|
|
||||||
|
|
||||||
if model_name in self.models:
|
|
||||||
requested_model = self.models[model_name]['model']
|
|
||||||
print(f'>> Retrieving model {model_name} from system RAM cache')
|
|
||||||
self.models[model_name]['model'] = self._model_from_cpu(requested_model)
|
|
||||||
width = self.models[model_name]['width']
|
|
||||||
height = self.models[model_name]['height']
|
|
||||||
hash = self.models[model_name]['hash']
|
|
||||||
|
|
||||||
else: # we're about to load a new model, so potentially offload the least recently used one
|
|
||||||
try:
|
|
||||||
requested_model, width, height, hash = self._load_model(model_name)
|
|
||||||
self.models[model_name] = {
|
|
||||||
'model': requested_model,
|
|
||||||
'width': width,
|
|
||||||
'height': height,
|
|
||||||
'hash': hash,
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f'** model {model_name} could not be loaded: {str(e)}')
|
|
||||||
print(traceback.format_exc())
|
|
||||||
assert self.current_model,'** FATAL: no current model to restore to'
|
|
||||||
print(f'** restoring {self.current_model}')
|
|
||||||
self.get_model(self.current_model)
|
|
||||||
return
|
|
||||||
|
|
||||||
self.current_model = model_name
|
|
||||||
self._push_newest_model(model_name)
|
|
||||||
return {
|
|
||||||
'model':requested_model,
|
|
||||||
'width':width,
|
|
||||||
'height':height,
|
|
||||||
'hash': hash
|
|
||||||
}
|
|
||||||
|
|
||||||
def default_model(self) -> str:
|
|
||||||
'''
|
|
||||||
Returns the name of the default model, or None
|
|
||||||
if none is defined.
|
|
||||||
'''
|
|
||||||
for model_name in self.config:
|
|
||||||
if self.config[model_name].get('default'):
|
|
||||||
return model_name
|
|
||||||
|
|
||||||
def set_default_model(self,model_name:str) -> None:
|
|
||||||
'''
|
|
||||||
Set the default model. The change will not take
|
|
||||||
effect until you call model_cache.commit()
|
|
||||||
'''
|
|
||||||
assert model_name in self.models,f"unknown model '{model_name}'"
|
|
||||||
|
|
||||||
config = self.config
|
|
||||||
for model in config:
|
|
||||||
config[model].pop('default',None)
|
|
||||||
config[model_name]['default'] = True
|
|
||||||
|
|
||||||
def list_models(self) -> dict:
|
|
||||||
'''
|
|
||||||
Return a dict of models in the format:
|
|
||||||
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
|
||||||
'description': description,
|
|
||||||
},
|
|
||||||
model_name2: { etc }
|
|
||||||
'''
|
|
||||||
models = {}
|
|
||||||
for name in self.config:
|
|
||||||
description = self.config[name].description if 'description' in self.config[name] else '<no description>'
|
|
||||||
weights = self.config[name].weights if 'weights' in self.config[name] else '<no weights>'
|
|
||||||
config = self.config[name].config if 'config' in self.config[name] else '<no config>'
|
|
||||||
width = self.config[name].width if 'width' in self.config[name] else 512
|
|
||||||
height = self.config[name].height if 'height' in self.config[name] else 512
|
|
||||||
default = self.config[name].default if 'default' in self.config[name] else False
|
|
||||||
vae = self.config[name].vae if 'vae' in self.config[name] else '<no vae>'
|
|
||||||
|
|
||||||
if self.current_model == name:
|
|
||||||
status = 'active'
|
|
||||||
elif name in self.models:
|
|
||||||
status = 'cached'
|
|
||||||
else:
|
|
||||||
status = 'not loaded'
|
|
||||||
|
|
||||||
models[name]={
|
|
||||||
'status' : status,
|
|
||||||
'description' : description,
|
|
||||||
'weights': weights,
|
|
||||||
'config': config,
|
|
||||||
'width': width,
|
|
||||||
'height': height,
|
|
||||||
'vae': vae,
|
|
||||||
'default': default
|
|
||||||
}
|
|
||||||
return models
|
|
||||||
|
|
||||||
def print_models(self) -> None:
|
|
||||||
'''
|
|
||||||
Print a table of models, their descriptions, and load status
|
|
||||||
'''
|
|
||||||
models = self.list_models()
|
|
||||||
for name in models:
|
|
||||||
line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["description"]}'
|
|
||||||
if models[name]['status'] == 'active':
|
|
||||||
line = f'\033[1m{line}\033[0m'
|
|
||||||
print(line)
|
|
||||||
|
|
||||||
def del_model(self, model_name:str) -> None:
|
|
||||||
'''
|
|
||||||
Delete the named model.
|
|
||||||
'''
|
|
||||||
omega = self.config
|
|
||||||
del omega[model_name]
|
|
||||||
if model_name in self.stack:
|
|
||||||
self.stack.remove(model_name)
|
|
||||||
|
|
||||||
def add_model(self, model_name:str, model_attributes:dict, clobber=False) -> None:
|
|
||||||
'''
|
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
|
||||||
On a successful update, the config will be changed in memory and the
|
|
||||||
method will return True. Will fail with an assertion error if provided
|
|
||||||
attributes are incorrect or the model name is missing.
|
|
||||||
'''
|
|
||||||
omega = self.config
|
|
||||||
for field in ('description','weights','height','width','config'):
|
|
||||||
assert field in model_attributes, f'required field {field} is missing'
|
|
||||||
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
|
|
||||||
|
|
||||||
config = omega[model_name] if model_name in omega else {}
|
|
||||||
for field in model_attributes:
|
|
||||||
if field == 'weights':
|
|
||||||
field.replace('\\', '/')
|
|
||||||
config[field] = model_attributes[field]
|
|
||||||
|
|
||||||
omega[model_name] = config
|
|
||||||
if clobber:
|
|
||||||
self._invalidate_cached_model(model_name)
|
|
||||||
|
|
||||||
def _load_model(self, model_name:str):
|
|
||||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
|
||||||
if model_name not in self.config:
|
|
||||||
print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
|
|
||||||
|
|
||||||
mconfig = self.config[model_name]
|
|
||||||
config = mconfig.config
|
|
||||||
weights = mconfig.weights
|
|
||||||
vae = mconfig.get('vae')
|
|
||||||
width = mconfig.width
|
|
||||||
height = mconfig.height
|
|
||||||
|
|
||||||
if not os.path.isabs(weights):
|
|
||||||
weights = os.path.normpath(os.path.join(Globals.root,weights))
|
|
||||||
# scan model
|
|
||||||
self.scan_model(model_name, weights)
|
|
||||||
|
|
||||||
print(f'>> Loading {model_name} from {weights}')
|
|
||||||
|
|
||||||
# for usage statistics
|
|
||||||
if self._has_cuda():
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
tic = time.time()
|
|
||||||
|
|
||||||
# this does the work
|
|
||||||
if not os.path.isabs(config):
|
|
||||||
config = os.path.join(Globals.root,config)
|
|
||||||
omega_config = OmegaConf.load(config)
|
|
||||||
with open(weights,'rb') as f:
|
|
||||||
weight_bytes = f.read()
|
|
||||||
model_hash = self._cached_sha256(weights,weight_bytes)
|
|
||||||
sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
|
||||||
del weight_bytes
|
|
||||||
# merged models from auto11 merge board are flat for some reason
|
|
||||||
if 'state_dict' in sd:
|
|
||||||
sd = sd['state_dict']
|
|
||||||
|
|
||||||
print(f' | Forcing garbage collection prior to loading new model')
|
|
||||||
gc.collect()
|
|
||||||
model = instantiate_from_config(omega_config.model)
|
|
||||||
model.load_state_dict(sd, strict=False)
|
|
||||||
|
|
||||||
if self.precision == 'float16':
|
|
||||||
print(' | Using faster float16 precision')
|
|
||||||
model.to(torch.float16)
|
|
||||||
else:
|
|
||||||
print(' | Using more accurate float32 precision')
|
|
||||||
|
|
||||||
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
|
|
||||||
if vae:
|
|
||||||
if not os.path.isabs(vae):
|
|
||||||
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
|
||||||
if os.path.exists(vae):
|
|
||||||
print(f' | Loading VAE weights from: {vae}')
|
|
||||||
vae_ckpt = torch.load(vae, map_location="cpu")
|
|
||||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
|
||||||
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
|
||||||
else:
|
|
||||||
print(f' | VAE file {vae} not found. Skipping.')
|
|
||||||
|
|
||||||
model.to(self.device)
|
|
||||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
|
||||||
model.cond_stage_model.device = self.device
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
for module in model.modules():
|
|
||||||
if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
|
|
||||||
module._orig_padding_mode = module.padding_mode
|
|
||||||
|
|
||||||
# usage statistics
|
|
||||||
toc = time.time()
|
|
||||||
print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
|
|
||||||
|
|
||||||
if self._has_cuda():
|
|
||||||
print(
|
|
||||||
'>> Max VRAM used to load the model:',
|
|
||||||
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
|
||||||
'\n>> Current VRAM usage:'
|
|
||||||
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
|
||||||
)
|
|
||||||
|
|
||||||
return model, width, height, model_hash
|
|
||||||
|
|
||||||
def offload_model(self, model_name:str) -> None:
|
|
||||||
'''
|
|
||||||
Offload the indicated model to CPU. Will call
|
|
||||||
_make_cache_room() to free space if needed.
|
|
||||||
'''
|
|
||||||
if model_name not in self.models:
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f'>> Offloading {model_name} to CPU')
|
|
||||||
model = self.models[model_name]['model']
|
|
||||||
self.models[model_name]['model'] = self._model_to_cpu(model)
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
if self._has_cuda():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
def scan_model(self, model_name, checkpoint):
|
|
||||||
# scan model
|
|
||||||
print(f'>> Scanning Model: {model_name}')
|
|
||||||
scan_result = scan_file_path(checkpoint)
|
|
||||||
if scan_result.infected_files != 0:
|
|
||||||
if scan_result.infected_files == 1:
|
|
||||||
print(f'\n### Issues Found In Model: {scan_result.issues_count}')
|
|
||||||
print('### WARNING: The model you are trying to load seems to be infected.')
|
|
||||||
print('### For your safety, InvokeAI will not load this model.')
|
|
||||||
print('### Please use checkpoints from trusted sources.')
|
|
||||||
print("### Exiting InvokeAI")
|
|
||||||
sys.exit()
|
|
||||||
else:
|
|
||||||
print('\n### WARNING: InvokeAI was unable to scan the model you are using.')
|
|
||||||
model_safe_check_fail = ask_user('Do you want to to continue loading the model?', ['y', 'n'])
|
|
||||||
if model_safe_check_fail.lower() != 'y':
|
|
||||||
print("### Exiting InvokeAI")
|
|
||||||
sys.exit()
|
|
||||||
else:
|
|
||||||
print('>> Model Scanned. OK!!')
|
|
||||||
|
|
||||||
def search_models(self, search_folder):
|
|
||||||
|
|
||||||
print(f'>> Finding Models In: {search_folder}')
|
|
||||||
models_folder = Path(search_folder).glob('**/*.ckpt')
|
|
||||||
|
|
||||||
files = [x for x in models_folder if x.is_file()]
|
|
||||||
|
|
||||||
found_models = []
|
|
||||||
for file in files:
|
|
||||||
found_models.append({
|
|
||||||
'name': file.stem,
|
|
||||||
'location': str(file.resolve()).replace('\\', '/')
|
|
||||||
})
|
|
||||||
|
|
||||||
return search_folder, found_models
|
|
||||||
|
|
||||||
def _make_cache_room(self) -> None:
|
|
||||||
num_loaded_models = len(self.models)
|
|
||||||
if num_loaded_models >= self.max_loaded_models:
|
|
||||||
least_recent_model = self._pop_oldest_model()
|
|
||||||
print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
|
|
||||||
if least_recent_model is not None:
|
|
||||||
del self.models[least_recent_model]
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
def print_vram_usage(self) -> None:
|
|
||||||
if self._has_cuda:
|
|
||||||
print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
|
|
||||||
|
|
||||||
def commit(self,config_file_path:str) -> None:
|
|
||||||
'''
|
|
||||||
Write current configuration out to the indicated file.
|
|
||||||
'''
|
|
||||||
yaml_str = OmegaConf.to_yaml(self.config)
|
|
||||||
if not os.path.isabs(config_file_path):
|
|
||||||
config_file_path = os.path.normpath(os.path.join(Globals.root,opt.conf))
|
|
||||||
tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
|
|
||||||
with open(tmpfile, 'w') as outfile:
|
|
||||||
outfile.write(self.preamble())
|
|
||||||
outfile.write(yaml_str)
|
|
||||||
os.replace(tmpfile,config_file_path)
|
|
||||||
|
|
||||||
def preamble(self) -> str:
|
|
||||||
'''
|
|
||||||
Returns the preamble for the config file.
|
|
||||||
'''
|
|
||||||
return textwrap.dedent('''\
|
|
||||||
# This file describes the alternative machine learning models
|
|
||||||
# available to InvokeAI script.
|
|
||||||
#
|
|
||||||
# To add a new model, follow the examples below. Each
|
|
||||||
# model requires a model config file, a weights file,
|
|
||||||
# and the width and height of the images it
|
|
||||||
# was trained on.
|
|
||||||
''')
|
|
||||||
|
|
||||||
def _invalidate_cached_model(self,model_name:str) -> None:
|
|
||||||
self.offload_model(model_name)
|
|
||||||
if model_name in self.stack:
|
|
||||||
self.stack.remove(model_name)
|
|
||||||
self.models.pop(model_name,None)
|
|
||||||
|
|
||||||
def _model_to_cpu(self,model):
|
|
||||||
if self.device != 'cpu':
|
|
||||||
model.cond_stage_model.device = 'cpu'
|
|
||||||
model.first_stage_model.to('cpu')
|
|
||||||
model.cond_stage_model.to('cpu')
|
|
||||||
model.model.to('cpu')
|
|
||||||
return model.to('cpu')
|
|
||||||
else:
|
|
||||||
return model
|
|
||||||
|
|
||||||
def _model_from_cpu(self,model):
|
|
||||||
if self.device != 'cpu':
|
|
||||||
model.to(self.device)
|
|
||||||
model.first_stage_model.to(self.device)
|
|
||||||
model.cond_stage_model.to(self.device)
|
|
||||||
model.cond_stage_model.device = self.device
|
|
||||||
return model
|
|
||||||
|
|
||||||
def _pop_oldest_model(self):
|
|
||||||
'''
|
|
||||||
Remove the first element of the FIFO, which ought
|
|
||||||
to be the least recently accessed model. Do not
|
|
||||||
pop the last one, because it is in active use!
|
|
||||||
'''
|
|
||||||
return self.stack.pop(0)
|
|
||||||
|
|
||||||
def _push_newest_model(self,model_name:str) -> None:
|
|
||||||
'''
|
|
||||||
Maintain a simple FIFO. First element is always the
|
|
||||||
least recent, and last element is always the most recent.
|
|
||||||
'''
|
|
||||||
with contextlib.suppress(ValueError):
|
|
||||||
self.stack.remove(model_name)
|
|
||||||
self.stack.append(model_name)
|
|
||||||
|
|
||||||
def _has_cuda(self) -> bool:
|
|
||||||
return self.device.type == 'cuda'
|
|
||||||
|
|
||||||
def _cached_sha256(self,path,data) -> Union[str, bytes]:
|
|
||||||
dirname = os.path.dirname(path)
|
|
||||||
basename = os.path.basename(path)
|
|
||||||
base, _ = os.path.splitext(basename)
|
|
||||||
hashpath = os.path.join(dirname,base+'.sha256')
|
|
||||||
|
|
||||||
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
|
|
||||||
with open(hashpath) as f:
|
|
||||||
hash = f.read()
|
|
||||||
return hash
|
|
||||||
|
|
||||||
print(f'>> Calculating sha256 hash of weights file')
|
|
||||||
tic = time.time()
|
|
||||||
sha = hashlib.sha256()
|
|
||||||
sha.update(data)
|
|
||||||
hash = sha.hexdigest()
|
|
||||||
toc = time.time()
|
|
||||||
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
|
|
||||||
|
|
||||||
with open(hashpath,'w') as f:
|
|
||||||
f.write(hash)
|
|
||||||
return hash
|
|
953
ldm/invoke/model_manager.py
Normal file
953
ldm/invoke/model_manager.py
Normal file
@ -0,0 +1,953 @@
|
|||||||
|
'''
|
||||||
|
Manage a cache of Stable Diffusion model files for fast switching.
|
||||||
|
They are moved between GPU and CPU as necessary. If CPU memory falls
|
||||||
|
below a preset minimum, the least recently used model will be
|
||||||
|
cleared and loaded from disk when next needed.
|
||||||
|
'''
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import gc
|
||||||
|
import hashlib
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import textwrap
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import warnings
|
||||||
|
import safetensors.torch
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union, Any
|
||||||
|
from ldm.util import download_with_progress_bar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import safetensors
|
||||||
|
import transformers
|
||||||
|
from diffusers import AutoencoderKL, logging as dlogging
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from omegaconf.dictconfig import DictConfig
|
||||||
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
|
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||||
|
from ldm.invoke.globals import Globals, global_models_dir, global_autoscan_dir, global_cache_dir
|
||||||
|
from ldm.util import instantiate_from_config, ask_user
|
||||||
|
|
||||||
|
DEFAULT_MAX_MODELS=2
|
||||||
|
|
||||||
|
class ModelManager(object):
|
||||||
|
def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
|
||||||
|
'''
|
||||||
|
Initialize with the path to the models.yaml config file,
|
||||||
|
the torch device type, and precision. The optional
|
||||||
|
min_avail_mem argument specifies how much unused system
|
||||||
|
(CPU) memory to preserve. The cache of models in RAM will
|
||||||
|
grow until this value is approached. Default is 2G.
|
||||||
|
'''
|
||||||
|
# prevent nasty-looking CLIP log message
|
||||||
|
transformers.logging.set_verbosity_error()
|
||||||
|
self.config = config
|
||||||
|
self.precision = precision
|
||||||
|
self.device = torch.device(device_type)
|
||||||
|
self.max_loaded_models = max_loaded_models
|
||||||
|
self.models = {}
|
||||||
|
self.stack = [] # this is an LRU FIFO
|
||||||
|
self.current_model = None
|
||||||
|
|
||||||
|
def valid_model(self, model_name:str)->bool:
|
||||||
|
'''
|
||||||
|
Given a model name, returns True if it is a valid
|
||||||
|
identifier.
|
||||||
|
'''
|
||||||
|
return model_name in self.config
|
||||||
|
|
||||||
|
def get_model(self, model_name:str):
|
||||||
|
'''
|
||||||
|
Given a model named identified in models.yaml, return
|
||||||
|
the model object. If in RAM will load into GPU VRAM.
|
||||||
|
If on disk, will load from there.
|
||||||
|
'''
|
||||||
|
if not self.valid_model(model_name):
|
||||||
|
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
|
||||||
|
return self.current_model
|
||||||
|
|
||||||
|
if self.current_model != model_name:
|
||||||
|
if model_name not in self.models: # make room for a new one
|
||||||
|
self._make_cache_room()
|
||||||
|
self.offload_model(self.current_model)
|
||||||
|
|
||||||
|
if model_name in self.models:
|
||||||
|
requested_model = self.models[model_name]['model']
|
||||||
|
print(f'>> Retrieving model {model_name} from system RAM cache')
|
||||||
|
self.models[model_name]['model'] = self._model_from_cpu(requested_model)
|
||||||
|
width = self.models[model_name]['width']
|
||||||
|
height = self.models[model_name]['height']
|
||||||
|
hash = self.models[model_name]['hash']
|
||||||
|
|
||||||
|
else: # we're about to load a new model, so potentially offload the least recently used one
|
||||||
|
requested_model, width, height, hash = self._load_model(model_name)
|
||||||
|
self.models[model_name] = {
|
||||||
|
'model': requested_model,
|
||||||
|
'width': width,
|
||||||
|
'height': height,
|
||||||
|
'hash': hash,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.current_model = model_name
|
||||||
|
self._push_newest_model(model_name)
|
||||||
|
return {
|
||||||
|
'model':requested_model,
|
||||||
|
'width':width,
|
||||||
|
'height':height,
|
||||||
|
'hash': hash
|
||||||
|
}
|
||||||
|
|
||||||
|
def default_model(self) -> str | None:
|
||||||
|
'''
|
||||||
|
Returns the name of the default model, or None
|
||||||
|
if none is defined.
|
||||||
|
'''
|
||||||
|
for model_name in self.config:
|
||||||
|
if self.config[model_name].get('default'):
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
def set_default_model(self,model_name:str) -> None:
|
||||||
|
'''
|
||||||
|
Set the default model. The change will not take
|
||||||
|
effect until you call model_manager.commit()
|
||||||
|
'''
|
||||||
|
assert model_name in self.models,f"unknown model '{model_name}'"
|
||||||
|
|
||||||
|
config = self.config
|
||||||
|
for model in config:
|
||||||
|
config[model].pop('default',None)
|
||||||
|
config[model_name]['default'] = True
|
||||||
|
|
||||||
|
def model_info(self, model_name:str)->dict:
|
||||||
|
'''
|
||||||
|
Given a model name returns the OmegaConf (dict-like) object describing it.
|
||||||
|
'''
|
||||||
|
if model_name not in self.config:
|
||||||
|
return None
|
||||||
|
return self.config[model_name]
|
||||||
|
|
||||||
|
def model_names(self)->list[str]:
|
||||||
|
'''
|
||||||
|
Return a list consisting of all the names of models defined in models.yaml
|
||||||
|
'''
|
||||||
|
return list(self.config.keys())
|
||||||
|
|
||||||
|
def is_legacy(self,model_name:str)->bool:
|
||||||
|
'''
|
||||||
|
Return true if this is a legacy (.ckpt) model
|
||||||
|
'''
|
||||||
|
info = self.model_info(model_name)
|
||||||
|
if 'weights' in info and info['weights'].endswith('.ckpt'):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_models(self) -> dict:
|
||||||
|
'''
|
||||||
|
Return a dict of models in the format:
|
||||||
|
{ model_name1: {'status': ('active'|'cached'|'not loaded'),
|
||||||
|
'description': description,
|
||||||
|
'format': ('ckpt'|'diffusers'|'vae'),
|
||||||
|
},
|
||||||
|
model_name2: { etc }
|
||||||
|
Please use model_manager.models() to get all the model names,
|
||||||
|
model_manager.model_info('model-name') to get the stanza for the model
|
||||||
|
named 'model-name', and model_manager.config to get the full OmegaConf
|
||||||
|
object derived from models.yaml
|
||||||
|
'''
|
||||||
|
models = {}
|
||||||
|
for name in sorted(self.config):
|
||||||
|
stanza = self.config[name]
|
||||||
|
|
||||||
|
# don't include VAEs in listing (legacy style)
|
||||||
|
if 'config' in stanza and '/VAE/' in stanza['config']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
models[name] = dict()
|
||||||
|
format = stanza.get('format','ckpt') # Determine Format
|
||||||
|
|
||||||
|
# Common Attribs
|
||||||
|
description = stanza.get('description', None)
|
||||||
|
if self.current_model == name:
|
||||||
|
status = 'active'
|
||||||
|
elif name in self.models:
|
||||||
|
status = 'cached'
|
||||||
|
else:
|
||||||
|
status = 'not loaded'
|
||||||
|
models[name].update(
|
||||||
|
description = description,
|
||||||
|
format = format,
|
||||||
|
status = status,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Checkpoint Config Parse
|
||||||
|
if format == 'ckpt':
|
||||||
|
models[name].update(
|
||||||
|
config = str(stanza.get('config', None)),
|
||||||
|
weights = str(stanza.get('weights', None)),
|
||||||
|
vae = str(stanza.get('vae', None)),
|
||||||
|
width = str(stanza.get('width', 512)),
|
||||||
|
height = str(stanza.get('height', 512)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Diffusers Config Parse
|
||||||
|
if (vae := stanza.get('vae',None)):
|
||||||
|
if isinstance(vae,DictConfig):
|
||||||
|
vae = dict(
|
||||||
|
repo_id = str(vae.get('repo_id',None)),
|
||||||
|
path = str(vae.get('path',None)),
|
||||||
|
subfolder = str(vae.get('subfolder',None))
|
||||||
|
)
|
||||||
|
|
||||||
|
if format == 'diffusers':
|
||||||
|
models[name].update(
|
||||||
|
vae = vae,
|
||||||
|
repo_id = str(stanza.get('repo_id', None)),
|
||||||
|
path = str(stanza.get('path',None)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
def print_models(self) -> None:
|
||||||
|
'''
|
||||||
|
Print a table of models, their descriptions, and load status
|
||||||
|
'''
|
||||||
|
models = self.list_models()
|
||||||
|
for name in models:
|
||||||
|
if models[name]['format'] == 'vae':
|
||||||
|
continue
|
||||||
|
line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["format"]:10s} {models[name]["description"]}'
|
||||||
|
if models[name]['status'] == 'active':
|
||||||
|
line = f'\033[1m{line}\033[0m'
|
||||||
|
print(line)
|
||||||
|
|
||||||
|
def del_model(self, model_name:str) -> None:
|
||||||
|
'''
|
||||||
|
Delete the named model.
|
||||||
|
'''
|
||||||
|
omega = self.config
|
||||||
|
del omega[model_name]
|
||||||
|
if model_name in self.stack:
|
||||||
|
self.stack.remove(model_name)
|
||||||
|
|
||||||
|
def add_model(self, model_name:str, model_attributes:dict, clobber:bool=False) -> None:
|
||||||
|
'''
|
||||||
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
|
On a successful update, the config will be changed in memory and the
|
||||||
|
method will return True. Will fail with an assertion error if provided
|
||||||
|
attributes are incorrect or the model name is missing.
|
||||||
|
'''
|
||||||
|
omega = self.config
|
||||||
|
assert 'format' in model_attributes, 'missing required field "format"'
|
||||||
|
if model_attributes['format']=='diffusers':
|
||||||
|
assert 'description' in model_attributes, 'required field "description" is missing'
|
||||||
|
assert 'path' in model_attributes or 'repo_id' in model_attributes,'model must have either the "path" or "repo_id" fields defined'
|
||||||
|
else:
|
||||||
|
for field in ('description','weights','height','width','config'):
|
||||||
|
assert field in model_attributes, f'required field {field} is missing'
|
||||||
|
|
||||||
|
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
|
||||||
|
|
||||||
|
if model_name not in omega:
|
||||||
|
omega[model_name] = dict()
|
||||||
|
OmegaConf.update(omega,model_name,model_attributes,merge=False)
|
||||||
|
if 'weights' in omega[model_name]:
|
||||||
|
omega[model_name]['weights'].replace('\\','/')
|
||||||
|
|
||||||
|
if clobber:
|
||||||
|
self._invalidate_cached_model(model_name)
|
||||||
|
|
||||||
|
def _load_model(self, model_name:str):
|
||||||
|
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||||
|
if model_name not in self.config:
|
||||||
|
print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
|
||||||
|
return
|
||||||
|
|
||||||
|
mconfig = self.config[model_name]
|
||||||
|
|
||||||
|
# for usage statistics
|
||||||
|
if self._has_cuda():
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
|
||||||
|
# this does the work
|
||||||
|
model_format = mconfig.get('format', 'ckpt')
|
||||||
|
if model_format == 'ckpt':
|
||||||
|
weights = mconfig.weights
|
||||||
|
print(f'>> Loading {model_name} from {weights}')
|
||||||
|
model, width, height, model_hash = self._load_ckpt_model(model_name, mconfig)
|
||||||
|
elif model_format == 'diffusers':
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter('ignore')
|
||||||
|
model, width, height, model_hash = self._load_diffusers_model(mconfig)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")
|
||||||
|
|
||||||
|
# usage statistics
|
||||||
|
toc = time.time()
|
||||||
|
print('>> Model loaded in', '%4.2fs' % (toc - tic))
|
||||||
|
if self._has_cuda():
|
||||||
|
print(
|
||||||
|
'>> Max VRAM used to load the model:',
|
||||||
|
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||||
|
'\n>> Current VRAM usage:'
|
||||||
|
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||||
|
)
|
||||||
|
return model, width, height, model_hash
|
||||||
|
|
||||||
|
def _load_ckpt_model(self, model_name, mconfig):
|
||||||
|
config = mconfig.config
|
||||||
|
weights = mconfig.weights
|
||||||
|
vae = mconfig.get('vae')
|
||||||
|
width = mconfig.width
|
||||||
|
height = mconfig.height
|
||||||
|
|
||||||
|
if not os.path.isabs(config):
|
||||||
|
config = os.path.join(Globals.root,config)
|
||||||
|
if not os.path.isabs(weights):
|
||||||
|
weights = os.path.normpath(os.path.join(Globals.root,weights))
|
||||||
|
# scan model
|
||||||
|
self.scan_model(model_name, weights)
|
||||||
|
|
||||||
|
print(f'>> Loading {model_name} from {weights}')
|
||||||
|
|
||||||
|
# for usage statistics
|
||||||
|
if self._has_cuda():
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
|
||||||
|
# this does the work
|
||||||
|
if not os.path.isabs(config):
|
||||||
|
config = os.path.join(Globals.root,config)
|
||||||
|
omega_config = OmegaConf.load(config)
|
||||||
|
with open(weights,'rb') as f:
|
||||||
|
weight_bytes = f.read()
|
||||||
|
model_hash = self._cached_sha256(weights, weight_bytes)
|
||||||
|
sd = None
|
||||||
|
if weights.endswith('.safetensors'):
|
||||||
|
sd = safetensors.torch.load(weight_bytes)
|
||||||
|
else:
|
||||||
|
sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
|
||||||
|
del weight_bytes
|
||||||
|
# merged models from auto11 merge board are flat for some reason
|
||||||
|
if 'state_dict' in sd:
|
||||||
|
sd = sd['state_dict']
|
||||||
|
|
||||||
|
print(' | Forcing garbage collection prior to loading new model')
|
||||||
|
gc.collect()
|
||||||
|
model = instantiate_from_config(omega_config.model)
|
||||||
|
model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
|
if self.precision == 'float16':
|
||||||
|
print(' | Using faster float16 precision')
|
||||||
|
model.to(torch.float16)
|
||||||
|
else:
|
||||||
|
print(' | Using more accurate float32 precision')
|
||||||
|
|
||||||
|
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
|
||||||
|
if vae:
|
||||||
|
if not os.path.isabs(vae):
|
||||||
|
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
||||||
|
if os.path.exists(vae):
|
||||||
|
print(f' | Loading VAE weights from: {vae}')
|
||||||
|
vae_ckpt = torch.load(vae, map_location="cpu")
|
||||||
|
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||||
|
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||||
|
else:
|
||||||
|
print(f' | VAE file {vae} not found. Skipping.')
|
||||||
|
|
||||||
|
model.to(self.device)
|
||||||
|
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||||
|
model.cond_stage_model.device = self.device
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
for module in model.modules():
|
||||||
|
if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
|
||||||
|
module._orig_padding_mode = module.padding_mode
|
||||||
|
|
||||||
|
# usage statistics
|
||||||
|
toc = time.time()
|
||||||
|
print('>> Model loaded in', '%4.2fs' % (toc - tic))
|
||||||
|
|
||||||
|
if self._has_cuda():
|
||||||
|
print(
|
||||||
|
'>> Max VRAM used to load the model:',
|
||||||
|
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
|
||||||
|
'\n>> Current VRAM usage:'
|
||||||
|
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
|
||||||
|
)
|
||||||
|
|
||||||
|
return model, width, height, model_hash
|
||||||
|
|
||||||
|
def _load_diffusers_model(self, mconfig):
|
||||||
|
name_or_path = self.model_name_or_path(mconfig)
|
||||||
|
using_fp16 = self.precision == 'float16'
|
||||||
|
|
||||||
|
print(f'>> Loading diffusers model from {name_or_path}')
|
||||||
|
if using_fp16:
|
||||||
|
print(' | Using faster float16 precision')
|
||||||
|
else:
|
||||||
|
print(' | Using more accurate float32 precision')
|
||||||
|
|
||||||
|
# TODO: scan weights maybe?
|
||||||
|
pipeline_args: dict[str, Any] = dict(
|
||||||
|
safety_checker=None,
|
||||||
|
local_files_only=not Globals.internet_available
|
||||||
|
)
|
||||||
|
if 'vae' in mconfig:
|
||||||
|
vae = self._load_vae(mconfig['vae'])
|
||||||
|
pipeline_args.update(vae=vae)
|
||||||
|
if not isinstance(name_or_path,Path):
|
||||||
|
pipeline_args.update(cache_dir=global_cache_dir('diffusers'))
|
||||||
|
if using_fp16:
|
||||||
|
pipeline_args.update(torch_dtype=torch.float16)
|
||||||
|
fp_args_list = [{'revision':'fp16'},{}]
|
||||||
|
else:
|
||||||
|
fp_args_list = [{}]
|
||||||
|
|
||||||
|
verbosity = dlogging.get_verbosity()
|
||||||
|
dlogging.set_verbosity_error()
|
||||||
|
|
||||||
|
pipeline = None
|
||||||
|
for fp_args in fp_args_list:
|
||||||
|
try:
|
||||||
|
pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
|
||||||
|
name_or_path,
|
||||||
|
**pipeline_args,
|
||||||
|
**fp_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
except OSError as e:
|
||||||
|
if str(e).startswith('fp16 is not a valid'):
|
||||||
|
print(f'Could not fetch half-precision version of model {name_or_path}; fetching full-precision instead')
|
||||||
|
else:
|
||||||
|
print(f'An unexpected error occurred while downloading the model: {e})')
|
||||||
|
if pipeline:
|
||||||
|
break
|
||||||
|
|
||||||
|
dlogging.set_verbosity(verbosity)
|
||||||
|
assert pipeline is not None, OSError(f'"{name_or_path}" could not be loaded')
|
||||||
|
|
||||||
|
pipeline.to(self.device)
|
||||||
|
|
||||||
|
model_hash = self._diffuser_sha256(name_or_path)
|
||||||
|
|
||||||
|
# square images???
|
||||||
|
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||||
|
height = width
|
||||||
|
|
||||||
|
print(f' | Default image dimensions = {width} x {height}')
|
||||||
|
|
||||||
|
return pipeline, width, height, model_hash
|
||||||
|
|
||||||
|
def model_name_or_path(self, model_name:Union[str,DictConfig]) -> str | Path:
|
||||||
|
if isinstance(model_name,DictConfig):
|
||||||
|
mconfig = model_name
|
||||||
|
elif model_name in self.config:
|
||||||
|
mconfig = self.config[model_name]
|
||||||
|
else:
|
||||||
|
raise ValueError(f'"{model_name}" is not a known model name. Please check your models.yaml file')
|
||||||
|
|
||||||
|
if 'path' in mconfig:
|
||||||
|
path = Path(mconfig['path'])
|
||||||
|
if not path.is_absolute():
|
||||||
|
path = Path(Globals.root, path).resolve()
|
||||||
|
return path
|
||||||
|
elif 'repo_id' in mconfig:
|
||||||
|
return mconfig['repo_id']
|
||||||
|
else:
|
||||||
|
raise ValueError("Model config must specify either repo_id or path.")
|
||||||
|
|
||||||
|
def offload_model(self, model_name:str) -> None:
|
||||||
|
'''
|
||||||
|
Offload the indicated model to CPU. Will call
|
||||||
|
_make_cache_room() to free space if needed.
|
||||||
|
'''
|
||||||
|
if model_name not in self.models:
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f'>> Offloading {model_name} to CPU')
|
||||||
|
model = self.models[model_name]['model']
|
||||||
|
self.models[model_name]['model'] = self._model_to_cpu(model)
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
if self._has_cuda():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def scan_model(self, model_name, checkpoint):
|
||||||
|
'''
|
||||||
|
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||||
|
and option to exit if an infected file is identified.
|
||||||
|
'''
|
||||||
|
# scan model
|
||||||
|
print(f'>> Scanning Model: {model_name}')
|
||||||
|
scan_result = scan_file_path(checkpoint)
|
||||||
|
if scan_result.infected_files != 0:
|
||||||
|
if scan_result.infected_files == 1:
|
||||||
|
print(f'\n### Issues Found In Model: {scan_result.issues_count}')
|
||||||
|
print('### WARNING: The model you are trying to load seems to be infected.')
|
||||||
|
print('### For your safety, InvokeAI will not load this model.')
|
||||||
|
print('### Please use checkpoints from trusted sources.')
|
||||||
|
print("### Exiting InvokeAI")
|
||||||
|
sys.exit()
|
||||||
|
else:
|
||||||
|
print('\n### WARNING: InvokeAI was unable to scan the model you are using.')
|
||||||
|
model_safe_check_fail = ask_user('Do you want to to continue loading the model?', ['y', 'n'])
|
||||||
|
if model_safe_check_fail.lower() != 'y':
|
||||||
|
print("### Exiting InvokeAI")
|
||||||
|
sys.exit()
|
||||||
|
else:
|
||||||
|
print('>> Model scanned ok!')
|
||||||
|
|
||||||
|
def import_diffuser_model(self,
|
||||||
|
repo_or_path:Union[str,Path],
|
||||||
|
model_name:str=None,
|
||||||
|
description:str=None,
|
||||||
|
commit_to_conf:Path=None,
|
||||||
|
)->bool:
|
||||||
|
'''
|
||||||
|
Attempts to install the indicated diffuser model and returns True if successful.
|
||||||
|
|
||||||
|
"repo_or_path" can be either a repo-id or a path-like object corresponding to the
|
||||||
|
top of a downloaded diffusers directory.
|
||||||
|
|
||||||
|
You can optionally provide a model name and/or description. If not provided,
|
||||||
|
then these will be derived from the repo name. If you provide a commit_to_conf
|
||||||
|
path to the configuration file, then the new entry will be committed to the
|
||||||
|
models.yaml file.
|
||||||
|
'''
|
||||||
|
model_name = model_name or Path(repo_or_path).stem
|
||||||
|
description = description or f'imported diffusers model {model_name}'
|
||||||
|
new_config = dict(
|
||||||
|
description=description,
|
||||||
|
format='diffusers',
|
||||||
|
)
|
||||||
|
if isinstance(repo_or_path,Path) and repo_or_path.exists():
|
||||||
|
new_config.update(path=repo_or_path)
|
||||||
|
else:
|
||||||
|
new_config.update(repo_id=repo_or_path)
|
||||||
|
|
||||||
|
self.add_model(model_name, new_config, True)
|
||||||
|
if commit_to_conf:
|
||||||
|
self.commit(commit_to_conf)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def import_ckpt_model(self,
|
||||||
|
weights:Union[str,Path],
|
||||||
|
config:Union[str,Path]='configs/stable-diffusion/v1-inference.yaml',
|
||||||
|
model_name:str=None,
|
||||||
|
model_description:str=None,
|
||||||
|
commit_to_conf:Path=None,
|
||||||
|
)->bool:
|
||||||
|
'''
|
||||||
|
Attempts to install the indicated ckpt file and returns True if successful.
|
||||||
|
|
||||||
|
"weights" can be either a path-like object corresponding to a local .ckpt file
|
||||||
|
or a http/https URL pointing to a remote model.
|
||||||
|
|
||||||
|
"config" is the model config file to use with this ckpt file. It defaults to
|
||||||
|
v1-inference.yaml. If a URL is provided, the config will be downloaded.
|
||||||
|
|
||||||
|
You can optionally provide a model name and/or description. If not provided,
|
||||||
|
then these will be derived from the weight file name. If you provide a commit_to_conf
|
||||||
|
path to the configuration file, then the new entry will be committed to the
|
||||||
|
models.yaml file.
|
||||||
|
'''
|
||||||
|
weights_path = self._resolve_path(weights,'models/ldm/stable-diffusion-v1')
|
||||||
|
config_path = self._resolve_path(config,'configs/stable-diffusion')
|
||||||
|
|
||||||
|
if weights_path is None or not weights_path.exists():
|
||||||
|
return False
|
||||||
|
if config_path is None or not config_path.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
model_name = model_name or Path(weights).stem
|
||||||
|
model_description = model_description or f'imported stable diffusion weights file {model_name}'
|
||||||
|
new_config = dict(
|
||||||
|
weights=str(weights_path),
|
||||||
|
config=str(config_path),
|
||||||
|
description=model_description,
|
||||||
|
format='ckpt',
|
||||||
|
width=512,
|
||||||
|
height=512
|
||||||
|
)
|
||||||
|
self.add_model(model_name, new_config, True)
|
||||||
|
if commit_to_conf:
|
||||||
|
self.commit(commit_to_conf)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def autoconvert_weights(
|
||||||
|
self,
|
||||||
|
conf_path:Path,
|
||||||
|
weights_directory:Path=None,
|
||||||
|
dest_directory:Path=None,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Scan the indicated directory for .ckpt files, convert into diffuser models,
|
||||||
|
and import.
|
||||||
|
'''
|
||||||
|
weights_directory = weights_directory or global_autoscan_dir()
|
||||||
|
dest_directory = dest_directory or Path(global_models_dir(), 'optimized-ckpts')
|
||||||
|
|
||||||
|
print('>> Checking for unconverted .ckpt files in {weights_directory}')
|
||||||
|
ckpt_files = dict()
|
||||||
|
for root, dirs, files in os.walk(weights_directory):
|
||||||
|
for f in files:
|
||||||
|
if not f.endswith('.ckpt'):
|
||||||
|
continue
|
||||||
|
basename = Path(f).stem
|
||||||
|
dest = Path(dest_directory,basename)
|
||||||
|
if not dest.exists():
|
||||||
|
ckpt_files[Path(root,f)]=dest
|
||||||
|
|
||||||
|
if len(ckpt_files)==0:
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f'>> New .ckpt file(s) found in {weights_directory}. Optimizing and importing...')
|
||||||
|
for ckpt in ckpt_files:
|
||||||
|
self.convert_and_import(ckpt, ckpt_files[ckpt])
|
||||||
|
self.commit(conf_path)
|
||||||
|
|
||||||
|
def convert_and_import(self,
|
||||||
|
ckpt_path:Path,
|
||||||
|
diffuser_path:Path,
|
||||||
|
model_name=None,
|
||||||
|
model_description=None,
|
||||||
|
commit_to_conf:Path=None,
|
||||||
|
)->dict:
|
||||||
|
'''
|
||||||
|
Convert a legacy ckpt weights file to diffuser model and import
|
||||||
|
into models.yaml.
|
||||||
|
'''
|
||||||
|
new_config = None
|
||||||
|
from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
|
||||||
|
import transformers
|
||||||
|
if diffuser_path.exists():
|
||||||
|
print(f'ERROR: The path {str(diffuser_path)} already exists. Please move or remove it and try again.')
|
||||||
|
return
|
||||||
|
|
||||||
|
model_name = model_name or diffuser_path.name
|
||||||
|
model_description = model_description or 'Optimized version of {model_name}'
|
||||||
|
print(f'>> {model_name}: optimizing (30-60s).')
|
||||||
|
try:
|
||||||
|
verbosity =transformers.logging.get_verbosity()
|
||||||
|
transformers.logging.set_verbosity_error()
|
||||||
|
convert_ckpt_to_diffuser(ckpt_path, diffuser_path,extract_ema=True)
|
||||||
|
transformers.logging.set_verbosity(verbosity)
|
||||||
|
print(f'>> Success. Optimized model is now located at {str(diffuser_path)}')
|
||||||
|
print(f'>> Writing new config file entry for {model_name}...',end='')
|
||||||
|
new_config = dict(
|
||||||
|
path=str(diffuser_path),
|
||||||
|
description=model_description,
|
||||||
|
format='diffusers',
|
||||||
|
)
|
||||||
|
self.del_model(model_name)
|
||||||
|
self.add_model(model_name, new_config, True)
|
||||||
|
if commit_to_conf:
|
||||||
|
self.commit(commit_to_conf)
|
||||||
|
except Exception as e:
|
||||||
|
print(f'** Conversion failed: {str(e)}')
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
print('done.')
|
||||||
|
return new_config
|
||||||
|
|
||||||
|
def del_config(self, model_name:str, gen, opt, completer):
|
||||||
|
current_model = gen.model_name
|
||||||
|
if model_name == current_model:
|
||||||
|
print("** Can't delete active model. !switch to another model first. **")
|
||||||
|
return
|
||||||
|
gen.model_manager.del_model(model_name)
|
||||||
|
gen.model_manager.commit(opt.conf)
|
||||||
|
print(f'** {model_name} deleted')
|
||||||
|
completer.del_model(model_name)
|
||||||
|
|
||||||
|
def search_models(self, search_folder):
|
||||||
|
print(f'>> Finding Models In: {search_folder}')
|
||||||
|
models_folder_ckpt = Path(search_folder).glob('**/*.ckpt')
|
||||||
|
models_folder_safetensors = Path(search_folder).glob('**/*.safetensors')
|
||||||
|
|
||||||
|
ckpt_files = [x for x in models_folder_ckpt if x.is_file()]
|
||||||
|
safetensor_files = [x for x in models_folder_safetensors if x.is_file]
|
||||||
|
|
||||||
|
files = ckpt_files + safetensor_files
|
||||||
|
|
||||||
|
found_models = []
|
||||||
|
for file in files:
|
||||||
|
found_models.append({
|
||||||
|
'name': file.stem,
|
||||||
|
'location': str(file.resolve()).replace('\\', '/')
|
||||||
|
})
|
||||||
|
|
||||||
|
return search_folder, found_models
|
||||||
|
|
||||||
|
def _make_cache_room(self) -> None:
|
||||||
|
num_loaded_models = len(self.models)
|
||||||
|
if num_loaded_models >= self.max_loaded_models:
|
||||||
|
least_recent_model = self._pop_oldest_model()
|
||||||
|
print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
|
||||||
|
if least_recent_model is not None:
|
||||||
|
del self.models[least_recent_model]
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
def print_vram_usage(self) -> None:
|
||||||
|
if self._has_cuda:
|
||||||
|
print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
|
||||||
|
|
||||||
|
def commit(self,config_file_path:str) -> None:
|
||||||
|
'''
|
||||||
|
Write current configuration out to the indicated file.
|
||||||
|
'''
|
||||||
|
yaml_str = OmegaConf.to_yaml(self.config)
|
||||||
|
if not os.path.isabs(config_file_path):
|
||||||
|
config_file_path = os.path.normpath(os.path.join(Globals.root,config_file_path))
|
||||||
|
tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
|
||||||
|
with open(tmpfile, 'w', encoding="utf-8") as outfile:
|
||||||
|
outfile.write(self.preamble())
|
||||||
|
outfile.write(yaml_str)
|
||||||
|
os.replace(tmpfile,config_file_path)
|
||||||
|
|
||||||
|
def preamble(self) -> str:
|
||||||
|
'''
|
||||||
|
Returns the preamble for the config file.
|
||||||
|
'''
|
||||||
|
return textwrap.dedent('''\
|
||||||
|
# This file describes the alternative machine learning models
|
||||||
|
# available to InvokeAI script.
|
||||||
|
#
|
||||||
|
# To add a new model, follow the examples below. Each
|
||||||
|
# model requires a model config file, a weights file,
|
||||||
|
# and the width and height of the images it
|
||||||
|
# was trained on.
|
||||||
|
''')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def migrate_models(cls):
|
||||||
|
'''
|
||||||
|
Migrate the ~/invokeai/models directory from the legacy format used through 2.2.5
|
||||||
|
to the 2.3.0 "diffusers" version. This should be a one-time operation, called at
|
||||||
|
script startup time.
|
||||||
|
'''
|
||||||
|
# Three transformer models to check: bert, clip and safety checker
|
||||||
|
legacy_locations = [
|
||||||
|
Path('CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker'),
|
||||||
|
Path('bert-base-uncased/models--bert-base-uncased'),
|
||||||
|
Path('openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14')
|
||||||
|
]
|
||||||
|
models_dir = Path(Globals.root,'models')
|
||||||
|
legacy_layout = False
|
||||||
|
for model in legacy_locations:
|
||||||
|
legacy_layout = legacy_layout or Path(models_dir,model).exists()
|
||||||
|
if not legacy_layout:
|
||||||
|
return
|
||||||
|
|
||||||
|
print('** Legacy version <= 2.2.5 model directory layout detected. Reorganizing.')
|
||||||
|
print('** This is a quick one-time operation.')
|
||||||
|
from shutil import move
|
||||||
|
|
||||||
|
# transformer files get moved into the hub directory
|
||||||
|
hub = models_dir / 'hub'
|
||||||
|
os.makedirs(hub, exist_ok=True)
|
||||||
|
for model in legacy_locations:
|
||||||
|
source = models_dir /model
|
||||||
|
if source.exists():
|
||||||
|
print(f'DEBUG: Moving {models_dir / model} into hub')
|
||||||
|
move(models_dir / model, hub)
|
||||||
|
|
||||||
|
# anything else gets moved into the diffusers directory
|
||||||
|
diffusers = models_dir / 'diffusers'
|
||||||
|
os.makedirs(diffusers, exist_ok=True)
|
||||||
|
for root, dirs, _ in os.walk(models_dir, topdown=False):
|
||||||
|
for dir in dirs:
|
||||||
|
full_path = Path(root,dir)
|
||||||
|
if full_path.is_relative_to(hub) or full_path.is_relative_to(diffusers):
|
||||||
|
continue
|
||||||
|
if Path(dir).match('models--*--*'):
|
||||||
|
move(full_path,diffusers)
|
||||||
|
|
||||||
|
# now clean up by removing any empty directories
|
||||||
|
empty = [root for root, dirs, files, in os.walk(models_dir) if not len(dirs) and not len(files)]
|
||||||
|
for d in empty:
|
||||||
|
os.rmdir(d)
|
||||||
|
print('** Migration is done. Continuing...')
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_path(self, source:Union[str,Path], dest_directory:str)->Path:
|
||||||
|
resolved_path = None
|
||||||
|
if source.startswith(('http:','https:','ftp:')):
|
||||||
|
basename = os.path.basename(source)
|
||||||
|
if not os.path.isabs(dest_directory):
|
||||||
|
dest_directory = os.path.join(Globals.root,dest_directory)
|
||||||
|
dest = os.path.join(dest_directory,basename)
|
||||||
|
if download_with_progress_bar(source,dest):
|
||||||
|
resolved_path = Path(dest)
|
||||||
|
else:
|
||||||
|
if not os.path.isabs(source):
|
||||||
|
source = os.path.join(Globals.root,source)
|
||||||
|
resolved_path = Path(source)
|
||||||
|
return resolved_path
|
||||||
|
|
||||||
|
def _invalidate_cached_model(self,model_name:str) -> None:
|
||||||
|
self.offload_model(model_name)
|
||||||
|
if model_name in self.stack:
|
||||||
|
self.stack.remove(model_name)
|
||||||
|
self.models.pop(model_name,None)
|
||||||
|
|
||||||
|
def _model_to_cpu(self,model):
|
||||||
|
if self.device == 'cpu':
|
||||||
|
return model
|
||||||
|
|
||||||
|
# diffusers really really doesn't like us moving a float16 model onto CPU
|
||||||
|
import logging
|
||||||
|
logging.getLogger('diffusers.pipeline_utils').setLevel(logging.CRITICAL)
|
||||||
|
model.cond_stage_model.device = 'cpu'
|
||||||
|
model.to('cpu')
|
||||||
|
logging.getLogger('pipeline_utils').setLevel(logging.INFO)
|
||||||
|
|
||||||
|
for submodel in ('first_stage_model','cond_stage_model','model'):
|
||||||
|
try:
|
||||||
|
getattr(model,submodel).to('cpu')
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _model_from_cpu(self,model):
|
||||||
|
if self.device == 'cpu':
|
||||||
|
return model
|
||||||
|
|
||||||
|
model.to(self.device)
|
||||||
|
model.cond_stage_model.device = self.device
|
||||||
|
|
||||||
|
for submodel in ('first_stage_model','cond_stage_model','model'):
|
||||||
|
try:
|
||||||
|
getattr(model,submodel).to(self.device)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _pop_oldest_model(self):
|
||||||
|
'''
|
||||||
|
Remove the first element of the FIFO, which ought
|
||||||
|
to be the least recently accessed model. Do not
|
||||||
|
pop the last one, because it is in active use!
|
||||||
|
'''
|
||||||
|
return self.stack.pop(0)
|
||||||
|
|
||||||
|
def _push_newest_model(self,model_name:str) -> None:
|
||||||
|
'''
|
||||||
|
Maintain a simple FIFO. First element is always the
|
||||||
|
least recent, and last element is always the most recent.
|
||||||
|
'''
|
||||||
|
with contextlib.suppress(ValueError):
|
||||||
|
self.stack.remove(model_name)
|
||||||
|
self.stack.append(model_name)
|
||||||
|
|
||||||
|
def _has_cuda(self) -> bool:
|
||||||
|
return self.device.type == 'cuda'
|
||||||
|
|
||||||
|
def _diffuser_sha256(self,name_or_path:Union[str, Path])->Union[str,bytes]:
|
||||||
|
path = None
|
||||||
|
if isinstance(name_or_path,Path):
|
||||||
|
path = name_or_path
|
||||||
|
else:
|
||||||
|
owner,repo = name_or_path.split('/')
|
||||||
|
path = Path(global_cache_dir('diffusers') / f'models--{owner}--{repo}')
|
||||||
|
if not path.exists():
|
||||||
|
return None
|
||||||
|
hashpath = path / 'checksum.sha256'
|
||||||
|
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
|
||||||
|
with open(hashpath) as f:
|
||||||
|
hash = f.read()
|
||||||
|
return hash
|
||||||
|
print(' | Calculating sha256 hash of model files')
|
||||||
|
tic = time.time()
|
||||||
|
sha = hashlib.sha256()
|
||||||
|
count = 0
|
||||||
|
for root, dirs, files in os.walk(path, followlinks=False):
|
||||||
|
for name in files:
|
||||||
|
count += 1
|
||||||
|
with open(os.path.join(root,name),'rb') as f:
|
||||||
|
sha.update(f.read())
|
||||||
|
hash = sha.hexdigest()
|
||||||
|
toc = time.time()
|
||||||
|
print(f' | sha256 = {hash} ({count} files hashed in','%4.2fs)' % (toc - tic))
|
||||||
|
with open(hashpath,'w') as f:
|
||||||
|
f.write(hash)
|
||||||
|
return hash
|
||||||
|
|
||||||
|
def _cached_sha256(self,path,data) -> Union[str, bytes]:
|
||||||
|
dirname = os.path.dirname(path)
|
||||||
|
basename = os.path.basename(path)
|
||||||
|
base, _ = os.path.splitext(basename)
|
||||||
|
hashpath = os.path.join(dirname,base+'.sha256')
|
||||||
|
|
||||||
|
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
|
||||||
|
with open(hashpath) as f:
|
||||||
|
hash = f.read()
|
||||||
|
return hash
|
||||||
|
|
||||||
|
print(' | Calculating sha256 hash of weights file')
|
||||||
|
tic = time.time()
|
||||||
|
sha = hashlib.sha256()
|
||||||
|
sha.update(data)
|
||||||
|
hash = sha.hexdigest()
|
||||||
|
toc = time.time()
|
||||||
|
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
|
||||||
|
|
||||||
|
with open(hashpath,'w') as f:
|
||||||
|
f.write(hash)
|
||||||
|
return hash
|
||||||
|
|
||||||
|
def _load_vae(self, vae_config):
|
||||||
|
vae_args = {}
|
||||||
|
name_or_path = self.model_name_or_path(vae_config)
|
||||||
|
using_fp16 = self.precision == 'float16'
|
||||||
|
|
||||||
|
vae_args.update(
|
||||||
|
cache_dir=global_cache_dir('diffusers'),
|
||||||
|
local_files_only=not Globals.internet_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f' | Loading diffusers VAE from {name_or_path}')
|
||||||
|
if using_fp16:
|
||||||
|
vae_args.update(torch_dtype=torch.float16)
|
||||||
|
fp_args_list = [{'revision':'fp16'},{}]
|
||||||
|
else:
|
||||||
|
print(' | Using more accurate float32 precision')
|
||||||
|
fp_args_list = [{}]
|
||||||
|
|
||||||
|
vae = None
|
||||||
|
deferred_error = None
|
||||||
|
|
||||||
|
# A VAE may be in a subfolder of a model's repository.
|
||||||
|
if 'subfolder' in vae_config:
|
||||||
|
vae_args['subfolder'] = vae_config['subfolder']
|
||||||
|
|
||||||
|
for fp_args in fp_args_list:
|
||||||
|
# At some point we might need to be able to use different classes here? But for now I think
|
||||||
|
# all Stable Diffusion VAE are AutoencoderKL.
|
||||||
|
try:
|
||||||
|
vae = AutoencoderKL.from_pretrained(name_or_path, **vae_args, **fp_args)
|
||||||
|
except OSError as e:
|
||||||
|
if str(e).startswith('fp16 is not a valid'):
|
||||||
|
print(' | Half-precision version of model not available; fetching full-precision instead')
|
||||||
|
else:
|
||||||
|
deferred_error = e
|
||||||
|
if vae:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not vae and deferred_error:
|
||||||
|
print(f'** Could not load VAE {name_or_path}: {str(deferred_error)}')
|
||||||
|
|
||||||
|
return vae
|
@ -12,7 +12,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import atexit
|
import atexit
|
||||||
from ldm.invoke.args import Args
|
from ldm.invoke.args import Args
|
||||||
from ldm.invoke.concepts_lib import Concepts
|
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals
|
||||||
|
|
||||||
# ---------------readline utilities---------------------
|
# ---------------readline utilities---------------------
|
||||||
@ -24,7 +24,7 @@ except (ImportError,ModuleNotFoundError) as e:
|
|||||||
readline_available = False
|
readline_available = False
|
||||||
|
|
||||||
IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF')
|
IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF')
|
||||||
WEIGHT_EXTENSIONS = ('.ckpt','.bae')
|
WEIGHT_EXTENSIONS = ('.ckpt','.vae','.safetensors')
|
||||||
TEXT_EXTENSIONS = ('.txt','.TXT')
|
TEXT_EXTENSIONS = ('.txt','.TXT')
|
||||||
CONFIG_EXTENSIONS = ('.yaml','.yml')
|
CONFIG_EXTENSIONS = ('.yaml','.yml')
|
||||||
COMMANDS = (
|
COMMANDS = (
|
||||||
@ -59,7 +59,7 @@ COMMANDS = (
|
|||||||
'--png_compression','-z',
|
'--png_compression','-z',
|
||||||
'--text_mask','-tm',
|
'--text_mask','-tm',
|
||||||
'!fix','!fetch','!replay','!history','!search','!clear',
|
'!fix','!fetch','!replay','!history','!search','!clear',
|
||||||
'!models','!switch','!import_model','!edit_model','!del_model',
|
'!models','!switch','!import_model','!optimize_model','!convert_model','!edit_model','!del_model',
|
||||||
'!mask',
|
'!mask',
|
||||||
)
|
)
|
||||||
MODEL_COMMANDS = (
|
MODEL_COMMANDS = (
|
||||||
@ -67,8 +67,12 @@ MODEL_COMMANDS = (
|
|||||||
'!edit_model',
|
'!edit_model',
|
||||||
'!del_model',
|
'!del_model',
|
||||||
)
|
)
|
||||||
|
CKPT_MODEL_COMMANDS = (
|
||||||
|
'!optimize_model',
|
||||||
|
)
|
||||||
WEIGHT_COMMANDS = (
|
WEIGHT_COMMANDS = (
|
||||||
'!import_model',
|
'!import_model',
|
||||||
|
'!convert_model',
|
||||||
)
|
)
|
||||||
IMG_PATH_COMMANDS = (
|
IMG_PATH_COMMANDS = (
|
||||||
'--outdir[=\s]',
|
'--outdir[=\s]',
|
||||||
@ -91,9 +95,9 @@ weight_regexp = '(' + '|'.join(WEIGHT_COMMANDS) + ')\s*\S*$'
|
|||||||
text_regexp = '(' + '|'.join(TEXT_PATH_COMMANDS) + ')\s*\S*$'
|
text_regexp = '(' + '|'.join(TEXT_PATH_COMMANDS) + ')\s*\S*$'
|
||||||
|
|
||||||
class Completer(object):
|
class Completer(object):
|
||||||
def __init__(self, options, models=[]):
|
def __init__(self, options, models={}):
|
||||||
self.options = sorted(options)
|
self.options = sorted(options)
|
||||||
self.models = sorted(models)
|
self.models = models
|
||||||
self.seeds = set()
|
self.seeds = set()
|
||||||
self.matches = list()
|
self.matches = list()
|
||||||
self.default_dir = None
|
self.default_dir = None
|
||||||
@ -134,6 +138,10 @@ class Completer(object):
|
|||||||
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
|
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
|
||||||
self.matches= self._model_completions(text, state)
|
self.matches= self._model_completions(text, state)
|
||||||
|
|
||||||
|
# looking for a ckpt model
|
||||||
|
elif re.match('^'+'|'.join(CKPT_MODEL_COMMANDS),buffer):
|
||||||
|
self.matches= self._model_completions(text, state, ckpt_only=True)
|
||||||
|
|
||||||
elif re.search(weight_regexp,buffer):
|
elif re.search(weight_regexp,buffer):
|
||||||
self.matches = self._path_completions(
|
self.matches = self._path_completions(
|
||||||
text,
|
text,
|
||||||
@ -242,17 +250,11 @@ class Completer(object):
|
|||||||
self.linebuffer = line
|
self.linebuffer = line
|
||||||
readline.redisplay()
|
readline.redisplay()
|
||||||
|
|
||||||
def add_model(self,model_name:str)->None:
|
def update_models(self,models:dict)->None:
|
||||||
'''
|
'''
|
||||||
add a model name to the completion list
|
update our list of models
|
||||||
'''
|
'''
|
||||||
self.models.append(model_name)
|
self.models = models
|
||||||
|
|
||||||
def del_model(self,model_name:str)->None:
|
|
||||||
'''
|
|
||||||
removes a model name from the completion list
|
|
||||||
'''
|
|
||||||
self.models.remove(model_name)
|
|
||||||
|
|
||||||
def _seed_completions(self, text, state):
|
def _seed_completions(self, text, state):
|
||||||
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
|
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
|
||||||
@ -278,7 +280,7 @@ class Completer(object):
|
|||||||
def _concept_completions(self, text, state):
|
def _concept_completions(self, text, state):
|
||||||
if self.concepts is None:
|
if self.concepts is None:
|
||||||
# cache Concepts() instance so we can check for updates in concepts_list during runtime.
|
# cache Concepts() instance so we can check for updates in concepts_list during runtime.
|
||||||
self.concepts = Concepts()
|
self.concepts = HuggingFaceConceptsLibrary()
|
||||||
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
||||||
else:
|
else:
|
||||||
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
||||||
@ -294,7 +296,7 @@ class Completer(object):
|
|||||||
matches.sort()
|
matches.sort()
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
def _model_completions(self, text, state):
|
def _model_completions(self, text, state, ckpt_only=False):
|
||||||
m = re.search('(!switch\s+)(\w*)',text)
|
m = re.search('(!switch\s+)(\w*)',text)
|
||||||
if m:
|
if m:
|
||||||
switch = m.groups()[0]
|
switch = m.groups()[0]
|
||||||
@ -304,6 +306,11 @@ class Completer(object):
|
|||||||
partial = text
|
partial = text
|
||||||
matches = list()
|
matches = list()
|
||||||
for s in self.models:
|
for s in self.models:
|
||||||
|
format = self.models[s]['format']
|
||||||
|
if format == 'vae':
|
||||||
|
continue
|
||||||
|
if ckpt_only and format != 'ckpt':
|
||||||
|
continue
|
||||||
if s.startswith(partial):
|
if s.startswith(partial):
|
||||||
matches.append(switch+s)
|
matches.append(switch+s)
|
||||||
matches.sort()
|
matches.sort()
|
||||||
|
@ -12,6 +12,7 @@ def configure_model_padding(model, seamless, seamless_axes):
|
|||||||
"""
|
"""
|
||||||
Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
|
Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
|
||||||
"""
|
"""
|
||||||
|
# TODO: get an explicit interface for this in diffusers: https://github.com/huggingface/diffusers/issues/556
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||||
if seamless:
|
if seamless:
|
||||||
|
799
ldm/invoke/textual_inversion_training.py
Normal file
799
ldm/invoke/textual_inversion_training.py
Normal file
@ -0,0 +1,799 @@
|
|||||||
|
# This code was copied from
|
||||||
|
# https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
|
||||||
|
# on January 2, 2023
|
||||||
|
# and modified slightly by Lincoln Stein (@lstein) to work with InvokeAI
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from argparse import Namespace
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
import diffusers
|
||||||
|
import PIL
|
||||||
|
import transformers
|
||||||
|
from accelerate import Accelerator
|
||||||
|
from accelerate.logging import get_logger
|
||||||
|
from accelerate.utils import set_seed
|
||||||
|
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||||
|
from diffusers.optimization import get_scheduler
|
||||||
|
from diffusers.utils import check_min_version
|
||||||
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
|
from huggingface_hub import HfFolder, Repository, whoami
|
||||||
|
|
||||||
|
# invokeai stuff
|
||||||
|
from ldm.invoke.globals import Globals, global_cache_dir
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||||
|
from packaging import version
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||||
|
PIL_INTERPOLATION = {
|
||||||
|
"linear": PIL.Image.Resampling.BILINEAR,
|
||||||
|
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||||
|
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||||
|
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||||
|
"nearest": PIL.Image.Resampling.NEAREST,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
PIL_INTERPOLATION = {
|
||||||
|
"linear": PIL.Image.LINEAR,
|
||||||
|
"bilinear": PIL.Image.BILINEAR,
|
||||||
|
"bicubic": PIL.Image.BICUBIC,
|
||||||
|
"lanczos": PIL.Image.LANCZOS,
|
||||||
|
"nearest": PIL.Image.NEAREST,
|
||||||
|
}
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
|
check_min_version("0.10.0.dev0")
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path):
|
||||||
|
logger.info("Saving embeddings")
|
||||||
|
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
|
||||||
|
learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
|
||||||
|
torch.save(learned_embeds_dict, save_path)
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_steps",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="Save learned_embeds.bin every X updates steps.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--root_dir','--root',
|
||||||
|
type=Path,
|
||||||
|
default=Globals.root,
|
||||||
|
help="Path to the invokeai runtime directory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--only_save_embeds",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Save only the embeddings for the new concept.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
required=True,
|
||||||
|
help="Name of the diffusers model to train against, as defined in configs/models.yaml.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--revision",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
required=False,
|
||||||
|
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_data_dir",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
required=True,
|
||||||
|
help="A folder containing the training data."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--placeholder_token",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
required=True,
|
||||||
|
help="A token to use as a placeholder for the concept.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initializer_token",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
required=False,
|
||||||
|
help="A token to use as initializer word."
|
||||||
|
)
|
||||||
|
parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
|
||||||
|
parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
type=Path,
|
||||||
|
default=f'{Globals.root}/text-inversion-model',
|
||||||
|
help="The output directory where the model predictions and checkpoints will be written.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--resolution",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help=(
|
||||||
|
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||||
|
" resolution"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
||||||
|
)
|
||||||
|
parser.add_argument("--num_train_epochs", type=int, default=100)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_train_steps",
|
||||||
|
type=int,
|
||||||
|
default=5000,
|
||||||
|
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gradient_accumulation_steps",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gradient_checkpointing",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--learning_rate",
|
||||||
|
type=float,
|
||||||
|
default=1e-4,
|
||||||
|
help="Initial learning rate (after the potential warmup period) to use.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scale_lr",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lr_scheduler",
|
||||||
|
type=str,
|
||||||
|
default="constant",
|
||||||
|
help=(
|
||||||
|
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||||
|
' "constant", "constant_with_warmup"]'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||||
|
)
|
||||||
|
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||||
|
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||||
|
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||||
|
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||||
|
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||||
|
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--logging_dir",
|
||||||
|
type=Path,
|
||||||
|
default="logs",
|
||||||
|
help=(
|
||||||
|
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||||
|
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mixed_precision",
|
||||||
|
type=str,
|
||||||
|
default="no",
|
||||||
|
choices=["no", "fp16", "bf16"],
|
||||||
|
help=(
|
||||||
|
"Whether to use mixed precision. Choose"
|
||||||
|
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||||
|
"and an Nvidia Ampere GPU."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--allow_tf32",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
||||||
|
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--report_to",
|
||||||
|
type=str,
|
||||||
|
default="tensorboard",
|
||||||
|
help=(
|
||||||
|
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||||
|
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpointing_steps",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help=(
|
||||||
|
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
||||||
|
" training using `--resume_from_checkpoint`."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--resume_from_checkpoint",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||||
|
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
imagenet_templates_small = [
|
||||||
|
"a photo of a {}",
|
||||||
|
"a rendering of a {}",
|
||||||
|
"a cropped photo of the {}",
|
||||||
|
"the photo of a {}",
|
||||||
|
"a photo of a clean {}",
|
||||||
|
"a photo of a dirty {}",
|
||||||
|
"a dark photo of the {}",
|
||||||
|
"a photo of my {}",
|
||||||
|
"a photo of the cool {}",
|
||||||
|
"a close-up photo of a {}",
|
||||||
|
"a bright photo of the {}",
|
||||||
|
"a cropped photo of a {}",
|
||||||
|
"a photo of the {}",
|
||||||
|
"a good photo of the {}",
|
||||||
|
"a photo of one {}",
|
||||||
|
"a close-up photo of the {}",
|
||||||
|
"a rendition of the {}",
|
||||||
|
"a photo of the clean {}",
|
||||||
|
"a rendition of a {}",
|
||||||
|
"a photo of a nice {}",
|
||||||
|
"a good photo of a {}",
|
||||||
|
"a photo of the nice {}",
|
||||||
|
"a photo of the small {}",
|
||||||
|
"a photo of the weird {}",
|
||||||
|
"a photo of the large {}",
|
||||||
|
"a photo of a cool {}",
|
||||||
|
"a photo of a small {}",
|
||||||
|
]
|
||||||
|
|
||||||
|
imagenet_style_templates_small = [
|
||||||
|
"a painting in the style of {}",
|
||||||
|
"a rendering in the style of {}",
|
||||||
|
"a cropped painting in the style of {}",
|
||||||
|
"the painting in the style of {}",
|
||||||
|
"a clean painting in the style of {}",
|
||||||
|
"a dirty painting in the style of {}",
|
||||||
|
"a dark painting in the style of {}",
|
||||||
|
"a picture in the style of {}",
|
||||||
|
"a cool painting in the style of {}",
|
||||||
|
"a close-up painting in the style of {}",
|
||||||
|
"a bright painting in the style of {}",
|
||||||
|
"a cropped painting in the style of {}",
|
||||||
|
"a good painting in the style of {}",
|
||||||
|
"a close-up painting in the style of {}",
|
||||||
|
"a rendition in the style of {}",
|
||||||
|
"a nice painting in the style of {}",
|
||||||
|
"a small painting in the style of {}",
|
||||||
|
"a weird painting in the style of {}",
|
||||||
|
"a large painting in the style of {}",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TextualInversionDataset(Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_root,
|
||||||
|
tokenizer,
|
||||||
|
learnable_property="object", # [object, style]
|
||||||
|
size=512,
|
||||||
|
repeats=100,
|
||||||
|
interpolation="bicubic",
|
||||||
|
flip_p=0.5,
|
||||||
|
set="train",
|
||||||
|
placeholder_token="*",
|
||||||
|
center_crop=False,
|
||||||
|
):
|
||||||
|
self.data_root = data_root
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.learnable_property = learnable_property
|
||||||
|
self.size = size
|
||||||
|
self.placeholder_token = placeholder_token
|
||||||
|
self.center_crop = center_crop
|
||||||
|
self.flip_p = flip_p
|
||||||
|
|
||||||
|
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
|
||||||
|
|
||||||
|
self.num_images = len(self.image_paths)
|
||||||
|
self._length = self.num_images
|
||||||
|
|
||||||
|
if set == "train":
|
||||||
|
self._length = self.num_images * repeats
|
||||||
|
|
||||||
|
self.interpolation = {
|
||||||
|
"linear": PIL_INTERPOLATION["linear"],
|
||||||
|
"bilinear": PIL_INTERPOLATION["bilinear"],
|
||||||
|
"bicubic": PIL_INTERPOLATION["bicubic"],
|
||||||
|
"lanczos": PIL_INTERPOLATION["lanczos"],
|
||||||
|
}[interpolation]
|
||||||
|
|
||||||
|
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
|
||||||
|
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._length
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
example = {}
|
||||||
|
image = Image.open(self.image_paths[i % self.num_images])
|
||||||
|
|
||||||
|
if not image.mode == "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
placeholder_string = self.placeholder_token
|
||||||
|
text = random.choice(self.templates).format(placeholder_string)
|
||||||
|
|
||||||
|
example["input_ids"] = self.tokenizer(
|
||||||
|
text,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.tokenizer.model_max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
).input_ids[0]
|
||||||
|
|
||||||
|
# default to score-sde preprocessing
|
||||||
|
img = np.array(image).astype(np.uint8)
|
||||||
|
|
||||||
|
if self.center_crop:
|
||||||
|
crop = min(img.shape[0], img.shape[1])
|
||||||
|
h, w, = (
|
||||||
|
img.shape[0],
|
||||||
|
img.shape[1],
|
||||||
|
)
|
||||||
|
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
|
||||||
|
|
||||||
|
image = Image.fromarray(img)
|
||||||
|
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||||
|
|
||||||
|
image = self.flip_transform(image)
|
||||||
|
image = np.array(image).astype(np.uint8)
|
||||||
|
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||||
|
|
||||||
|
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||||
|
return example
|
||||||
|
|
||||||
|
|
||||||
|
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||||
|
if token is None:
|
||||||
|
token = HfFolder.get_token()
|
||||||
|
if organization is None:
|
||||||
|
username = whoami(token)["name"]
|
||||||
|
return f"{username}/{model_id}"
|
||||||
|
else:
|
||||||
|
return f"{organization}/{model_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def do_textual_inversion_training(
|
||||||
|
model:str,
|
||||||
|
train_data_dir:Path,
|
||||||
|
output_dir:Path,
|
||||||
|
placeholder_token:str,
|
||||||
|
initializer_token:str,
|
||||||
|
save_steps:int=500,
|
||||||
|
only_save_embeds:bool=False,
|
||||||
|
revision:str=None,
|
||||||
|
tokenizer_name:str=None,
|
||||||
|
learnable_property:str='object',
|
||||||
|
repeats:int=100,
|
||||||
|
seed:int=None,
|
||||||
|
resolution:int=512,
|
||||||
|
center_crop:bool=False,
|
||||||
|
train_batch_size:int=16,
|
||||||
|
num_train_epochs:int=100,
|
||||||
|
max_train_steps:int=5000,
|
||||||
|
gradient_accumulation_steps:int=1,
|
||||||
|
gradient_checkpointing:bool=False,
|
||||||
|
learning_rate:float=1e-4,
|
||||||
|
scale_lr:bool=True,
|
||||||
|
lr_scheduler:str='constant',
|
||||||
|
lr_warmup_steps:int=500,
|
||||||
|
adam_beta1:float=0.9,
|
||||||
|
adam_beta2:float=0.999,
|
||||||
|
adam_weight_decay:float=1e-02,
|
||||||
|
adam_epsilon:float=1e-08,
|
||||||
|
push_to_hub:bool=False,
|
||||||
|
hub_token:str=None,
|
||||||
|
logging_dir:Path=Path('logs'),
|
||||||
|
mixed_precision:str='fp16',
|
||||||
|
allow_tf32:bool=False,
|
||||||
|
report_to:str='tensorboard',
|
||||||
|
local_rank:int=-1,
|
||||||
|
checkpointing_steps:int=500,
|
||||||
|
resume_from_checkpoint:Path=None,
|
||||||
|
enable_xformers_memory_efficient_attention:bool=False,
|
||||||
|
root_dir:Path=None
|
||||||
|
):
|
||||||
|
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||||
|
if env_local_rank != -1 and env_local_rank != local_rank:
|
||||||
|
local_rank = env_local_rank
|
||||||
|
|
||||||
|
# setting up things the way invokeai expects them
|
||||||
|
if not os.path.isabs(output_dir):
|
||||||
|
output_dir = os.path.join(Globals.root,output_dir)
|
||||||
|
|
||||||
|
logging_dir = output_dir / logging_dir
|
||||||
|
|
||||||
|
accelerator = Accelerator(
|
||||||
|
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||||
|
mixed_precision=mixed_precision,
|
||||||
|
log_with=report_to,
|
||||||
|
logging_dir=logging_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make one log on every process with the configuration for debugging.
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
|
level=logging.INFO,
|
||||||
|
)
|
||||||
|
logger.info(accelerator.state, main_process_only=False)
|
||||||
|
if accelerator.is_local_main_process:
|
||||||
|
datasets.utils.logging.set_verbosity_warning()
|
||||||
|
transformers.utils.logging.set_verbosity_warning()
|
||||||
|
diffusers.utils.logging.set_verbosity_info()
|
||||||
|
else:
|
||||||
|
datasets.utils.logging.set_verbosity_error()
|
||||||
|
transformers.utils.logging.set_verbosity_error()
|
||||||
|
diffusers.utils.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
# If passed along, set the training seed now.
|
||||||
|
if seed is not None:
|
||||||
|
set_seed(seed)
|
||||||
|
|
||||||
|
# Handle the repository creation
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
if push_to_hub:
|
||||||
|
if hub_model_id is None:
|
||||||
|
repo_name = get_full_repo_name(Path(output_dir).name, token=hub_token)
|
||||||
|
else:
|
||||||
|
repo_name = hub_model_id
|
||||||
|
repo = Repository(output_dir, clone_from=repo_name)
|
||||||
|
|
||||||
|
with open(os.path.join(output_dir, ".gitignore"), "w+") as gitignore:
|
||||||
|
if "step_*" not in gitignore:
|
||||||
|
gitignore.write("step_*\n")
|
||||||
|
if "epoch_*" not in gitignore:
|
||||||
|
gitignore.write("epoch_*\n")
|
||||||
|
elif output_dir is not None:
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
models_conf = OmegaConf.load(os.path.join(Globals.root,'configs/models.yaml'))
|
||||||
|
model_conf = models_conf.get(model,None)
|
||||||
|
assert model_conf is not None,f'Unknown model: {model}'
|
||||||
|
assert model_conf.get('format','diffusers')=='diffusers', "This script only works with models of type 'diffusers'"
|
||||||
|
pretrained_model_name_or_path = model_conf.get('repo_id',None) or Path(model_conf.get('path'))
|
||||||
|
assert pretrained_model_name_or_path, f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
|
||||||
|
pipeline_args = dict(cache_dir=global_cache_dir('diffusers'))
|
||||||
|
|
||||||
|
# Load tokenizer
|
||||||
|
if tokenizer_name:
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name,cache_dir=global_cache_dir('transformers'))
|
||||||
|
else:
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", **pipeline_args)
|
||||||
|
|
||||||
|
# Load scheduler and models
|
||||||
|
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", **pipeline_args)
|
||||||
|
text_encoder = CLIPTextModel.from_pretrained(
|
||||||
|
pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, **pipeline_args
|
||||||
|
)
|
||||||
|
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision, **pipeline_args)
|
||||||
|
unet = UNet2DConditionModel.from_pretrained(
|
||||||
|
pretrained_model_name_or_path, subfolder="unet", revision=revision, **pipeline_args
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the placeholder token in tokenizer
|
||||||
|
num_added_tokens = tokenizer.add_tokens(placeholder_token)
|
||||||
|
if num_added_tokens == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
|
||||||
|
" `placeholder_token` that is not already in the tokenizer."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert the initializer_token, placeholder_token to ids
|
||||||
|
token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
|
||||||
|
# Check if initializer_token is a single token or a sequence of tokens
|
||||||
|
if len(token_ids) > 1:
|
||||||
|
raise ValueError(f"The initializer token must be a single token. Provided initializer={initializer_token}. Token ids={token_ids}")
|
||||||
|
|
||||||
|
initializer_token_id = token_ids[0]
|
||||||
|
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
|
||||||
|
|
||||||
|
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||||
|
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
|
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||||
|
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||||
|
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
||||||
|
|
||||||
|
# Freeze vae and unet
|
||||||
|
vae.requires_grad_(False)
|
||||||
|
unet.requires_grad_(False)
|
||||||
|
# Freeze all parameters except for the token embeddings in text encoder
|
||||||
|
text_encoder.text_model.encoder.requires_grad_(False)
|
||||||
|
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
||||||
|
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||||
|
|
||||||
|
if gradient_checkpointing:
|
||||||
|
# Keep unet in train mode if we are using gradient checkpointing to save memory.
|
||||||
|
# The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
|
||||||
|
unet.train()
|
||||||
|
text_encoder.gradient_checkpointing_enable()
|
||||||
|
unet.enable_gradient_checkpointing()
|
||||||
|
|
||||||
|
if enable_xformers_memory_efficient_attention:
|
||||||
|
if is_xformers_available():
|
||||||
|
unet.enable_xformers_memory_efficient_attention()
|
||||||
|
else:
|
||||||
|
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||||
|
|
||||||
|
# Enable TF32 for faster training on Ampere GPUs,
|
||||||
|
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||||
|
if allow_tf32:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
|
if scale_lr:
|
||||||
|
learning_rate = (
|
||||||
|
learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the optimizer
|
||||||
|
optimizer = torch.optim.AdamW(
|
||||||
|
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
|
||||||
|
lr=learning_rate,
|
||||||
|
betas=(adam_beta1, adam_beta2),
|
||||||
|
weight_decay=adam_weight_decay,
|
||||||
|
eps=adam_epsilon,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dataset and DataLoaders creation:
|
||||||
|
train_dataset = TextualInversionDataset(
|
||||||
|
data_root=train_data_dir,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
size=resolution,
|
||||||
|
placeholder_token=placeholder_token,
|
||||||
|
repeats=repeats,
|
||||||
|
learnable_property=learnable_property,
|
||||||
|
center_crop=center_crop,
|
||||||
|
set="train",
|
||||||
|
)
|
||||||
|
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
|
||||||
|
|
||||||
|
# Scheduler and math around the number of training steps.
|
||||||
|
overrode_max_train_steps = False
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
|
||||||
|
if max_train_steps is None:
|
||||||
|
max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
||||||
|
overrode_max_train_steps = True
|
||||||
|
|
||||||
|
lr_scheduler = get_scheduler(
|
||||||
|
lr_scheduler,
|
||||||
|
optimizer=optimizer,
|
||||||
|
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
|
||||||
|
num_training_steps=max_train_steps * gradient_accumulation_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare everything with our `accelerator`.
|
||||||
|
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
|
||||||
|
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
||||||
|
# as these models are only used for inference, keeping weights in full precision is not required.
|
||||||
|
weight_dtype = torch.float32
|
||||||
|
if accelerator.mixed_precision == "fp16":
|
||||||
|
weight_dtype = torch.float16
|
||||||
|
elif accelerator.mixed_precision == "bf16":
|
||||||
|
weight_dtype = torch.bfloat16
|
||||||
|
|
||||||
|
# Move vae and unet to device and cast to weight_dtype
|
||||||
|
unet.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
|
||||||
|
if overrode_max_train_steps:
|
||||||
|
max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
||||||
|
# Afterwards we recalculate our number of training epochs
|
||||||
|
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
|
||||||
|
|
||||||
|
# We need to initialize the trackers we use, and also store our configuration.
|
||||||
|
# The trackers initializes automatically on the main process.
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
params = locals()
|
||||||
|
for k in params: # init_trackers() doesn't like objects
|
||||||
|
params[k] = str(params[k]) if isinstance(params[k],object) else params[k]
|
||||||
|
accelerator.init_trackers("textual_inversion", config=params)
|
||||||
|
|
||||||
|
# Train!
|
||||||
|
total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
||||||
|
|
||||||
|
logger.info("***** Running training *****")
|
||||||
|
logger.info(f" Num examples = {len(train_dataset)}")
|
||||||
|
logger.info(f" Num Epochs = {num_train_epochs}")
|
||||||
|
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
|
||||||
|
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||||
|
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
|
||||||
|
logger.info(f" Total optimization steps = {max_train_steps}")
|
||||||
|
global_step = 0
|
||||||
|
first_epoch = 0
|
||||||
|
|
||||||
|
# Potentially load in the weights and states from a previous save
|
||||||
|
if resume_from_checkpoint:
|
||||||
|
if resume_from_checkpoint != "latest":
|
||||||
|
path = os.path.basename(resume_from_checkpoint)
|
||||||
|
else:
|
||||||
|
# Get the most recent checkpoint
|
||||||
|
dirs = os.listdir(output_dir)
|
||||||
|
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||||
|
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||||
|
path = dirs[-1]
|
||||||
|
accelerator.print(f"Resuming from checkpoint {path}")
|
||||||
|
accelerator.load_state(os.path.join(output_dir, path))
|
||||||
|
global_step = int(path.split("-")[1])
|
||||||
|
|
||||||
|
resume_global_step = global_step * gradient_accumulation_steps
|
||||||
|
first_epoch = resume_global_step // num_update_steps_per_epoch
|
||||||
|
resume_step = resume_global_step % num_update_steps_per_epoch
|
||||||
|
|
||||||
|
# Only show the progress bar once on each machine.
|
||||||
|
progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
|
||||||
|
progress_bar.set_description("Steps")
|
||||||
|
|
||||||
|
# keep original embeddings as reference
|
||||||
|
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
|
||||||
|
|
||||||
|
for epoch in range(first_epoch, num_train_epochs):
|
||||||
|
text_encoder.train()
|
||||||
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
# Skip steps until we reach the resumed step
|
||||||
|
if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||||
|
if step % gradient_accumulation_steps == 0:
|
||||||
|
progress_bar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
with accelerator.accumulate(text_encoder):
|
||||||
|
# Convert images to latent space
|
||||||
|
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
|
||||||
|
latents = latents * 0.18215
|
||||||
|
|
||||||
|
# Sample noise that we'll add to the latents
|
||||||
|
noise = torch.randn_like(latents)
|
||||||
|
bsz = latents.shape[0]
|
||||||
|
# Sample a random timestep for each image
|
||||||
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||||
|
timesteps = timesteps.long()
|
||||||
|
|
||||||
|
# Add noise to the latents according to the noise magnitude at each timestep
|
||||||
|
# (this is the forward diffusion process)
|
||||||
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
|
# Get the text embedding for conditioning
|
||||||
|
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
|
||||||
|
|
||||||
|
# Predict the noise residual
|
||||||
|
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
|
|
||||||
|
# Get the target for loss depending on the prediction type
|
||||||
|
if noise_scheduler.config.prediction_type == "epsilon":
|
||||||
|
target = noise
|
||||||
|
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||||
|
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||||
|
|
||||||
|
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||||
|
|
||||||
|
accelerator.backward(loss)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||||
|
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
|
||||||
|
with torch.no_grad():
|
||||||
|
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||||
|
index_no_updates
|
||||||
|
] = orig_embeds_params[index_no_updates]
|
||||||
|
|
||||||
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
|
if accelerator.sync_gradients:
|
||||||
|
progress_bar.update(1)
|
||||||
|
global_step += 1
|
||||||
|
if global_step % save_steps == 0:
|
||||||
|
save_path = os.path.join(output_dir, f"learned_embeds-steps-{global_step}.bin")
|
||||||
|
save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path)
|
||||||
|
|
||||||
|
if global_step % checkpointing_steps == 0:
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
|
||||||
|
accelerator.save_state(save_path)
|
||||||
|
logger.info(f"Saved state to {save_path}")
|
||||||
|
|
||||||
|
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
|
progress_bar.set_postfix(**logs)
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
|
if global_step >= max_train_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Create the pipeline using using the trained modules and save it.
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
if push_to_hub and only_save_embeds:
|
||||||
|
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||||
|
save_full_model = True
|
||||||
|
else:
|
||||||
|
save_full_model = not only_save_embeds
|
||||||
|
if save_full_model:
|
||||||
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||||
|
vae=vae,
|
||||||
|
unet=unet,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
**pipeline_args,
|
||||||
|
)
|
||||||
|
pipeline.save_pretrained(output_dir)
|
||||||
|
# Save the newly trained embeddings
|
||||||
|
save_path = os.path.join(output_dir, "learned_embeds.bin")
|
||||||
|
save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path)
|
||||||
|
|
||||||
|
if push_to_hub:
|
||||||
|
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
||||||
|
|
||||||
|
accelerator.end_training()
|
@ -4,7 +4,9 @@ from typing import Optional, Callable
|
|||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
import diffusers
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||||
|
|
||||||
# adapted from bloc97's CrossAttentionControl colab
|
# adapted from bloc97's CrossAttentionControl colab
|
||||||
# https://github.com/bloc97/CrossAttentionControl
|
# https://github.com/bloc97/CrossAttentionControl
|
||||||
@ -337,8 +339,8 @@ def setup_cross_attention_control(model, context: Context):
|
|||||||
|
|
||||||
|
|
||||||
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
||||||
cross_attention_class: type = InvokeAICrossAttentionMixin
|
from ldm.modules.attention import CrossAttention # avoid circular import
|
||||||
# cross_attention_class: type = InvokeAIDiffusersCrossAttention
|
cross_attention_class: type = InvokeAIDiffusersCrossAttention if isinstance(model,UNet2DConditionModel) else CrossAttention
|
||||||
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||||
attention_module_tuples = [(name,module) for name, module in model.named_modules() if
|
attention_module_tuples = [(name,module) for name, module in model.named_modules() if
|
||||||
isinstance(module, cross_attention_class) and which_attn in name]
|
isinstance(module, cross_attention_class) and which_attn in name]
|
||||||
@ -441,3 +443,19 @@ def get_mem_free_total(device):
|
|||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
return mem_free_total
|
return mem_free_total
|
||||||
|
|
||||||
|
|
||||||
|
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
InvokeAICrossAttentionMixin.__init__(self)
|
||||||
|
|
||||||
|
def _attention(self, query, key, value, attention_mask=None):
|
||||||
|
#default_result = super()._attention(query, key, value)
|
||||||
|
if attention_mask is not None:
|
||||||
|
print(f"{type(self).__name__} ignoring passed-in attention_mask")
|
||||||
|
attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
|
||||||
|
|
||||||
|
hidden_states = self.reshape_batch_dim_to_heads(attention_result)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ from pytorch_lightning.utilities.distributed import rank_zero_only
|
|||||||
from omegaconf import ListConfig
|
from omegaconf import ListConfig
|
||||||
import urllib
|
import urllib
|
||||||
|
|
||||||
|
from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||||
from ldm.util import (
|
from ldm.util import (
|
||||||
log_txt_as_img,
|
log_txt_as_img,
|
||||||
exists,
|
exists,
|
||||||
@ -678,6 +679,13 @@ class LatentDiffusion(DDPM):
|
|||||||
self.embedding_manager = self.instantiate_embedding_manager(
|
self.embedding_manager = self.instantiate_embedding_manager(
|
||||||
personalization_config, self.cond_stage_model
|
personalization_config, self.cond_stage_model
|
||||||
)
|
)
|
||||||
|
self.textual_inversion_manager = TextualInversionManager(
|
||||||
|
tokenizer = self.cond_stage_model.tokenizer,
|
||||||
|
text_encoder = self.cond_stage_model.transformer,
|
||||||
|
full_precision = True
|
||||||
|
)
|
||||||
|
# this circular component dependency is gross and bad, needs to be rethought
|
||||||
|
self.cond_stage_model.set_textual_inversion_manager(self.textual_inversion_manager)
|
||||||
|
|
||||||
self.emb_ckpt_counter = 0
|
self.emb_ckpt_counter = 0
|
||||||
|
|
||||||
|
@ -209,12 +209,12 @@ class KSampler(Sampler):
|
|||||||
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
|
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
|
||||||
|
|
||||||
# setup attention maps saving. checks for None are because there are multiple code paths to get here.
|
# setup attention maps saving. checks for None are because there are multiple code paths to get here.
|
||||||
attention_maps_saver = None
|
attention_map_saver = None
|
||||||
if attention_maps_callback is not None and extra_conditioning_info is not None:
|
if attention_maps_callback is not None and extra_conditioning_info is not None:
|
||||||
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||||
attention_map_token_ids = range(1, eos_token_index)
|
attention_map_token_ids = range(1, eos_token_index)
|
||||||
attention_maps_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
|
attention_map_saver = AttentionMapSaver(token_ids = attention_map_token_ids, latents_shape=x.shape[-2:])
|
||||||
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_maps_saver)
|
model_wrap_cfg.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||||
|
|
||||||
extra_args = {
|
extra_args = {
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
@ -229,8 +229,8 @@ class KSampler(Sampler):
|
|||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if attention_maps_saver is not None:
|
if attention_map_saver is not None:
|
||||||
attention_maps_callback(attention_maps_saver)
|
attention_maps_callback(attention_map_saver)
|
||||||
return sampling_result
|
return sampling_result
|
||||||
|
|
||||||
# this code will support inpainting if and when ksampler API modified or
|
# this code will support inpainting if and when ksampler API modified or
|
||||||
|
@ -1,14 +1,23 @@
|
|||||||
import traceback
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import Callable, Optional, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ldm.models.diffusion.cross_attention_control import Arguments, \
|
from ldm.models.diffusion.cross_attention_control import Arguments, \
|
||||||
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, CrossAttentionType
|
remove_cross_attention_control, setup_cross_attention_control, Context, get_cross_attention_modules, \
|
||||||
|
CrossAttentionType
|
||||||
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ThresholdSettings:
|
||||||
|
threshold: float
|
||||||
|
warmup: float
|
||||||
|
|
||||||
|
|
||||||
class InvokeAIDiffuserComponent:
|
class InvokeAIDiffuserComponent:
|
||||||
'''
|
'''
|
||||||
The aim of this component is to provide a single place for code that can be applied identically to
|
The aim of this component is to provide a single place for code that can be applied identically to
|
||||||
@ -18,6 +27,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
* Cross attention control ("prompt2prompt")
|
* Cross attention control ("prompt2prompt")
|
||||||
* Hybrid conditioning (used for inpainting)
|
* Hybrid conditioning (used for inpainting)
|
||||||
'''
|
'''
|
||||||
|
debug_thresholding = False
|
||||||
|
|
||||||
|
|
||||||
class ExtraConditioningInfo:
|
class ExtraConditioningInfo:
|
||||||
@ -36,6 +46,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
:param model: the unet model to pass through to cross attention control
|
:param model: the unet model to pass through to cross attention control
|
||||||
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
||||||
"""
|
"""
|
||||||
|
self.conditioning = None
|
||||||
self.model = model
|
self.model = model
|
||||||
self.model_forward_callback = model_forward_callback
|
self.model_forward_callback = model_forward_callback
|
||||||
self.cross_attention_control_context = None
|
self.cross_attention_control_context = None
|
||||||
@ -77,7 +88,8 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning: Union[torch.Tensor,dict],
|
unconditioning: Union[torch.Tensor,dict],
|
||||||
conditioning: Union[torch.Tensor,dict],
|
conditioning: Union[torch.Tensor,dict],
|
||||||
unconditional_guidance_scale: float,
|
unconditional_guidance_scale: float,
|
||||||
step_index: Optional[int]=None
|
step_index: Optional[int]=None,
|
||||||
|
threshold: Optional[ThresholdSettings]=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param x: current latents
|
:param x: current latents
|
||||||
@ -86,6 +98,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
||||||
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
|
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
|
||||||
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
|
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
|
||||||
|
:param threshold: threshold to apply after each step
|
||||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -106,13 +119,13 @@ class InvokeAIDiffuserComponent:
|
|||||||
else:
|
else:
|
||||||
unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
|
unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
|
||||||
|
|
||||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
|
||||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
|
|
||||||
combined_next_x = unconditioned_next_x + scaled_delta
|
if threshold:
|
||||||
|
combined_next_x = self._threshold(threshold.threshold, threshold.warmup, combined_next_x, sigma)
|
||||||
|
|
||||||
return combined_next_x
|
return combined_next_x
|
||||||
|
|
||||||
|
|
||||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||||
|
|
||||||
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
|
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||||
@ -120,8 +133,11 @@ class InvokeAIDiffuserComponent:
|
|||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice,
|
both_results = self.model_forward_callback(x_twice, sigma_twice, both_conditionings)
|
||||||
both_conditionings).chunk(2)
|
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||||
|
if conditioned_next_x.device.type == 'mps':
|
||||||
|
# prevent a result filled with zeros. seems to be a torch bug.
|
||||||
|
conditioned_next_x = conditioned_next_x.clone()
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
|
|
||||||
@ -179,6 +195,51 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
|
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
|
||||||
|
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||||
|
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
||||||
|
combined_next_x = unconditioned_next_x + scaled_delta
|
||||||
|
return combined_next_x
|
||||||
|
|
||||||
|
def _threshold(self, threshold, warmup, latents: torch.Tensor, sigma) -> torch.Tensor:
|
||||||
|
warmup_scale = (1 - sigma.item() / 1000) / warmup if warmup else math.inf
|
||||||
|
if warmup_scale < 1:
|
||||||
|
# This arithmetic based on https://github.com/invoke-ai/InvokeAI/pull/395
|
||||||
|
warming_threshold = 1 + (threshold - 1) * warmup_scale
|
||||||
|
current_threshold = np.clip(warming_threshold, 1, threshold)
|
||||||
|
else:
|
||||||
|
current_threshold = threshold
|
||||||
|
|
||||||
|
if current_threshold <= 0:
|
||||||
|
return latents
|
||||||
|
maxval = latents.max().item()
|
||||||
|
minval = latents.min().item()
|
||||||
|
|
||||||
|
scale = 0.7 # default value from #395
|
||||||
|
|
||||||
|
if self.debug_thresholding:
|
||||||
|
std, mean = [i.item() for i in torch.std_mean(latents)]
|
||||||
|
outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
|
||||||
|
print(f"\nThreshold: 𝜎={sigma.item()} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
|
||||||
|
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
|
||||||
|
f" | {outside / latents.numel() * 100:.2f}% values outside threshold")
|
||||||
|
|
||||||
|
if maxval < current_threshold and minval > -current_threshold:
|
||||||
|
return latents
|
||||||
|
|
||||||
|
if maxval > current_threshold:
|
||||||
|
maxval = np.clip(maxval * scale, 1, current_threshold)
|
||||||
|
|
||||||
|
if minval < -current_threshold:
|
||||||
|
minval = np.clip(minval * scale, -current_threshold, -1)
|
||||||
|
|
||||||
|
if self.debug_thresholding:
|
||||||
|
outside = torch.count_nonzero((latents < minval) | (latents > maxval))
|
||||||
|
print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
|
||||||
|
f" | {outside / latents.numel() * 100:.2f}% values will be clamped")
|
||||||
|
|
||||||
|
return latents.clamp(minval, maxval)
|
||||||
|
|
||||||
def estimate_percent_through(self, step_index, sigma):
|
def estimate_percent_through(self, step_index, sigma):
|
||||||
if step_index is not None and self.cross_attention_control_context is not None:
|
if step_index is not None and self.cross_attention_control_context is not None:
|
||||||
# percent_through will never reach 1.0 (but this is intended)
|
# percent_through will never reach 1.0 (but this is intended)
|
||||||
|
@ -162,7 +162,6 @@ def get_mem_free_total(device):
|
|||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
return mem_free_total
|
return mem_free_total
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
|
class CrossAttention(nn.Module, InvokeAICrossAttentionMixin):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import os.path
|
import os.path
|
||||||
from cmath import log
|
from cmath import log
|
||||||
import torch
|
import torch
|
||||||
|
from attr import dataclass
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from ldm.invoke.concepts_lib import Concepts
|
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||||
from ldm.data.personalized import per_img_token_list
|
from ldm.data.personalized import per_img_token_list
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -14,36 +15,16 @@ from picklescan.scanner import scan_file_path
|
|||||||
PROGRESSIVE_SCALE = 2000
|
PROGRESSIVE_SCALE = 2000
|
||||||
|
|
||||||
|
|
||||||
def get_clip_token_for_string(tokenizer, string):
|
def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int:
|
||||||
batch_encoding = tokenizer(
|
token_id = tokenizer.convert_tokens_to_ids(token_str)
|
||||||
string,
|
return token_id
|
||||||
truncation=True,
|
|
||||||
max_length=77,
|
|
||||||
return_length=True,
|
|
||||||
return_overflowing_tokens=False,
|
|
||||||
padding='max_length',
|
|
||||||
return_tensors='pt',
|
|
||||||
)
|
|
||||||
tokens = batch_encoding['input_ids']
|
|
||||||
""" assert (
|
|
||||||
torch.count_nonzero(tokens - 49407) == 2
|
|
||||||
), f"String '{string}' maps to more than a single token. Please use another string" """
|
|
||||||
|
|
||||||
return tokens[0, 1]
|
def get_embedding_for_clip_token_id(embedder, token_id):
|
||||||
|
if type(token_id) is not torch.Tensor:
|
||||||
|
token_id = torch.tensor(token_id, dtype=torch.int)
|
||||||
|
return embedder(token_id.unsqueeze(0))[0, 0]
|
||||||
|
|
||||||
|
|
||||||
def get_bert_token_for_string(tokenizer, string):
|
|
||||||
token = tokenizer(string)
|
|
||||||
# assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
|
|
||||||
|
|
||||||
token = token[0, 1]
|
|
||||||
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_for_clip_token(embedder, token):
|
|
||||||
return embedder(token.unsqueeze(0))[0, 0]
|
|
||||||
|
|
||||||
class EmbeddingManager(nn.Module):
|
class EmbeddingManager(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -58,8 +39,7 @@ class EmbeddingManager(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.embedder = embedder
|
self.embedder = embedder
|
||||||
self.concepts_library=Concepts()
|
self.concepts_library=HuggingFaceConceptsLibrary()
|
||||||
self.concepts_loaded = dict()
|
|
||||||
|
|
||||||
self.string_to_token_dict = {}
|
self.string_to_token_dict = {}
|
||||||
self.string_to_param_dict = nn.ParameterDict()
|
self.string_to_param_dict = nn.ParameterDict()
|
||||||
@ -77,11 +57,11 @@ class EmbeddingManager(nn.Module):
|
|||||||
embedder, 'tokenizer'
|
embedder, 'tokenizer'
|
||||||
): # using Stable Diffusion's CLIP encoder
|
): # using Stable Diffusion's CLIP encoder
|
||||||
self.is_clip = True
|
self.is_clip = True
|
||||||
get_token_for_string = partial(
|
get_token_id_for_string = partial(
|
||||||
get_clip_token_for_string, embedder.tokenizer
|
get_clip_token_id_for_string, embedder.tokenizer
|
||||||
)
|
)
|
||||||
get_embedding_for_tkn = partial(
|
get_embedding_for_tkn_id = partial(
|
||||||
get_embedding_for_clip_token,
|
get_embedding_for_clip_token_id,
|
||||||
embedder.transformer.text_model.embeddings,
|
embedder.transformer.text_model.embeddings,
|
||||||
)
|
)
|
||||||
# per bug report #572
|
# per bug report #572
|
||||||
@ -89,10 +69,10 @@ class EmbeddingManager(nn.Module):
|
|||||||
token_dim = 768
|
token_dim = 768
|
||||||
else: # using LDM's BERT encoder
|
else: # using LDM's BERT encoder
|
||||||
self.is_clip = False
|
self.is_clip = False
|
||||||
get_token_for_string = partial(
|
get_token_id_for_string = partial(
|
||||||
get_bert_token_for_string, embedder.tknz_fn
|
get_bert_token_id_for_string, embedder.tknz_fn
|
||||||
)
|
)
|
||||||
get_embedding_for_tkn = embedder.transformer.token_emb
|
get_embedding_for_tkn_id = embedder.transformer.token_emb
|
||||||
token_dim = 1280
|
token_dim = 1280
|
||||||
|
|
||||||
if per_image_tokens:
|
if per_image_tokens:
|
||||||
@ -100,15 +80,13 @@ class EmbeddingManager(nn.Module):
|
|||||||
|
|
||||||
for idx, placeholder_string in enumerate(placeholder_strings):
|
for idx, placeholder_string in enumerate(placeholder_strings):
|
||||||
|
|
||||||
token = get_token_for_string(placeholder_string)
|
token_id = get_token_id_for_string(placeholder_string)
|
||||||
|
|
||||||
if initializer_words and idx < len(initializer_words):
|
if initializer_words and idx < len(initializer_words):
|
||||||
init_word_token = get_token_for_string(initializer_words[idx])
|
init_word_token_id = get_token_id_for_string(initializer_words[idx])
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
init_word_embedding = get_embedding_for_tkn(
|
init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)
|
||||||
init_word_token.cpu()
|
|
||||||
)
|
|
||||||
|
|
||||||
token_params = torch.nn.Parameter(
|
token_params = torch.nn.Parameter(
|
||||||
init_word_embedding.unsqueeze(0).repeat(
|
init_word_embedding.unsqueeze(0).repeat(
|
||||||
@ -132,7 +110,7 @@ class EmbeddingManager(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.string_to_token_dict[placeholder_string] = token
|
self.string_to_token_dict[placeholder_string] = token_id
|
||||||
self.string_to_param_dict[placeholder_string] = token_params
|
self.string_to_param_dict[placeholder_string] = token_params
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -140,6 +118,8 @@ class EmbeddingManager(nn.Module):
|
|||||||
tokenized_text,
|
tokenized_text,
|
||||||
embedded_text,
|
embedded_text,
|
||||||
):
|
):
|
||||||
|
# torch.save(embedded_text, '/tmp/embedding-manager-uglysonic-pre-rewrite.pt')
|
||||||
|
|
||||||
b, n, device = *tokenized_text.shape, tokenized_text.device
|
b, n, device = *tokenized_text.shape, tokenized_text.device
|
||||||
|
|
||||||
for (
|
for (
|
||||||
@ -164,7 +144,7 @@ class EmbeddingManager(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
placeholder_rows, placeholder_cols = torch.where(
|
placeholder_rows, placeholder_cols = torch.where(
|
||||||
tokenized_text == placeholder_token.to(tokenized_text.device)
|
tokenized_text == placeholder_token
|
||||||
)
|
)
|
||||||
|
|
||||||
if placeholder_rows.nelement() == 0:
|
if placeholder_rows.nelement() == 0:
|
||||||
@ -182,9 +162,7 @@ class EmbeddingManager(nn.Module):
|
|||||||
new_token_row = torch.cat(
|
new_token_row = torch.cat(
|
||||||
[
|
[
|
||||||
tokenized_text[row][:col],
|
tokenized_text[row][:col],
|
||||||
placeholder_token.repeat(num_vectors_for_token).to(
|
torch.tensor([placeholder_token] * num_vectors_for_token, device=device),
|
||||||
device
|
|
||||||
),
|
|
||||||
tokenized_text[row][col + 1 :],
|
tokenized_text[row][col + 1 :],
|
||||||
],
|
],
|
||||||
axis=0,
|
axis=0,
|
||||||
@ -212,22 +190,6 @@ class EmbeddingManager(nn.Module):
|
|||||||
ckpt_path,
|
ckpt_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_concepts(self, concepts:list[str], full=True):
|
|
||||||
bin_files = list()
|
|
||||||
for concept_name in concepts:
|
|
||||||
if concept_name in self.concepts_loaded:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
bin_file = self.concepts_library.get_concept_model_path(concept_name)
|
|
||||||
if not bin_file:
|
|
||||||
continue
|
|
||||||
bin_files.append(bin_file)
|
|
||||||
self.concepts_loaded[concept_name]=True
|
|
||||||
self.load(bin_files, full)
|
|
||||||
|
|
||||||
def list_terms(self) -> list[str]:
|
|
||||||
return self.concepts_loaded.keys()
|
|
||||||
|
|
||||||
def load(self, ckpt_paths, full=True):
|
def load(self, ckpt_paths, full=True):
|
||||||
if len(ckpt_paths) == 0:
|
if len(ckpt_paths) == 0:
|
||||||
return
|
return
|
||||||
@ -282,14 +244,16 @@ class EmbeddingManager(nn.Module):
|
|||||||
if len(embedding.shape) == 1:
|
if len(embedding.shape) == 1:
|
||||||
embedding = embedding.unsqueeze(0)
|
embedding = embedding.unsqueeze(0)
|
||||||
|
|
||||||
|
existing_token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str)
|
||||||
|
if existing_token_id == self.embedder.tokenizer.unk_token_id:
|
||||||
num_tokens_added = self.embedder.tokenizer.add_tokens(token_str)
|
num_tokens_added = self.embedder.tokenizer.add_tokens(token_str)
|
||||||
current_embeddings = self.embedder.transformer.resize_token_embeddings(None)
|
current_embeddings = self.embedder.transformer.resize_token_embeddings(None)
|
||||||
current_token_count = current_embeddings.num_embeddings
|
current_token_count = current_embeddings.num_embeddings
|
||||||
new_token_count = current_token_count + num_tokens_added
|
new_token_count = current_token_count + num_tokens_added
|
||||||
self.embedder.transformer.resize_token_embeddings(new_token_count)
|
self.embedder.transformer.resize_token_embeddings(new_token_count)
|
||||||
|
|
||||||
token = get_clip_token_for_string(self.embedder.tokenizer, token_str)
|
token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str)
|
||||||
self.string_to_token_dict[token_str] = token
|
self.string_to_token_dict[token_str] = token_id
|
||||||
self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding)
|
self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding)
|
||||||
|
|
||||||
def parse_embedding(self, embedding_file: str):
|
def parse_embedding(self, embedding_file: str):
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
import os.path
|
import os.path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -8,7 +10,8 @@ from einops import rearrange, repeat
|
|||||||
from transformers import CLIPTokenizer, CLIPTextModel
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
import kornia
|
import kornia
|
||||||
from ldm.invoke.devices import choose_torch_device
|
from ldm.invoke.devices import choose_torch_device
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals, global_cache_dir
|
||||||
|
#from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||||
|
|
||||||
from ldm.modules.x_transformer import (
|
from ldm.modules.x_transformer import (
|
||||||
Encoder,
|
Encoder,
|
||||||
@ -106,7 +109,7 @@ class BERTTokenizer(AbstractEncoder):
|
|||||||
BertTokenizerFast,
|
BertTokenizerFast,
|
||||||
)
|
)
|
||||||
|
|
||||||
cache = os.path.join(Globals.root,'models/bert-base-uncased')
|
cache = global_cache_dir('hub')
|
||||||
try:
|
try:
|
||||||
self.tokenizer = BertTokenizerFast.from_pretrained(
|
self.tokenizer = BertTokenizerFast.from_pretrained(
|
||||||
'bert-base-uncased',
|
'bert-base-uncased',
|
||||||
@ -235,26 +238,28 @@ class SpatialRescaler(nn.Module):
|
|||||||
|
|
||||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||||
|
tokenizer: CLIPTokenizer
|
||||||
|
transformer: CLIPTextModel
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
version='openai/clip-vit-large-patch14',
|
version:str='openai/clip-vit-large-patch14',
|
||||||
device=choose_torch_device(),
|
max_length:int=77,
|
||||||
max_length=77,
|
tokenizer:Optional[CLIPTokenizer]=None,
|
||||||
|
transformer:Optional[CLIPTextModel]=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
cache = os.path.join(Globals.root,'models',version)
|
cache = global_cache_dir('hub')
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(
|
self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained(
|
||||||
version,
|
version,
|
||||||
cache_dir=cache,
|
cache_dir=cache,
|
||||||
local_files_only=True
|
local_files_only=True
|
||||||
)
|
)
|
||||||
self.transformer = CLIPTextModel.from_pretrained(
|
self.transformer = transformer or CLIPTextModel.from_pretrained(
|
||||||
version,
|
version,
|
||||||
cache_dir=cache,
|
cache_dir=cache,
|
||||||
local_files_only=True
|
local_files_only=True
|
||||||
)
|
)
|
||||||
self.device = device
|
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.freeze()
|
self.freeze()
|
||||||
|
|
||||||
@ -460,12 +465,25 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|||||||
def encode(self, text, **kwargs):
|
def encode(self, text, **kwargs):
|
||||||
return self(text, **kwargs)
|
return self(text, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self.transformer.device
|
||||||
|
|
||||||
|
@device.setter
|
||||||
|
def device(self, device):
|
||||||
|
self.transformer.to(device=device)
|
||||||
|
|
||||||
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||||
|
|
||||||
fragment_weights_key = "fragment_weights"
|
fragment_weights_key = "fragment_weights"
|
||||||
return_tokens_key = "return_tokens"
|
return_tokens_key = "return_tokens"
|
||||||
|
|
||||||
|
def set_textual_inversion_manager(self, manager): #TextualInversionManager):
|
||||||
|
# TODO all of the weighting and expanding stuff needs be moved out of this class
|
||||||
|
self.textual_inversion_manager = manager
|
||||||
|
|
||||||
def forward(self, text: list, **kwargs):
|
def forward(self, text: list, **kwargs):
|
||||||
|
# TODO all of the weighting and expanding stuff needs be moved out of this class
|
||||||
'''
|
'''
|
||||||
|
|
||||||
:param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
|
:param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
|
||||||
@ -560,19 +578,43 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
|||||||
else:
|
else:
|
||||||
return batch_z
|
return batch_z
|
||||||
|
|
||||||
def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
|
def get_token_ids(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
|
||||||
tokens = self.tokenizer(
|
"""
|
||||||
|
Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like
|
||||||
|
`[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if
|
||||||
|
`include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length
|
||||||
|
(typically 75 tokens + eos/bos markers).
|
||||||
|
|
||||||
|
:param fragments: The strings to convert.
|
||||||
|
:param include_start_and_end_markers:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
# for args documentation see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py (in `transformers` lib)
|
||||||
|
token_ids_list = self.tokenizer(
|
||||||
fragments,
|
fragments,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
return_overflowing_tokens=False,
|
return_overflowing_tokens=False,
|
||||||
padding='do_not_pad',
|
padding='do_not_pad',
|
||||||
return_tensors=None, # just give me a list of ints
|
return_tensors=None, # just give me lists of ints
|
||||||
)['input_ids']
|
)['input_ids']
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for token_ids in token_ids_list:
|
||||||
|
# trim eos/bos
|
||||||
|
token_ids = token_ids[1:-1]
|
||||||
|
# pad for textual inversions with vector length >1
|
||||||
|
token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids)
|
||||||
|
# restrict length to max_length-2 (leaving room for bos/eos)
|
||||||
|
token_ids = token_ids[0:self.max_length - 2]
|
||||||
|
# add back eos/bos if requested
|
||||||
if include_start_and_end_markers:
|
if include_start_and_end_markers:
|
||||||
return tokens
|
token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]
|
||||||
else:
|
|
||||||
return [x[1:-1] for x in tokens]
|
result.append(token_ids)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -597,56 +639,58 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
|||||||
if len(fragments) == 0 and len(weights) == 0:
|
if len(fragments) == 0 and len(weights) == 0:
|
||||||
fragments = ['']
|
fragments = ['']
|
||||||
weights = [1]
|
weights = [1]
|
||||||
item_encodings = self.tokenizer(
|
per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False)
|
||||||
fragments,
|
all_token_ids = []
|
||||||
truncation=True,
|
|
||||||
max_length=self.max_length,
|
|
||||||
return_overflowing_tokens=True,
|
|
||||||
padding='do_not_pad',
|
|
||||||
return_tensors=None, # just give me a list of ints
|
|
||||||
)['input_ids']
|
|
||||||
all_tokens = []
|
|
||||||
per_token_weights = []
|
per_token_weights = []
|
||||||
#print("all fragments:", fragments, weights)
|
#print("all fragments:", fragments, weights)
|
||||||
for index, fragment in enumerate(item_encodings):
|
for index, fragment in enumerate(per_fragment_token_ids):
|
||||||
weight = weights[index]
|
weight = float(weights[index])
|
||||||
#print("processing fragment", fragment, weight)
|
#print("processing fragment", fragment, weight)
|
||||||
fragment_tokens = item_encodings[index]
|
this_fragment_token_ids = per_fragment_token_ids[index]
|
||||||
#print("fragment", fragment, "processed to", fragment_tokens)
|
#print("fragment", fragment, "processed to", this_fragment_token_ids)
|
||||||
# trim bos and eos markers before appending
|
# append
|
||||||
all_tokens.extend(fragment_tokens[1:-1])
|
all_token_ids += this_fragment_token_ids
|
||||||
per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
|
# fill out weights tensor with one float per token
|
||||||
|
per_token_weights += [weight] * len(this_fragment_token_ids)
|
||||||
|
|
||||||
if (len(all_tokens) + 2) > self.max_length:
|
# leave room for bos/eos
|
||||||
excess_token_count = (len(all_tokens) + 2) - self.max_length
|
if len(all_token_ids) > self.max_length - 2:
|
||||||
|
excess_token_count = len(all_token_ids) - self.max_length - 2
|
||||||
|
# TODO build nice description string of how the truncation was applied
|
||||||
|
# this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
|
||||||
|
# self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
|
||||||
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
|
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
|
||||||
all_tokens = all_tokens[:self.max_length - 2]
|
all_token_ids = all_token_ids[0:self.max_length]
|
||||||
per_token_weights = per_token_weights[:self.max_length - 2]
|
per_token_weights = per_token_weights[0:self.max_length]
|
||||||
|
|
||||||
# pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
|
# pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
|
||||||
# (77 = self.max_length)
|
# (77 = self.max_length)
|
||||||
pad_length = self.max_length - 1 - len(all_tokens)
|
all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
|
||||||
all_tokens.insert(0, self.tokenizer.bos_token_id)
|
per_token_weights = [1.0] + per_token_weights + [1.0]
|
||||||
all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
|
pad_length = self.max_length - len(all_token_ids)
|
||||||
per_token_weights.insert(0, 1)
|
all_token_ids += [self.tokenizer.eos_token_id] * pad_length
|
||||||
per_token_weights.extend([1] * pad_length)
|
per_token_weights += [1.0] * pad_length
|
||||||
|
|
||||||
all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
|
all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long).to(self.device)
|
||||||
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
|
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
|
||||||
#print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
|
#print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
|
||||||
return all_tokens_tensor, per_token_weights_tensor
|
return all_token_ids_tensor, per_token_weights_tensor
|
||||||
|
|
||||||
def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
|
def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
|
||||||
'''
|
'''
|
||||||
Build a tensor representing the passed-in tokens, each of which has a weight.
|
Build a tensor representing the passed-in tokens, each of which has a weight.
|
||||||
:param tokens: A tensor of shape (77) containing token ids (integers)
|
:param token_ids: A tensor of shape (77) containing token ids (integers)
|
||||||
:param per_token_weights: A tensor of shape (77) containing weights (floats)
|
:param per_token_weights: A tensor of shape (77) containing weights (floats)
|
||||||
:param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
|
:param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
|
||||||
:param kwargs: passed on to self.transformer()
|
:param kwargs: passed on to self.transformer()
|
||||||
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
|
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
|
||||||
'''
|
'''
|
||||||
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
|
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
|
||||||
z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
|
if token_ids.shape != torch.Size([self.max_length]):
|
||||||
|
raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")
|
||||||
|
|
||||||
|
z = self.transformer(input_ids=token_ids.unsqueeze(0), **kwargs)
|
||||||
|
|
||||||
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
|
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
|
||||||
|
|
||||||
if weight_delta_from_empty:
|
if weight_delta_from_empty:
|
||||||
@ -660,7 +704,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
|||||||
z_delta_from_empty = z - empty_z
|
z_delta_from_empty = z - empty_z
|
||||||
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
|
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
|
||||||
|
|
||||||
weighted_z_delta_from_empty = (weighted_z-empty_z)
|
#weighted_z_delta_from_empty = (weighted_z-empty_z)
|
||||||
#print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
|
#print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
|
||||||
|
|
||||||
#print("using empty-delta method, first 5 rows:")
|
#print("using empty-delta method, first 5 rows:")
|
||||||
|
236
ldm/modules/prompt_to_embeddings_converter.py
Normal file
236
ldm/modules/prompt_to_embeddings_converter.py
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
|
|
||||||
|
from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||||
|
|
||||||
|
|
||||||
|
class WeightedPromptFragmentsToEmbeddingsConverter():
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
tokenizer: CLIPTokenizer, # converts strings to lists of int token ids
|
||||||
|
text_encoder: CLIPTextModel, # convert a list of int token ids to a tensor of embeddings
|
||||||
|
textual_inversion_manager: TextualInversionManager = None
|
||||||
|
):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.text_encoder = text_encoder
|
||||||
|
self.textual_inversion_manager = textual_inversion_manager
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_length(self):
|
||||||
|
return self.tokenizer.model_max_length
|
||||||
|
|
||||||
|
def get_embeddings_for_weighted_prompt_fragments(self,
|
||||||
|
text: list[str],
|
||||||
|
fragment_weights: list[float],
|
||||||
|
should_return_tokens: bool = False,
|
||||||
|
device='cpu'
|
||||||
|
) -> torch.Tensor:
|
||||||
|
'''
|
||||||
|
|
||||||
|
:param text: A list of fragments of text to which different weights are to be applied.
|
||||||
|
:param fragment_weights: A batch of lists of weights, one for each entry in `fragments`.
|
||||||
|
:return: A tensor of shape `[1, 77, token_dim]` containing weighted embeddings where token_dim is 768 for SD1
|
||||||
|
and 1280 for SD2
|
||||||
|
'''
|
||||||
|
if len(text) != len(fragment_weights):
|
||||||
|
raise ValueError(f"lengths of text and fragment_weights lists are not the same ({len(text)} != {len(fragment_weights)})")
|
||||||
|
|
||||||
|
batch_z = None
|
||||||
|
batch_tokens = None
|
||||||
|
for fragments, weights in zip(text, fragment_weights):
|
||||||
|
|
||||||
|
# First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
|
||||||
|
# applying a multiplier to the CFG scale on a per-token basis).
|
||||||
|
# For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
|
||||||
|
# captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
|
||||||
|
# interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
|
||||||
|
# "red" is to tell SD that it should almost completely *ignore* redness).
|
||||||
|
# To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
|
||||||
|
# string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
|
||||||
|
# closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
|
||||||
|
|
||||||
|
# handle weights >=1
|
||||||
|
tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments, weights, device=device)
|
||||||
|
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights)
|
||||||
|
|
||||||
|
# this is our starting point
|
||||||
|
embeddings = base_embedding.unsqueeze(0)
|
||||||
|
per_embedding_weights = [1.0]
|
||||||
|
|
||||||
|
# now handle weights <1
|
||||||
|
# Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
|
||||||
|
# with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
|
||||||
|
# embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
|
||||||
|
# removed.
|
||||||
|
# eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
|
||||||
|
# for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
|
||||||
|
# such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
|
||||||
|
for index, fragment_weight in enumerate(weights):
|
||||||
|
if fragment_weight < 1:
|
||||||
|
fragments_without_this = fragments[:index] + fragments[index+1:]
|
||||||
|
weights_without_this = weights[:index] + weights[index+1:]
|
||||||
|
tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments_without_this, weights_without_this, device=device)
|
||||||
|
embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights)
|
||||||
|
|
||||||
|
embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
|
||||||
|
# weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
|
||||||
|
# if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
|
||||||
|
# therefore:
|
||||||
|
# fragment_weight = 1: we are at base_z => lerp weight 0
|
||||||
|
# fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
|
||||||
|
# fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
|
||||||
|
# so let's use tan(), because:
|
||||||
|
# tan is 0.0 at 0,
|
||||||
|
# 1.0 at PI/4, and
|
||||||
|
# inf at PI/2
|
||||||
|
# -> tan((1-weight)*PI/2) should give us ideal lerp weights
|
||||||
|
epsilon = 1e-9
|
||||||
|
fragment_weight = max(epsilon, fragment_weight) # inf is bad
|
||||||
|
embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
|
||||||
|
# todo handle negative weight?
|
||||||
|
|
||||||
|
per_embedding_weights.append(embedding_lerp_weight)
|
||||||
|
|
||||||
|
lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
|
||||||
|
|
||||||
|
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
|
||||||
|
|
||||||
|
# append to batch
|
||||||
|
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
|
||||||
|
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
|
||||||
|
|
||||||
|
# should have shape (B, 77, 768)
|
||||||
|
#print(f"assembled all tokens into tensor of shape {batch_z.shape}")
|
||||||
|
|
||||||
|
if should_return_tokens:
|
||||||
|
return batch_z, batch_tokens
|
||||||
|
else:
|
||||||
|
return batch_z
|
||||||
|
|
||||||
|
def get_token_ids(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
|
||||||
|
"""
|
||||||
|
Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like
|
||||||
|
`[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if
|
||||||
|
`include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length
|
||||||
|
(typically 75 tokens + eos/bos markers).
|
||||||
|
|
||||||
|
:param fragments: The strings to convert.
|
||||||
|
:param include_start_and_end_markers:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# for args documentation see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py (in `transformers` lib)
|
||||||
|
token_ids_list = self.tokenizer(
|
||||||
|
fragments,
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.max_length,
|
||||||
|
return_overflowing_tokens=False,
|
||||||
|
padding='do_not_pad',
|
||||||
|
return_tensors=None, # just give me lists of ints
|
||||||
|
)['input_ids']
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for token_ids in token_ids_list:
|
||||||
|
# trim eos/bos
|
||||||
|
token_ids = token_ids[1:-1]
|
||||||
|
# pad for textual inversions with vector length >1
|
||||||
|
token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids)
|
||||||
|
# restrict length to max_length-2 (leaving room for bos/eos)
|
||||||
|
token_ids = token_ids[0:self.max_length - 2]
|
||||||
|
# add back eos/bos if requested
|
||||||
|
if include_start_and_end_markers:
|
||||||
|
token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]
|
||||||
|
|
||||||
|
result.append(token_ids)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
|
||||||
|
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
|
||||||
|
if normalize:
|
||||||
|
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
|
||||||
|
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
|
||||||
|
#reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
|
||||||
|
return torch.sum(embeddings * reshaped_weights, dim=1)
|
||||||
|
# lerped embeddings has shape (77, 768)
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_ids_and_expand_weights(self, fragments: list[str], weights: list[float], device: str) -> (torch.Tensor, torch.Tensor):
|
||||||
|
'''
|
||||||
|
Given a list of text fragments and corresponding weights: tokenize each fragment, append the token sequences
|
||||||
|
together and return a padded token sequence starting with the bos marker, ending with the eos marker, and padded
|
||||||
|
or truncated as appropriate to `self.max_length`. Also return a list of weights expanded from the passed-in
|
||||||
|
weights to match each token.
|
||||||
|
|
||||||
|
:param fragments: Text fragments to tokenize and concatenate. May be empty.
|
||||||
|
:param weights: Per-fragment weights (i.e. quasi-CFG scaling). Values from 0 to inf are permitted. In practise with SD1.5
|
||||||
|
values >1.6 tend to produce garbage output. Must have same length as `fragment`.
|
||||||
|
:return: A tuple of tensors `(token_ids, weights)`. `token_ids` is ints, `weights` is floats, both have shape `[self.max_length]`.
|
||||||
|
'''
|
||||||
|
if len(fragments) != len(weights):
|
||||||
|
raise ValueError(f"lengths of text and fragment_weights lists are not the same ({len(fragments)} != {len(weights)})")
|
||||||
|
|
||||||
|
# empty is meaningful
|
||||||
|
if len(fragments) == 0:
|
||||||
|
fragments = ['']
|
||||||
|
weights = [1.0]
|
||||||
|
per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False)
|
||||||
|
all_token_ids = []
|
||||||
|
per_token_weights = []
|
||||||
|
#print("all fragments:", fragments, weights)
|
||||||
|
for this_fragment_token_ids, weight in zip(per_fragment_token_ids, weights):
|
||||||
|
# append
|
||||||
|
all_token_ids += this_fragment_token_ids
|
||||||
|
# fill out weights tensor with one float per token
|
||||||
|
per_token_weights += [float(weight)] * len(this_fragment_token_ids)
|
||||||
|
|
||||||
|
# leave room for bos/eos
|
||||||
|
max_token_count_without_bos_eos_markers = self.max_length - 2
|
||||||
|
if len(all_token_ids) > max_token_count_without_bos_eos_markers:
|
||||||
|
excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers
|
||||||
|
# TODO build nice description string of how the truncation was applied
|
||||||
|
# this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
|
||||||
|
# self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
|
||||||
|
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
|
||||||
|
all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers]
|
||||||
|
per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers]
|
||||||
|
|
||||||
|
# pad out to a self.max_length-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
|
||||||
|
# (typically self.max_length == 77)
|
||||||
|
all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
|
||||||
|
per_token_weights = [1.0] + per_token_weights + [1.0]
|
||||||
|
pad_length = self.max_length - len(all_token_ids)
|
||||||
|
all_token_ids += [self.tokenizer.eos_token_id] * pad_length
|
||||||
|
per_token_weights += [1.0] * pad_length
|
||||||
|
|
||||||
|
all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device)
|
||||||
|
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32, device=device)
|
||||||
|
#print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
|
||||||
|
return all_token_ids_tensor, per_token_weights_tensor
|
||||||
|
|
||||||
|
def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor:
|
||||||
|
'''
|
||||||
|
Build a tensor that embeds the passed-in token IDs and applyies the given per_token weights
|
||||||
|
:param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints)
|
||||||
|
:param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats)
|
||||||
|
:return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings
|
||||||
|
where `token_dim` is 768 for SD1 and 1280 for SD2.
|
||||||
|
'''
|
||||||
|
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
|
||||||
|
if token_ids.shape != torch.Size([self.max_length]):
|
||||||
|
raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")
|
||||||
|
|
||||||
|
z = self.text_encoder.forward(input_ids=token_ids.unsqueeze(0),
|
||||||
|
return_dict=False)[0]
|
||||||
|
empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] +
|
||||||
|
[self.tokenizer.pad_token_id] * (self.max_length-2) +
|
||||||
|
[self.tokenizer.eos_token_id], dtype=torch.int, device=token_ids.device).unsqueeze(0)
|
||||||
|
empty_z = self.text_encoder(input_ids=empty_token_ids).last_hidden_state
|
||||||
|
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
|
||||||
|
z_delta_from_empty = z - empty_z
|
||||||
|
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
|
||||||
|
|
||||||
|
return weighted_z
|
293
ldm/modules/textual_inversion_manager.py
Normal file
293
ldm/modules/textual_inversion_manager.py
Normal file
@ -0,0 +1,293 @@
|
|||||||
|
import os
|
||||||
|
import traceback
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from picklescan.scanner import scan_file_path
|
||||||
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
|
|
||||||
|
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextualInversion:
|
||||||
|
trigger_string: str
|
||||||
|
embedding: torch.Tensor
|
||||||
|
trigger_token_id: Optional[int] = None
|
||||||
|
pad_token_ids: Optional[list[int]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embedding_vector_length(self) -> int:
|
||||||
|
return self.embedding.shape[0]
|
||||||
|
|
||||||
|
class TextualInversionManager():
|
||||||
|
def __init__(self,
|
||||||
|
tokenizer: CLIPTokenizer,
|
||||||
|
text_encoder: CLIPTextModel,
|
||||||
|
full_precision: bool=True):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.text_encoder = text_encoder
|
||||||
|
self.full_precision = full_precision
|
||||||
|
self.hf_concepts_library = HuggingFaceConceptsLibrary()
|
||||||
|
default_textual_inversions: list[TextualInversion] = []
|
||||||
|
self.textual_inversions = default_textual_inversions
|
||||||
|
|
||||||
|
def load_huggingface_concepts(self, concepts: list[str]):
|
||||||
|
for concept_name in concepts:
|
||||||
|
if concept_name in self.hf_concepts_library.concepts_loaded:
|
||||||
|
continue
|
||||||
|
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
||||||
|
if self.has_textual_inversion_for_trigger_string(trigger):
|
||||||
|
continue
|
||||||
|
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||||
|
if not bin_file:
|
||||||
|
continue
|
||||||
|
self.load_textual_inversion(bin_file)
|
||||||
|
self.hf_concepts_library.concepts_loaded[concept_name]=True
|
||||||
|
|
||||||
|
def get_all_trigger_strings(self) -> list[str]:
|
||||||
|
return [ti.trigger_string for ti in self.textual_inversions]
|
||||||
|
|
||||||
|
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
|
||||||
|
try:
|
||||||
|
scan_result = scan_file_path(ckpt_path)
|
||||||
|
if scan_result.infected_files == 1:
|
||||||
|
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
|
||||||
|
print('### For your safety, InvokeAI will not load this embed.')
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
print(f"### WARNING::: Invalid or corrupt embeddings found. Ignoring: {ckpt_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
embedding_info = self._parse_embedding(ckpt_path)
|
||||||
|
if embedding_info:
|
||||||
|
try:
|
||||||
|
self._add_textual_inversion(embedding_info['name'],
|
||||||
|
embedding_info['embedding'],
|
||||||
|
defer_injecting_tokens=defer_injecting_tokens)
|
||||||
|
except ValueError:
|
||||||
|
print(f' | ignoring incompatible embedding {embedding_info["name"]}')
|
||||||
|
else:
|
||||||
|
print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.')
|
||||||
|
|
||||||
|
def _add_textual_inversion(self, trigger_str, embedding, defer_injecting_tokens=False) -> TextualInversion:
|
||||||
|
"""
|
||||||
|
Add a textual inversion to be recognised.
|
||||||
|
:param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
|
||||||
|
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
|
||||||
|
:return: The token id for the added embedding, either existing or newly-added.
|
||||||
|
"""
|
||||||
|
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||||
|
print(f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'")
|
||||||
|
return
|
||||||
|
if not self.full_precision:
|
||||||
|
embedding = embedding.half()
|
||||||
|
if len(embedding.shape) == 1:
|
||||||
|
embedding = embedding.unsqueeze(0)
|
||||||
|
elif len(embedding.shape) > 2:
|
||||||
|
raise ValueError(f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
ti = TextualInversion(
|
||||||
|
trigger_string=trigger_str,
|
||||||
|
embedding=embedding
|
||||||
|
)
|
||||||
|
if not defer_injecting_tokens:
|
||||||
|
self._inject_tokens_and_assign_embeddings(ti)
|
||||||
|
self.textual_inversions.append(ti)
|
||||||
|
return ti
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
if str(e).startswith('Warning'):
|
||||||
|
print(f">> {str(e)}")
|
||||||
|
else:
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}.")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:
|
||||||
|
|
||||||
|
if ti.trigger_token_id is not None:
|
||||||
|
raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'")
|
||||||
|
|
||||||
|
print(f'DEBUG: Injecting token {ti.trigger_string}')
|
||||||
|
trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0])
|
||||||
|
|
||||||
|
if ti.embedding_vector_length > 1:
|
||||||
|
# for embeddings with vector length > 1
|
||||||
|
pad_token_strings = [ti.trigger_string + "-!pad-" + str(pad_index) for pad_index in range(1, ti.embedding_vector_length)]
|
||||||
|
# todo: batched UI for faster loading when vector length >2
|
||||||
|
pad_token_ids = [self._get_or_create_token_id_and_assign_embedding(pad_token_str, ti.embedding[1 + i]) \
|
||||||
|
for (i, pad_token_str) in enumerate(pad_token_strings)]
|
||||||
|
else:
|
||||||
|
pad_token_ids = []
|
||||||
|
|
||||||
|
ti.trigger_token_id = trigger_token_id
|
||||||
|
ti.pad_token_ids = pad_token_ids
|
||||||
|
return ti.trigger_token_id
|
||||||
|
|
||||||
|
|
||||||
|
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
|
||||||
|
try:
|
||||||
|
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
|
||||||
|
return ti is not None
|
||||||
|
except StopIteration:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
|
||||||
|
return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
|
||||||
|
|
||||||
|
|
||||||
|
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
|
||||||
|
return next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id)
|
||||||
|
|
||||||
|
def create_deferred_token_ids_for_any_trigger_terms(self, prompt_string: str) -> list[int]:
|
||||||
|
injected_token_ids = []
|
||||||
|
for ti in self.textual_inversions:
|
||||||
|
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
||||||
|
if ti.embedding_vector_length > 1:
|
||||||
|
print(f">> Preparing tokens for textual inversion {ti.trigger_string}...")
|
||||||
|
try:
|
||||||
|
self._inject_tokens_and_assign_embeddings(ti)
|
||||||
|
except ValueError as e:
|
||||||
|
print(f' | ignoring incompatible embedding trigger {ti.trigger_string}')
|
||||||
|
continue
|
||||||
|
injected_token_ids.append(ti.trigger_token_id)
|
||||||
|
injected_token_ids.extend(ti.pad_token_ids)
|
||||||
|
return injected_token_ids
|
||||||
|
|
||||||
|
|
||||||
|
def expand_textual_inversion_token_ids_if_necessary(self, prompt_token_ids: list[int]) -> list[int]:
|
||||||
|
"""
|
||||||
|
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
|
||||||
|
|
||||||
|
:param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
|
||||||
|
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
|
||||||
|
long - caller is responsible for prepending/appending eos and bos token ids, and truncating if necessary.
|
||||||
|
"""
|
||||||
|
if len(prompt_token_ids) == 0:
|
||||||
|
return prompt_token_ids
|
||||||
|
|
||||||
|
if prompt_token_ids[0] == self.tokenizer.bos_token_id:
|
||||||
|
raise ValueError("prompt_token_ids must not start with bos_token_id")
|
||||||
|
if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
|
||||||
|
raise ValueError("prompt_token_ids must not end with eos_token_id")
|
||||||
|
textual_inversion_trigger_token_ids = [ti.trigger_token_id for ti in self.textual_inversions]
|
||||||
|
prompt_token_ids = prompt_token_ids.copy()
|
||||||
|
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
|
||||||
|
if token_id in textual_inversion_trigger_token_ids:
|
||||||
|
textual_inversion = next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id)
|
||||||
|
for pad_idx in range(0, textual_inversion.embedding_vector_length-1):
|
||||||
|
prompt_token_ids.insert(i+pad_idx+1, textual_inversion.pad_token_ids[pad_idx])
|
||||||
|
|
||||||
|
return prompt_token_ids
|
||||||
|
|
||||||
|
|
||||||
|
def _get_or_create_token_id_and_assign_embedding(self, token_str: str, embedding: torch.Tensor) -> int:
|
||||||
|
if len(embedding.shape) != 1:
|
||||||
|
raise ValueError("Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2")
|
||||||
|
existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
||||||
|
if existing_token_id == self.tokenizer.unk_token_id:
|
||||||
|
num_tokens_added = self.tokenizer.add_tokens(token_str)
|
||||||
|
current_embeddings = self.text_encoder.resize_token_embeddings(None)
|
||||||
|
current_token_count = current_embeddings.num_embeddings
|
||||||
|
new_token_count = current_token_count + num_tokens_added
|
||||||
|
# the following call is slow - todo make batched for better performance with vector length >1
|
||||||
|
self.text_encoder.resize_token_embeddings(new_token_count)
|
||||||
|
|
||||||
|
token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
||||||
|
if token_id == self.tokenizer.unk_token_id:
|
||||||
|
raise RuntimeError(f"Unable to find token id for token '{token_str}'")
|
||||||
|
if self.text_encoder.get_input_embeddings().weight.data[token_id].shape != embedding.shape:
|
||||||
|
raise ValueError(f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}.")
|
||||||
|
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
|
||||||
|
|
||||||
|
return token_id
|
||||||
|
|
||||||
|
def _parse_embedding(self, embedding_file: str):
|
||||||
|
file_type = embedding_file.split('.')[-1]
|
||||||
|
if file_type == 'pt':
|
||||||
|
return self._parse_embedding_pt(embedding_file)
|
||||||
|
elif file_type == 'bin':
|
||||||
|
return self._parse_embedding_bin(embedding_file)
|
||||||
|
else:
|
||||||
|
print(f'>> Not a recognized embedding file: {embedding_file}')
|
||||||
|
|
||||||
|
def _parse_embedding_pt(self, embedding_file):
|
||||||
|
embedding_ckpt = torch.load(embedding_file, map_location='cpu')
|
||||||
|
embedding_info = {}
|
||||||
|
|
||||||
|
# Check if valid embedding file
|
||||||
|
if 'string_to_token' and 'string_to_param' in embedding_ckpt:
|
||||||
|
|
||||||
|
# Catch variants that do not have the expected keys or values.
|
||||||
|
try:
|
||||||
|
embedding_info['name'] = embedding_ckpt['name'] or os.path.basename(os.path.splitext(embedding_file)[0])
|
||||||
|
|
||||||
|
# Check num of embeddings and warn user only the first will be used
|
||||||
|
embedding_info['num_of_embeddings'] = len(embedding_ckpt["string_to_token"])
|
||||||
|
if embedding_info['num_of_embeddings'] > 1:
|
||||||
|
print('>> More than 1 embedding found. Will use the first one')
|
||||||
|
|
||||||
|
embedding = list(embedding_ckpt['string_to_param'].values())[0]
|
||||||
|
except (AttributeError,KeyError):
|
||||||
|
return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
|
||||||
|
|
||||||
|
embedding_info['embedding'] = embedding
|
||||||
|
embedding_info['num_vectors_per_token'] = embedding.size()[0]
|
||||||
|
embedding_info['token_dim'] = embedding.size()[1]
|
||||||
|
|
||||||
|
try:
|
||||||
|
embedding_info['trained_steps'] = embedding_ckpt['step']
|
||||||
|
embedding_info['trained_model_name'] = embedding_ckpt['sd_checkpoint_name']
|
||||||
|
embedding_info['trained_model_checksum'] = embedding_ckpt['sd_checkpoint']
|
||||||
|
except AttributeError:
|
||||||
|
print(">> No Training Details Found. Passing ...")
|
||||||
|
|
||||||
|
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
|
||||||
|
# They are actually .bin files
|
||||||
|
elif len(embedding_ckpt.keys())==1:
|
||||||
|
print('>> Detected .bin file masquerading as .pt file')
|
||||||
|
embedding_info = self._parse_embedding_bin(embedding_file)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print('>> Invalid embedding format')
|
||||||
|
embedding_info = None
|
||||||
|
|
||||||
|
return embedding_info
|
||||||
|
|
||||||
|
def _parse_embedding_bin(self, embedding_file):
|
||||||
|
embedding_ckpt = torch.load(embedding_file, map_location='cpu')
|
||||||
|
embedding_info = {}
|
||||||
|
|
||||||
|
if list(embedding_ckpt.keys()) == 0:
|
||||||
|
print(">> Invalid concepts file")
|
||||||
|
embedding_info = None
|
||||||
|
else:
|
||||||
|
for token in list(embedding_ckpt.keys()):
|
||||||
|
embedding_info['name'] = token or os.path.basename(os.path.splitext(embedding_file)[0])
|
||||||
|
embedding_info['embedding'] = embedding_ckpt[token]
|
||||||
|
embedding_info['num_vectors_per_token'] = 1 # All Concepts seem to default to 1
|
||||||
|
embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
|
||||||
|
|
||||||
|
return embedding_info
|
||||||
|
|
||||||
|
def _handle_broken_pt_variants(self, embedding_ckpt:dict, embedding_file:str)->dict:
|
||||||
|
'''
|
||||||
|
This handles the broken .pt file variants. We only know of one at present.
|
||||||
|
'''
|
||||||
|
embedding_info = {}
|
||||||
|
if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor):
|
||||||
|
print('>> Detected .pt file variant 1') # example at https://github.com/invoke-ai/InvokeAI/issues/1829
|
||||||
|
for token in list(embedding_ckpt['string_to_token'].keys()):
|
||||||
|
embedding_info['name'] = token if token != '*' else os.path.basename(os.path.splitext(embedding_file)[0])
|
||||||
|
embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
|
||||||
|
embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
|
||||||
|
embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
|
||||||
|
else:
|
||||||
|
print('>> Invalid embedding format')
|
||||||
|
embedding_info = None
|
||||||
|
|
||||||
|
return embedding_info
|
52
ldm/util.py
52
ldm/util.py
@ -1,17 +1,18 @@
|
|||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import math
|
import math
|
||||||
from collections import abc
|
|
||||||
from einops import rearrange
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
from threading import Thread
|
from collections import abc
|
||||||
from queue import Queue
|
|
||||||
|
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
|
from queue import Queue
|
||||||
|
from threading import Thread
|
||||||
|
from urllib import request
|
||||||
|
from tqdm import tqdm
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import traceback
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
|
||||||
@ -250,7 +251,7 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de
|
|||||||
if not debug_status:
|
if not debug_status:
|
||||||
return
|
return
|
||||||
|
|
||||||
image_copy = debug_image.copy()
|
image_copy = debug_image.copy().convert("RGBA")
|
||||||
ImageDraw.Draw(image_copy).text(
|
ImageDraw.Draw(image_copy).text(
|
||||||
(5, 5),
|
(5, 5),
|
||||||
debug_text,
|
debug_text,
|
||||||
@ -262,3 +263,32 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de
|
|||||||
|
|
||||||
if debug_result:
|
if debug_result:
|
||||||
return image_copy
|
return image_copy
|
||||||
|
|
||||||
|
#-------------------------------------
|
||||||
|
class ProgressBar():
|
||||||
|
def __init__(self,model_name='file'):
|
||||||
|
self.pbar = None
|
||||||
|
self.name = model_name
|
||||||
|
|
||||||
|
def __call__(self, block_num, block_size, total_size):
|
||||||
|
if not self.pbar:
|
||||||
|
self.pbar=tqdm(desc=self.name,
|
||||||
|
initial=0,
|
||||||
|
unit='iB',
|
||||||
|
unit_scale=True,
|
||||||
|
unit_divisor=1000,
|
||||||
|
total=total_size)
|
||||||
|
self.pbar.update(block_size)
|
||||||
|
|
||||||
|
def download_with_progress_bar(url:str, dest:Path)->bool:
|
||||||
|
try:
|
||||||
|
if not os.path.exists(dest):
|
||||||
|
os.makedirs((os.path.dirname(dest) or '.'), exist_ok=True)
|
||||||
|
request.urlretrieve(url,dest,ProgressBar(os.path.basename(dest)))
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
except OSError:
|
||||||
|
print(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
@ -8,32 +8,43 @@
|
|||||||
#
|
#
|
||||||
print('Loading Python libraries...\n')
|
print('Loading Python libraries...\n')
|
||||||
import argparse
|
import argparse
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
import io
|
import io
|
||||||
import re
|
import re
|
||||||
import warnings
|
|
||||||
import shutil
|
import shutil
|
||||||
from urllib import request
|
import sys
|
||||||
from tqdm import tqdm
|
import traceback
|
||||||
from omegaconf import OmegaConf
|
import warnings
|
||||||
from huggingface_hub import HfFolder, hf_hub_url, login as hf_hub_login
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Dict, Union
|
||||||
|
from urllib import request
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import transformers
|
||||||
|
from diffusers import StableDiffusionPipeline, AutoencoderKL
|
||||||
|
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||||
|
from ldm.invoke.devices import choose_precision, choose_torch_device
|
||||||
from getpass_asterisk import getpass_asterisk
|
from getpass_asterisk import getpass_asterisk
|
||||||
|
from huggingface_hub import HfFolder, hf_hub_url, login as hf_hub_login, whoami as hf_whoami
|
||||||
|
from huggingface_hub.utils._errors import RevisionNotFoundError
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from omegaconf.dictconfig import DictConfig
|
||||||
|
from tqdm import tqdm
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
from ldm.invoke.globals import Globals
|
|
||||||
|
from ldm.invoke.globals import Globals, global_cache_dir
|
||||||
from ldm.invoke.readline import generic_completer
|
from ldm.invoke.readline import generic_completer
|
||||||
|
|
||||||
import traceback
|
|
||||||
import requests
|
|
||||||
import clip
|
|
||||||
import transformers
|
|
||||||
import warnings
|
|
||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
import torch
|
import torch
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ldm.invoke.model_manager import ModelManager
|
||||||
|
except ImportError:
|
||||||
|
sys.path.append('.')
|
||||||
|
from ldm.invoke.model_manager import ModelManager
|
||||||
|
|
||||||
#--------------------------globals-----------------------
|
#--------------------------globals-----------------------
|
||||||
Model_dir = 'models'
|
Model_dir = 'models'
|
||||||
Weights_dir = 'ldm/stable-diffusion-v1/'
|
Weights_dir = 'ldm/stable-diffusion-v1/'
|
||||||
@ -150,14 +161,15 @@ will be given the option to view and change your selections.
|
|||||||
'''
|
'''
|
||||||
)
|
)
|
||||||
for ds in Datasets.keys():
|
for ds in Datasets.keys():
|
||||||
recommended = '(recommended)' if Datasets[ds]['recommended'] else ''
|
recommended = Datasets[ds].get('recommended',False)
|
||||||
print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {recommended}')
|
r_str = '(recommended)' if recommended else ''
|
||||||
if yes_or_no(' Download?',default_yes=Datasets[ds]['recommended']):
|
print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {r_str}')
|
||||||
|
if yes_or_no(' Download?',default_yes=recommended):
|
||||||
datasets[ds]=counter
|
datasets[ds]=counter
|
||||||
counter += 1
|
counter += 1
|
||||||
else:
|
else:
|
||||||
for ds in Datasets.keys():
|
for ds in Datasets.keys():
|
||||||
if Datasets[ds]['recommended']:
|
if Datasets[ds].get('recommended',False):
|
||||||
datasets[ds]=counter
|
datasets[ds]=counter
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
@ -181,7 +193,7 @@ will be given the option to view and change your selections.
|
|||||||
def recommended_datasets()->dict:
|
def recommended_datasets()->dict:
|
||||||
datasets = dict()
|
datasets = dict()
|
||||||
for ds in Datasets.keys():
|
for ds in Datasets.keys():
|
||||||
if Datasets[ds]['recommended']:
|
if Datasets[ds].get('recommended',False):
|
||||||
datasets[ds]=True
|
datasets[ds]=True
|
||||||
return datasets
|
return datasets
|
||||||
|
|
||||||
@ -240,6 +252,7 @@ The license terms are located here:
|
|||||||
print("=" * shutil.get_terminal_size()[0])
|
print("=" * shutil.get_terminal_size()[0])
|
||||||
print('Authenticating to Huggingface')
|
print('Authenticating to Huggingface')
|
||||||
hf_envvars = [ "HUGGING_FACE_HUB_TOKEN", "HUGGINGFACE_TOKEN" ]
|
hf_envvars = [ "HUGGING_FACE_HUB_TOKEN", "HUGGINGFACE_TOKEN" ]
|
||||||
|
token_found = False
|
||||||
if not (access_token := HfFolder.get_token()):
|
if not (access_token := HfFolder.get_token()):
|
||||||
print(f"Huggingface token not found in cache.")
|
print(f"Huggingface token not found in cache.")
|
||||||
|
|
||||||
@ -257,17 +270,21 @@ The license terms are located here:
|
|||||||
print(f"Huggingface token found in cache.")
|
print(f"Huggingface token found in cache.")
|
||||||
try:
|
try:
|
||||||
HfLogin(access_token)
|
HfLogin(access_token)
|
||||||
|
token_found = True
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print(f"Login failed due to invalid token found in cache")
|
print(f"Login failed due to invalid token found in cache")
|
||||||
|
|
||||||
if not yes_to_all:
|
if not (yes_to_all or token_found):
|
||||||
print('''
|
print(''' You may optionally enter your Huggingface token now. InvokeAI
|
||||||
You may optionally enter your Huggingface token now. InvokeAI *will* work without it, but some functionality may be limited.
|
*will* work without it but you will not be able to automatically
|
||||||
See https://invoke-ai.github.io/InvokeAI/features/CONCEPTS/#using-a-hugging-face-concept for more information.
|
download some of the Hugging Face style concepts. See
|
||||||
|
https://invoke-ai.github.io/InvokeAI/features/CONCEPTS/#using-a-hugging-face-concept
|
||||||
|
for more information.
|
||||||
|
|
||||||
Visit https://huggingface.co/settings/tokens to generate a token. (Sign up for an account if needed).
|
Visit https://huggingface.co/settings/tokens to generate a token. (Sign up for an account if needed).
|
||||||
|
|
||||||
Paste the token below using Ctrl-Shift-V (macOS/Linux) or right-click (Windows), and/or 'Enter' to continue.
|
Paste the token below using Ctrl-V on macOS/Linux, or Ctrl-Shift-V or right-click on Windows.
|
||||||
|
Alternatively press 'Enter' to skip this step and continue.
|
||||||
You may re-run the configuration script again in the future if you do not wish to set the token right now.
|
You may re-run the configuration script again in the future if you do not wish to set the token right now.
|
||||||
''')
|
''')
|
||||||
again = True
|
again = True
|
||||||
@ -313,34 +330,61 @@ def migrate_models_ckpt():
|
|||||||
os.replace(os.path.join(model_path,'model.ckpt'),os.path.join(model_path,new_name))
|
os.replace(os.path.join(model_path,'model.ckpt'),os.path.join(model_path,new_name))
|
||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def download_weight_datasets(models:dict, access_token:str):
|
def download_weight_datasets(models:dict, access_token:str, precision:str='float32'):
|
||||||
migrate_models_ckpt()
|
migrate_models_ckpt()
|
||||||
successful = dict()
|
successful = dict()
|
||||||
for mod in models.keys():
|
for mod in models.keys():
|
||||||
repo_id = Datasets[mod]['repo_id']
|
print(f'{mod}...',file=sys.stderr,end='')
|
||||||
filename = Datasets[mod]['file']
|
successful[mod] = _download_repo_or_file(Datasets[mod], access_token, precision=precision)
|
||||||
dest = os.path.join(Globals.root,Model_dir,Weights_dir)
|
return successful
|
||||||
success = hf_download_with_resume(
|
|
||||||
|
def _download_repo_or_file(mconfig:DictConfig, access_token:str, precision:str='float32')->Path:
|
||||||
|
path = None
|
||||||
|
if mconfig['format'] == 'ckpt':
|
||||||
|
path = _download_ckpt_weights(mconfig, access_token)
|
||||||
|
else:
|
||||||
|
path = _download_diffusion_weights(mconfig, access_token, precision=precision)
|
||||||
|
if 'vae' in mconfig and 'repo_id' in mconfig['vae']:
|
||||||
|
_download_diffusion_weights(mconfig['vae'], access_token, precision=precision)
|
||||||
|
return path
|
||||||
|
|
||||||
|
def _download_ckpt_weights(mconfig:DictConfig, access_token:str)->Path:
|
||||||
|
repo_id = mconfig['repo_id']
|
||||||
|
filename = mconfig['file']
|
||||||
|
cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||||
|
return hf_download_with_resume(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
model_dir=dest,
|
model_dir=cache_dir,
|
||||||
model_name=filename,
|
model_name=filename,
|
||||||
access_token=access_token
|
access_token=access_token
|
||||||
)
|
)
|
||||||
if success:
|
|
||||||
successful[mod] = True
|
|
||||||
if len(successful) < len(models):
|
|
||||||
print(f'\n\n** There were errors downloading one or more files. **')
|
|
||||||
print('Press any key to try again. Type ^C to quit.\n')
|
|
||||||
input()
|
|
||||||
return None
|
|
||||||
|
|
||||||
keys = ', '.join(successful.keys())
|
def _download_diffusion_weights(mconfig:DictConfig, access_token:str, precision:str='float32'):
|
||||||
print(f'Successfully installed {keys}')
|
repo_id = mconfig['repo_id']
|
||||||
return successful
|
model_class = StableDiffusionGeneratorPipeline if mconfig.get('format',None)=='diffusers' else AutoencoderKL
|
||||||
|
extra_arg_list = [{'revision':'fp16'},{}] if precision=='float16' else [{}]
|
||||||
|
path = None
|
||||||
|
for extra_args in extra_arg_list:
|
||||||
|
try:
|
||||||
|
path = download_from_hf(
|
||||||
|
model_class,
|
||||||
|
repo_id,
|
||||||
|
cache_subdir='diffusers',
|
||||||
|
safety_checker=None,
|
||||||
|
**extra_args,
|
||||||
|
)
|
||||||
|
except OSError as e:
|
||||||
|
if str(e).startswith('fp16 is not a valid'):
|
||||||
|
print(f'Could not fetch half-precision version of model {repo_id}; fetching full-precision instead')
|
||||||
|
else:
|
||||||
|
print(f'An unexpected error occurred while downloading the model: {e})')
|
||||||
|
if path:
|
||||||
|
break
|
||||||
|
return path
|
||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
|
def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->Path:
|
||||||
model_dest = os.path.join(model_dir, model_name)
|
model_dest = Path(os.path.join(model_dir, model_name))
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
|
||||||
url = hf_hub_url(repo_id, model_name)
|
url = hf_hub_url(repo_id, model_name)
|
||||||
@ -359,7 +403,7 @@ def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_t
|
|||||||
|
|
||||||
if resp.status_code==416: # "range not satisfiable", which means nothing to return
|
if resp.status_code==416: # "range not satisfiable", which means nothing to return
|
||||||
print(f'* {model_name}: complete file found. Skipping.')
|
print(f'* {model_name}: complete file found. Skipping.')
|
||||||
return True
|
return model_dest
|
||||||
elif resp.status_code != 200:
|
elif resp.status_code != 200:
|
||||||
print(f'** An error occurred during downloading {model_name}: {resp.reason}')
|
print(f'** An error occurred during downloading {model_name}: {resp.reason}')
|
||||||
elif exist_size > 0:
|
elif exist_size > 0:
|
||||||
@ -370,7 +414,7 @@ def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_t
|
|||||||
try:
|
try:
|
||||||
if total < 2000:
|
if total < 2000:
|
||||||
print(f'*** ERROR DOWNLOADING {model_name}: {resp.text}')
|
print(f'*** ERROR DOWNLOADING {model_name}: {resp.text}')
|
||||||
return False
|
return None
|
||||||
|
|
||||||
with open(model_dest, open_mode) as file, tqdm(
|
with open(model_dest, open_mode) as file, tqdm(
|
||||||
desc=model_name,
|
desc=model_name,
|
||||||
@ -385,8 +429,22 @@ def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_t
|
|||||||
bar.update(size)
|
bar.update(size)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'An error occurred while downloading {model_name}: {str(e)}')
|
print(f'An error occurred while downloading {model_name}: {str(e)}')
|
||||||
return False
|
return None
|
||||||
|
return model_dest
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------------
|
||||||
|
#---------------------------------------------
|
||||||
|
def is_huggingface_authenticated():
|
||||||
|
# huggingface_hub 0.10 API isn't great for this, it could be OSError, ValueError,
|
||||||
|
# maybe other things, not all end-user-friendly.
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
response = hf_whoami()
|
||||||
|
if response.get('id') is not None:
|
||||||
return True
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
|
def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
|
||||||
@ -404,7 +462,6 @@ def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
|
|||||||
print(f'Error downloading {label} model')
|
print(f'Error downloading {label} model')
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def update_config_file(successfully_downloaded:dict,opt:dict):
|
def update_config_file(successfully_downloaded:dict,opt:dict):
|
||||||
config_file = opt.config_file or Default_config_file
|
config_file = opt.config_file or Default_config_file
|
||||||
@ -441,29 +498,27 @@ def new_config_file_contents(successfully_downloaded:dict, config_file:str)->str
|
|||||||
default_selected = False
|
default_selected = False
|
||||||
|
|
||||||
for model in successfully_downloaded:
|
for model in successfully_downloaded:
|
||||||
a = Datasets[model]['config'].split('/')
|
|
||||||
if a[0] != 'VAE':
|
|
||||||
continue
|
|
||||||
vae_target = a[1] if len(a)>1 else 'default'
|
|
||||||
vaes[vae_target] = Datasets[model]['file']
|
|
||||||
|
|
||||||
for model in successfully_downloaded:
|
|
||||||
if Datasets[model]['config'].startswith('VAE'): # skip VAE entries
|
|
||||||
continue
|
|
||||||
stanza = conf[model] if model in conf else { }
|
stanza = conf[model] if model in conf else { }
|
||||||
|
mod = Datasets[model]
|
||||||
stanza['description'] = Datasets[model]['description']
|
stanza['description'] = mod['description']
|
||||||
stanza['weights'] = os.path.join(Model_dir,Weights_dir,Datasets[model]['file'])
|
stanza['repo_id'] = mod['repo_id']
|
||||||
stanza['config'] = os.path.normpath(os.path.join(SD_Configs, Datasets[model]['config']))
|
stanza['format'] = mod['format']
|
||||||
stanza['width'] = Datasets[model]['width']
|
# diffusers don't need width and height (probably .ckpt doesn't either)
|
||||||
stanza['height'] = Datasets[model]['height']
|
# so we no longer require these in INITIAL_MODELS.yaml
|
||||||
stanza.pop('default',None) # this will be set later
|
if 'width' in mod:
|
||||||
if vaes:
|
stanza['width'] = mod['width']
|
||||||
for target in vaes:
|
if 'height' in mod:
|
||||||
if re.search(target, model, flags=re.IGNORECASE):
|
stanza['height'] = mod['height']
|
||||||
stanza['vae'] = os.path.normpath(os.path.join(Model_dir,Weights_dir,vaes[target]))
|
if 'file' in mod:
|
||||||
|
stanza['weights'] = os.path.relpath(successfully_downloaded[model], start=Globals.root)
|
||||||
|
stanza['config'] = os.path.normpath(os.path.join(SD_Configs,mod['config']))
|
||||||
|
if 'vae' in mod:
|
||||||
|
if 'file' in mod['vae']:
|
||||||
|
stanza['vae'] = os.path.normpath(os.path.join(Model_dir, Weights_dir,mod['vae']['file']))
|
||||||
else:
|
else:
|
||||||
stanza['vae'] = os.path.normpath(os.path.join(Model_dir,Weights_dir,vaes['default']))
|
stanza['vae'] = mod['vae']
|
||||||
|
stanza.pop('default',None) # this will be set later
|
||||||
|
|
||||||
# BUG - the first stanza is always the default. User should select.
|
# BUG - the first stanza is always the default. User should select.
|
||||||
if not default_selected:
|
if not default_selected:
|
||||||
stanza['default'] = True
|
stanza['default'] = True
|
||||||
@ -477,17 +532,20 @@ def download_bert():
|
|||||||
print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr)
|
print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr)
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
from transformers import BertTokenizerFast
|
||||||
download_from_hf(BertTokenizerFast,'bert-base-uncased')
|
download_from_hf(BertTokenizerFast,'bert-base-uncased')
|
||||||
print('...success',file=sys.stderr)
|
print('...success',file=sys.stderr)
|
||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def download_from_hf(model_class:object, model_name:str):
|
def download_from_hf(model_class:object, model_name:str, cache_subdir:Path=Path('hub'), **kwargs):
|
||||||
print('',file=sys.stderr) # to prevent tqdm from overwriting
|
print('',file=sys.stderr) # to prevent tqdm from overwriting
|
||||||
return model_class.from_pretrained(model_name,
|
path = global_cache_dir(cache_subdir)
|
||||||
cache_dir=os.path.join(Globals.root,Model_dir,model_name),
|
model = model_class.from_pretrained(model_name,
|
||||||
resume_download=True
|
cache_dir=path,
|
||||||
|
resume_download=True,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
return path if model else None
|
||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def download_clip():
|
def download_clip():
|
||||||
@ -585,11 +643,13 @@ def download_safety_checker():
|
|||||||
#-------------------------------------
|
#-------------------------------------
|
||||||
def download_weights(opt:dict) -> Union[str, None]:
|
def download_weights(opt:dict) -> Union[str, None]:
|
||||||
|
|
||||||
|
precision = 'float32' if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||||
|
|
||||||
if opt.yes_to_all:
|
if opt.yes_to_all:
|
||||||
models = recommended_datasets()
|
models = recommended_datasets()
|
||||||
access_token = authenticate(opt.yes_to_all)
|
access_token = authenticate(opt.yes_to_all)
|
||||||
if len(models)>0:
|
if len(models)>0:
|
||||||
successfully_downloaded = download_weight_datasets(models, access_token)
|
successfully_downloaded = download_weight_datasets(models, access_token, precision=precision)
|
||||||
update_config_file(successfully_downloaded,opt)
|
update_config_file(successfully_downloaded,opt)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -607,11 +667,11 @@ def download_weights(opt:dict) -> Union[str, None]:
|
|||||||
else: # 'skip'
|
else: # 'skip'
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
access_token = authenticate()
|
access_token = authenticate()
|
||||||
|
HfFolder.save_token(access_token)
|
||||||
|
|
||||||
print('\n** DOWNLOADING WEIGHTS **')
|
print('\n** DOWNLOADING WEIGHTS **')
|
||||||
successfully_downloaded = download_weight_datasets(models, access_token)
|
successfully_downloaded = download_weight_datasets(models, access_token, precision=precision)
|
||||||
|
|
||||||
update_config_file(successfully_downloaded,opt)
|
update_config_file(successfully_downloaded,opt)
|
||||||
if len(successfully_downloaded) < len(models):
|
if len(successfully_downloaded) < len(models):
|
||||||
@ -738,6 +798,12 @@ def main():
|
|||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
default=False,
|
default=False,
|
||||||
help='skip downloading the large Stable Diffusion weight files')
|
help='skip downloading the large Stable Diffusion weight files')
|
||||||
|
parser.add_argument('--full-precision',
|
||||||
|
dest='full_precision',
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
type=bool,
|
||||||
|
default=False,
|
||||||
|
help='use 32-bit weights instead of faster 16-bit weights')
|
||||||
parser.add_argument('--yes','-y',
|
parser.add_argument('--yes','-y',
|
||||||
dest='yes_to_all',
|
dest='yes_to_all',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
|
156
scripts/merge_fe.py
Normal file
156
scripts/merge_fe.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import npyscreen
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import traceback
|
||||||
|
import argparse
|
||||||
|
from ldm.invoke.globals import Globals, global_set_root
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
class FloatSlider(npyscreen.Slider):
|
||||||
|
# this is supposed to adjust display precision, but doesn't
|
||||||
|
def translate_value(self):
|
||||||
|
stri = "%3.2f / %3.2f" %(self.value, self.out_of)
|
||||||
|
l = (len(str(self.out_of)))*2+4
|
||||||
|
stri = stri.rjust(l)
|
||||||
|
return stri
|
||||||
|
|
||||||
|
class FloatTitleSlider(npyscreen.TitleText):
|
||||||
|
_entry_type = FloatSlider
|
||||||
|
|
||||||
|
class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||||
|
|
||||||
|
interpolations = ['weighted_sum',
|
||||||
|
'sigmoid',
|
||||||
|
'inv_sigmoid',
|
||||||
|
'add_difference']
|
||||||
|
|
||||||
|
def afterEditing(self):
|
||||||
|
self.parentApp.setNextForm(None)
|
||||||
|
|
||||||
|
def create(self):
|
||||||
|
self.model_names = self.get_model_names()
|
||||||
|
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
name="Select up to three models to merge",
|
||||||
|
value=''
|
||||||
|
)
|
||||||
|
self.model1 = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name='First Model:',
|
||||||
|
values=self.model_names,
|
||||||
|
value=0,
|
||||||
|
max_height=len(self.model_names)+1
|
||||||
|
)
|
||||||
|
self.model2 = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name='Second Model:',
|
||||||
|
values=self.model_names,
|
||||||
|
value=1,
|
||||||
|
max_height=len(self.model_names)+1
|
||||||
|
)
|
||||||
|
models_plus_none = self.model_names.copy()
|
||||||
|
models_plus_none.insert(0,'None')
|
||||||
|
self.model3 = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name='Third Model:',
|
||||||
|
values=models_plus_none,
|
||||||
|
value=0,
|
||||||
|
max_height=len(self.model_names)+1,
|
||||||
|
)
|
||||||
|
|
||||||
|
for m in [self.model1,self.model2,self.model3]:
|
||||||
|
m.when_value_edited = self.models_changed
|
||||||
|
|
||||||
|
self.merge_method = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name='Merge Method:',
|
||||||
|
values=self.interpolations,
|
||||||
|
value=0,
|
||||||
|
max_height=len(self.interpolations),
|
||||||
|
)
|
||||||
|
self.alpha = self.add_widget_intelligent(
|
||||||
|
FloatTitleSlider,
|
||||||
|
name='Weight (alpha) to assign to second and third models:',
|
||||||
|
out_of=1,
|
||||||
|
step=0.05,
|
||||||
|
lowest=0,
|
||||||
|
value=0.5,
|
||||||
|
)
|
||||||
|
self.merged_model_name = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleText,
|
||||||
|
name='Name for merged model',
|
||||||
|
value='',
|
||||||
|
)
|
||||||
|
|
||||||
|
def models_changed(self):
|
||||||
|
models = self.model1.values
|
||||||
|
selected_model1 = self.model1.value[0]
|
||||||
|
selected_model2 = self.model2.value[0]
|
||||||
|
selected_model3 = self.model3.value[0]
|
||||||
|
merged_model_name = f'{models[selected_model1]}+{models[selected_model2]}'
|
||||||
|
self.merged_model_name.value = merged_model_name
|
||||||
|
|
||||||
|
if selected_model3 > 0:
|
||||||
|
self.merge_method.values=['add_difference'],
|
||||||
|
self.merged_model_name.value += f'+{models[selected_model3]}'
|
||||||
|
else:
|
||||||
|
self.merge_method.values=self.interpolations
|
||||||
|
self.merge_method.value=0
|
||||||
|
|
||||||
|
def on_ok(self):
|
||||||
|
if self.validate_field_values():
|
||||||
|
self.parentApp.setNextForm(None)
|
||||||
|
self.editing = False
|
||||||
|
else:
|
||||||
|
self.editing = True
|
||||||
|
|
||||||
|
def ok_cancel(self):
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
def validate_field_values(self)->bool:
|
||||||
|
bad_fields = []
|
||||||
|
selected_models = set((self.model1.value[0],self.model2.value[0],self.model3.value[0]))
|
||||||
|
if len(selected_models) < 3:
|
||||||
|
bad_fields.append('Please select two or three DIFFERENT models to compare')
|
||||||
|
if len(bad_fields) > 0:
|
||||||
|
message = 'The following problems were detected and must be corrected:'
|
||||||
|
for problem in bad_fields:
|
||||||
|
message += f'\n* {problem}'
|
||||||
|
npyscreen.notify_confirm(message)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_model_names(self)->List[str]:
|
||||||
|
conf = OmegaConf.load(os.path.join(Globals.root,'configs/models.yaml'))
|
||||||
|
model_names = [name for name in conf.keys() if conf[name].get('format',None)=='diffusers']
|
||||||
|
return sorted(model_names)
|
||||||
|
|
||||||
|
class MyApplication(npyscreen.NPSAppManaged):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def onStart(self):
|
||||||
|
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||||
|
self.main = self.addForm('MAIN', mergeModelsForm, name='Merge Models Settings')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='InvokeAI textual inversion training')
|
||||||
|
parser.add_argument(
|
||||||
|
'--root_dir','--root-dir',
|
||||||
|
type=Path,
|
||||||
|
default=Globals.root,
|
||||||
|
help='Path to the invokeai runtime directory',
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
global_set_root(args.root_dir)
|
||||||
|
|
||||||
|
myapplication = MyApplication()
|
||||||
|
myapplication.run()
|
11
scripts/textual_inversion.py
Executable file
11
scripts/textual_inversion.py
Executable file
@ -0,0 +1,11 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2023, Lincoln Stein @lstein
|
||||||
|
from ldm.invoke.globals import Globals, set_root
|
||||||
|
from ldm.invoke.textual_inversion_training import parse_args, do_textual_inversion_training
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
|
set_root(args.root_dir or Globals.root)
|
||||||
|
kwargs = vars(args)
|
||||||
|
do_textual_inversion_training(**kwargs)
|
333
scripts/textual_inversion_fe.py
Executable file
333
scripts/textual_inversion_fe.py
Executable file
@ -0,0 +1,333 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import npyscreen
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import traceback
|
||||||
|
from ldm.invoke.globals import Globals, global_set_root
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
TRAINING_DATA = 'training-data'
|
||||||
|
TRAINING_DIR = 'text-inversion-training'
|
||||||
|
CONF_FILE = 'preferences.conf'
|
||||||
|
|
||||||
|
class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||||
|
resolutions = [512, 768, 1024]
|
||||||
|
lr_schedulers = [
|
||||||
|
"linear", "cosine", "cosine_with_restarts",
|
||||||
|
"polynomial","constant", "constant_with_warmup"
|
||||||
|
]
|
||||||
|
precisions = ['no','fp16','bf16']
|
||||||
|
learnable_properties = ['object','style']
|
||||||
|
|
||||||
|
def __init__(self, parentApp, name, saved_args=None):
|
||||||
|
self.saved_args = saved_args or {}
|
||||||
|
super().__init__(parentApp, name)
|
||||||
|
|
||||||
|
def afterEditing(self):
|
||||||
|
self.parentApp.setNextForm(None)
|
||||||
|
|
||||||
|
def create(self):
|
||||||
|
self.model_names, default = self.get_model_names()
|
||||||
|
default_initializer_token = '★'
|
||||||
|
default_placeholder_token = ''
|
||||||
|
saved_args = self.saved_args
|
||||||
|
|
||||||
|
try:
|
||||||
|
default = self.model_names.index(saved_args['model'])
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.model = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name='Model Name:',
|
||||||
|
values=self.model_names,
|
||||||
|
value=default,
|
||||||
|
max_height=len(self.model_names)+1
|
||||||
|
)
|
||||||
|
self.placeholder_token = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleText,
|
||||||
|
name='Trigger Term:',
|
||||||
|
value='', # saved_args.get('placeholder_token',''), # to restore previous term
|
||||||
|
)
|
||||||
|
self.placeholder_token.when_value_edited = self.initializer_changed
|
||||||
|
self.nextrely -= 1
|
||||||
|
self.nextrelx += 30
|
||||||
|
self.prompt_token = self.add_widget_intelligent(
|
||||||
|
npyscreen.FixedText,
|
||||||
|
name="Trigger term for use in prompt",
|
||||||
|
value='',
|
||||||
|
)
|
||||||
|
self.nextrelx -= 30
|
||||||
|
self.initializer_token = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleText,
|
||||||
|
name='Initializer:',
|
||||||
|
value=saved_args.get('initializer_token',default_initializer_token),
|
||||||
|
)
|
||||||
|
self.resume_from_checkpoint = self.add_widget_intelligent(
|
||||||
|
npyscreen.Checkbox,
|
||||||
|
name="Resume from last saved checkpoint",
|
||||||
|
value=False,
|
||||||
|
)
|
||||||
|
self.learnable_property = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name="Learnable property:",
|
||||||
|
values=self.learnable_properties,
|
||||||
|
value=self.learnable_properties.index(saved_args.get('learnable_property','object')),
|
||||||
|
max_height=4,
|
||||||
|
)
|
||||||
|
self.train_data_dir = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleFilenameCombo,
|
||||||
|
name='Data Training Directory:',
|
||||||
|
select_dir=True,
|
||||||
|
must_exist=True,
|
||||||
|
value=saved_args.get('train_data_dir',Path(Globals.root) / TRAINING_DATA / default_placeholder_token)
|
||||||
|
)
|
||||||
|
self.output_dir = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleFilenameCombo,
|
||||||
|
name='Output Destination Directory:',
|
||||||
|
select_dir=True,
|
||||||
|
must_exist=False,
|
||||||
|
value=saved_args.get('output_dir',Path(Globals.root) / TRAINING_DIR / default_placeholder_token)
|
||||||
|
)
|
||||||
|
self.resolution = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name='Image resolution (pixels):',
|
||||||
|
values = self.resolutions,
|
||||||
|
value=self.resolutions.index(saved_args.get('resolution',512)),
|
||||||
|
scroll_exit = True,
|
||||||
|
max_height=4,
|
||||||
|
)
|
||||||
|
self.center_crop = self.add_widget_intelligent(
|
||||||
|
npyscreen.Checkbox,
|
||||||
|
name="Center crop images before resizing to resolution",
|
||||||
|
value=saved_args.get('center_crop',False)
|
||||||
|
)
|
||||||
|
self.mixed_precision = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name='Mixed Precision:',
|
||||||
|
values=self.precisions,
|
||||||
|
value=self.precisions.index(saved_args.get('mixed_precision','fp16')),
|
||||||
|
max_height=4,
|
||||||
|
)
|
||||||
|
self.max_train_steps = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSlider,
|
||||||
|
name='Max Training Steps:',
|
||||||
|
out_of=10000,
|
||||||
|
step=500,
|
||||||
|
lowest=1,
|
||||||
|
value=saved_args.get('max_train_steps',3000)
|
||||||
|
)
|
||||||
|
self.train_batch_size = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSlider,
|
||||||
|
name='Batch Size (reduce if you run out of memory):',
|
||||||
|
out_of=50,
|
||||||
|
step=1,
|
||||||
|
lowest=1,
|
||||||
|
value=saved_args.get('train_batch_size',8),
|
||||||
|
)
|
||||||
|
self.learning_rate = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleText,
|
||||||
|
name="Learning Rate:",
|
||||||
|
value=str(saved_args.get('learning_rate','5.0e-04'),)
|
||||||
|
)
|
||||||
|
self.scale_lr = self.add_widget_intelligent(
|
||||||
|
npyscreen.Checkbox,
|
||||||
|
name="Scale learning rate by number GPUs, steps and batch size",
|
||||||
|
value=saved_args.get('scale_lr',True),
|
||||||
|
)
|
||||||
|
self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
|
||||||
|
npyscreen.Checkbox,
|
||||||
|
name="Use xformers acceleration",
|
||||||
|
value=saved_args.get('enable_xformers_memory_efficient_attention',False),
|
||||||
|
)
|
||||||
|
self.lr_scheduler = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name='Learning rate scheduler:',
|
||||||
|
values = self.lr_schedulers,
|
||||||
|
max_height=7,
|
||||||
|
scroll_exit = True,
|
||||||
|
value=self.lr_schedulers.index(saved_args.get('lr_scheduler','constant')),
|
||||||
|
)
|
||||||
|
self.gradient_accumulation_steps = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSlider,
|
||||||
|
name='Gradient Accumulation Steps:',
|
||||||
|
out_of=10,
|
||||||
|
step=1,
|
||||||
|
lowest=1,
|
||||||
|
value=saved_args.get('gradient_accumulation_steps',4)
|
||||||
|
)
|
||||||
|
self.lr_warmup_steps = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSlider,
|
||||||
|
name='Warmup Steps:',
|
||||||
|
out_of=100,
|
||||||
|
step=1,
|
||||||
|
lowest=0,
|
||||||
|
value=saved_args.get('lr_warmup_steps',0),
|
||||||
|
)
|
||||||
|
|
||||||
|
def initializer_changed(self):
|
||||||
|
placeholder = self.placeholder_token.value
|
||||||
|
self.prompt_token.value = f'(Trigger by using <{placeholder}> in your prompts)'
|
||||||
|
self.train_data_dir.value = Path(Globals.root) / TRAINING_DATA / placeholder
|
||||||
|
self.output_dir.value = Path(Globals.root) / TRAINING_DIR / placeholder
|
||||||
|
self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
|
||||||
|
|
||||||
|
def on_ok(self):
|
||||||
|
if self.validate_field_values():
|
||||||
|
self.parentApp.setNextForm(None)
|
||||||
|
self.editing = False
|
||||||
|
self.parentApp.ti_arguments = self.marshall_arguments()
|
||||||
|
npyscreen.notify('Launching textual inversion training. This will take a while...')
|
||||||
|
# The module load takes a while, so we do it while the form and message are still up
|
||||||
|
import ldm.invoke.textual_inversion_training
|
||||||
|
else:
|
||||||
|
self.editing = True
|
||||||
|
|
||||||
|
def ok_cancel(self):
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
def validate_field_values(self)->bool:
|
||||||
|
bad_fields = []
|
||||||
|
if self.model.value is None:
|
||||||
|
bad_fields.append('Model Name must correspond to a known model in models.yaml')
|
||||||
|
if not re.match('^[a-zA-Z0-9.-]+$',self.placeholder_token.value):
|
||||||
|
bad_fields.append('Trigger term must only contain alphanumeric characters, the dot and hyphen')
|
||||||
|
if self.train_data_dir.value is None:
|
||||||
|
bad_fields.append('Data Training Directory cannot be empty')
|
||||||
|
if self.output_dir.value is None:
|
||||||
|
bad_fields.append('The Output Destination Directory cannot be empty')
|
||||||
|
if len(bad_fields) > 0:
|
||||||
|
message = 'The following problems were detected and must be corrected:'
|
||||||
|
for problem in bad_fields:
|
||||||
|
message += f'\n* {problem}'
|
||||||
|
npyscreen.notify_confirm(message)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_model_names(self)->(List[str],int):
|
||||||
|
conf = OmegaConf.load(os.path.join(Globals.root,'configs/models.yaml'))
|
||||||
|
model_names = list(conf.keys())
|
||||||
|
defaults = [idx for idx in range(len(model_names)) if 'default' in conf[model_names[idx]]]
|
||||||
|
return (model_names,defaults[0])
|
||||||
|
|
||||||
|
def marshall_arguments(self)->dict:
|
||||||
|
args = dict()
|
||||||
|
|
||||||
|
# the choices
|
||||||
|
args.update(
|
||||||
|
model = self.model_names[self.model.value[0]],
|
||||||
|
resolution = self.resolutions[self.resolution.value[0]],
|
||||||
|
lr_scheduler = self.lr_schedulers[self.lr_scheduler.value[0]],
|
||||||
|
mixed_precision = self.precisions[self.mixed_precision.value[0]],
|
||||||
|
learnable_property = self.learnable_properties[self.learnable_property.value[0]],
|
||||||
|
)
|
||||||
|
|
||||||
|
# all the strings and booleans
|
||||||
|
for attr in ('initializer_token','placeholder_token','train_data_dir',
|
||||||
|
'output_dir','scale_lr','center_crop','enable_xformers_memory_efficient_attention'):
|
||||||
|
args[attr] = getattr(self,attr).value
|
||||||
|
|
||||||
|
# all the integers
|
||||||
|
for attr in ('train_batch_size','gradient_accumulation_steps',
|
||||||
|
'max_train_steps','lr_warmup_steps'):
|
||||||
|
args[attr] = int(getattr(self,attr).value)
|
||||||
|
|
||||||
|
# the floats (just one)
|
||||||
|
args.update(
|
||||||
|
learning_rate = float(self.learning_rate.value)
|
||||||
|
)
|
||||||
|
|
||||||
|
# a special case
|
||||||
|
if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
|
||||||
|
args['resume_from_checkpoint'] = 'latest'
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
class MyApplication(npyscreen.NPSAppManaged):
|
||||||
|
def __init__(self, saved_args=None):
|
||||||
|
super().__init__()
|
||||||
|
self.ti_arguments=None
|
||||||
|
self.saved_args=saved_args
|
||||||
|
|
||||||
|
def onStart(self):
|
||||||
|
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||||
|
self.main = self.addForm('MAIN', textualInversionForm, name='Textual Inversion Settings', saved_args=self.saved_args)
|
||||||
|
|
||||||
|
def copy_to_embeddings_folder(args:dict):
|
||||||
|
'''
|
||||||
|
Copy learned_embeds.bin into the embeddings folder, and offer to
|
||||||
|
delete the full model and checkpoints.
|
||||||
|
'''
|
||||||
|
source = Path(args['output_dir'],'learned_embeds.bin')
|
||||||
|
dest_dir_name = args['placeholder_token'].strip('<>')
|
||||||
|
destination = Path(Globals.root,'embeddings',dest_dir_name)
|
||||||
|
os.makedirs(destination,exist_ok=True)
|
||||||
|
print(f'>> Training completed. Copying learned_embeds.bin into {str(destination)}')
|
||||||
|
shutil.copy(source,destination)
|
||||||
|
if (input('Delete training logs and intermediate checkpoints? [y] ') or 'y').startswith(('y','Y')):
|
||||||
|
shutil.rmtree(Path(args['output_dir']))
|
||||||
|
else:
|
||||||
|
print(f'>> Keeping {args["output_dir"]}')
|
||||||
|
|
||||||
|
def save_args(args:dict):
|
||||||
|
'''
|
||||||
|
Save the current argument values to an omegaconf file
|
||||||
|
'''
|
||||||
|
conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
|
||||||
|
conf = OmegaConf.create(args)
|
||||||
|
OmegaConf.save(config=conf, f=conf_file)
|
||||||
|
|
||||||
|
def previous_args()->dict:
|
||||||
|
'''
|
||||||
|
Get the previous arguments used.
|
||||||
|
'''
|
||||||
|
conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
|
||||||
|
try:
|
||||||
|
conf = OmegaConf.load(conf_file)
|
||||||
|
conf['placeholder_token'] = conf['placeholder_token'].strip('<>')
|
||||||
|
except:
|
||||||
|
conf= None
|
||||||
|
|
||||||
|
return conf
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='InvokeAI textual inversion training')
|
||||||
|
parser.add_argument(
|
||||||
|
'--root_dir','--root-dir',
|
||||||
|
type=Path,
|
||||||
|
default=Globals.root,
|
||||||
|
help='Path to the invokeai runtime directory',
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
global_set_root(args.root_dir)
|
||||||
|
|
||||||
|
saved_args = previous_args()
|
||||||
|
myapplication = MyApplication(saved_args=saved_args)
|
||||||
|
myapplication.run()
|
||||||
|
|
||||||
|
from ldm.invoke.textual_inversion_training import do_textual_inversion_training
|
||||||
|
if args := myapplication.ti_arguments:
|
||||||
|
os.makedirs(args['output_dir'],exist_ok=True)
|
||||||
|
|
||||||
|
# Automatically add angle brackets around the trigger
|
||||||
|
if not re.match('^<.+>$',args['placeholder_token']):
|
||||||
|
args['placeholder_token'] = f"<{args['placeholder_token']}>"
|
||||||
|
|
||||||
|
args['only_save_embeds'] = True
|
||||||
|
save_args(args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
do_textual_inversion_training(**args)
|
||||||
|
copy_to_embeddings_folder(args)
|
||||||
|
except Exception as e:
|
||||||
|
print('** An exception occurred during training. The exception was:')
|
||||||
|
print(str(e))
|
||||||
|
print('** DETAILS:')
|
||||||
|
print(traceback.format_exc())
|
4
setup.py
4
setup.py
@ -10,6 +10,7 @@ def list_files(directory):
|
|||||||
listing.append(pair)
|
listing.append(pair)
|
||||||
return listing
|
return listing
|
||||||
|
|
||||||
|
|
||||||
def get_version()->str:
|
def get_version()->str:
|
||||||
from ldm.invoke import __version__ as version
|
from ldm.invoke import __version__ as version
|
||||||
return version
|
return version
|
||||||
@ -91,7 +92,8 @@ setup(
|
|||||||
'Topic :: Scientific/Engineering :: Image Processing',
|
'Topic :: Scientific/Engineering :: Image Processing',
|
||||||
],
|
],
|
||||||
scripts = ['scripts/invoke.py','scripts/configure_invokeai.py', 'scripts/sd-metadata.py',
|
scripts = ['scripts/invoke.py','scripts/configure_invokeai.py', 'scripts/sd-metadata.py',
|
||||||
'scripts/preload_models.py', 'scripts/images2prompt.py','scripts/merge_embeddings.py'
|
'scripts/preload_models.py', 'scripts/images2prompt.py','scripts/merge_embeddings.py',
|
||||||
|
'scripts/textual_inversion_fe.py','scripts/textual_inversion.py'
|
||||||
],
|
],
|
||||||
data_files=FRONTEND_FILES,
|
data_files=FRONTEND_FILES,
|
||||||
)
|
)
|
||||||
|
14
tests/inpainting/coyote-inpainting.prompt
Normal file
14
tests/inpainting/coyote-inpainting.prompt
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# 🌻 🌻 🌻 sunflowers 🌻 🌻 🌻
|
||||||
|
a coyote, deep palette knife oil painting, sunflowers, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.2
|
||||||
|
a coyote, deep palette knife oil painting, sunflowers, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.4
|
||||||
|
a coyote, deep palette knife oil painting, sunflowers, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.6
|
||||||
|
a coyote, deep palette knife oil painting, sunflowers, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.8
|
||||||
|
a coyote, deep palette knife oil painting, sunflowers, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.99
|
||||||
|
|
||||||
|
# 🌹 🌹 🌹 roses 🌹 🌹 🌹
|
||||||
|
a coyote, deep palette knife oil painting, red roses, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.2
|
||||||
|
a coyote, deep palette knife oil painting, red roses, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.4
|
||||||
|
a coyote, deep palette knife oil painting, red roses, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.6
|
||||||
|
a coyote, deep palette knife oil painting, red roses, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.8
|
||||||
|
a coyote, deep palette knife oil painting, red roses, plants, desert landscape, award winning -s 50 -S 1234554321 -W 512 -H 512 -C 7.5 -I tests/inpainting/coyote-input.webp -A k_lms -M tests/inpainting/coyote-mask.webp -f 0.99
|
||||||
|
|
BIN
tests/inpainting/coyote-input.webp
Normal file
BIN
tests/inpainting/coyote-input.webp
Normal file
Binary file not shown.
After Width: | Height: | Size: 36 KiB |
BIN
tests/inpainting/coyote-mask.webp
Normal file
BIN
tests/inpainting/coyote-mask.webp
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.5 KiB |
30
tests/inpainting/original.json
Normal file
30
tests/inpainting/original.json
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
{
|
||||||
|
"model": "stable diffusion",
|
||||||
|
"model_id": null,
|
||||||
|
"model_hash": "cc6cb27103417325ff94f52b7a5d2dde45a7515b25c255d8e396c90014281516",
|
||||||
|
"app_id": "invoke-ai/InvokeAI",
|
||||||
|
"app_version": "v2.2.3",
|
||||||
|
"image": {
|
||||||
|
"height": 512,
|
||||||
|
"steps": 50,
|
||||||
|
"facetool": "gfpgan",
|
||||||
|
"facetool_strength": 0,
|
||||||
|
"seed": 1948097268,
|
||||||
|
"perlin": 0,
|
||||||
|
"init_mask": null,
|
||||||
|
"width": 512,
|
||||||
|
"upscale": null,
|
||||||
|
"cfg_scale": 7.5,
|
||||||
|
"prompt": [
|
||||||
|
{
|
||||||
|
"prompt": "a coyote, deep palette knife oil painting, red aloe, plants, desert landscape, award winning",
|
||||||
|
"weight": 1
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"threshold": 0,
|
||||||
|
"postprocessing": null,
|
||||||
|
"sampler": "k_lms",
|
||||||
|
"variations": [],
|
||||||
|
"type": "txt2img"
|
||||||
|
}
|
||||||
|
}
|
301
tests/test_textual_inversion.py
Normal file
301
tests/test_textual_inversion.py
Normal file
@ -0,0 +1,301 @@
|
|||||||
|
|
||||||
|
import unittest
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||||
|
|
||||||
|
|
||||||
|
KNOWN_WORDS = ['a', 'b', 'c']
|
||||||
|
KNOWN_WORDS_TOKEN_IDS = [0, 1, 2]
|
||||||
|
UNKNOWN_WORDS = ['d', 'e', 'f']
|
||||||
|
|
||||||
|
class DummyEmbeddingsList(list):
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name == 'num_embeddings':
|
||||||
|
return len(self)
|
||||||
|
elif name == 'weight':
|
||||||
|
return self
|
||||||
|
elif name == 'data':
|
||||||
|
return self
|
||||||
|
|
||||||
|
def make_dummy_embedding():
|
||||||
|
return torch.randn([768])
|
||||||
|
|
||||||
|
class DummyTransformer:
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.embeddings = DummyEmbeddingsList([make_dummy_embedding() for _ in range(len(KNOWN_WORDS))])
|
||||||
|
|
||||||
|
def resize_token_embeddings(self, new_size=None):
|
||||||
|
if new_size is None:
|
||||||
|
return self.embeddings
|
||||||
|
else:
|
||||||
|
while len(self.embeddings) > new_size:
|
||||||
|
self.embeddings.pop(-1)
|
||||||
|
while len(self.embeddings) < new_size:
|
||||||
|
self.embeddings.append(make_dummy_embedding())
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings
|
||||||
|
|
||||||
|
class DummyTokenizer():
|
||||||
|
def __init__(self):
|
||||||
|
self.tokens = KNOWN_WORDS.copy()
|
||||||
|
self.bos_token_id = 49406 # these are what the real CLIPTokenizer has
|
||||||
|
self.eos_token_id = 49407
|
||||||
|
self.pad_token_id = 49407
|
||||||
|
self.unk_token_id = 49407
|
||||||
|
|
||||||
|
def convert_tokens_to_ids(self, token_str):
|
||||||
|
try:
|
||||||
|
return self.tokens.index(token_str)
|
||||||
|
except ValueError:
|
||||||
|
return self.unk_token_id
|
||||||
|
|
||||||
|
def add_tokens(self, token_str):
|
||||||
|
if token_str in self.tokens:
|
||||||
|
return 0
|
||||||
|
self.tokens.append(token_str)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
class DummyClipEmbedder:
|
||||||
|
def __init__(self):
|
||||||
|
self.max_length = 77
|
||||||
|
self.transformer = DummyTransformer()
|
||||||
|
self.tokenizer = DummyTokenizer()
|
||||||
|
self.position_embeddings_tensor = torch.randn([77,768], dtype=torch.float32)
|
||||||
|
|
||||||
|
def position_embedding(self, indices: Union[list,torch.Tensor]):
|
||||||
|
if type(indices) is list:
|
||||||
|
indices = torch.tensor(indices, dtype=int)
|
||||||
|
return torch.index_select(self.position_embeddings_tensor, 0, indices)
|
||||||
|
|
||||||
|
|
||||||
|
def was_embedding_overwritten_correctly(tim: TextualInversionManager, overwritten_embedding: torch.Tensor, ti_indices: list, ti_embedding: torch.Tensor) -> bool:
|
||||||
|
return torch.allclose(overwritten_embedding[ti_indices], ti_embedding + tim.clip_embedder.position_embedding(ti_indices))
|
||||||
|
|
||||||
|
|
||||||
|
def make_dummy_textual_inversion_manager():
|
||||||
|
return TextualInversionManager(
|
||||||
|
tokenizer=DummyTokenizer(),
|
||||||
|
text_encoder=DummyTransformer()
|
||||||
|
)
|
||||||
|
|
||||||
|
class TextualInversionManagerTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
|
def test_construction(self):
|
||||||
|
tim = make_dummy_textual_inversion_manager()
|
||||||
|
|
||||||
|
def test_add_embedding_for_known_token(self):
|
||||||
|
tim = make_dummy_textual_inversion_manager()
|
||||||
|
test_embedding = torch.randn([1, 768])
|
||||||
|
test_embedding_name = KNOWN_WORDS[0]
|
||||||
|
self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
||||||
|
|
||||||
|
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||||
|
|
||||||
|
ti = tim._add_textual_inversion(test_embedding_name, test_embedding)
|
||||||
|
self.assertEqual(ti.trigger_token_id, 0)
|
||||||
|
|
||||||
|
|
||||||
|
# check adding 'test' did not create a new word
|
||||||
|
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||||
|
self.assertEqual(pre_embeddings_count, embeddings_count)
|
||||||
|
|
||||||
|
# check it was added
|
||||||
|
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
||||||
|
textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
|
||||||
|
self.assertIsNotNone(textual_inversion)
|
||||||
|
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding))
|
||||||
|
self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
|
||||||
|
self.assertEqual(textual_inversion.trigger_token_id, ti.trigger_token_id)
|
||||||
|
|
||||||
|
def test_add_embedding_for_unknown_token(self):
|
||||||
|
tim = make_dummy_textual_inversion_manager()
|
||||||
|
test_embedding_1 = torch.randn([1, 768])
|
||||||
|
test_embedding_name_1 = UNKNOWN_WORDS[0]
|
||||||
|
|
||||||
|
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||||
|
|
||||||
|
added_token_id_1 = tim._add_textual_inversion(test_embedding_name_1, test_embedding_1).trigger_token_id
|
||||||
|
# new token id should get added on the end
|
||||||
|
self.assertEqual(added_token_id_1, len(KNOWN_WORDS))
|
||||||
|
|
||||||
|
# check adding did create a new word
|
||||||
|
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||||
|
self.assertEqual(pre_embeddings_count+1, embeddings_count)
|
||||||
|
|
||||||
|
# check it was added
|
||||||
|
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1))
|
||||||
|
textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1)
|
||||||
|
self.assertIsNotNone(textual_inversion)
|
||||||
|
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1))
|
||||||
|
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1)
|
||||||
|
self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1)
|
||||||
|
|
||||||
|
# add another one
|
||||||
|
test_embedding_2 = torch.randn([1, 768])
|
||||||
|
test_embedding_name_2 = UNKNOWN_WORDS[1]
|
||||||
|
|
||||||
|
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||||
|
|
||||||
|
added_token_id_2 = tim._add_textual_inversion(test_embedding_name_2, test_embedding_2).trigger_token_id
|
||||||
|
self.assertEqual(added_token_id_2, len(KNOWN_WORDS)+1)
|
||||||
|
|
||||||
|
# check adding did create a new word
|
||||||
|
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||||
|
self.assertEqual(pre_embeddings_count+1, embeddings_count)
|
||||||
|
|
||||||
|
# check it was added
|
||||||
|
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_2))
|
||||||
|
textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_2)
|
||||||
|
self.assertIsNotNone(textual_inversion)
|
||||||
|
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_2))
|
||||||
|
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_2)
|
||||||
|
self.assertEqual(textual_inversion.trigger_token_id, added_token_id_2)
|
||||||
|
|
||||||
|
# check the old one is still there
|
||||||
|
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name_1))
|
||||||
|
textual_inversion = next(ti for ti in tim.textual_inversions if ti.trigger_token_id == added_token_id_1)
|
||||||
|
self.assertIsNotNone(textual_inversion)
|
||||||
|
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding_1))
|
||||||
|
self.assertEqual(textual_inversion.trigger_string, test_embedding_name_1)
|
||||||
|
self.assertEqual(textual_inversion.trigger_token_id, added_token_id_1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pad_raises_on_eos_bos(self):
|
||||||
|
tim = make_dummy_textual_inversion_manager()
|
||||||
|
prompt_token_ids_with_eos_bos = [tim.tokenizer.bos_token_id] + \
|
||||||
|
[KNOWN_WORDS_TOKEN_IDS] + \
|
||||||
|
[tim.tokenizer.eos_token_id]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_with_eos_bos)
|
||||||
|
|
||||||
|
def test_pad_tokens_list_vector_length_1(self):
|
||||||
|
tim = make_dummy_textual_inversion_manager()
|
||||||
|
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
|
||||||
|
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
|
||||||
|
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
test_embedding_1v = torch.randn([1, 768])
|
||||||
|
test_embedding_1v_token = "<inversion-trigger-vector-length-1>"
|
||||||
|
test_embedding_1v_token_id = tim._add_textual_inversion(test_embedding_1v_token, test_embedding_1v).trigger_token_id
|
||||||
|
self.assertEqual(test_embedding_1v_token_id, len(KNOWN_WORDS))
|
||||||
|
|
||||||
|
# at the end
|
||||||
|
prompt_token_ids_1v_append = prompt_token_ids + [test_embedding_1v_token_id]
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_append)
|
||||||
|
self.assertEqual(prompt_token_ids_1v_append, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
# at the start
|
||||||
|
prompt_token_ids_1v_prepend = [test_embedding_1v_token_id] + prompt_token_ids
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_prepend)
|
||||||
|
self.assertEqual(prompt_token_ids_1v_prepend, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
# in the middle
|
||||||
|
prompt_token_ids_1v_insert = prompt_token_ids[0:2] + [test_embedding_1v_token_id] + prompt_token_ids[2:3]
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_1v_insert)
|
||||||
|
self.assertEqual(prompt_token_ids_1v_insert, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
def test_pad_tokens_list_vector_length_2(self):
|
||||||
|
tim = make_dummy_textual_inversion_manager()
|
||||||
|
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
|
||||||
|
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
|
||||||
|
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
test_embedding_2v = torch.randn([2, 768])
|
||||||
|
test_embedding_2v_token = "<inversion-trigger-vector-length-2>"
|
||||||
|
test_embedding_2v_token_id = tim._add_textual_inversion(test_embedding_2v_token, test_embedding_2v).trigger_token_id
|
||||||
|
test_embedding_2v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_2v_token_id).pad_token_ids
|
||||||
|
self.assertEqual(test_embedding_2v_token_id, len(KNOWN_WORDS))
|
||||||
|
|
||||||
|
# at the end
|
||||||
|
prompt_token_ids_2v_append = prompt_token_ids + [test_embedding_2v_token_id]
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_append)
|
||||||
|
self.assertNotEqual(prompt_token_ids_2v_append, expanded_prompt_token_ids)
|
||||||
|
self.assertEqual(prompt_token_ids + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
# at the start
|
||||||
|
prompt_token_ids_2v_prepend = [test_embedding_2v_token_id] + prompt_token_ids
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_prepend)
|
||||||
|
self.assertNotEqual(prompt_token_ids_2v_prepend, expanded_prompt_token_ids)
|
||||||
|
self.assertEqual([test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
# in the middle
|
||||||
|
prompt_token_ids_2v_insert = prompt_token_ids[0:2] + [test_embedding_2v_token_id] + prompt_token_ids[2:3]
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_2v_insert)
|
||||||
|
self.assertNotEqual(prompt_token_ids_2v_insert, expanded_prompt_token_ids)
|
||||||
|
self.assertEqual(prompt_token_ids[0:2] + [test_embedding_2v_token_id] + test_embedding_2v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
def test_pad_tokens_list_vector_length_8(self):
|
||||||
|
tim = make_dummy_textual_inversion_manager()
|
||||||
|
prompt_token_ids = KNOWN_WORDS_TOKEN_IDS.copy()
|
||||||
|
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids)
|
||||||
|
self.assertEqual(prompt_token_ids, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
test_embedding_8v = torch.randn([8, 768])
|
||||||
|
test_embedding_8v_token = "<inversion-trigger-vector-length-8>"
|
||||||
|
test_embedding_8v_token_id = tim._add_textual_inversion(test_embedding_8v_token, test_embedding_8v).trigger_token_id
|
||||||
|
test_embedding_8v_pad_token_ids = tim.get_textual_inversion_for_token_id(test_embedding_8v_token_id).pad_token_ids
|
||||||
|
self.assertEqual(test_embedding_8v_token_id, len(KNOWN_WORDS))
|
||||||
|
|
||||||
|
# at the end
|
||||||
|
prompt_token_ids_8v_append = prompt_token_ids + [test_embedding_8v_token_id]
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_append)
|
||||||
|
self.assertNotEqual(prompt_token_ids_8v_append, expanded_prompt_token_ids)
|
||||||
|
self.assertEqual(prompt_token_ids + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
# at the start
|
||||||
|
prompt_token_ids_8v_prepend = [test_embedding_8v_token_id] + prompt_token_ids
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_prepend)
|
||||||
|
self.assertNotEqual(prompt_token_ids_8v_prepend, expanded_prompt_token_ids)
|
||||||
|
self.assertEqual([test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids, expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
# in the middle
|
||||||
|
prompt_token_ids_8v_insert = prompt_token_ids[0:2] + [test_embedding_8v_token_id] + prompt_token_ids[2:3]
|
||||||
|
expanded_prompt_token_ids = tim.expand_textual_inversion_token_ids_if_necessary(prompt_token_ids=prompt_token_ids_8v_insert)
|
||||||
|
self.assertNotEqual(prompt_token_ids_8v_insert, expanded_prompt_token_ids)
|
||||||
|
self.assertEqual(prompt_token_ids[0:2] + [test_embedding_8v_token_id] + test_embedding_8v_pad_token_ids + prompt_token_ids[2:3], expanded_prompt_token_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def test_deferred_loading(self):
|
||||||
|
tim = make_dummy_textual_inversion_manager()
|
||||||
|
test_embedding = torch.randn([1, 768])
|
||||||
|
test_embedding_name = UNKNOWN_WORDS[0]
|
||||||
|
self.assertFalse(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
||||||
|
|
||||||
|
pre_embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||||
|
|
||||||
|
ti = tim._add_textual_inversion(test_embedding_name, test_embedding, defer_injecting_tokens=True)
|
||||||
|
self.assertIsNone(ti.trigger_token_id)
|
||||||
|
|
||||||
|
# check that a new word is not yet created
|
||||||
|
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||||
|
self.assertEqual(pre_embeddings_count, embeddings_count)
|
||||||
|
|
||||||
|
# check it was added
|
||||||
|
self.assertTrue(tim.has_textual_inversion_for_trigger_string(test_embedding_name))
|
||||||
|
textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
|
||||||
|
self.assertIsNotNone(textual_inversion)
|
||||||
|
self.assertTrue(torch.equal(textual_inversion.embedding, test_embedding))
|
||||||
|
self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
|
||||||
|
self.assertIsNone(textual_inversion.trigger_token_id, ti.trigger_token_id)
|
||||||
|
|
||||||
|
# check it lazy-loads
|
||||||
|
prompt = " ".join([KNOWN_WORDS[0], UNKNOWN_WORDS[0], KNOWN_WORDS[1]])
|
||||||
|
tim.create_deferred_token_ids_for_any_trigger_terms(prompt)
|
||||||
|
|
||||||
|
embeddings_count = len(tim.text_encoder.resize_token_embeddings(None))
|
||||||
|
self.assertEqual(pre_embeddings_count+1, embeddings_count)
|
||||||
|
|
||||||
|
textual_inversion = tim.get_textual_inversion_for_trigger_string(test_embedding_name)
|
||||||
|
self.assertEqual(textual_inversion.trigger_string, test_embedding_name)
|
||||||
|
self.assertEqual(textual_inversion.trigger_token_id, len(KNOWN_WORDS))
|
Loading…
Reference in New Issue
Block a user