fix: Black linting

This commit is contained in:
blessedcoolant 2023-07-29 17:34:43 +12:00
parent 6ed1bf7084
commit 6d82a1019a
3 changed files with 327 additions and 321 deletions

View File

@ -6,8 +6,7 @@ from pydantic import Field
from invokeai.app.invocations.prompt import PromptOutput
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .math import FloatOutput, IntOutput
# Pass-through parameter nodes - used by subgraphs
@ -68,6 +67,7 @@ class ParamStringInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(text=self.text)
class ParamPromptInvocation(BaseInvocation):
"""A prompt input parameter"""

View File

@ -107,6 +107,7 @@
"source": [
"# @title 5. Load small ML models required\n",
"import gc\n",
"\n",
"%cd /content/InvokeAI/\n",
"!python scripts/preload_models.py\n",
"gc.collect()"
@ -132,7 +133,8 @@
"source": [
"# @title 6. Mount google Drive\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')"
"\n",
"drive.mount(\"/content/drive\")"
]
},
{
@ -183,7 +185,7 @@
" print(\"❗ Symlink already created\")\n",
"else:\n",
" src = model_path\n",
" dst = '/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt'\n",
" dst = \"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"\n",
" os.symlink(src, dst)\n",
" print(\"✅ Symbolic link created successfully\")"
]

View File

@ -52,17 +52,17 @@
"name": "stdout",
"text": [
"Cloning into 'latent-diffusion'...\n",
"remote: Enumerating objects: 992, done.\u001B[K\n",
"remote: Counting objects: 100% (695/695), done.\u001B[K\n",
"remote: Compressing objects: 100% (397/397), done.\u001B[K\n",
"remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001B[K\n",
"remote: Enumerating objects: 992, done.\u001b[K\n",
"remote: Counting objects: 100% (695/695), done.\u001b[K\n",
"remote: Compressing objects: 100% (397/397), done.\u001b[K\n",
"remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001b[K\n",
"Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.\n",
"Resolving deltas: 100% (510/510), done.\n",
"Cloning into 'taming-transformers'...\n",
"remote: Enumerating objects: 1335, done.\u001B[K\n",
"remote: Counting objects: 100% (525/525), done.\u001B[K\n",
"remote: Compressing objects: 100% (493/493), done.\u001B[K\n",
"remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001B[K\n",
"remote: Enumerating objects: 1335, done.\u001b[K\n",
"remote: Counting objects: 100% (525/525), done.\u001b[K\n",
"remote: Compressing objects: 100% (493/493), done.\u001b[K\n",
"remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001b[K\n",
"Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.\n",
"Resolving deltas: 100% (267/267), done.\n",
"Obtaining file:///content/taming-transformers\n",
@ -73,9 +73,9 @@
"Installing collected packages: taming-transformers\n",
" Running setup.py develop for taming-transformers\n",
"Successfully installed taming-transformers-0.0.1\n",
"\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n",
"arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001B[0m\n"
"arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001b[0m\n"
]
}
],
@ -87,8 +87,9 @@
"!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops\n",
"\n",
"import sys\n",
"\n",
"sys.path.append(\".\")\n",
"sys.path.append('./taming-transformers')\n",
"sys.path.append(\"./taming-transformers\")\n",
"from taming.models import vqgan"
]
},
@ -299,32 +300,35 @@
" )\n",
"\n",
" for class_label in classes:\n",
" print(f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\")\n",
" print(\n",
" f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\"\n",
" )\n",
" xc = torch.tensor(n_samples_per_class * [class_label])\n",
" c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})\n",
"\n",
" samples_ddim, _ = sampler.sample(S=ddim_steps,\n",
" samples_ddim, _ = sampler.sample(\n",
" S=ddim_steps,\n",
" conditioning=c,\n",
" batch_size=n_samples_per_class,\n",
" shape=[3, 64, 64],\n",
" verbose=False,\n",
" unconditional_guidance_scale=scale,\n",
" unconditional_conditioning=uc,\n",
" eta=ddim_eta)\n",
" eta=ddim_eta,\n",
" )\n",
"\n",
" x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
" x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, \n",
" min=0.0, max=1.0)\n",
" x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
" all_samples.append(x_samples_ddim)\n",
"\n",
"\n",
"# display as grid\n",
"grid = torch.stack(all_samples, 0)\n",
"grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n",
"grid = rearrange(grid, \"n b c h w -> (n b) c h w\")\n",
"grid = make_grid(grid, nrow=n_samples_per_class)\n",
"\n",
"# to image\n",
"grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
"grid = 255.0 * rearrange(grid, \"c h w -> h w c\").cpu().numpy()\n",
"Image.fromarray(grid.astype(np.uint8))"
],
"metadata": {