Merge branch 'main' into set-timestep-mps-fix

This commit is contained in:
Millun Atluri 2023-07-28 16:12:07 +10:00 committed by GitHub
commit 7f81a95b20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 518 additions and 210 deletions

View File

@ -123,7 +123,7 @@ and go to http://localhost:9090.
### Command-Line Installation (for developers and users familiar with Terminals)
You must have Python 3.9 or 3.10 installed on your machine. Earlier or
You must have Python 3.9 through 3.11 installed on your machine. Earlier or
later versions are not supported.
Node.js also needs to be installed along with yarn (can be installed with
the command `npm install -g yarn` if needed)

View File

@ -40,10 +40,8 @@ experimental versions later.
this, open up a command-line window ("Terminal" on Linux and
Macintosh, "Command" or "Powershell" on Windows) and type `python
--version`. If Python is installed, it will print out the version
number. If it is version `3.9.*` or `3.10.*`, you meet
requirements. We do not recommend using Python 3.11 or higher,
as not all the libraries that InvokeAI depends on work properly
with this version.
number. If it is version `3.9.*`, `3.10.*` or `3.11.*` you meet
requirements.
!!! warning "What to do if you have an unsupported version"

View File

@ -32,7 +32,7 @@ gaming):
* **Python**
version 3.9 or 3.10 (3.11 is not recommended).
version 3.9 through 3.11
* **CUDA Tools**
@ -65,7 +65,7 @@ gaming):
To install InvokeAI with virtual environments and the PIP package
manager, please follow these steps:
1. Please make sure you are using Python 3.9 or 3.10. The rest of the install
1. Please make sure you are using Python 3.9 through 3.11. The rest of the install
procedure depends on this and will not work with other versions:
```bash

View File

@ -9,16 +9,20 @@ cd $scriptdir
function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }
MINIMUM_PYTHON_VERSION=3.9.0
MAXIMUM_PYTHON_VERSION=3.11.0
MAXIMUM_PYTHON_VERSION=3.11.100
PYTHON=""
for candidate in python3.10 python3.9 python3 python ; do
for candidate in python3.11 python3.10 python3.9 python3 python ; do
if ppath=`which $candidate`; then
# when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
# we check that this found executable can actually run
if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi
python_version=$($ppath -V | awk '{ print $2 }')
if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
if [ $(version $python_version) -lt $(version "$MAXIMUM_PYTHON_VERSION") ]; then
PYTHON=$ppath
break
fi
if [ $(version $python_version) -le $(version "$MAXIMUM_PYTHON_VERSION") ]; then
PYTHON=$ppath
break
fi
fi
fi
done

View File

@ -90,7 +90,7 @@ async def update_model(
new_name=info.model_name,
new_base=info.base_model,
)
logger.info(f"Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}")
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
# update information to support an update of attributes
model_name = info.model_name
base_model = info.base_model

View File

@ -3,6 +3,7 @@ import asyncio
import sys
from inspect import signature
import logging
import uvicorn
import socket
@ -210,11 +211,25 @@ def invoke_api():
port = find_port(app_config.port)
if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}")
# Start our own event loop for eventing usage
loop = asyncio.new_event_loop()
config = uvicorn.Config(app=app, host=app_config.host, port=port, loop=loop)
# Use access_log to turn off logging
config = uvicorn.Config(
app=app,
host=app_config.host,
port=port,
loop=loop,
log_level=app_config.log_level,
)
server = uvicorn.Server(config)
# replace uvicorn's loggers with InvokeAI's for consistent appearance
for logname in ["uvicorn.access", "uvicorn"]:
l = logging.getLogger(logname)
l.handlers.clear()
for ch in logger.handlers:
l.addHandler(ch)
loop.run_until_complete(server.serve())

View File

@ -12,7 +12,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models.base import ModelType
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
@ -311,70 +311,71 @@ class TextToLatentsInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
with SilenceWarnings():
noise = context.services.latents.get(self.noise.latents_name)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
control_data = self.prep_control_data(
model=pipeline,
context=context,
control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
control_data = self.prep_control_data(
model=pipeline,
context=context,
control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
@ -402,82 +403,83 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
with SilenceWarnings(): # this quenches NSFW nag from diffusers
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
control_data = self.prep_control_data(
model=pipeline,
context=context,
control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
# TODO: Verify the noise is the right size
initial_latents = (
latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
)
control_data = self.prep_control_data(
model=pipeline,
context=context,
control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
timesteps, _ = pipeline.get_img2img_timesteps(
self.steps,
self.strength,
device=unet.device,
)
# TODO: Verify the noise is the right size
initial_latents = (
latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
)
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
timesteps, _ = pipeline.get_img2img_timesteps(
self.steps,
self.strength,
device=unet.device,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
@ -490,7 +492,7 @@ class LatentsToImageInvocation(BaseInvocation):
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Decode latents by overlapping tiles(less memory consumption)")
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
metadata: Optional[CoreMetadata] = Field(
default=None, description="Optional core metadata to be written to the image"

View File

@ -401,7 +401,11 @@ class ModelManager(object):
base_model: BaseModelType,
model_type: ModelType,
) -> str:
return f"{base_model}/{model_type}/{model_name}"
# In 3.11, the behavior of (str,enum) when interpolated into a
# string has changed. The next two lines are defensive.
base_model = BaseModelType(base_model)
model_type = ModelType(model_type)
return f"{base_model.value}/{model_type.value}/{model_name}"
@classmethod
def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:

View File

@ -57,7 +57,7 @@ class LoRAModel(ModelBase):
@classproperty
def save_to_config(cls) -> bool:
return False
return True
@classmethod
def detect_format(cls, path: str):

View File

@ -1,4 +1,4 @@
import{A as g,fS as Xe,z as x,a4 as Ba,fT as Ea,af as ca,aj as c,fU as b,al as Da,fV as t,fW as Ra,fX as h,fY as ba,fZ as ja,f_ as Ha,aI as Wa,f$ as Va,ad as La,g0 as qa}from"./index-89941396.js";import{n,o as Sr,p as Oa,T as Na,q as Ga,s as Ua,t as Ya,v as Xa,w as Ka,x as Za,y as Ja,z as Qa,A as et,B as rt,D as at,E as tt,F as ot,G as nt,e as it,M as lt}from"./MantineProvider-8184f020.js";var va=String.raw,ua=va`
import{A as g,fS as Xe,z as x,a4 as Ba,fT as Ea,af as ca,aj as c,fU as b,al as Da,fV as t,fW as Ra,fX as h,fY as ba,fZ as ja,f_ as Ha,aI as Wa,f$ as Va,ad as La,g0 as qa}from"./index-5a784cdd.js";import{n,o as Sr,p as Oa,T as Na,q as Ga,s as Ua,t as Ya,v as Xa,w as Ka,x as Za,y as Ja,z as Qa,A as et,B as rt,D as at,E as tt,F as ot,G as nt,e as it,M as lt}from"./MantineProvider-ea42d3d1.js";var va=String.raw,ua=va`
:root,
:host {
--chakra-vh: 100vh;

View File

@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-89941396.js"></script>
<script type="module" crossorigin src="./assets/index-5a784cdd.js"></script>
</head>
<body dir="ltr">

View File

@ -340,6 +340,7 @@
"allModels": "All Models",
"checkpointModels": "Checkpoints",
"diffusersModels": "Diffusers",
"loraModels": "LoRAs",
"safetensorModels": "SafeTensors",
"modelAdded": "Model Added",
"modelUpdated": "Model Updated",

View File

@ -1,3 +1,5 @@
import { components } from 'services/api/schema';
export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
@ -5,6 +7,13 @@ export const MODEL_TYPE_MAP = {
'sdxl-refiner': 'Stable Diffusion XL Refiner',
};
export const MODEL_TYPE_SHORT_MAP = {
'sd-1': 'SD1',
'sd-2': 'SD2',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
};
export const clipSkipMap = {
'sd-1': {
maxClip: 12,
@ -23,3 +32,12 @@ export const clipSkipMap = {
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
};
type LoRAModelFormatMap = {
[key in components['schemas']['LoRAModelFormat']]: string;
};
export const LORA_MODEL_FORMAT_MAP: LoRAModelFormatMap = {
lycoris: 'LyCORIS',
diffusers: 'Diffusers',
};

View File

@ -3,20 +3,31 @@ import { Flex, Text } from '@chakra-ui/react';
import { useState } from 'react';
import {
MainModelConfigEntity,
DiffusersModelConfigEntity,
LoRAModelConfigEntity,
useGetMainModelsQuery,
useGetLoRAModelsQuery,
} from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
import { ALL_BASE_MODELS } from 'services/api/constants';
export default function ModelManagerPanel() {
const [selectedModelId, setSelectedModelId] = useState<string>();
const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, {
const { mainModel } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
mainModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),
});
const { loraModel } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
loraModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),
});
const model = mainModel ? mainModel : loraModel;
return (
<Flex sx={{ gap: 8, w: 'full', h: 'full' }}>
@ -30,7 +41,7 @@ export default function ModelManagerPanel() {
}
type ModelEditProps = {
model: MainModelConfigEntity | undefined;
model: MainModelConfigEntity | LoRAModelConfigEntity | undefined;
};
const ModelEdit = (props: ModelEditProps) => {
@ -41,7 +52,16 @@ const ModelEdit = (props: ModelEditProps) => {
}
if (model?.model_format === 'diffusers') {
return <DiffusersModelEdit key={model.id} model={model} />;
return (
<DiffusersModelEdit
key={model.id}
model={model as DiffusersModelConfigEntity}
/>
);
}
if (model?.model_type === 'lora') {
return <LoRAModelEdit key={model.id} model={model} />;
}
return (

View File

@ -0,0 +1,137 @@
import { Divider, Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
LORA_MODEL_FORMAT_MAP,
MODEL_TYPE_MAP,
} from 'features/parameters/types/constants';
import {
LoRAModelConfigEntity,
useUpdateLoRAModelsMutation,
} from 'services/api/endpoints/models';
import { LoRAModelConfig } from 'services/api/types';
import BaseModelSelect from '../shared/BaseModelSelect';
type LoRAModelEditProps = {
model: LoRAModelConfigEntity;
};
export default function LoRAModelEdit(props: LoRAModelEditProps) {
const isBusy = useAppSelector(selectIsBusy);
const { model } = props;
const [updateLoRAModel, { isLoading }] = useUpdateLoRAModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const loraEditForm = useForm<LoRAModelConfig>({
initialValues: {
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'lora',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
model_format: model.model_format,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
},
});
const editModelFormSubmitHandler = useCallback(
(values: LoRAModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
};
updateLoRAModel(responseBody)
.unwrap()
.then((payload) => {
loraEditForm.setValues(payload as LoRAModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
loraEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
dispatch,
loraEditForm,
model.base_model,
model.model_name,
t,
updateLoRAModel,
]
);
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{model.model_name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[model.base_model]} Model {' '}
{LORA_MODEL_FORMAT_MAP[model.model_format]} format
</Text>
</Flex>
<Divider />
<form
onSubmit={loraEditForm.onSubmit((values) =>
editModelFormSubmitHandler(values)
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIMantineTextInput
label={t('modelManager.name')}
{...loraEditForm.getInputProps('model_name')}
/>
<IAIMantineTextInput
label={t('modelManager.description')}
{...loraEditForm.getInputProps('description')}
/>
<BaseModelSelect {...loraEditForm.getInputProps('base_model')} />
<IAIMantineTextInput
label={t('modelManager.modelLocation')}
{...loraEditForm.getInputProps('path')}
/>
<IAIButton
type="submit"
isDisabled={isBusy || isLoading}
isLoading={isLoading}
>
{t('modelManager.updateModel')}
</IAIButton>
</Flex>
</form>
</Flex>
);
}

View File

@ -9,6 +9,8 @@ import { useTranslation } from 'react-i18next';
import {
MainModelConfigEntity,
useGetMainModelsQuery,
useGetLoRAModelsQuery,
LoRAModelConfigEntity,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants';
@ -20,22 +22,42 @@ type ModelListProps = {
type ModelFormat = 'images' | 'checkpoint' | 'diffusers';
type ModelType = 'main' | 'lora';
type CombinedModelFormat = ModelFormat | 'lora';
const ModelList = (props: ModelListProps) => {
const { selectedModelId, setSelectedModelId } = props;
const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] =
useState<ModelFormat>('images');
useState<CombinedModelFormat>('images');
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
filteredDiffusersModels: modelsFilter(
data,
'main',
'diffusers',
nameFilter
),
}),
});
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
filteredCheckpointModels: modelsFilter(
data,
'main',
'checkpoint',
nameFilter
),
}),
});
const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
}),
});
@ -68,6 +90,13 @@ const ModelList = (props: ModelListProps) => {
>
{t('modelManager.checkpointModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('lora')}
isChecked={modelFormatFilter === 'lora'}
>
{t('modelManager.loraModels')}
</IAIButton>
</ButtonGroup>
<IAIInput
@ -118,6 +147,24 @@ const ModelList = (props: ModelListProps) => {
</Flex>
</StyledModelContainer>
)}
{['images', 'lora'].includes(modelFormatFilter) &&
filteredLoraModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
LoRAs
</Text>
{filteredLoraModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
</Flex>
</Flex>
</Flex>
@ -126,12 +173,13 @@ const ModelList = (props: ModelListProps) => {
export default ModelList;
const modelsFilter = (
data: EntityState<MainModelConfigEntity> | undefined,
model_format: ModelFormat,
const modelsFilter = <T extends MainModelConfigEntity | LoRAModelConfigEntity>(
data: EntityState<T> | undefined,
model_type: ModelType,
model_format: ModelFormat | undefined,
nameFilter: string
) => {
const filteredModels: MainModelConfigEntity[] = [];
const filteredModels: T[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
@ -141,9 +189,11 @@ const modelsFilter = (
.toLowerCase()
.includes(nameFilter.toLowerCase());
const matchesFormat = model.model_format === model_format;
const matchesFormat =
model_format === undefined || model.model_format === model_format;
const matchesType = model.model_type === model_type;
if (matchesFilter && matchesFormat) {
if (matchesFilter && matchesFormat && matchesType) {
filteredModels.push(model);
}
});

View File

@ -9,29 +9,26 @@ import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import {
MainModelConfigEntity,
LoRAModelConfigEntity,
useDeleteMainModelsMutation,
useDeleteLoRAModelsMutation,
} from 'services/api/endpoints/models';
type ModelListItemProps = {
model: MainModelConfigEntity;
model: MainModelConfigEntity | LoRAModelConfigEntity;
isSelected: boolean;
setSelectedModelId: (v: string | undefined) => void;
};
const modelBaseTypeMap = {
'sd-1': 'SD1',
'sd-2': 'SD2',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
};
export default function ModelListItem(props: ModelListItemProps) {
const isBusy = useAppSelector(selectIsBusy);
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [deleteMainModel] = useDeleteMainModelsMutation();
const [deleteLoRAModel] = useDeleteLoRAModelsMutation();
const { model, isSelected, setSelectedModelId } = props;
@ -40,7 +37,10 @@ export default function ModelListItem(props: ModelListItemProps) {
}, [model.id, setSelectedModelId]);
const handleModelDelete = useCallback(() => {
deleteMainModel(model)
const method = { main: deleteMainModel, lora: deleteLoRAModel }[
model.model_type
];
method(model)
.unwrap()
.then((_) => {
dispatch(
@ -60,14 +60,21 @@ export default function ModelListItem(props: ModelListItemProps) {
title: `${t('modelManager.modelDeleteFailed')}: ${
model.model_name
}`,
status: 'success',
status: 'error',
})
)
);
}
});
setSelectedModelId(undefined);
}, [deleteMainModel, model, setSelectedModelId, dispatch, t]);
}, [
deleteMainModel,
deleteLoRAModel,
model,
setSelectedModelId,
dispatch,
t,
]);
return (
<Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>
@ -100,8 +107,8 @@ export default function ModelListItem(props: ModelListItemProps) {
<Flex gap={4} alignItems="center">
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
{
modelBaseTypeMap[
model.base_model as keyof typeof modelBaseTypeMap
MODEL_TYPE_SHORT_MAP[
model.base_model as keyof typeof MODEL_TYPE_SHORT_MAP
]
}
</Badge>

View File

@ -52,9 +52,17 @@ type UpdateMainModelArg = {
body: MainModelConfig;
};
type UpdateLoRAModelArg = {
base_model: BaseModelType;
model_name: string;
body: LoRAModelConfig;
};
type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
type UpdateLoRAModelResponse = UpdateMainModelResponse;
type DeleteMainModelArg = {
base_model: BaseModelType;
model_name: string;
@ -62,6 +70,10 @@ type DeleteMainModelArg = {
type DeleteMainModelResponse = void;
type DeleteLoRAModelArg = DeleteMainModelArg;
type DeleteLoRAModelResponse = void;
type ConvertMainModelArg = {
base_model: BaseModelType;
model_name: string;
@ -320,6 +332,31 @@ export const modelsApi = api.injectEndpoints({
);
},
}),
updateLoRAModels: build.mutation<
UpdateLoRAModelResponse,
UpdateLoRAModelArg
>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/lora/${model_name}`,
method: 'PATCH',
body: body,
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
deleteLoRAModels: build.mutation<
DeleteLoRAModelResponse,
DeleteLoRAModelArg
>({
query: ({ base_model, model_name }) => {
return {
url: `models/${base_model}/lora/${model_name}`,
method: 'DELETE',
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
getControlNetModels: build.query<
EntityState<ControlNetModelConfigEntity>,
void
@ -467,6 +504,8 @@ export const {
useAddMainModelsMutation,
useConvertMainModelsMutation,
useMergeMainModelsMutation,
useDeleteLoRAModelsMutation,
useUpdateLoRAModelsMutation,
useSyncModelsMutation,
useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery,

View File

@ -5562,12 +5562,6 @@ export type components = {
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
@ -5580,6 +5574,12 @@ export type components = {
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;

View File

@ -42,8 +42,13 @@ export type ControlField = components['schemas']['ControlField'];
// Model Configs
export type LoRAModelConfig = components['schemas']['LoRAModelConfig'];
export type VaeModelConfig = components['schemas']['VaeModelConfig'];
export type ControlNetModelCheckpointConfig =
components['schemas']['ControlNetModelCheckpointConfig'];
export type ControlNetModelDiffusersConfig =
components['schemas']['ControlNetModelDiffusersConfig'];
export type ControlNetModelConfig =
components['schemas']['ControlNetModelConfig'];
| ControlNetModelCheckpointConfig
| ControlNetModelDiffusersConfig;
export type TextualInversionModelConfig =
components['schemas']['TextualInversionModelConfig'];
export type DiffusersModelConfig =

View File

@ -1 +1 @@
__version__ = "3.0.1rc1"
__version__ = "3.0.1rc2"

View File

@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "InvokeAI"
description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process"
requires-python = ">=3.9, <3.11"
requires-python = ">=3.9, <3.12"
readme = { content-type = "text/markdown", file = "README.md" }
keywords = ["stable-diffusion", "AI"]
dynamic = ["version"]
@ -32,16 +32,16 @@ classifiers = [
'Topic :: Scientific/Engineering :: Image Processing',
]
dependencies = [
"accelerate~=0.16",
"accelerate~=0.21.0",
"albumentations",
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.0",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel~=2.0.0",
"controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets",
"diffusers[torch]~=0.18.1",
"dnspython==2.2.1",
"diffusers[torch]~=0.19.0",
"dnspython~=2.4.0",
"dynamicprompts",
"easing-functions",
"einops",
@ -54,37 +54,37 @@ dependencies = [
"flask_cors==3.0.10",
"flask_socketio==5.3.0",
"flaskwebgui==1.0.3",
"gfpgan==1.3.8",
"huggingface-hub>=0.11.1",
"invisible-watermark>=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model
"npyscreen",
"numpy<1.24",
"numpy==1.24.4",
"omegaconf",
"opencv-python",
"picklescan",
"pillow",
"prompt-toolkit",
"pympler==1.0.1",
"pydantic==1.10.10",
"pympler~=1.0.1",
"pypatchmatch",
'pyperclip',
"pyreadline3",
"python-multipart==0.0.6",
"pytorch-lightning==1.7.7",
"python-multipart",
"pytorch-lightning",
"realesrgan",
"requests==2.28.2",
"requests~=2.28.2",
"rich~=13.3",
"safetensors~=0.3.0",
"scikit-image>=0.19",
"scikit-image~=0.21.0",
"send2trash",
"test-tube>=0.7.5",
"torch~=2.0.0",
"torchvision>=0.14.1",
"torchmetrics==0.11.4",
"torchsde==0.2.5",
"test-tube~=0.7.5",
"torch~=2.0.1",
"torchvision~=0.15.2",
"torchmetrics~=1.0.1",
"torchsde~=0.2.5",
"transformers~=4.31.0",
"uvicorn[standard]==0.21.1",
"uvicorn[standard]~=0.21.1",
"windows-curses; sys_platform=='win32'",
]

View File

@ -1,8 +1,16 @@
#!/bin/env python
import argparse
import sys
from pathlib import Path
from invokeai.backend.model_management.model_probe import ModelProbe
info = ModelProbe().probe(Path(sys.argv[1]))
parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument(
"model_path",
type=Path,
)
args = parser.parse_args()
info = ModelProbe().probe(args.model_path)
print(info)