2023-06-20 23:12:21 +00:00
|
|
|
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
|
|
|
|
2023-08-14 03:23:09 +00:00
|
|
|
import inspect
|
|
|
|
import re
|
2023-08-17 22:45:25 +00:00
|
|
|
|
|
|
|
# from contextlib import ExitStack
|
2023-10-17 06:23:10 +00:00
|
|
|
from typing import List, Literal, Union
|
2023-06-20 23:12:21 +00:00
|
|
|
|
|
|
|
import numpy as np
|
2023-08-14 03:23:09 +00:00
|
|
|
import torch
|
2023-06-20 23:12:21 +00:00
|
|
|
from diffusers.image_processor import VaeImageProcessor
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
    model_name: str = Field(description="Name of the IP-Adapter model")
    base_model: BaseModelType = Field(description="Base model")
    model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
2023-08-14 03:23:09 +00:00
|
|
|
from tqdm import tqdm
|
|
|
|
|
2023-09-15 17:18:00 +00:00
|
|
|
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
feat: refactor services folder/module structure
Refactor services folder/module structure.
**Motivation**
While working on our services I've repeatedly encountered circular imports and a general lack of clarity regarding where to put things. The structure introduced goes a long way towards resolving those issues, setting us up for a clean structure going forward.
**Services**
Services are now in their own folder with a few files:
- `services/{service_name}/__init__.py`: init as needed, mostly empty now
- `services/{service_name}/{service_name}_base.py`: the base class for the service
- `services/{service_name}/{service_name}_{impl_type}.py`: the default concrete implementation of the service - typically one of `sqlite`, `default`, or `memory`
- `services/{service_name}/{service_name}_common.py`: any common items - models, exceptions, utilities, etc
Though it's a bit verbose to have the service name both as the folder name and the prefix for files, I found it is _extremely_ confusing to have all of the base classes just be named `base.py`. So, at the cost of some verbosity when importing things, I've included the service name in the filename.
There are some minor logic changes. For example, in `InvocationProcessor`, instead of assigning the model manager service to a variable to be used later in the file, the service is used directly via the `Invoker`.
**Shared**
Things that are used across disparate services are in `services/shared/`:
- `default_graphs.py`: previously in `services/`
- `graphs.py`: previously in `services/`
- `pagination`: generic pagination models used in a few services
- `sqlite`: the `SqliteDatabase` class, other sqlite-specific things
2023-09-24 08:11:07 +00:00
|
|
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
2023-08-14 03:23:09 +00:00
|
|
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
|
|
|
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
2023-06-20 23:12:21 +00:00
|
|
|
|
2023-06-21 01:24:25 +00:00
|
|
|
from ...backend.model_management import ONNXModelPatcher
|
2023-08-14 03:23:09 +00:00
|
|
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
2023-07-18 16:35:07 +00:00
|
|
|
from ...backend.util import choose_torch_device
|
2023-08-14 03:23:09 +00:00
|
|
|
from .baseinvocation import (
|
|
|
|
BaseInvocation,
|
|
|
|
BaseInvocationOutput,
|
|
|
|
FieldDescriptions,
|
|
|
|
Input,
|
2023-09-06 23:30:30 +00:00
|
|
|
InputField,
|
2023-08-14 03:23:09 +00:00
|
|
|
InvocationContext,
|
|
|
|
OutputField,
|
|
|
|
UIComponent,
|
2023-08-15 11:45:40 +00:00
|
|
|
UIType,
|
2023-10-17 06:23:10 +00:00
|
|
|
WithMetadata,
|
|
|
|
WithWorkflow,
|
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators.
The decorators add the `type: Literal["invocation_type"] = "invocation_type"` field to the invocation.
Category is a new piece of invocation metadata, but it is not used by the frontend just yet.
- `@invocation()` decorator for invocations
```py
@invocation(
    "sdxl_compel_prompt",
    title="SDXL Prompt",
    tags=["sdxl", "compel", "prompt"],
    category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
    ...
```
- `@invocation_output()` decorator for invocation outputs
```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
    ...
```
- update invocation docs
- add category to decorator
- regen frontend types
2023-08-30 08:35:12 +00:00
|
|
|
invocation,
|
|
|
|
invocation_output,
|
2023-08-14 03:23:09 +00:00
|
|
|
)
|
2023-09-06 23:30:30 +00:00
|
|
|
from .controlnet_image_processors import ControlField
|
2023-09-15 17:18:00 +00:00
|
|
|
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
|
2023-08-14 03:23:09 +00:00
|
|
|
from .model import ClipField, ModelInfo, UNetField, VaeField
|
2023-06-20 23:12:21 +00:00
|
|
|
|
|
|
|
# Lookup table from ONNX Runtime tensor type strings to the matching numpy
# scalar types. Key order is significant: PRECISION_VALUES is built from these
# keys, so this order is also the order of the allowed precision values.
ORT_TO_NP_TYPE = {
    # boolean
    "tensor(bool)": np.bool_,
    # integers, signed/unsigned pairs by increasing width
    "tensor(int8)": np.int8,
    "tensor(uint8)": np.uint8,
    "tensor(int16)": np.int16,
    "tensor(uint16)": np.uint16,
    "tensor(int32)": np.int32,
    "tensor(uint32)": np.uint32,
    "tensor(int64)": np.int64,
    "tensor(uint64)": np.uint64,
    # floating point: ORT's "float" is 32-bit and "double" is 64-bit
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}
|
|
|
|
|
2023-07-28 13:46:44 +00:00
|
|
|
# Literal type whose allowed values are exactly the keys of ORT_TO_NP_TYPE, in
# that dict's order; used as the type of ONNX `precision` inputs.
# Iterating a dict yields its keys, so `tuple(ORT_TO_NP_TYPE)` is equivalent to
# the previous `tuple(list(ORT_TO_NP_TYPE.keys()))`.
PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE)]
|
|
|
|
|
2023-06-20 23:12:21 +00:00
|
|
|
|
2023-09-04 08:11:56 +00:00
|
|
|
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
class ONNXPromptInvocation(BaseInvocation):
    """Encode a raw text prompt with an ONNX CLIP text encoder.

    LoRAs listed on ``clip.loras`` and textual-inversion triggers written as
    ``<name>`` in the prompt are patched into the tokenizer/text encoder while
    the prompt is encoded. The resulting embeddings are saved through
    ``context.services.latents`` under a per-node name, and that name is
    returned in a ``ConditioningField``.
    """

    prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
    clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)

    def invoke(self, context: InvocationContext) -> ConditioningOutput:
        model_manager = context.services.model_manager
        tokenizer_info = model_manager.get_model(**self.clip.tokenizer.model_dump())
        text_encoder_info = model_manager.get_model(**self.clip.text_encoder.model_dump())

        with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder:
            # (lora_model, weight) pairs consumed by ONNXModelPatcher.apply_lora_text_encoder.
            loras = [
                (
                    model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
                    lora.weight,
                )
                for lora in self.clip.loras
            ]

            # (name, ti_model) pairs consumed by ONNXModelPatcher.apply_ti.
            # dict.fromkeys de-duplicates triggers while keeping prompt order, so a
            # trigger repeated in the prompt is only loaded and patched in once.
            ti_list = []
            for trigger in dict.fromkeys(re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt)):
                name = trigger[1:-1]
                try:
                    ti_model = model_manager.get_model(
                        model_name=name,
                        base_model=self.clip.text_encoder.base_model,
                        model_type=ModelType.TextualInversion,
                    ).context.model
                except Exception as e:
                    # Best effort: an unknown/unloadable trigger is skipped so the
                    # prompt still encodes; report it through the app logger.
                    context.services.logger.warning(f'Textual inversion trigger "{trigger}" not found: {e}')
                    continue
                ti_list.append((name, ti_model))

            # Release the current session before patching; presumably this lets the
            # create_session() call below pick up the patched weights — TODO confirm.
            if loras or ti_list:
                text_encoder.release_session()
            with (
                ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),
                ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, _ti_manager),
            ):
                text_encoder.create_session()

                # Adapted from
                # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
                # Prompts longer than tokenizer.model_max_length tokens are silently truncated.
                text_inputs = tokenizer(
                    self.prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="np",
                )
                # The ONNX text encoder is fed int32 token ids.
                prompt_embeds = text_encoder(input_ids=text_inputs.input_ids.astype(np.int32))[0]

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

        # TODO: hacky but works ;D maybe rename latents somehow?
        # The embeddings are stored via the latents service as an (embeds, None) tuple.
        context.services.latents.save(conditioning_name, (prompt_embeds, None))

        return ConditioningOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )
