2024-02-29 06:02:28 +00:00
|
|
|
from typing import Iterator, List, Optional, Tuple, Union, cast
|
2023-04-25 00:48:44 +00:00
|
|
|
|
2023-07-03 14:08:10 +00:00
|
|
|
import torch
|
2023-07-17 22:49:45 +00:00
|
|
|
from compel import Compel, ReturnedEmbeddingsType
|
2023-07-05 02:37:16 +00:00
|
|
|
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
2024-02-29 06:02:28 +00:00
|
|
|
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
|
|
|
|
2024-07-03 16:20:35 +00:00
|
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
2024-03-08 15:48:45 +00:00
|
|
|
from invokeai.app.invocations.fields import (
|
|
|
|
ConditioningField,
|
|
|
|
FieldDescriptions,
|
|
|
|
Input,
|
|
|
|
InputField,
|
|
|
|
OutputField,
|
2024-04-08 18:16:22 +00:00
|
|
|
TensorField,
|
2024-03-08 15:48:45 +00:00
|
|
|
UIComponent,
|
|
|
|
)
|
2024-07-03 16:20:35 +00:00
|
|
|
from invokeai.app.invocations.model import CLIPField
|
2024-01-13 12:23:16 +00:00
|
|
|
from invokeai.app.invocations.primitives import ConditioningOutput
|
2024-02-05 06:16:35 +00:00
|
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
2024-02-27 20:20:14 +00:00
|
|
|
from invokeai.app.util.ti_utils import generate_ti_list
|
2024-02-17 16:45:32 +00:00
|
|
|
from invokeai.backend.lora import LoRAModelRaw
|
2024-02-18 06:27:42 +00:00
|
|
|
from invokeai.backend.model_patcher import ModelPatcher
|
2023-09-08 15:00:11 +00:00
|
|
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
2023-08-14 03:23:09 +00:00
|
|
|
BasicConditioningInfo,
|
2024-01-14 23:41:25 +00:00
|
|
|
ConditioningFieldData,
|
2023-08-14 03:23:09 +00:00
|
|
|
SDXLConditioningInfo,
|
|
|
|
)
|
2024-04-15 13:12:49 +00:00
|
|
|
from invokeai.backend.util.devices import TorchDevice
|
2023-08-14 03:23:09 +00:00
|
|
|
|
2024-02-05 06:16:35 +00:00
|
|
|
# unconditioned: Optional[torch.Tensor]
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-07-11 15:19:36 +00:00
|
|
|
|
|
|
|
# class ConditioningAlgo(str, Enum):
|
|
|
|
# Compose = "compose"
|
|
|
|
# ComposeEx = "compose_ex"
|
|
|
|
# PerpNeg = "perp_neg"
|
2023-04-25 00:48:44 +00:00
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
@invocation(
|
|
|
|
"compel",
|
|
|
|
title="Prompt",
|
|
|
|
tags=["prompt", "compel"],
|
|
|
|
category="conditioning",
|
2024-03-08 15:48:45 +00:00
|
|
|
version="1.2.0",
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
)
|
2023-04-25 00:48:44 +00:00
|
|
|
class CompelInvocation(BaseInvocation):
    """Parse prompt using compel package to conditioning."""

    # Node inputs. Declaration order is preserved on purpose: pydantic field order
    # drives the node's schema / input ordering.
    prompt: str = InputField(
        default="",
        description=FieldDescriptions.compel_prompt,
        ui_component=UIComponent.Textarea,
    )
    clip: CLIPField = InputField(
        title="CLIP",
        description=FieldDescriptions.clip,
        input=Input.Connection,
    )
    mask: Optional[TensorField] = InputField(
        default=None, description="A mask defining the region that this conditioning prompt applies to."
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ConditioningOutput:
        """Encode ``self.prompt`` with the CLIP model in ``self.clip`` and save the result.

        The text encoder is patched with the field's LoRAs, its CLIP-skip setting and
        the textual-inversion list produced by ``generate_ti_list`` for this prompt.
        The prompt is then encoded through compel, and the embeddings are saved via
        ``context.conditioning``. The returned output references the saved
        conditioning by name, together with ``self.mask``.
        """
        loaded_tokenizer = context.models.load(self.clip.tokenizer)
        loaded_encoder = context.models.load(self.clip.text_encoder)

        def _iter_loras() -> Iterator[Tuple[LoRAModelRaw, float]]:
            # Generator: each LoRA is loaded only when the patcher asks for it, and
            # the local reference is dropped before the next one is loaded.
            for entry in self.clip.loras:
                loaded_lora = context.models.load(entry.lora)
                assert isinstance(loaded_lora.model, LoRAModelRaw)
                yield (loaded_lora.model, entry.weight)
                del loaded_lora

        ti_entries = generate_ti_list(self.prompt, loaded_encoder.config.base, context)

        # Every patch is applied while the encoder is on its execution device.
        # The order below matters: CLIP skip comes after LoRA so that LoRA
        # application does not fail on layers that CLIP skip removes.
        with (
            loaded_encoder.model_on_device() as (encoder_state, encoder),
            loaded_tokenizer as base_tokenizer,
            ModelPatcher.apply_lora_text_encoder(
                encoder,
                loras=_iter_loras(),
                model_state_dict=encoder_state,
            ),
            ModelPatcher.apply_clip_skip(encoder, self.clip.skipped_layers),
            ModelPatcher.apply_ti(base_tokenizer, encoder, ti_entries) as (ti_tokenizer, ti_manager),
        ):
            assert isinstance(encoder, CLIPTextModel)
            assert isinstance(base_tokenizer, CLIPTokenizer)
            prompt_encoder = Compel(
                tokenizer=ti_tokenizer,
                text_encoder=encoder,
                textual_inversion_manager=ti_manager,
                dtype_for_device_getter=TorchDevice.choose_torch_dtype,
                truncate_long_prompts=False,
            )

            parsed = Compel.parse_prompt_string(self.prompt)

            if context.config.get().log_tokenization:
                log_tokenization_for_conjunction(parsed, ti_tokenizer)

            embeddings, _ = prompt_encoder.build_conditioning_tensor_for_conjunction(parsed)

        # Detach and move the embeddings to CPU before they are persisted.
        embeddings = embeddings.detach().to("cpu")

        saved_name = context.conditioning.save(
            ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=embeddings)])
        )
        conditioning_ref = ConditioningField(conditioning_name=saved_name, mask=self.mask)
        return ConditioningOutput(conditioning=conditioning_ref)
