fix: Black linting

blessedcoolant 2023-07-29 17:34:43 +12:00
parent 6ed1bf7084
commit 6d82a1019a
3 changed files with 327 additions and 321 deletions
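
Note: the hunks below are the mechanical rewrites Black applies: string quotes normalized to double quotes, long statements split and short ones collapsed against the configured line length, bare float literals completed, and blank lines adjusted. A minimal sketch of the same pass via Black's Python API (assuming `black` is installed; the input string is illustrative, not taken from this repository):

    # Illustrative only: format a source string the way `black` formats files.
    import black

    src = "x = {'a':1,  'b':2}\n"
    print(black.format_str(src, mode=black.Mode()), end="")
    # -> x = {"a": 1, "b": 2}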

View File

@@ -6,8 +6,7 @@ from pydantic import Field
 from invokeai.app.invocations.prompt import PromptOutput

-from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
-                             InvocationConfig, InvocationContext)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
 from .math import FloatOutput, IntOutput

 # Pass-through parameter nodes - used by subgraphs
@@ -68,6 +67,7 @@ class ParamStringInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> StringOutput:
         return StringOutput(text=self.text)

+
 class ParamPromptInvocation(BaseInvocation):
     """A prompt input parameter"""

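Two Black behaviours are visible in this file. The parenthesized import is collapsed because the one-line form fits the configured line length; at roughly 102 characters it would not fit Black's default of 88, so the project evidently configures a longer limit (120 is assumed in the sketch below). The hunk counts (+67,7 versus -68,6) also show one line being added, consistent with Black inserting the second blank line required before a top-level class definition. A minimal reproduction of the import collapse:

    # Sketch of the import collapse; line_length=120 is an assumption,
    # since Black's default of 88 would keep this import split.
    import black

    src = (
        "from .baseinvocation import (BaseInvocation, BaseInvocationOutput,\n"
        "                             InvocationConfig, InvocationContext)\n"
    )
    print(black.format_str(src, mode=black.Mode(line_length=120)), end="")
    # -> from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext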
View File

@@ -107,6 +107,7 @@
    "source": [
     "# @title 5. Load small ML models required\n",
     "import gc\n",
+    "\n",
     "%cd /content/InvokeAI/\n",
     "!python scripts/preload_models.py\n",
     "gc.collect()"
@@ -132,7 +133,8 @@
    "source": [
     "# @title 6. Mount google Drive\n",
     "from google.colab import drive\n",
-    "drive.mount('/content/drive')"
+    "\n",
+    "drive.mount(\"/content/drive\")"
    ]
   },
   {
@@ -183,7 +185,7 @@
     "    print(\"❗ Symlink already created\")\n",
     "else:\n",
     "    src = model_path\n",
-    "    dst = '/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt'\n",
+    "    dst = \"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"\n",
     "    os.symlink(src, dst)\n",
     "    print(\"✅ Symbolic link created successfully\")"
    ]
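The notebook changes above are Black's string normalization (single-quoted Python strings become double-quoted, escaped as needed inside the JSON source arrays) plus a blank line inserted after each import statement. Formatting .ipynb files requires Black's Jupyter extra, installed with pip install "black[jupyter]". A small sketch of the quote rule on the line from the last hunk:

    # Sketch: Black rewrites string quotes; runtime behavior is unchanged.
    import black

    src = "dst = '/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt'\n"
    print(black.format_str(src, mode=black.Mode()), end="")
    # -> dst = "/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt"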

View File

@@ -52,17 +52,17 @@
    "name": "stdout",
    "text": [
     "Cloning into 'latent-diffusion'...\n",
-    "remote: Enumerating objects: 992, done.\u001B[K\n",
-    "remote: Counting objects: 100% (695/695), done.\u001B[K\n",
-    "remote: Compressing objects: 100% (397/397), done.\u001B[K\n",
-    "remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001B[K\n",
+    "remote: Enumerating objects: 992, done.\u001b[K\n",
+    "remote: Counting objects: 100% (695/695), done.\u001b[K\n",
+    "remote: Compressing objects: 100% (397/397), done.\u001b[K\n",
+    "remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001b[K\n",
     "Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.\n",
     "Resolving deltas: 100% (510/510), done.\n",
     "Cloning into 'taming-transformers'...\n",
-    "remote: Enumerating objects: 1335, done.\u001B[K\n",
-    "remote: Counting objects: 100% (525/525), done.\u001B[K\n",
-    "remote: Compressing objects: 100% (493/493), done.\u001B[K\n",
-    "remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001B[K\n",
+    "remote: Enumerating objects: 1335, done.\u001b[K\n",
+    "remote: Counting objects: 100% (525/525), done.\u001b[K\n",
+    "remote: Compressing objects: 100% (493/493), done.\u001b[K\n",
+    "remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001b[K\n",
     "Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.\n",
     "Resolving deltas: 100% (267/267), done.\n",
     "Obtaining file:///content/taming-transformers\n",
@@ -73,9 +73,9 @@
     "Installing collected packages: taming-transformers\n",
     "  Running setup.py develop for taming-transformers\n",
     "Successfully installed taming-transformers-0.0.1\n",
-    "\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+    "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
     "tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n",
-    "arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001B[0m\n"
+    "arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001b[0m\n"
    ]
   }
  ],
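The \u001B to \u001b rewrites above touch captured cell outputs, not code: the git and pip logs contain ANSI color escape codes (ESC, U+001B), and re-serializing the notebook JSON, presumably as a side effect of the formatting run, normalizes the escape's hex digits to lowercase, as Python's json module does:

    # The same ESC control character always serializes with lowercase
    # hex digits under Python's json module.
    import json

    print(json.dumps("\u001B[31mERROR\u001B[0m"))
    # -> "\u001b[31mERROR\u001b[0m"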
@@ -87,8 +87,9 @@
     "!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops\n",
     "\n",
     "import sys\n",
+    "\n",
     "sys.path.append(\".\")\n",
-    "sys.path.append('./taming-transformers')\n",
+    "sys.path.append(\"./taming-transformers\")\n",
     "from taming.models import vqgan"
    ]
   },
@@ -299,32 +300,35 @@
     "        )\n",
     "\n",
     "        for class_label in classes:\n",
-    "            print(f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\")\n",
+    "            print(\n",
+    "                f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\"\n",
+    "            )\n",
     "            xc = torch.tensor(n_samples_per_class * [class_label])\n",
     "            c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})\n",
     "\n",
-    "            samples_ddim, _ = sampler.sample(S=ddim_steps,\n",
-    "                                             conditioning=c,\n",
-    "                                             batch_size=n_samples_per_class,\n",
-    "                                             shape=[3, 64, 64],\n",
-    "                                             verbose=False,\n",
-    "                                             unconditional_guidance_scale=scale,\n",
-    "                                             unconditional_conditioning=uc,\n",
-    "                                             eta=ddim_eta)\n",
+    "            samples_ddim, _ = sampler.sample(\n",
+    "                S=ddim_steps,\n",
+    "                conditioning=c,\n",
+    "                batch_size=n_samples_per_class,\n",
+    "                shape=[3, 64, 64],\n",
+    "                verbose=False,\n",
+    "                unconditional_guidance_scale=scale,\n",
+    "                unconditional_conditioning=uc,\n",
+    "                eta=ddim_eta,\n",
+    "            )\n",
     "\n",
     "            x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
-    "            x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, \n",
-    "                                         min=0.0, max=1.0)\n",
+    "            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
     "            all_samples.append(x_samples_ddim)\n",
     "\n",
     "\n",
     "# display as grid\n",
     "grid = torch.stack(all_samples, 0)\n",
-    "grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n",
+    "grid = rearrange(grid, \"n b c h w -> (n b) c h w\")\n",
     "grid = make_grid(grid, nrow=n_samples_per_class)\n",
     "\n",
     "# to image\n",
-    "grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
+    "grid = 255.0 * rearrange(grid, \"c h w -> h w c\").cpu().numpy()\n",
     "Image.fromarray(grid.astype(np.uint8))"
    ],
    "metadata": {