|
|
|
|
|
2023-07-28 13:46:44 +00:00
|
|
|
|
2023-06-20 23:12:21 +00:00
|
|
|
# Text to image
|
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators.
The decorators add the `type: Literal["invocation_type"] = "invocation_type"` field to the invocation.
Category is a new piece of invocation metadata, but it is not used by the frontend just yet.
- `@invocation()` decorator for invocations
```py
@invocation(
    "sdxl_compel_prompt",
    title="SDXL Prompt",
    tags=["sdxl", "compel", "prompt"],
    category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
    ...
```
- `@invocation_output()` decorator for invocation outputs
```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
    ...
```
- update invocation docs
- add category to decorator
- regen frontend types
2023-08-30 08:35:12 +00:00
|
|
|
@invocation(
|
|
|
|
"t2l_onnx",
|
|
|
|
title="ONNX Text to Latents",
|
|
|
|
tags=["latents", "inference", "txt2img", "onnx"],
|
|
|
|
category="latents",
|
2023-09-04 08:11:56 +00:00
|
|
|
version="1.0.0",
|
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators.
The decorators add the `type: Literal["invocation_type"] = "invocation_type"` field to the invocation.
Category is a new piece of invocation metadata, but it is not used by the frontend just yet.
- `@invocation()` decorator for invocations
```py
@invocation(
    "sdxl_compel_prompt",
    title="SDXL Prompt",
    tags=["sdxl", "compel", "prompt"],
    category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
    ...
```
- `@invocation_output()` decorator for invocation outputs
```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
    ...
```
- update invocation docs
- add category to decorator
- regen frontend types
2023-08-30 08:35:12 +00:00
|
|
|
)
|
2023-06-20 23:12:21 +00:00
|
|
|
class ONNXTextToLatentsInvocation(BaseInvocation):
|
|
|
|
"""Generates latents from conditionings."""
|
|
|
|
|
2023-08-14 03:23:09 +00:00
|
|
|
positive_conditioning: ConditioningField = InputField(
|
|
|
|
description=FieldDescriptions.positive_cond,
|
|
|
|
input=Input.Connection,
|
|
|
|
)
|
|
|
|
negative_conditioning: ConditioningField = InputField(
|
|
|
|
description=FieldDescriptions.negative_cond,
|
|
|
|
input=Input.Connection,
|
|
|
|
)
|
|
|
|
noise: LatentsField = InputField(
|
|
|
|
description=FieldDescriptions.noise,
|
|
|
|
input=Input.Connection,
|
|
|
|
)
|
|
|
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
|
|
|
cfg_scale: Union[float, List[float]] = InputField(
|
|
|
|
default=7.5,
|
|
|
|
ge=1,
|
|
|
|
description=FieldDescriptions.cfg_scale,
|
|
|
|
)
|
|
|
|
scheduler: SAMPLER_NAME_VALUES = InputField(
|
2023-08-17 08:58:01 +00:00
|
|
|
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
|
2023-08-14 03:23:09 +00:00
|
|
|
)
|
|
|
|
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
|
|
|
unet: UNetField = InputField(
|
|
|
|
description=FieldDescriptions.unet,
|
|
|
|
input=Input.Connection,
|
|
|
|
)
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
    model_name: str = Field(description="Name of the IP-Adapter model")
    base_model: BaseModelType = Field(description="Base model")
    model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
