mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: model events (#3786)
[feat(nodes): emit model loading events](7b6159f8d6
) - remove dependency on having access to a `node` during emits, would need a bit of additional args passed through the system and I don't think its necessary at this point. this also allowed us to drop an extraneous fetching/parsing of the session from db. - provide the invocation context to all `get_model()` calls, so the events are able to be emitted - test all model loading events in the app and confirm socket events are received [feat(ui): add listeners for model load events](c487166d9c
) - currently only exposed as DEBUG-level logs --- One change I missed in the commit messages is the `ModelInfo` class is not serializable, so I split out the pieces of information we didn't already have (hash, location, precision) and added them to the event payload directly.
This commit is contained in:
commit
0edb31febd
@ -57,10 +57,10 @@ class CompelInvocation(BaseInvocation):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**self.clip.tokenizer.dict(),
|
**self.clip.tokenizer.dict(), context=context,
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**self.clip.text_encoder.dict(),
|
**self.clip.text_encoder.dict(), context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
@ -82,6 +82,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=self.clip.text_encoder.base_model,
|
base_model=self.clip.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
|
context=context,
|
||||||
).context.model
|
).context.model
|
||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
|
@ -157,13 +157,13 @@ class InpaintInvocation(BaseInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}))
|
**lora.dict(exclude={"weight"}), context=context,)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,)
|
||||||
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,)
|
||||||
|
|
||||||
with vae_info as vae,\
|
with vae_info as vae,\
|
||||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||||
|
@ -76,7 +76,7 @@ def get_scheduler(
|
|||||||
scheduler_name, SCHEDULER_MAP['ddim']
|
scheduler_name, SCHEDULER_MAP['ddim']
|
||||||
)
|
)
|
||||||
orig_scheduler_info = context.services.model_manager.get_model(
|
orig_scheduler_info = context.services.model_manager.get_model(
|
||||||
**scheduler_info.dict()
|
**scheduler_info.dict(), context=context,
|
||||||
)
|
)
|
||||||
with orig_scheduler_info as orig_scheduler:
|
with orig_scheduler_info as orig_scheduler:
|
||||||
scheduler_config = orig_scheduler.config
|
scheduler_config = orig_scheduler.config
|
||||||
@ -262,6 +262,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
model_name=control_info.control_model.model_name,
|
model_name=control_info.control_model.model_name,
|
||||||
model_type=ModelType.ControlNet,
|
model_type=ModelType.ControlNet,
|
||||||
base_model=control_info.control_model.base_model,
|
base_model=control_info.control_model.base_model,
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -313,14 +314,14 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"})
|
**lora.dict(exclude={"weight"}), context=context,
|
||||||
)
|
)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict()
|
**self.unet.unet.dict(), context=context,
|
||||||
)
|
)
|
||||||
with ExitStack() as exit_stack,\
|
with ExitStack() as exit_stack,\
|
||||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||||
@ -403,14 +404,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"})
|
**lora.dict(exclude={"weight"}), context=context,
|
||||||
)
|
)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict()
|
**self.unet.unet.dict(), context=context,
|
||||||
)
|
)
|
||||||
with ExitStack() as exit_stack,\
|
with ExitStack() as exit_stack,\
|
||||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||||
@ -491,7 +492,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.dict(), context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
with vae_info as vae:
|
with vae_info as vae:
|
||||||
@ -636,7 +637,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.dict(), context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
|
@ -105,8 +105,6 @@ class EventServiceBase:
|
|||||||
def emit_model_load_started (
|
def emit_model_load_started (
|
||||||
self,
|
self,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
node: dict,
|
|
||||||
source_node_id: str,
|
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
@ -117,8 +115,6 @@ class EventServiceBase:
|
|||||||
event_name="model_load_started",
|
event_name="model_load_started",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
node=node,
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
@ -129,8 +125,6 @@ class EventServiceBase:
|
|||||||
def emit_model_load_completed(
|
def emit_model_load_completed(
|
||||||
self,
|
self,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
node: dict,
|
|
||||||
source_node_id: str,
|
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
@ -142,12 +136,12 @@ class EventServiceBase:
|
|||||||
event_name="model_load_completed",
|
event_name="model_load_completed",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
node=node,
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
model_info=model_info,
|
hash=model_info.hash,
|
||||||
|
location=model_info.location,
|
||||||
|
precision=str(model_info.precision),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -339,7 +339,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel: Optional[SubModelType] = None,
|
submodel: Optional[SubModelType] = None,
|
||||||
node: Optional[BaseInvocation] = None,
|
|
||||||
context: Optional[InvocationContext] = None,
|
context: Optional[InvocationContext] = None,
|
||||||
) -> ModelInfo:
|
) -> ModelInfo:
|
||||||
"""
|
"""
|
||||||
@ -347,11 +346,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
part (such as the vae) of a diffusers mode.
|
part (such as the vae) of a diffusers mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# if we are called from within a node, then we get to emit
|
# we can emit model loading events if we are executing with access to the invocation context
|
||||||
# load start and complete events
|
if context:
|
||||||
if node and context:
|
|
||||||
self._emit_load_event(
|
self._emit_load_event(
|
||||||
node=node,
|
|
||||||
context=context,
|
context=context,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
@ -366,9 +363,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
submodel,
|
submodel,
|
||||||
)
|
)
|
||||||
|
|
||||||
if node and context:
|
if context:
|
||||||
self._emit_load_event(
|
self._emit_load_event(
|
||||||
node=node,
|
|
||||||
context=context,
|
context=context,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
@ -510,23 +506,19 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
|
|
||||||
def _emit_load_event(
|
def _emit_load_event(
|
||||||
self,
|
self,
|
||||||
node,
|
|
||||||
context,
|
context,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel: SubModelType,
|
submodel: Optional[SubModelType] = None,
|
||||||
model_info: Optional[ModelInfo] = None,
|
model_info: Optional[ModelInfo] = None,
|
||||||
):
|
):
|
||||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||||
raise CanceledException()
|
raise CanceledException()
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
|
|
||||||
if model_info:
|
if model_info:
|
||||||
context.services.events.emit_model_load_completed(
|
context.services.events.emit_model_load_completed(
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
graph_execution_state_id=context.graph_execution_state_id,
|
||||||
node=node.dict(),
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
@ -536,8 +528,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
else:
|
else:
|
||||||
context.services.events.emit_model_load_started(
|
context.services.events.emit_model_load_started(
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
graph_execution_state_id=context.graph_execution_state_id,
|
||||||
node=node.dict(),
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
|
@ -88,6 +88,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
|||||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||||
|
import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
|
||||||
|
import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
|
|
||||||
@ -177,6 +179,8 @@ addSocketConnectedListener();
|
|||||||
addSocketDisconnectedListener();
|
addSocketDisconnectedListener();
|
||||||
addSocketSubscribedListener();
|
addSocketSubscribedListener();
|
||||||
addSocketUnsubscribedListener();
|
addSocketUnsubscribedListener();
|
||||||
|
addModelLoadStartedEventListener();
|
||||||
|
addModelLoadCompletedEventListener();
|
||||||
|
|
||||||
// Session Created
|
// Session Created
|
||||||
addSessionCreatedPendingListener();
|
addSessionCreatedPendingListener();
|
||||||
|
@ -0,0 +1,28 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
appSocketModelLoadCompleted,
|
||||||
|
socketModelLoadCompleted,
|
||||||
|
} from 'services/events/actions';
|
||||||
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
|
export const addModelLoadCompletedEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketModelLoadCompleted,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
const { model_name, model_type, submodel } = action.payload.data;
|
||||||
|
|
||||||
|
let modelString = `${model_type} model: ${model_name}`;
|
||||||
|
|
||||||
|
if (submodel) {
|
||||||
|
modelString = modelString.concat(`, submodel: ${submodel}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
moduleLog.debug(action.payload, `Model load completed (${modelString})`);
|
||||||
|
|
||||||
|
// pass along the socket event as an application action
|
||||||
|
dispatch(appSocketModelLoadCompleted(action.payload));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,28 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
appSocketModelLoadStarted,
|
||||||
|
socketModelLoadStarted,
|
||||||
|
} from 'services/events/actions';
|
||||||
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
|
export const addModelLoadStartedEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketModelLoadStarted,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
const { model_name, model_type, submodel } = action.payload.data;
|
||||||
|
|
||||||
|
let modelString = `${model_type} model: ${model_name}`;
|
||||||
|
|
||||||
|
if (submodel) {
|
||||||
|
modelString = modelString.concat(`, submodel: ${submodel}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
moduleLog.debug(action.payload, `Model load started (${modelString})`);
|
||||||
|
|
||||||
|
// pass along the socket event as an application action
|
||||||
|
dispatch(appSocketModelLoadStarted(action.payload));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -3,4 +3,5 @@ import dateFormat from 'dateformat';
|
|||||||
/**
|
/**
|
||||||
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
|
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
|
||||||
*/
|
*/
|
||||||
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');
|
export const getTimestamp = () =>
|
||||||
|
dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);
|
||||||
|
@ -28,6 +28,7 @@ export type OffsetPaginatedResults_ImageDTO_ =
|
|||||||
|
|
||||||
// Models
|
// Models
|
||||||
export type ModelType = components['schemas']['ModelType'];
|
export type ModelType = components['schemas']['ModelType'];
|
||||||
|
export type SubModelType = components['schemas']['SubModelType'];
|
||||||
export type BaseModelType = components['schemas']['BaseModelType'];
|
export type BaseModelType = components['schemas']['BaseModelType'];
|
||||||
export type MainModelField = components['schemas']['MainModelField'];
|
export type MainModelField = components['schemas']['MainModelField'];
|
||||||
export type VAEModelField = components['schemas']['VAEModelField'];
|
export type VAEModelField = components['schemas']['VAEModelField'];
|
||||||
|
@ -5,6 +5,8 @@ import {
|
|||||||
InvocationCompleteEvent,
|
InvocationCompleteEvent,
|
||||||
InvocationErrorEvent,
|
InvocationErrorEvent,
|
||||||
InvocationStartedEvent,
|
InvocationStartedEvent,
|
||||||
|
ModelLoadCompletedEvent,
|
||||||
|
ModelLoadStartedEvent,
|
||||||
} from 'services/events/types';
|
} from 'services/events/types';
|
||||||
|
|
||||||
// Common socket action payload data
|
// Common socket action payload data
|
||||||
@ -162,3 +164,35 @@ export const socketGeneratorProgress = createAction<
|
|||||||
export const appSocketGeneratorProgress = createAction<
|
export const appSocketGeneratorProgress = createAction<
|
||||||
BaseSocketPayload & { data: GeneratorProgressEvent }
|
BaseSocketPayload & { data: GeneratorProgressEvent }
|
||||||
>('socket/appSocketGeneratorProgress');
|
>('socket/appSocketGeneratorProgress');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Socket.IO Model Load Started
|
||||||
|
*
|
||||||
|
* Do not use. Only for use in middleware.
|
||||||
|
*/
|
||||||
|
export const socketModelLoadStarted = createAction<
|
||||||
|
BaseSocketPayload & { data: ModelLoadStartedEvent }
|
||||||
|
>('socket/socketModelLoadStarted');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* App-level Model Load Started
|
||||||
|
*/
|
||||||
|
export const appSocketModelLoadStarted = createAction<
|
||||||
|
BaseSocketPayload & { data: ModelLoadStartedEvent }
|
||||||
|
>('socket/appSocketModelLoadStarted');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Socket.IO Model Load Started
|
||||||
|
*
|
||||||
|
* Do not use. Only for use in middleware.
|
||||||
|
*/
|
||||||
|
export const socketModelLoadCompleted = createAction<
|
||||||
|
BaseSocketPayload & { data: ModelLoadCompletedEvent }
|
||||||
|
>('socket/socketModelLoadCompleted');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* App-level Model Load Completed
|
||||||
|
*/
|
||||||
|
export const appSocketModelLoadCompleted = createAction<
|
||||||
|
BaseSocketPayload & { data: ModelLoadCompletedEvent }
|
||||||
|
>('socket/appSocketModelLoadCompleted');
|
||||||
|
@ -1,5 +1,11 @@
|
|||||||
import { O } from 'ts-toolbelt';
|
import { O } from 'ts-toolbelt';
|
||||||
import { Graph, GraphExecutionState } from '../api/types';
|
import {
|
||||||
|
BaseModelType,
|
||||||
|
Graph,
|
||||||
|
GraphExecutionState,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
} from '../api/types';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A progress image, we get one for each step in the generation
|
* A progress image, we get one for each step in the generation
|
||||||
@ -25,6 +31,25 @@ export type BaseNode = {
|
|||||||
[key: string]: AnyInvocation[keyof AnyInvocation];
|
[key: string]: AnyInvocation[keyof AnyInvocation];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type ModelLoadStartedEvent = {
|
||||||
|
graph_execution_state_id: string;
|
||||||
|
model_name: string;
|
||||||
|
base_model: BaseModelType;
|
||||||
|
model_type: ModelType;
|
||||||
|
submodel: SubModelType;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ModelLoadCompletedEvent = {
|
||||||
|
graph_execution_state_id: string;
|
||||||
|
model_name: string;
|
||||||
|
base_model: BaseModelType;
|
||||||
|
model_type: ModelType;
|
||||||
|
submodel: SubModelType;
|
||||||
|
hash?: string;
|
||||||
|
location: string;
|
||||||
|
precision: string;
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A `generator_progress` socket.io event.
|
* A `generator_progress` socket.io event.
|
||||||
*
|
*
|
||||||
@ -101,6 +126,8 @@ export type ServerToClientEvents = {
|
|||||||
graph_execution_state_complete: (
|
graph_execution_state_complete: (
|
||||||
payload: GraphExecutionStateCompleteEvent
|
payload: GraphExecutionStateCompleteEvent
|
||||||
) => void;
|
) => void;
|
||||||
|
model_load_started: (payload: ModelLoadStartedEvent) => void;
|
||||||
|
model_load_completed: (payload: ModelLoadCompletedEvent) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ClientToServerEvents = {
|
export type ClientToServerEvents = {
|
||||||
|
@ -11,6 +11,8 @@ import {
|
|||||||
socketConnected,
|
socketConnected,
|
||||||
socketDisconnected,
|
socketDisconnected,
|
||||||
socketSubscribed,
|
socketSubscribed,
|
||||||
|
socketModelLoadStarted,
|
||||||
|
socketModelLoadCompleted,
|
||||||
} from '../actions';
|
} from '../actions';
|
||||||
import { ClientToServerEvents, ServerToClientEvents } from '../types';
|
import { ClientToServerEvents, ServerToClientEvents } from '../types';
|
||||||
import { Logger } from 'roarr';
|
import { Logger } from 'roarr';
|
||||||
@ -44,7 +46,7 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
|
|||||||
socketSubscribed({
|
socketSubscribed({
|
||||||
sessionId,
|
sessionId,
|
||||||
timestamp: getTimestamp(),
|
timestamp: getTimestamp(),
|
||||||
boardId: getState().boards.selectedBoardId,
|
boardId: getState().gallery.selectedBoardId,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -118,4 +120,28 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
|
|||||||
})
|
})
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Model load started
|
||||||
|
*/
|
||||||
|
socket.on('model_load_started', (data) => {
|
||||||
|
dispatch(
|
||||||
|
socketModelLoadStarted({
|
||||||
|
data,
|
||||||
|
timestamp: getTimestamp(),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Model load completed
|
||||||
|
*/
|
||||||
|
socket.on('model_load_completed', (data) => {
|
||||||
|
dispatch(
|
||||||
|
socketModelLoadCompleted({
|
||||||
|
data,
|
||||||
|
timestamp: getTimestamp(),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
});
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user