From 7b6159f8d65a3876ef7ecc8e8a31dacf9a50d213 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 16 Jul 2023 02:12:01 +1000 Subject: [PATCH 1/4] feat(nodes): emit model loading events - remove dependency on having access to a `node` during emits, would need a bit of additional args passed through the system and I don't think its necessary at this point. this also allowed us to drop an extraneous fetching/parsing of the session from db. - provide the invocation context to all `get_model()` calls, so the events are able to be emitted - test all model loading events in the app and confirm socket events are received --- invokeai/app/invocations/latent.py | 15 +++++++------- invokeai/app/services/events.py | 12 +++-------- .../app/services/model_manager_service.py | 20 +++++-------------- 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index baf78c7c23..f9844c5932 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -76,7 +76,7 @@ def get_scheduler( scheduler_name, SCHEDULER_MAP['ddim'] ) orig_scheduler_info = context.services.model_manager.get_model( - **scheduler_info.dict() + **scheduler_info.dict(), context=context, ) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config @@ -262,6 +262,7 @@ class TextToLatentsInvocation(BaseInvocation): model_name=control_info.control_model.model_name, model_type=ModelType.ControlNet, base_model=control_info.control_model.base_model, + context=context, ) ) @@ -313,14 +314,14 @@ class TextToLatentsInvocation(BaseInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"}) + **lora.dict(exclude={"weight"}), context=context, ) yield (lora_info.context.model, lora.weight) del lora_info return unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict() + **self.unet.unet.dict(), context=context, ) with ExitStack() as exit_stack,\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ @@ -403,14 +404,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"}) + **lora.dict(exclude={"weight"}), context=context, ) yield (lora_info.context.model, lora.weight) del lora_info return unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict() + **self.unet.unet.dict(), context=context, ) with ExitStack() as exit_stack,\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ @@ -491,7 +492,7 @@ class LatentsToImageInvocation(BaseInvocation): latents = context.services.latents.get(self.latents.latents_name) vae_info = context.services.model_manager.get_model( - **self.vae.vae.dict(), + **self.vae.vae.dict(), context=context, ) with vae_info as vae: @@ -636,7 +637,7 @@ class ImageToLatentsInvocation(BaseInvocation): #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) vae_info = context.services.model_manager.get_model( - **self.vae.vae.dict(), + **self.vae.vae.dict(), context=context, ) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 6c516c9b74..30d1b5e7a9 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -105,8 +105,6 @@ class EventServiceBase: def emit_model_load_started ( self, graph_execution_state_id: str, - node: dict, - source_node_id: str, model_name: str, base_model: BaseModelType, model_type: ModelType, @@ -117,8 +115,6 @@ class EventServiceBase: event_name="model_load_started", payload=dict( graph_execution_state_id=graph_execution_state_id, - node=node, - source_node_id=source_node_id, model_name=model_name, base_model=base_model, model_type=model_type, @@ -129,8 +125,6 @@ class EventServiceBase: def emit_model_load_completed( self, graph_execution_state_id: str, - node: dict, - source_node_id: str, model_name: str, base_model: BaseModelType, model_type: ModelType, @@ -142,12 +136,12 @@ class EventServiceBase: event_name="model_load_completed", payload=dict( graph_execution_state_id=graph_execution_state_id, - node=node, - source_node_id=source_node_id, model_name=model_name, base_model=base_model, model_type=model_type, submodel=submodel, - model_info=model_info, + hash=model_info.hash, + location=model_info.location, + precision=str(model_info.precision), ), ) diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 67db5c9478..51ba1199fb 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -338,7 +338,6 @@ class ModelManagerService(ModelManagerServiceBase): base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - node: Optional[BaseInvocation] = None, context: Optional[InvocationContext] = None, ) -> ModelInfo: """ @@ -346,11 +345,9 @@ class ModelManagerService(ModelManagerServiceBase): part (such as the vae) of a diffusers mode. """ - # if we are called from within a node, then we get to emit - # load start and complete events - if node and context: + # we can emit model loading events if we are executing with access to the invocation context + if context: self._emit_load_event( - node=node, context=context, model_name=model_name, base_model=base_model, @@ -365,9 +362,8 @@ class ModelManagerService(ModelManagerServiceBase): submodel, ) - if node and context: + if context: self._emit_load_event( - node=node, context=context, model_name=model_name, base_model=base_model, @@ -509,23 +505,19 @@ class ModelManagerService(ModelManagerServiceBase): def _emit_load_event( self, - node, context, model_name: str, base_model: BaseModelType, model_type: ModelType, - submodel: SubModelType, + submodel: Optional[SubModelType] = None, model_info: Optional[ModelInfo] = None, ): if context.services.queue.is_canceled(context.graph_execution_state_id): raise CanceledException() - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[node.id] + if model_info: context.services.events.emit_model_load_completed( graph_execution_state_id=context.graph_execution_state_id, - node=node.dict(), - source_node_id=source_node_id, model_name=model_name, base_model=base_model, model_type=model_type, @@ -535,8 +527,6 @@ class ModelManagerService(ModelManagerServiceBase): else: context.services.events.emit_model_load_started( graph_execution_state_id=context.graph_execution_state_id, - node=node.dict(), - source_node_id=source_node_id, model_name=model_name, base_model=base_model, model_type=model_type, From c487166d9ca0b868ee39125a53e010aa62ed7777 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 16 Jul 2023 02:26:30 +1000 Subject: [PATCH 2/4] feat(ui): add listeners for model load events - currently only exposed as DEBUG-level logs --- .../middleware/listenerMiddleware/index.ts | 4 +++ .../socketio/socketModelLoadCompleted.ts | 28 +++++++++++++++ .../socketio/socketModelLoadStarted.ts | 28 +++++++++++++++ .../frontend/web/src/services/api/types.d.ts | 1 + .../web/src/services/events/actions.ts | 34 +++++++++++++++++++ .../frontend/web/src/services/events/types.ts | 29 +++++++++++++++- .../services/events/util/setEventListeners.ts | 28 ++++++++++++++- 7 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoadCompleted.ts create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoadStarted.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index edeb156439..94044d3cc5 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -88,6 +88,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; +import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted'; +import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted'; export const listenerMiddleware = createListenerMiddleware(); @@ -177,6 +179,8 @@ addSocketConnectedListener(); addSocketDisconnectedListener(); addSocketSubscribedListener(); addSocketUnsubscribedListener(); +addModelLoadStartedEventListener(); +addModelLoadCompletedEventListener(); // Session Created addSessionCreatedPendingListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoadCompleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoadCompleted.ts new file mode 100644 index 0000000000..bc533b9178 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoadCompleted.ts @@ -0,0 +1,28 @@ +import { log } from 'app/logging/useLogger'; +import { + appSocketModelLoadCompleted, + socketModelLoadCompleted, +} from 'services/events/actions'; +import { startAppListening } from '../..'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addModelLoadCompletedEventListener = () => { + startAppListening({ + actionCreator: socketModelLoadCompleted, + effect: (action, { dispatch, getState }) => { + const { model_name, model_type, submodel } = action.payload.data; + + let modelString = `${model_type} model: ${model_name}`; + + if (submodel) { + modelString = modelString.concat(`, submodel: ${submodel}`); + } + + moduleLog.debug(action.payload, `Model load completed (${modelString})`); + + // pass along the socket event as an application action + dispatch(appSocketModelLoadCompleted(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoadStarted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoadStarted.ts new file mode 100644 index 0000000000..aa53a70cee --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoadStarted.ts @@ -0,0 +1,28 @@ +import { log } from 'app/logging/useLogger'; +import { + appSocketModelLoadStarted, + socketModelLoadStarted, +} from 'services/events/actions'; +import { startAppListening } from '../..'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addModelLoadStartedEventListener = () => { + startAppListening({ + actionCreator: socketModelLoadStarted, + effect: (action, { dispatch, getState }) => { + const { model_name, model_type, submodel } = action.payload.data; + + let modelString = `${model_type} model: ${model_name}`; + + if (submodel) { + modelString = modelString.concat(`, submodel: ${submodel}`); + } + + moduleLog.debug(action.payload, `Model load started (${modelString})`); + + // pass along the socket event as an application action + dispatch(appSocketModelLoadStarted(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index 37faae592f..dce7f5c28b 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -28,6 +28,7 @@ export type OffsetPaginatedResults_ImageDTO_ = // Models export type ModelType = components['schemas']['ModelType']; +export type SubModelType = components['schemas']['SubModelType']; export type BaseModelType = components['schemas']['BaseModelType']; export type MainModelField = components['schemas']['MainModelField']; export type VAEModelField = components['schemas']['VAEModelField']; diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts index ed154b9cd8..cb2d665748 100644 --- a/invokeai/frontend/web/src/services/events/actions.ts +++ b/invokeai/frontend/web/src/services/events/actions.ts @@ -5,6 +5,8 @@ import { InvocationCompleteEvent, InvocationErrorEvent, InvocationStartedEvent, + ModelLoadCompletedEvent, + ModelLoadStartedEvent, } from 'services/events/types'; // Common socket action payload data @@ -162,3 +164,35 @@ export const socketGeneratorProgress = createAction< export const appSocketGeneratorProgress = createAction< BaseSocketPayload & { data: GeneratorProgressEvent } >('socket/appSocketGeneratorProgress'); + +/** + * Socket.IO Model Load Started + * + * Do not use. Only for use in middleware. + */ +export const socketModelLoadStarted = createAction< + BaseSocketPayload & { data: ModelLoadStartedEvent } +>('socket/socketModelLoadStarted'); + +/** + * App-level Model Load Started + */ +export const appSocketModelLoadStarted = createAction< + BaseSocketPayload & { data: ModelLoadStartedEvent } +>('socket/appSocketModelLoadStarted'); + +/** + * Socket.IO Model Load Started + * + * Do not use. Only for use in middleware. + */ +export const socketModelLoadCompleted = createAction< + BaseSocketPayload & { data: ModelLoadCompletedEvent } +>('socket/socketModelLoadCompleted'); + +/** + * App-level Model Load Completed + */ +export const appSocketModelLoadCompleted = createAction< + BaseSocketPayload & { data: ModelLoadCompletedEvent } +>('socket/appSocketModelLoadCompleted'); diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index f589fdd6cc..ec1b55e3fe 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -1,5 +1,11 @@ import { O } from 'ts-toolbelt'; -import { Graph, GraphExecutionState } from '../api/types'; +import { + BaseModelType, + Graph, + GraphExecutionState, + ModelType, + SubModelType, +} from '../api/types'; /** * A progress image, we get one for each step in the generation @@ -25,6 +31,25 @@ export type BaseNode = { [key: string]: AnyInvocation[keyof AnyInvocation]; }; +export type ModelLoadStartedEvent = { + graph_execution_state_id: string; + model_name: string; + base_model: BaseModelType; + model_type: ModelType; + submodel: SubModelType; +}; + +export type ModelLoadCompletedEvent = { + graph_execution_state_id: string; + model_name: string; + base_model: BaseModelType; + model_type: ModelType; + submodel: SubModelType; + hash?: string; + location: string; + precision: string; +}; + /** * A `generator_progress` socket.io event. * @@ -101,6 +126,8 @@ export type ServerToClientEvents = { graph_execution_state_complete: ( payload: GraphExecutionStateCompleteEvent ) => void; + model_load_started: (payload: ModelLoadStartedEvent) => void; + model_load_completed: (payload: ModelLoadCompletedEvent) => void; }; export type ClientToServerEvents = { diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index 62b5864185..44d3301556 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -11,6 +11,8 @@ import { socketConnected, socketDisconnected, socketSubscribed, + socketModelLoadStarted, + socketModelLoadCompleted, } from '../actions'; import { ClientToServerEvents, ServerToClientEvents } from '../types'; import { Logger } from 'roarr'; @@ -44,7 +46,7 @@ export const setEventListeners = (arg: SetEventListenersArg) => { socketSubscribed({ sessionId, timestamp: getTimestamp(), - boardId: getState().boards.selectedBoardId, + boardId: getState().gallery.selectedBoardId, }) ); } @@ -118,4 +120,28 @@ export const setEventListeners = (arg: SetEventListenersArg) => { }) ); }); + + /** + * Model load started + */ + socket.on('model_load_started', (data) => { + dispatch( + socketModelLoadStarted({ + data, + timestamp: getTimestamp(), + }) + ); + }); + + /** + * Model load completed + */ + socket.on('model_load_completed', (data) => { + dispatch( + socketModelLoadCompleted({ + data, + timestamp: getTimestamp(), + }) + ); + }); }; From ba1284968528896896cc438b9e5ca1655ed2dafc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 17 Jul 2023 17:16:55 +1000 Subject: [PATCH 3/4] fix(nodes): fix some model load events not emitting Missed adding the `context` arg to them initially --- invokeai/app/invocations/compel.py | 5 +++-- invokeai/app/invocations/generate.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index a5a9701149..c8a9bf4464 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -57,10 +57,10 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: tokenizer_info = context.services.model_manager.get_model( - **self.clip.tokenizer.dict(), + **self.clip.tokenizer.dict(), context=context, ) text_encoder_info = context.services.model_manager.get_model( - **self.clip.text_encoder.dict(), + **self.clip.text_encoder.dict(), context=context, ) def _lora_loader(): @@ -82,6 +82,7 @@ class CompelInvocation(BaseInvocation): model_name=name, base_model=self.clip.text_encoder.base_model, model_type=ModelType.TextualInversion, + context=context, ).context.model ) except ModelNotFoundException: diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 6cdb83effc..b8e9ec2038 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -157,13 +157,13 @@ class InpaintInvocation(BaseInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"})) + **lora.dict(exclude={"weight"}), context=context,) yield (lora_info.context.model, lora.weight) del lora_info return - unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) - vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) + unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,) + vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,) with vae_info as vae,\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ From af9e8fefce22bf7fd2946ef211375bcea4c25317 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 17 Jul 2023 17:35:20 +1000 Subject: [PATCH 4/4] feat(ui): socket event timestamps have ms precision --- invokeai/frontend/web/src/common/util/getTimestamp.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/common/util/getTimestamp.ts b/invokeai/frontend/web/src/common/util/getTimestamp.ts index 570283fa8f..daa9f8dc33 100644 --- a/invokeai/frontend/web/src/common/util/getTimestamp.ts +++ b/invokeai/frontend/web/src/common/util/getTimestamp.ts @@ -3,4 +3,5 @@ import dateFormat from 'dateformat'; /** * Get a `now` timestamp with 1s precision, formatted as ISO datetime. */ -export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime'); +export const getTimestamp = () => + dateFormat(new Date(), `yyyy-mm-dd'T'HH:MM:ss:lo`);