control: Union[ControlField, list[ControlField]] = InputField(
|
2023-08-14 03:23:09 +00:00
|
|
|
default=None,
|
|
|
|
description=FieldDescriptions.control,
|
|
|
|
)
|
|
|
|
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
|
|
|
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
2023-06-20 23:12:21 +00:00
|
|
|
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
@field_validator("cfg_scale")
def ge_one(cls, v):
    """Validate that every cfg_scale value is >= 1.

    Accepts either a single number or a list of numbers and returns the
    input unchanged when it is valid.

    Raises:
        ValueError: if any cfg_scale value is less than 1.
    """
    # Normalise to a sequence so scalar and list inputs share a single check.
    values = v if isinstance(v, list) else [v]
    if any(value < 1 for value in values):
        # The comparison is `< 1`, so exactly 1 is accepted; the message says so.
        raise ValueError("cfg_scale must be greater than or equal to 1")
    return v
|
|
|
|
|
2023-06-21 01:24:25 +00:00
|
|
|
# based on
|
|
|
|
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
2023-06-20 23:12:21 +00:00
|
|
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
|
|
|
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
|
|
|
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
2023-07-28 13:46:44 +00:00
|
|
|
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
2023-07-18 16:35:07 +00:00
|
|
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
2023-06-20 23:12:21 +00:00
|
|
|
if isinstance(c, torch.Tensor):
|
|
|
|
c = c.cpu().numpy()
|
|
|
|
if isinstance(uc, torch.Tensor):
|
|
|
|
uc = uc.cpu().numpy()
|
2023-07-18 16:35:07 +00:00
|
|
|
device = torch.device(choose_torch_device())
|
2023-06-20 23:12:21 +00:00
|
|
|
prompt_embeds = np.concatenate([uc, c])
|
|
|
|
|
|
|
|
latents = context.services.latents.get(self.noise.latents_name)
|
|
|
|
if isinstance(latents, torch.Tensor):
|
|
|
|
latents = latents.cpu().numpy()
|
|
|
|
|
|
|
|
# TODO: better execution device handling
|
2023-07-20 17:15:45 +00:00
|
|
|
latents = latents.astype(ORT_TO_NP_TYPE[self.precision])
|
2023-06-20 23:12:21 +00:00
|
|
|
|
|
|
|
# get the initial random noise unless the user supplied it
|
|
|
|
do_classifier_free_guidance = True
|
2023-07-28 13:46:44 +00:00
|
|
|
# latents_dtype = prompt_embeds.dtype
|
|
|
|
# latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
|
|
|
# if latents.shape != latents_shape:
|
2023-06-20 23:12:21 +00:00
|
|
|
# raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
|
|
|
|
|
|
|
scheduler = get_scheduler(
|
|
|
|
context=context,
|
|
|
|
scheduler_info=self.unet.scheduler,
|
|
|
|
scheduler_name=self.scheduler,
|
2023-08-14 03:02:33 +00:00
|
|
|
seed=0, # TODO: refactor this node
|
2023-06-20 23:12:21 +00:00
|
|
|
)
|
|
|
|
|
2023-07-18 16:35:07 +00:00
|
|
|
def torch2numpy(latent: torch.Tensor):
    """Move ``latent`` to host memory and expose it as a NumPy array."""
    host_tensor = latent.cpu()
    return host_tensor.numpy()