|
2023-04-25 00:48:44 +00:00
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-07-17 22:49:45 +00:00
|
|
|
class SDXLPromptInvocationBase:
    """Prompt processor for SDXL models.

    Mixin providing ``run_clip_compel``, which encodes a prompt with one CLIP text
    encoder (patched with its LoRAs, CLIP skip and textual inversions) via compel.
    """

    def run_clip_compel(
        self,
        context: InvocationContext,
        clip_field: CLIPField,
        prompt: str,
        get_pooled: bool,
        lora_prefix: str,
        zero_on_empty: bool,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Encode ``prompt`` with the text encoder referenced by ``clip_field``.

        Args:
            context: Invocation context used to load models and read the app config.
            clip_field: Tokenizer, text encoder, LoRAs and CLIP-skip setting to use.
            prompt: Prompt text; parsed with ``Compel.parse_prompt_string``.
            get_pooled: Whether to also produce pooled embeddings.
            lora_prefix: Prefix forwarded to ``ModelPatcher.apply_lora`` (e.g.
                ``"lora_te1_"`` / ``"lora_te2_"``); presumably selects the LoRA keys
                that target this encoder -- confirm in ModelPatcher.
            zero_on_empty: If true and ``prompt`` is exactly ``""``, skip encoding and
                return all-zero tensors shaped ``(1, max_position_embeddings,
                hidden_size)`` and ``(1, hidden_size)``, using the encoder's dtype.

        Returns:
            ``(c, c_pooled)`` as CPU tensors. ``c_pooled`` is ``None`` unless
            ``get_pooled`` is true.
        """
        text_encoder_info = context.models.load(clip_field.text_encoder)

        # Empty prompt -> all-zero conditioning. Only the encoder's config and dtype
        # are read here: the encoder is never run and no tokenizer is needed.
        if prompt == "" and zero_on_empty:
            cpu_text_encoder = text_encoder_info.model
            assert isinstance(cpu_text_encoder, torch.nn.Module)
            c = torch.zeros(
                (
                    1,
                    cpu_text_encoder.config.max_position_embeddings,
                    cpu_text_encoder.config.hidden_size,
                ),
                dtype=cpu_text_encoder.dtype,
            )
            if get_pooled:
                c_pooled = torch.zeros(
                    (1, cpu_text_encoder.config.hidden_size),
                    dtype=c.dtype,
                )
            else:
                c_pooled = None
            return c, c_pooled

        # Loaded only after the short-circuit above, which never uses it.
        tokenizer_info = context.models.load(clip_field.tokenizer)

        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
            # Generator: each LoRA is loaded only when the patcher consumes it, and
            # the local reference is dropped before the next one is loaded.
            for lora in clip_field.loras:
                lora_info = context.models.load(lora.lora)
                lora_model = lora_info.model
                assert isinstance(lora_model, LoRAModelRaw)
                yield (lora_model, lora.weight)
                del lora_info

        ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)

        with (
            # apply all patches while the model is on the target device
            text_encoder_info.model_on_device() as (state_dict, text_encoder),
            tokenizer_info as tokenizer,
            ModelPatcher.apply_lora(
                text_encoder,
                loras=_lora_loader(),
                prefix=lora_prefix,
                model_state_dict=state_dict,
            ),
            # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
            ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
            ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
                patched_tokenizer,
                ti_manager,
            ),
        ):
            assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
            assert isinstance(tokenizer, CLIPTokenizer)

            # cast() is a runtime no-op; it only narrows the type for static checkers.
            text_encoder = cast(CLIPTextModel, text_encoder)
            compel = Compel(
                tokenizer=patched_tokenizer,
                text_encoder=text_encoder,
                textual_inversion_manager=ti_manager,
                dtype_for_device_getter=TorchDevice.choose_torch_dtype,
                truncate_long_prompts=False,  # TODO:
                returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
                requires_pooled=get_pooled,
            )

            conjunction = Compel.parse_prompt_string(prompt)

            if context.config.get().log_tokenization:
                # TODO: better logging for and syntax
                log_tokenization_for_conjunction(conjunction, patched_tokenizer)

            # TODO: ask for optimizations? to not run text_encoder twice
            c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
            # NOTE(review): the pooled embeddings are computed from the raw prompt
            # string (any compel weighting syntax included), not from the parsed
            # conjunction -- confirm this is intended.
            c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt]) if get_pooled else None

        # Drop the local model references; only the output tensors are used below.
        del tokenizer
        del text_encoder
        del tokenizer_info
        del text_encoder_info

        c = c.detach().to("cpu")
        if c_pooled is not None:
            c_pooled = c_pooled.detach().to("cpu")

        return c, c_pooled
|
2023-07-17 22:49:45 +00:00
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators.
The decorators add the `type: Literal["invocation_type"]: "invocation_type"` field to the invocation.
Category is a new invocation metadata, but it is not used by the frontend just yet.
- `@invocation()` decorator for invocations
```py
@invocation(
"sdxl_compel_prompt",
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
...
```
- `@invocation_output()` decorator for invocation outputs
```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
...
```
- update invocation docs
- add category to decorator
- regen frontend types
2023-08-30 08:35:12 +00:00
|
|
|
@invocation(
|
|
|
|
"sdxl_compel_prompt",
|
|
|
|
title="SDXL Prompt",
|
|
|
|
tags=["sdxl", "compel", "prompt"],
|
|
|
|
category="conditioning",
|
2024-03-08 15:48:45 +00:00
|
|
|
version="1.2.0",
|
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators.
The decorators add the `type: Literal["invocation_type"]: "invocation_type"` field to the invocation.
Category is a new invocation metadata, but it is not used by the frontend just yet.
- `@invocation()` decorator for invocations
```py
@invocation(
"sdxl_compel_prompt",
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
...
```
- `@invocation_output()` decorator for invocation outputs
```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
...
```
- update invocation docs
- add category to decorator
- regen frontend types
2023-08-30 08:35:12 +00:00
|
|
|
)
|
2023-07-17 22:49:45 +00:00
|
|
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|
|
|
"""Parse prompt using compel package to conditioning."""
|
|
|
|
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
prompt: str = InputField(
|
|
|
|
default="",
|
|
|
|
description=FieldDescriptions.compel_prompt,
|
|
|
|
ui_component=UIComponent.Textarea,
|
|
|
|
)
|
|
|
|
style: str = InputField(
|
|
|
|
default="",
|
|
|
|
description=FieldDescriptions.compel_prompt,
|
|
|
|
ui_component=UIComponent.Textarea,
|
|
|
|
)
|
2023-08-14 03:23:09 +00:00
|
|
|
original_width: int = InputField(default=1024, description="")
|
|
|
|
original_height: int = InputField(default=1024, description="")
|
|
|
|
crop_top: int = InputField(default=0, description="")
|
|
|
|
crop_left: int = InputField(default=0, description="")
|
|
|
|
target_width: int = InputField(default=1024, description="")
|
|
|
|
target_height: int = InputField(default=1024, description="")
|
2024-03-06 08:42:47 +00:00
|
|
|
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
|
|
|
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
2024-04-08 18:16:22 +00:00
|
|
|
mask: Optional[TensorField] = InputField(
|
2024-03-08 15:48:45 +00:00
|
|
|
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
|
|
|
)
|
2023-07-11 15:19:36 +00:00
|
|
|
|
|
|
|
@torch.no_grad()
|
2024-02-05 06:16:35 +00:00
|
|
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
2024-03-11 22:22:49 +00:00
|
|
|
c1, c1_pooled = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True)
|
2023-07-17 22:49:45 +00:00
|
|
|
if self.style.strip() == "":
|
2024-03-11 22:22:49 +00:00
|
|
|
c2, c2_pooled = self.run_clip_compel(
|
2023-08-13 09:28:39 +00:00
|
|
|
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
|
|
|
|
)
|
2023-07-17 22:49:45 +00:00
|
|
|
else:
|
2024-03-11 22:22:49 +00:00
|
|
|
c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True)
|
2023-07-17 22:49:45 +00:00
|
|
|
|
|
|
|
original_size = (self.original_height, self.original_width)
|
|
|
|
crop_coords = (self.crop_top, self.crop_left)
|
|
|
|
target_size = (self.target_height, self.target_width)
|
|
|
|
|
|
|
|
add_time_ids = torch.tensor([original_size + crop_coords + target_size])
|
2023-07-11 15:19:36 +00:00
|
|
|
|
2023-08-31 01:07:44 +00:00
|
|
|
# [1, 77, 768], [1, 154, 1280]
|
|
|
|
if c1.shape[1] < c2.shape[1]:
|
|
|
|
c1 = torch.cat(
|
|
|
|
[
|
|
|
|
c1,
|
|
|
|
torch.zeros(
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialized and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problem with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Forunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must tell set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specific aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very larges graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
(c1.shape[0], c2.shape[1] - c1.shape[1], c1.shape[2]),
|
|
|
|
device=c1.device,
|
|
|
|
dtype=c1.dtype,
|
2023-08-31 01:07:44 +00:00
|
|
|
),
|
|
|
|
],
|
|
|
|
dim=1,
|
|
|
|
)
|
|
|
|
|
|
|
|
elif c1.shape[1] > c2.shape[1]:
|
|
|
|
c2 = torch.cat(
|
|
|
|
[
|
|
|
|
c2,
|
|
|
|
torch.zeros(
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialized and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problem with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Forunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must tell set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specific aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very larges graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
(c2.shape[0], c1.shape[1] - c2.shape[1], c2.shape[2]),
|
|
|
|
device=c2.device,
|
|
|
|
dtype=c2.dtype,
|
2023-08-31 01:07:44 +00:00
|
|
|
),
|
|
|
|
],
|
|
|
|
dim=1,
|
|
|
|
)
|
|
|
|
|
2024-02-06 03:56:32 +00:00
|
|
|
assert c2_pooled is not None
|
2023-07-11 15:19:36 +00:00
|
|
|
conditioning_data = ConditioningFieldData(
|
|
|
|
conditionings=[
|
|
|
|
SDXLConditioningInfo(
|
2024-03-11 22:22:49 +00:00
|
|
|
embeds=torch.cat([c1, c2], dim=-1), pooled_embeds=c2_pooled, add_time_ids=add_time_ids
|
2023-07-11 15:19:36 +00:00
|
|
|
)
|
|
|
|
]
|
|
|
|
)
|
|
|
|
|
2024-01-13 12:23:16 +00:00
|
|
|
conditioning_name = context.conditioning.save(conditioning_data)
|
2023-07-05 02:37:16 +00:00
|
|
|
|
2024-03-08 15:48:45 +00:00
|
|
|
return ConditioningOutput(
|
|
|
|
conditioning=ConditioningField(
|
|
|
|
conditioning_name=conditioning_name,
|
|
|
|
mask=self.mask,
|
|
|
|
)
|
|
|
|
)
|
2023-04-25 00:48:44 +00:00
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators.
The decorators add the `type: Literal["invocation_type"]: "invocation_type"` field to the invocation.
Category is a new invocation metadata, but it is not used by the frontend just yet.
- `@invocation()` decorator for invocations
```py
@invocation(
"sdxl_compel_prompt",
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
...
```
- `@invocation_output()` decorator for invocation outputs
```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
...
```
- update invocation docs
- add category to decorator
- regen frontend types
2023-08-30 08:35:12 +00:00
|
|
|
@invocation(
|
|
|
|
"sdxl_refiner_compel_prompt",
|
|
|
|
title="SDXL Refiner Prompt",
|
|
|
|
tags=["sdxl", "compel", "prompt"],
|
|
|
|
category="conditioning",
|
2024-03-19 11:08:16 +00:00
|
|
|
version="1.1.1",
|
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators.
The decorators add the `type: Literal["invocation_type"]: "invocation_type"` field to the invocation.
Category is a new invocation metadata, but it is not used by the frontend just yet.
- `@invocation()` decorator for invocations
```py
@invocation(
"sdxl_compel_prompt",
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
...
```
- `@invocation_output()` decorator for invocation outputs
```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
...
```
- update invocation docs
- add category to decorator
- regen frontend types
2023-08-30 08:35:12 +00:00
|
|
|
)
|
2023-07-17 22:49:45 +00:00
|
|
|
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
2023-07-11 15:19:36 +00:00
|
|
|
"""Parse prompt using compel package to conditioning."""
|
|
|
|
|
2023-08-14 03:23:09 +00:00
|
|
|
style: str = InputField(
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialized and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problem with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Forunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must tell set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specific aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very larges graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
default="",
|
|
|
|
description=FieldDescriptions.compel_prompt,
|
|
|
|
ui_component=UIComponent.Textarea,
|
2023-08-14 03:23:09 +00:00
|
|
|
) # TODO: ?
|
|
|
|
original_width: int = InputField(default=1024, description="")
|
|
|
|
original_height: int = InputField(default=1024, description="")
|
|
|
|
crop_top: int = InputField(default=0, description="")
|
|
|
|
crop_left: int = InputField(default=0, description="")
|
|
|
|
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
|
2024-03-06 08:42:47 +00:00
|
|
|
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
2023-07-11 15:19:36 +00:00
|
|
|
|
|
|
|
@torch.no_grad()
|
2024-02-05 06:16:35 +00:00
|
|
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
2023-07-31 20:18:02 +00:00
|
|
|
# TODO: if there will appear lora for refiner - write proper prefix
|
2024-03-11 22:22:49 +00:00
|
|
|
c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
2023-07-16 03:00:37 +00:00
|
|
|
|
|
|
|
original_size = (self.original_height, self.original_width)
|
|
|
|
crop_coords = (self.crop_top, self.crop_left)
|
|
|
|
|
|
|
|
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
|
|
|
|
|
2024-02-06 03:56:32 +00:00
|
|
|
assert c2_pooled is not None
|
2023-07-16 03:00:37 +00:00
|
|
|
conditioning_data = ConditioningFieldData(
|
2024-03-11 22:22:49 +00:00
|
|
|
conditionings=[SDXLConditioningInfo(embeds=c2, pooled_embeds=c2_pooled, add_time_ids=add_time_ids)]
|
2023-07-16 03:00:37 +00:00
|
|
|
)
|
|
|
|
|
2024-01-13 12:23:16 +00:00
|
|
|
conditioning_name = context.conditioning.save(conditioning_data)
|
2023-07-16 03:00:37 +00:00
|
|
|
|
2024-01-13 12:23:16 +00:00
|
|
|
return ConditioningOutput.build(conditioning_name)
|
2023-07-16 03:00:37 +00:00
|
|
|
|
2023-07-17 22:49:45 +00:00
|
|
|
|
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators.
The decorators add the `type: Literal["invocation_type"]: "invocation_type"` field to the invocation.
Category is a new invocation metadata, but it is not used by the frontend just yet.
- `@invocation()` decorator for invocations
```py
@invocation(
"sdxl_compel_prompt",
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
...
```
- `@invocation_output()` decorator for invocation outputs
```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
...
```
- update invocation docs
- add category to decorator
- regen frontend types
2023-08-30 08:35:12 +00:00
|
|
|
@invocation_output("clip_skip_output")
|
2024-03-06 08:42:47 +00:00
|
|
|
class CLIPSkipInvocationOutput(BaseInvocationOutput):
|
|
|
|
"""CLIP skip node output"""
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2024-03-06 08:42:47 +00:00
|
|
|
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
2023-07-06 14:39:49 +00:00
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialized and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problem with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Forunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must tell set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specific aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very larges graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
@invocation(
|
|
|
|
"clip_skip",
|
|
|
|
title="CLIP Skip",
|
|
|
|
tags=["clipskip", "clip", "skip"],
|
|
|
|
category="conditioning",
|
2024-03-19 11:08:16 +00:00
|
|
|
version="1.1.0",
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialized and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problem with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Forunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must tell set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specific aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very larges graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
)
|
2024-03-06 08:42:47 +00:00
|
|
|
class CLIPSkipInvocation(BaseInvocation):
|
2023-07-06 14:39:49 +00:00
|
|
|
"""Skip layers in clip text_encoder model."""
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2024-03-06 08:42:47 +00:00
|
|
|
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
feat(nodes): JIT graph nodes validation
We use pydantic to validate a union of valid invocations when instantiating a graph.
Previously, we constructed the union while creating the `Graph` class. This introduces a dependency on the order of imports.
For example, consider a setup where we have 3 invocations in the app:
- Python executes the module where `FirstInvocation` is defined, registering `FirstInvocation`.
- Python executes the module where `SecondInvocation` is defined, registering `SecondInvocation`.
- Python executes the module where `Graph` is defined. A union of invocations is created and used to define the `Graph.nodes` field. The union contains `FirstInvocation` and `SecondInvocation`.
- Python executes the module where `ThirdInvocation` is defined, registering `ThirdInvocation`.
- A graph is created that includes `ThirdInvocation`. Pydantic validates the graph using the union, which does not know about `ThirdInvocation`, raising a `ValidationError` about an unknown invocation type.
This scenario has been particularly problematic in tests, where we may create invocations dynamically. The test files have to be structured in such a way that the imports happen in the right order. It's a major pain.
This PR refactors the validation of graph nodes to resolve this issue:
- `BaseInvocation` gets a new method `get_typeadapter`. This builds a pydantic `TypeAdapter` for the union of all registered invocations, caching it after the first call.
- `Graph.nodes`'s type is widened to `dict[str, BaseInvocation]`. This actually is a nice bonus, because we get better type hints whenever we reference `some_graph.nodes`.
- A "plain" field validator takes over the validation logic for `Graph.nodes`. "Plain" validators totally override pydantic's own validation logic. The validator grabs the `TypeAdapter` from `BaseInvocation`, then validates each node with it. The validation is identical to the previous implementation - we get the same errors.
`BaseInvocationOutput` gets the same treatment.
2024-02-17 00:22:08 +00:00
|
|
|
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
|
2023-07-18 14:26:45 +00:00
|
|
|
|
2024-03-06 08:42:47 +00:00
|
|
|
def invoke(self, context: InvocationContext) -> CLIPSkipInvocationOutput:
    """Add this node's ``skipped_layers`` to the incoming CLIP field and pass the field on.

    The ``self.clip`` object is updated in place (its ``skipped_layers`` grows by
    ``self.skipped_layers``) and that same object is returned as the output's ``clip``.
    ``context`` is not used.
    """
    clip_field = self.clip
    clip_field.skipped_layers = clip_field.skipped_layers + self.skipped_layers
    return CLIPSkipInvocationOutput(clip=clip_field)
|
|
|
|
|
2023-04-25 00:48:44 +00:00
|
|
|
|
|
|
|
def get_max_token_count(
    tokenizer: CLIPTokenizer,
    prompt: Union[FlattenedPrompt, Blend, Conjunction],
    truncate_if_too_long: bool = False,
) -> int:
    """Return the number of tokens a parsed compel prompt occupies.

    The count is computed recursively:

    * ``Blend``: the largest count among its ``.prompts``.
    * ``Conjunction``: the sum of the counts of its ``.prompts``.
    * anything else: the length of ``get_tokens_for_prompt_object`` for that prompt.

    ``truncate_if_too_long`` is forwarded unchanged to every recursive call and finally
    to ``get_tokens_for_prompt_object``.
    """
    kind = type(prompt)
    if kind is Blend:
        return max(get_max_token_count(tokenizer, part, truncate_if_too_long) for part in prompt.prompts)
    if kind is Conjunction:
        return sum(get_max_token_count(tokenizer, part, truncate_if_too_long) for part in prompt.prompts)
    # Leaf case: tokenize the prompt directly.
    return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
|
2023-04-25 00:48:44 +00:00
|
|
|
|
|
|
|
|
2024-02-10 23:09:45 +00:00
|
|
|
def get_tokens_for_prompt_object(
    tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
) -> List[str]:
    """Tokenize the text of a flattened compel prompt.

    Each child of ``parsed_prompt`` contributes a piece of text:

    * a ``Fragment`` contributes its ``.text``;
    * a ``CrossAttentionControlSubstitute`` contributes the space-joined text of its
      ``.original`` fragments (its ``.edited`` side is not counted);
    * anything else contributes ``str(child)``.

    The pieces are joined with spaces and passed to ``tokenizer.tokenize``.

    Args:
        tokenizer: tokenizer used to split the text.
        parsed_prompt: the flattened prompt to tokenize. Blends are rejected.
        truncate_if_too_long: when True, keep only the first
            ``tokenizer.model_max_length - 2`` tokens (two positions presumably reserved
            for begin/end markers that ``tokenize`` does not emit).

    Returns:
        The list of token strings.

    Raises:
        ValueError: if ``parsed_prompt`` is a ``Blend``.
    """
    if type(parsed_prompt) is Blend:
        # A Blend's parts live in `.prompts` (see get_max_token_count); point callers there.
        raise ValueError("Blend is not supported here - you need to get tokens for each of its .prompts")

    text_fragments = [
        (
            x.text
            if type(x) is Fragment
            else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
        )
        for x in parsed_prompt.children
    ]
    text = " ".join(text_fragments)
    tokens: List[str] = tokenizer.tokenize(text)
    if truncate_if_too_long:
        max_tokens_length = tokenizer.model_max_length - 2  # typically 75
        tokens = tokens[:max_tokens_length]
    return tokens
|
|
|
|
|
|
|
|
|
2024-02-10 23:09:45 +00:00
|
|
|
def log_tokenization_for_conjunction(
    c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
) -> None:
    """Log the tokenization of every part of a parsed Conjunction.

    With more than one part, each part's label is the prefix followed by its 1-based
    position and its weight. A single-part conjunction reuses the prefix unchanged.
    Each part is handed to ``log_tokenization_for_prompt_object``.
    """
    prefix = display_label_prefix or ""
    is_multi_part = len(c.prompts) > 1
    for part_index, part in enumerate(c.prompts):
        if is_multi_part:
            label = f"{prefix}(conjunction part {part_index + 1}, weight={c.weights[part_index]})"
        else:
            label = prefix
        log_tokenization_for_prompt_object(part, tokenizer, display_label_prefix=label)
|
|
|
|
|
|
|
|
|
2024-02-10 23:09:45 +00:00
|
|
|
def log_tokenization_for_prompt_object(
    p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
) -> None:
    """Log the tokenization of a parsed Blend or FlattenedPrompt.

    * ``Blend``: recurse into each of its ``.prompts``, labelling each with its 1-based
      position and weight.
    * ``FlattenedPrompt`` with cross-attention control: log two texts, one built from the
      ``.original`` side of each ``CrossAttentionControlSubstitute`` and one built from the
      ``.edited`` side. Other children appear in both.
    * plain ``FlattenedPrompt``: log the space-joined text of its children.

    Any other prompt type is ignored.
    """
    prefix = display_label_prefix or ""

    if type(p) is Blend:
        for part_index, part in enumerate(p.prompts):
            part_label = f"{prefix}(blend part {part_index + 1}, weight={p.weights[part_index]})"
            log_tokenization_for_prompt_object(part, tokenizer, display_label_prefix=part_label)
        return

    if type(p) is not FlattenedPrompt:
        return

    if not p.wants_cross_attention_control:
        plain_text = " ".join([child.text for child in p.children])
        log_tokenization_for_text(plain_text, tokenizer, display_label=prefix)
        return

    # Build both sides of every `.swap`; fragments outside a swap belong to both sides.
    swap_from = []
    swap_to = []
    for child in p.children:
        if type(child) is CrossAttentionControlSubstitute:
            swap_from.extend(child.original)
            swap_to.extend(child.edited)
        else:
            swap_from.append(child)
            swap_to.append(child)

    from_text = " ".join([frag.text for frag in swap_from])
    log_tokenization_for_text(from_text, tokenizer, display_label=f"{prefix}(.swap originals)")
    to_text = " ".join([frag.text for frag in swap_to])
    log_tokenization_for_text(to_text, tokenizer, display_label=f"{prefix}(.swap replacements)")
|
|
|
|
|
|
|
|
|
2024-02-10 23:09:45 +00:00
|
|
|
def log_tokenization_for_text(
    text: str,
    tokenizer: CLIPTokenizer,
    display_label: Optional[str] = None,
    truncate_if_too_long: Optional[bool] = False,
) -> None:
    """Print how ``text`` is tokenized, colouring tokens to show their boundaries.

    Tokens usually carry a ``'</w>'`` end-of-word marker; for readability it is replaced
    with a space. Tokens cycle through six ANSI foreground colours by position.

    Args:
        text: the text to tokenize.
        tokenizer: object providing ``tokenize()`` and ``model_max_length``.
        display_label: optional label shown in the header line.
        truncate_if_too_long: when truthy, tokens at positions ``>= model_max_length - 2``
            (the same limit ``get_tokens_for_prompt_object`` truncates to) are printed
            separately as discarded.
    """
    tokens = tokenizer.tokenize(text)
    # Two positions of model_max_length presumably go to begin/end markers that tokenize()
    # does not emit; keep the cut-off identical to get_tokens_for_prompt_object's truncation.
    limit = tokenizer.model_max_length - 2 if truncate_if_too_long else len(tokens)

    used_parts = []
    discarded_parts = []
    for i, raw_token in enumerate(tokens):
        token = raw_token.replace("</w>", " ")
        # Colour by absolute position so discarded tokens keep alternating colours too.
        colour = (i % 6) + 1
        part = f"\x1b[0;3{colour};40m{token}"
        if i >= limit:
            discarded_parts.append(part)
        else:
            used_parts.append(part)

    used_count = len(used_parts)
    if used_count > 0:
        tokenized = "".join(used_parts)
        print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({used_count}):')
        print(f"{tokenized}\x1b[0m")

    if discarded_parts:
        discarded = "".join(discarded_parts)
        print(f"\n>> [TOKENLOG] Tokens Discarded ({len(tokens) - used_count}):")
        print(f"{discarded}\x1b[0m")
|