|
|
|
|
|
|
|
|
def numpy2torch(latent, device):
    """Wrap the NumPy array ``latent`` as a torch tensor placed on ``device``."""
    tensor = torch.from_numpy(latent)
    return tensor.to(device)
|
|
|
|
|
|
|
|
def dispatch_progress(
|
2023-07-28 13:46:44 +00:00
|
|
|
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
|
|
|
) -> None:
|
2023-07-18 16:35:07 +00:00
|
|
|
stable_diffusion_step_callback(
|
|
|
|
context=context,
|
|
|
|
intermediate_state=intermediate_state,
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
node=self.model_dump(),
|
2023-07-18 16:35:07 +00:00
|
|
|
source_node_id=source_node_id,
|
|
|
|
)
|
|
|
|
|
2023-06-20 23:12:21 +00:00
|
|
|
scheduler.set_timesteps(self.steps)
|
|
|
|
latents = latents * np.float64(scheduler.init_noise_sigma)
|
|
|
|
|
|
|
|
extra_step_kwargs = dict()
|
|
|
|
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
|
|
|
extra_step_kwargs.update(
|
|
|
|
eta=0.0,
|
|
|
|
)
|
|
|
|
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump())
|
2023-06-20 23:12:21 +00:00
|
|
|
|
2023-08-17 22:45:25 +00:00
|
|
|
with unet_info as unet: # , ExitStack() as stack:
|
2023-07-28 13:46:44 +00:00
|
|
|
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
|
|
|
loras = [
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
(
|
|
|
|
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
|
|
|
lora.weight,
|
|
|
|
)
|
2023-07-28 13:46:44 +00:00
|
|
|
for lora in self.unet.loras
|
|
|
|
]
|
2023-06-20 23:12:21 +00:00
|
|
|
|
2023-07-27 19:20:38 +00:00
|
|
|
if loras:
|
|
|
|
unet.release_session()
|
2023-06-20 23:12:21 +00:00
|
|
|
with ONNXModelPatcher.apply_lora_unet(unet, loras):
|
2023-07-28 13:46:44 +00:00
|
|
|
# TODO:
|
2023-07-21 16:16:24 +00:00
|
|
|
_, _, h, w = latents.shape
|
|
|
|
unet.create_session(h, w)
|
2023-06-20 23:12:21 +00:00
|
|
|
|
|
|
|
timestep_dtype = next(
|
2023-07-17 20:27:33 +00:00
|
|
|
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
|
2023-06-20 23:12:21 +00:00
|
|
|
)
|
|
|
|
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
|
|
|
for i in tqdm(range(len(scheduler.timesteps))):
|
|
|
|
t = scheduler.timesteps[i]
|
|
|
|
# expand the latents if we are doing classifier free guidance
|
|
|
|
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
2023-07-18 16:35:07 +00:00
|
|
|
latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
|
2023-06-20 23:12:21 +00:00
|
|
|
latent_model_input = latent_model_input.cpu().numpy()
|
|
|
|
|
|
|
|
# predict the noise residual
|
|
|
|
timestep = np.array([t], dtype=timestep_dtype)
|
|
|
|
noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
|
|
|
|
noise_pred = noise_pred[0]
|
|
|
|
|
|
|
|
# perform guidance
|
|
|
|
if do_classifier_free_guidance:
|
|
|
|
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
|
|
|
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
|
|
|
|
|
|
|
# compute the previous noisy sample x_t -> x_t-1
|
|
|
|
scheduler_output = scheduler.step(
|
2023-07-18 16:35:07 +00:00
|
|
|
numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
|
|
|
|
)
|
|
|
|
latents = torch2numpy(scheduler_output.prev_sample)
|
|
|
|
|
|
|
|
state = PipelineIntermediateState(
|
2023-07-28 13:46:44 +00:00
|
|
|
run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample
|
2023-06-20 23:12:21 +00:00
|
|
|
)
|
2023-07-28 13:46:44 +00:00
|
|
|
dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state)
|
2023-06-20 23:12:21 +00:00
|
|
|
|
|
|
|
# call the callback, if provided
|
2023-07-28 13:46:44 +00:00
|
|
|
# if callback is not None and i % callback_steps == 0:
|
2023-06-20 23:12:21 +00:00
|
|
|
# callback(i, t, latents)
|
|
|
|
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
2023-07-28 13:46:44 +00:00
|
|
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
2023-06-20 23:12:21 +00:00
|
|
|
context.services.latents.save(name, latents)
|
2023-07-17 20:27:33 +00:00
|
|
|
return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
|
2023-06-20 23:12:21 +00:00
|
|
|
|
2023-07-28 13:46:44 +00:00
|
|
|
|
2023-06-20 23:12:21 +00:00
|
|
|
# Latent to image
|
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators.
The decorators add the `type: Literal["invocation_type"] = "invocation_type"` field to the invocation.
Category is a new piece of invocation metadata, but the frontend does not use it yet.
- `@invocation()` decorator for invocations
```py
@invocation(
"sdxl_compel_prompt",
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
...
```
- `@invocation_output()` decorator for invocation outputs
```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
...
```
- update invocation docs
- add category to decorator
- regen frontend types
2023-08-30 08:35:12 +00:00
|
|
|
@invocation(
|
|
|
|
"l2i_onnx",
|
|
|
|
title="ONNX Latents to Image",
|
|
|
|
tags=["latents", "image", "vae", "onnx"],
|
|
|
|
category="image",
|
2023-09-04 08:11:56 +00:00
|
|
|
version="1.0.0",
|
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators.
The decorators add the `type: Literal["invocation_type"] = "invocation_type"` field to the invocation.
Category is a new piece of invocation metadata, but the frontend does not use it yet.
- `@invocation()` decorator for invocations
```py
@invocation(
"sdxl_compel_prompt",
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
...
```
- `@invocation_output()` decorator for invocation outputs
```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
...
```
- update invocation docs
- add category to decorator
- regen frontend types
2023-08-30 08:35:12 +00:00
|
|
|
)
|
2023-10-17 06:23:10 +00:00
|
|
|
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
2023-06-20 23:12:21 +00:00
|
|
|
"""Generates an image from latents."""
|
|
|
|
|
2023-08-14 03:23:09 +00:00
|
|
|
latents: LatentsField = InputField(
|
|
|
|
description=FieldDescriptions.denoised_latents,
|
|
|
|
input=Input.Connection,
|
2023-07-28 13:46:44 +00:00
|
|
|
)
|
2023-08-14 03:23:09 +00:00
|
|
|
vae: VaeField = InputField(
|
|
|
|
description=FieldDescriptions.vae,
|
|
|
|
input=Input.Connection,
|
|
|
|
)
|
|
|
|
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
2023-06-20 23:12:21 +00:00
|
|
|
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
|
|
latents = context.services.latents.get(self.latents.latents_name)
|
|
|
|
|
|
|
|
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
|
|
|
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
|
|
|
|
|
|
|
vae_info = context.services.model_manager.get_model(
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very often - a couple of times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
**self.vae.vae.model_dump(),
|
2023-06-20 23:12:21 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
# clear memory as vae decode can request a lot
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
with vae_info as vae:
|
|
|
|
vae.create_session()
|
|
|
|
|
2023-06-21 01:24:25 +00:00
|
|
|
# copied from
|
|
|
|
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
|
2023-06-20 23:12:21 +00:00
|
|
|
latents = 1 / 0.18215 * latents
|
|
|
|
# image = self.vae_decoder(latent_sample=latents)[0]
|
|
|
|
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
2023-07-28 13:46:44 +00:00
|
|
|
image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])])
|
2023-06-20 23:12:21 +00:00
|
|
|
|
|
|
|
image = np.clip(image / 2 + 0.5, 0, 1)
|
|
|
|
image = image.transpose((0, 2, 3, 1))
|
|
|
|
image = VaeImageProcessor.numpy_to_pil(image)[0]
|
|
|
|
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
image_dto = context.services.images.create(
|
|
|
|
image=image,
|
|
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
|
|
image_category=ImageCategory.GENERAL,
|
|
|
|
node_id=self.id,
|
|
|
|
session_id=context.graph_execution_state_id,
|
2023-07-18 18:27:54 +00:00
|
|
|
is_intermediate=self.is_intermediate,
|
2023-10-17 06:23:10 +00:00
|
|
|
metadata=self.metadata,
|
2023-08-24 11:42:32 +00:00
|
|
|
workflow=self.workflow,
|
2023-06-20 23:12:21 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
return ImageOutput(
|
|
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
|
|
width=image_dto.width,
|
|
|
|
height=image_dto.height,
|
|
|
|
)
|
|
|
|
|
2023-07-28 13:46:44 +00:00
|
|
|
|
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators.
The decorators add the `type: Literal["invocation_type"] = "invocation_type"` field to the invocation.
Category is a new invocation metadata, but it is not used by the frontend just yet.
- `@invocation()` decorator for invocations
```py
@invocation(
"sdxl_compel_prompt",
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
...
```
- `@invocation_output()` decorator for invocation outputs
```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
...
```
- update invocation docs
- add category to decorator
- regen frontend types
2023-08-30 08:35:12 +00:00
|
|
|
@invocation_output("model_loader_output_onnx")
class ONNXModelLoaderOutput(BaseInvocationOutput):
    """Model loader output"""

    # Output of the ONNX main-model loader: one handle per submodel group.
    # The class docstring above is intentionally terse and left unchanged,
    # since pydantic surfaces it as the schema description.
    #
    # NOTE(review): every field defaults to None although none of the
    # annotations is Optional[...]; confirm whether these should be nullable
    # before changing them (doing so alters the generated schema).
    # Field declaration order is kept as-is because it drives schema order.
    unet: UNetField = OutputField(title="UNet", description=FieldDescriptions.unet, default=None)
    clip: ClipField = OutputField(title="CLIP", description=FieldDescriptions.clip, default=None)
    vae_decoder: VaeField = OutputField(title="VAE Decoder", description=FieldDescriptions.vae, default=None)
    vae_encoder: VaeField = OutputField(title="VAE Encoder", description=FieldDescriptions.vae, default=None)
|
2023-07-28 13:46:44 +00:00
|
|
|
|
2023-06-20 23:12:21 +00:00
|
|
|
|
2023-07-14 18:24:15 +00:00
|
|
|
class OnnxModelField(BaseModel):
    """Onnx model field"""

    # pydantic v2 treats the `model_` prefix as a protected namespace. Clearing
    # it lets this model declare `model_name` / `model_type` as regular fields.
    model_config = ConfigDict(protected_namespaces=())

    # Identity of the selected model, as understood by the model manager.
    model_name: str = Field(description="Name of the model")
    base_model: BaseModelType = Field(description="Base model")
    model_type: ModelType = Field(description="Model Type")
|
|
|
|
|
2023-07-28 13:46:44 +00:00
|
|
|
|
2023-09-04 08:11:56 +00:00
|
|
|
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
class OnnxModelLoaderInvocation(BaseInvocation):
    """Loads a main model, outputting its submodels."""

    # The class docstring above is kept verbatim: it is exposed as the node's
    # description in the generated schema / UI.

    model: OnnxModelField = InputField(
        description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
    )

    def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
        """Check that the selected ONNX model exists and return references to its submodels.

        Args:
            context: invocation context; only ``context.services.model_manager``
                is used here.

        Returns:
            An ``ONNXModelLoaderOutput`` whose UNet (+ scheduler), CLIP
            (tokenizer + text encoder), VAE decoder and VAE encoder fields each
            reference a submodel of the selected model. LoRA lists start empty
            and ``skipped_layers`` is 0.

        Raises:
            Exception: if the model manager reports no ONNX model with the
                selected name and base model.
        """
        base_model = self.model.base_model
        model_name = self.model.model_name
        # The lookup type is always ONNX; self.model.model_type is not consulted.
        model_type = ModelType.ONNX

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        ):
            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")

        # A former no-op string literal holding obsolete per-submodel existence
        # checks (written against an older API) was removed here; it had no
        # effect at runtime.

        def submodel_info(submodel: SubModelType) -> ModelInfo:
            """Build a ModelInfo pointing at one submodel of the selected model."""
            return ModelInfo(
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=submodel,
            )

        return ONNXModelLoaderOutput(
            unet=UNetField(
                unet=submodel_info(SubModelType.UNet),
                scheduler=submodel_info(SubModelType.Scheduler),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=submodel_info(SubModelType.Tokenizer),
                text_encoder=submodel_info(SubModelType.TextEncoder),
                loras=[],
                skipped_layers=0,
            ),
            vae_decoder=VaeField(
                vae=submodel_info(SubModelType.VaeDecoder),
            ),
            vae_encoder=VaeField(
                vae=submodel_info(SubModelType.VaeEncoder),
            ),
        )
|