diff --git a/frontend/appflowy_flutter/lib/shared/feature_flags.dart b/frontend/appflowy_flutter/lib/shared/feature_flags.dart index 8a1b230461..ff780e2647 100644 --- a/frontend/appflowy_flutter/lib/shared/feature_flags.dart +++ b/frontend/appflowy_flutter/lib/shared/feature_flags.dart @@ -91,7 +91,7 @@ enum FeatureFlag { bool get isOn { if ([ - // FeatureFlag.planBilling, + // if (kDebugMode) FeatureFlag.planBilling, // release this feature in version 0.6.1 FeatureFlag.spaceDesign, // release this feature in version 0.5.9 diff --git a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/local_ai_on_boarding_bloc.dart b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/local_ai_on_boarding_bloc.dart new file mode 100644 index 0000000000..7a7d0fecd7 --- /dev/null +++ b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/local_ai_on_boarding_bloc.dart @@ -0,0 +1,69 @@ +import 'package:appflowy_backend/dispatch/dispatch.dart'; +import 'package:appflowy_backend/log.dart'; +import 'package:appflowy_backend/protobuf/flowy-error/errors.pb.dart'; +import 'package:appflowy_backend/protobuf/flowy-user/workspace.pb.dart'; +import 'package:appflowy_result/appflowy_result.dart'; +import 'package:bloc/bloc.dart'; +import 'package:freezed_annotation/freezed_annotation.dart'; + +part 'local_ai_on_boarding_bloc.freezed.dart'; + +class LocalAIOnBoardingBloc + extends Bloc { + LocalAIOnBoardingBloc(this.workspaceId) + : super(const LocalAIOnBoardingState()) { + _dispatch(); + } + + final String workspaceId; + + void _dispatch() { + on((event, emit) { + event.when( + started: () { + _loadSubscriptionPlans(); + }, + didGetSubscriptionPlans: (result) { + result.fold( + (workspaceSubInfo) { + final isPurchaseAILocal = workspaceSubInfo.addOns.any((addOn) { + return addOn.type == WorkspaceAddOnPBType.AddOnAiLocal; + }); + + emit( + state.copyWith(isPurchaseAILocal: isPurchaseAILocal), + ); + }, + (err) { + Log.error("Failed to get subscription plans: $err"); + }, + ); + }, + ); + }); + } + + void _loadSubscriptionPlans() { + final payload = UserWorkspaceIdPB()..workspaceId = workspaceId; + UserEventGetWorkspaceSubscriptionInfo(payload).send().then((result) { + if (!isClosed) { + add(LocalAIOnBoardingEvent.didGetSubscriptionPlans(result)); + } + }); + } +} + +@freezed +class LocalAIOnBoardingEvent with _$LocalAIOnBoardingEvent { + const factory LocalAIOnBoardingEvent.started() = _Started; + const factory LocalAIOnBoardingEvent.didGetSubscriptionPlans( + FlowyResult result, + ) = _LoadSubscriptionPlans; +} + +@freezed +class LocalAIOnBoardingState with _$LocalAIOnBoardingState { + const factory LocalAIOnBoardingState({ + @Default(false) bool isPurchaseAILocal, + }) = _LocalAIOnBoardingState; +} diff --git a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/plugin_state_bloc.dart b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/plugin_state_bloc.dart index 4ca9522208..2e73696b2a 100644 --- a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/plugin_state_bloc.dart +++ b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/plugin_state_bloc.dart @@ -1,5 +1,6 @@ import 'dart:async'; +import 'package:appflowy/core/helpers/url_launcher.dart'; import 'package:appflowy/workspace/application/settings/ai/local_llm_listener.dart'; import 'package:appflowy_backend/dispatch/dispatch.dart'; import 'package:appflowy_backend/log.dart'; @@ -67,8 +68,20 @@ class PluginStateBloc extends Bloc { break; } }, - restartLocalAI: () { - ChatEventRestartLocalAIChat().send(); + restartLocalAI: () async { + emit( + const PluginStateState(action: PluginStateAction.loadingPlugin()), + ); + unawaited(ChatEventRestartLocalAIChat().send()); + }, + openModelDirectory: () async { + final result = await ChatEventGetModelStorageDirectory().send(); + result.fold( + (data) { + afLaunchUrl(Uri.file(data.filePath)); + }, + (err) => Log.error(err.toString()), + ); }, ); } @@ -80,12 +93,15 @@ class PluginStateEvent with _$PluginStateEvent { const factory PluginStateEvent.updateState(LocalAIPluginStatePB pluginState) = _UpdatePluginState; const factory PluginStateEvent.restartLocalAI() = _RestartLocalAI; + const factory PluginStateEvent.openModelDirectory() = + _OpenModelStorageDirectory; } @freezed class PluginStateState with _$PluginStateState { - const factory PluginStateState({required PluginStateAction action}) = - _PluginStateState; + const factory PluginStateState({ + required PluginStateAction action, + }) = _PluginStateState; } @freezed diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/home/menu/sidebar/footer/sidebar_footer.dart b/frontend/appflowy_flutter/lib/workspace/presentation/home/menu/sidebar/footer/sidebar_footer.dart index 7825023263..c22bc95978 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/home/menu/sidebar/footer/sidebar_footer.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/home/menu/sidebar/footer/sidebar_footer.dart @@ -1,5 +1,6 @@ import 'package:appflowy/shared/feature_flags.dart'; import 'package:appflowy/workspace/presentation/home/menu/sidebar/footer/sidebar_toast.dart'; +import 'package:appflowy/workspace/presentation/settings/widgets/setting_appflowy_cloud.dart'; import 'package:flutter/material.dart'; import 'package:appflowy/generated/flowy_svgs.g.dart'; @@ -19,7 +20,12 @@ class SidebarFooter extends StatelessWidget { Widget build(BuildContext context) { return Column( children: [ - if (FeatureFlag.planBilling.isOn) const SidebarToast(), + if (FeatureFlag.planBilling.isOn) + BillingGateGuard( + builder: (context) { + return const SidebarToast(); + }, + ), const Row( children: [ Expanded(child: SidebarTrashButton()), diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/downloading.dart b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/downloading_model.dart similarity index 87% rename from frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/downloading.dart rename to frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/downloading_model.dart index 08611542a4..a9bfb6d2d4 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/downloading.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/downloading_model.dart @@ -37,15 +37,10 @@ class DownloadingIndicator extends StatelessWidget { color: Theme.of(context).colorScheme.surfaceContainerHighest, borderRadius: BorderRadius.circular(8), ), - child: Padding( - padding: const EdgeInsets.all(12.0), - child: Column( - children: [ - // const DownloadingPrompt(), - // const VSpace(12), - DownloadingProgressBar(onCancel: onCancel), - ], - ), + child: Column( + children: [ + DownloadingProgressBar(onCancel: onCancel), + ], ), ), ), @@ -65,9 +60,12 @@ class DownloadingProgressBar extends StatelessWidget { return Column( crossAxisAlignment: CrossAxisAlignment.start, children: [ - FlowyText( - "${LocaleKeys.settings_aiPage_keys_downloadingModel.tr()}: ${state.object}", - fontSize: 11, + Opacity( + opacity: 0.6, + child: FlowyText( + "${LocaleKeys.settings_aiPage_keys_downloadingModel.tr()}: ${state.object}", + fontSize: 11, + ), ), IntrinsicHeight( child: Row( diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/local_ai_chat_setting.dart b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/local_ai_chat_setting.dart index 98f48df89c..46e9e1cd48 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/local_ai_chat_setting.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/local_ai_chat_setting.dart @@ -1,7 +1,7 @@ import 'package:appflowy/generated/flowy_svgs.g.dart'; import 'package:appflowy/workspace/application/settings/ai/local_ai_chat_bloc.dart'; import 'package:appflowy/workspace/application/settings/ai/local_ai_chat_toggle_bloc.dart'; -import 'package:appflowy/workspace/presentation/settings/pages/setting_ai_view/downloading.dart'; +import 'package:appflowy/workspace/presentation/settings/pages/setting_ai_view/downloading_model.dart'; import 'package:appflowy/workspace/presentation/settings/pages/setting_ai_view/init_local_ai.dart'; import 'package:appflowy/workspace/presentation/settings/pages/setting_ai_view/plugin_state.dart'; import 'package:appflowy/workspace/presentation/widgets/dialogs.dart'; @@ -70,7 +70,7 @@ class LocalAIChatSetting extends StatelessWidget { header: const LocalAIChatSettingHeader(), collapsed: const SizedBox.shrink(), expanded: Padding( - padding: const EdgeInsets.symmetric(horizontal: 12, vertical: 6), + padding: const EdgeInsets.symmetric(vertical: 6), child: Column( crossAxisAlignment: CrossAxisAlignment.start, children: [ @@ -240,7 +240,7 @@ class _LocalLLMInfoWidget extends StatelessWidget { ); }, finishDownload: () => const InitLocalAIIndicator(), - checkPluginState: () => const CheckPluginStateIndicator(), + checkPluginState: () => const PluginStateIndicator(), ); return Padding( @@ -253,9 +253,12 @@ class _LocalLLMInfoWidget extends StatelessWidget { } else { return Opacity( opacity: 0.5, - child: FlowyText( - error.msg, - maxLines: 10, + child: Padding( + padding: const EdgeInsets.symmetric(vertical: 6), + child: FlowyText( + error.msg, + maxLines: 10, + ), ), ); } diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/local_ai_setting.dart b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/local_ai_setting.dart index 8f58d20b17..26bd0d8426 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/local_ai_setting.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/local_ai_setting.dart @@ -6,6 +6,7 @@ import 'package:appflowy/workspace/presentation/widgets/toggle/toggle.dart'; import 'package:easy_localization/easy_localization.dart'; import 'package:expandable/expandable.dart'; import 'package:flowy_infra_ui/style_widget/text.dart'; +import 'package:flowy_infra_ui/widget/spacing.dart'; import 'package:flutter/material.dart'; import 'package:appflowy/workspace/application/settings/ai/settings_ai_bloc.dart'; @@ -55,6 +56,7 @@ class _LocalAISettingState extends State { collapsed: const SizedBox.shrink(), expanded: Column( children: [ + const VSpace(6), DecoratedBox( decoration: BoxDecoration( color: Theme.of(context) @@ -64,11 +66,8 @@ class _LocalAISettingState extends State { const BorderRadius.all(Radius.circular(4)), ), child: const Padding( - padding: EdgeInsets.only( - left: 12.0, - top: 6, - bottom: 6, - ), + padding: + EdgeInsets.symmetric(horizontal: 12, vertical: 6), child: LocalAIChatSetting(), ), ), diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/plugin_state.dart b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/plugin_state.dart index b06fb4ae39..a8af9db11d 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/plugin_state.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/plugin_state.dart @@ -8,8 +8,8 @@ import 'package:flowy_infra_ui/widget/spacing.dart'; import 'package:flutter/material.dart'; import 'package:flutter_bloc/flutter_bloc.dart'; -class CheckPluginStateIndicator extends StatelessWidget { - const CheckPluginStateIndicator({super.key}); +class PluginStateIndicator extends StatelessWidget { + const PluginStateIndicator({super.key}); @override Widget build(BuildContext context) { @@ -20,7 +20,7 @@ class CheckPluginStateIndicator extends StatelessWidget { builder: (context, state) { return state.action.when( init: () => const _InitPlugin(), - ready: () => const _ReadyToUse(), + ready: () => const _LocalAIReadyToUse(), restart: () => const _ReloadButton(), loadingPlugin: () => const _InitPlugin(), ); @@ -74,8 +74,8 @@ class _ReloadButton extends StatelessWidget { } } -class _ReadyToUse extends StatelessWidget { - const _ReadyToUse(); +class _LocalAIReadyToUse extends StatelessWidget { + const _LocalAIReadyToUse(); @override Widget build(BuildContext context) { @@ -87,7 +87,7 @@ class _ReadyToUse extends StatelessWidget { ), ), child: Padding( - padding: const EdgeInsets.symmetric(vertical: 8), + padding: const EdgeInsets.symmetric(vertical: 4), child: Row( children: [ const HSpace(8), @@ -101,6 +101,23 @@ class _ReadyToUse extends StatelessWidget { fontSize: 11, color: const Color(0xFF1E4620), ), + const Spacer(), + Padding( + padding: const EdgeInsets.symmetric(horizontal: 6), + child: FlowyButton( + useIntrinsicWidth: true, + text: FlowyText( + LocaleKeys.settings_aiPage_keys_openModelDirectory.tr(), + fontSize: 11, + color: const Color(0xFF1E4620), + ), + onTap: () { + context.read().add( + const PluginStateEvent.openModelDirectory(), + ); + }, + ), + ), ], ), ), diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/settings_ai_view.dart b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/settings_ai_view.dart index af6a6ed2d1..09b6eef77c 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/settings_ai_view.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/settings_ai_view.dart @@ -1,4 +1,12 @@ +import 'package:appflowy/generated/flowy_svgs.g.dart'; +import 'package:appflowy/shared/feature_flags.dart'; +import 'package:appflowy/workspace/application/settings/ai/local_ai_on_boarding_bloc.dart'; +import 'package:appflowy/workspace/application/settings/settings_dialog_bloc.dart'; +import 'package:appflowy/workspace/presentation/settings/pages/setting_ai_view/local_ai_setting.dart'; import 'package:appflowy/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart'; +import 'package:appflowy/workspace/presentation/settings/widgets/setting_appflowy_cloud.dart'; +import 'package:flowy_infra/theme_extension.dart'; +import 'package:flowy_infra_ui/widget/spacing.dart'; import 'package:flutter/material.dart'; import 'package:appflowy/generated/locale_keys.g.dart'; @@ -43,9 +51,12 @@ class SettingsAIView extends StatelessWidget { const AIModelSelection(), ]; - // children.add(const LocalAISetting()); - children.add(const _AISearchToggle(value: false)); + children.add( + _LocalAIOnBoarding( + workspaceId: userProfile.workspaceId, + ), + ); return SettingsBody( title: LocaleKeys.settings_aiPage_title.tr(), @@ -101,3 +112,113 @@ class _AISearchToggle extends StatelessWidget { ); } } + +class _LocalAIOnBoarding extends StatelessWidget { + const _LocalAIOnBoarding({required this.workspaceId}); + final String workspaceId; + + @override + Widget build(BuildContext context) { + if (FeatureFlag.planBilling.isOn) { + return BillingGateGuard( + builder: (context) { + return BlocProvider( + create: (context) => LocalAIOnBoardingBloc(workspaceId) + ..add(const LocalAIOnBoardingEvent.started()), + child: BlocBuilder( + builder: (context, state) { + // Show the local AI settings if the user has purchased the AI Local plan + if (state.isPurchaseAILocal) { + return const LocalAISetting(); + } else { + // Show the upgrade to AI Local plan button if the user has not purchased the AI Local plan + return _UpgradeToAILocalPlan( + onTap: () { + context.read().add( + const SettingsDialogEvent.setSelectedPage( + SettingsPage.plan, + ), + ); + }, + ); + } + }, + ), + ); + }, + ); + } else { + return const SizedBox.shrink(); + } + } +} + +class _UpgradeToAILocalPlan extends StatefulWidget { + const _UpgradeToAILocalPlan({required this.onTap}); + + final VoidCallback onTap; + + @override + State<_UpgradeToAILocalPlan> createState() => _UpgradeToAILocalPlanState(); +} + +class _UpgradeToAILocalPlanState extends State<_UpgradeToAILocalPlan> { + bool _isHovered = false; + + @override + Widget build(BuildContext context) { + const textGradient = LinearGradient( + begin: Alignment.bottomLeft, + end: Alignment.bottomRight, + colors: [Color(0xFF8032FF), Color(0xFFEF35FF)], + stops: [0.1545, 0.8225], + ); + + final backgroundGradient = LinearGradient( + begin: Alignment.topLeft, + end: Alignment.bottomRight, + colors: [ + _isHovered + ? const Color(0xFF8032FF).withOpacity(0.3) + : Colors.transparent, + _isHovered + ? const Color(0xFFEF35FF).withOpacity(0.3) + : Colors.transparent, + ], + ); + + return GestureDetector( + onTap: widget.onTap, + child: MouseRegion( + cursor: SystemMouseCursors.click, + onEnter: (_) => setState(() => _isHovered = true), + onExit: (_) => setState(() => _isHovered = false), + child: Container( + padding: const EdgeInsets.symmetric(vertical: 8, horizontal: 10), + clipBehavior: Clip.antiAlias, + decoration: BoxDecoration( + gradient: backgroundGradient, + borderRadius: BorderRadius.circular(10), + ), + child: Row( + children: [ + const FlowySvg( + FlowySvgs.upgrade_storage_s, + blendMode: null, + ), + const HSpace(6), + ShaderMask( + shaderCallback: (bounds) => textGradient.createShader(bounds), + blendMode: BlendMode.srcIn, + child: FlowyText( + LocaleKeys.sideBar_upgradeToAILocal.tr(), + color: AFThemeExtension.of(context).strongText, + ), + ), + ], + ), + ), + ), + ); + } +} diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/settings/widgets/setting_appflowy_cloud.dart b/frontend/appflowy_flutter/lib/workspace/presentation/settings/widgets/setting_appflowy_cloud.dart index 8716c8ca89..c6057094bb 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/settings/widgets/setting_appflowy_cloud.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/settings/widgets/setting_appflowy_cloud.dart @@ -1,3 +1,5 @@ +import 'package:appflowy_backend/log.dart'; +import 'package:flutter/foundation.dart'; import 'package:flutter/gestures.dart'; import 'package:flutter/material.dart'; @@ -327,3 +329,51 @@ class AppFlowyCloudEnableSync extends StatelessWidget { ); } } + +class BillingGateGuard extends StatelessWidget { + const BillingGateGuard({required this.builder, super.key}); + + final Widget Function(BuildContext context) builder; + + @override + Widget build(BuildContext context) { + return FutureBuilder( + future: isBillingEnabled(), + builder: (context, snapshot) { + final isBillingEnabled = snapshot.data ?? false; + if (isBillingEnabled && + snapshot.connectionState == ConnectionState.done) { + return builder(context); + } + + // If the billing is not enabled, show nothing + return const SizedBox.shrink(); + }, + ); + } +} + +Future isBillingEnabled() async { + final result = await UserEventGetCloudConfig().send(); + return result.fold((cloudSetting) { + final whiteList = [ + "https://beta.appflowy.cloud", + "https://test.appflowy.cloud", + ]; + if (kDebugMode) { + whiteList.add("http://localhost:8000"); + } + + if (whiteList.contains(cloudSetting.serverUrl)) { + return true; + } else { + Log.warn( + "Billing is not enabled for this server:${cloudSetting.serverUrl}", + ); + return false; + } + }, (err) { + Log.error("Failed to get cloud config: $err"); + return false; + }); +} diff --git a/frontend/appflowy_tauri/src-tauri/Cargo.lock b/frontend/appflowy_tauri/src-tauri/Cargo.lock index 88ce19eca4..a30c5132ca 100644 --- a/frontend/appflowy_tauri/src-tauri/Cargo.lock +++ b/frontend/appflowy_tauri/src-tauri/Cargo.lock @@ -206,7 +206,7 @@ dependencies = [ [[package]] name = "appflowy-local-ai" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=c4ab1db44e96348f9b0770dd8ecc990f68ac415d#c4ab1db44e96348f9b0770dd8ecc990f68ac415d" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=f3b678e36f22012b241f8e2f3cb811be2da245c0#f3b678e36f22012b241f8e2f3cb811be2da245c0" dependencies = [ "anyhow", "appflowy-plugin", @@ -225,7 +225,7 @@ dependencies = [ [[package]] name = "appflowy-plugin" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=c4ab1db44e96348f9b0770dd8ecc990f68ac415d#c4ab1db44e96348f9b0770dd8ecc990f68ac415d" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=f3b678e36f22012b241f8e2f3cb811be2da245c0#f3b678e36f22012b241f8e2f3cb811be2da245c0" dependencies = [ "anyhow", "cfg-if", diff --git a/frontend/appflowy_tauri/src-tauri/Cargo.toml b/frontend/appflowy_tauri/src-tauri/Cargo.toml index f128965d6d..1a26f66887 100644 --- a/frontend/appflowy_tauri/src-tauri/Cargo.toml +++ b/frontend/appflowy_tauri/src-tauri/Cargo.toml @@ -128,5 +128,5 @@ collab-user = { version = "0.2", git = "https://github.com/AppFlowy-IO/AppFlowy- # To update the commit ID, run: # scripts/tool/update_local_ai_rev.sh new_rev_id # ⚠️⚠️⚠️️ -appflowy-local-ai = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "c4ab1db44e96348f9b0770dd8ecc990f68ac415d" } -appflowy-plugin = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "c4ab1db44e96348f9b0770dd8ecc990f68ac415d" } +appflowy-local-ai = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "f3b678e36f22012b241f8e2f3cb811be2da245c0" } +appflowy-plugin = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "f3b678e36f22012b241f8e2f3cb811be2da245c0" } diff --git a/frontend/appflowy_web_app/src-tauri/Cargo.lock b/frontend/appflowy_web_app/src-tauri/Cargo.lock index 54a5eb6fc0..3935849100 100644 --- a/frontend/appflowy_web_app/src-tauri/Cargo.lock +++ b/frontend/appflowy_web_app/src-tauri/Cargo.lock @@ -197,7 +197,7 @@ dependencies = [ [[package]] name = "appflowy-local-ai" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=c4ab1db44e96348f9b0770dd8ecc990f68ac415d#c4ab1db44e96348f9b0770dd8ecc990f68ac415d" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=f3b678e36f22012b241f8e2f3cb811be2da245c0#f3b678e36f22012b241f8e2f3cb811be2da245c0" dependencies = [ "anyhow", "appflowy-plugin", @@ -216,7 +216,7 @@ dependencies = [ [[package]] name = "appflowy-plugin" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=c4ab1db44e96348f9b0770dd8ecc990f68ac415d#c4ab1db44e96348f9b0770dd8ecc990f68ac415d" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=f3b678e36f22012b241f8e2f3cb811be2da245c0#f3b678e36f22012b241f8e2f3cb811be2da245c0" dependencies = [ "anyhow", "cfg-if", diff --git a/frontend/appflowy_web_app/src-tauri/Cargo.toml b/frontend/appflowy_web_app/src-tauri/Cargo.toml index 9591ac709f..30c9eafd20 100644 --- a/frontend/appflowy_web_app/src-tauri/Cargo.toml +++ b/frontend/appflowy_web_app/src-tauri/Cargo.toml @@ -128,6 +128,6 @@ collab-user = { version = "0.2", git = "https://github.com/AppFlowy-IO/AppFlowy- # To update the commit ID, run: # scripts/tool/update_local_ai_rev.sh new_rev_id # ⚠️⚠️⚠️️ -appflowy-local-ai = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "c4ab1db44e96348f9b0770dd8ecc990f68ac415d" } -appflowy-plugin = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "c4ab1db44e96348f9b0770dd8ecc990f68ac415d" } +appflowy-local-ai = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "f3b678e36f22012b241f8e2f3cb811be2da245c0" } +appflowy-plugin = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "f3b678e36f22012b241f8e2f3cb811be2da245c0" } diff --git a/frontend/resources/translations/en.json b/frontend/resources/translations/en.json index 7e0eb80f5a..973e3f82b7 100644 --- a/frontend/resources/translations/en.json +++ b/frontend/resources/translations/en.json @@ -660,7 +660,8 @@ "restartLocalAI": "Restart Local AI", "disableLocalAIDialog": "Do you want to disable local AI?", "localAIToggleTitle": "Toggle to enable or disable local AI", - "fetchLocalModel": "Fetch local model configuration" + "fetchLocalModel": "Fetch local model configuration", + "openModelDirectory": "Open folder" } }, "planPage": { diff --git a/frontend/rust-lib/Cargo.lock b/frontend/rust-lib/Cargo.lock index 1715d2fc9a..a1894ba14a 100644 --- a/frontend/rust-lib/Cargo.lock +++ b/frontend/rust-lib/Cargo.lock @@ -197,7 +197,7 @@ dependencies = [ [[package]] name = "appflowy-local-ai" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=c4ab1db44e96348f9b0770dd8ecc990f68ac415d#c4ab1db44e96348f9b0770dd8ecc990f68ac415d" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=f3b678e36f22012b241f8e2f3cb811be2da245c0#f3b678e36f22012b241f8e2f3cb811be2da245c0" dependencies = [ "anyhow", "appflowy-plugin", @@ -216,7 +216,7 @@ dependencies = [ [[package]] name = "appflowy-plugin" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=c4ab1db44e96348f9b0770dd8ecc990f68ac415d#c4ab1db44e96348f9b0770dd8ecc990f68ac415d" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=f3b678e36f22012b241f8e2f3cb811be2da245c0#f3b678e36f22012b241f8e2f3cb811be2da245c0" dependencies = [ "anyhow", "cfg-if", diff --git a/frontend/rust-lib/Cargo.toml b/frontend/rust-lib/Cargo.toml index 2b4896b45e..eca9b34df1 100644 --- a/frontend/rust-lib/Cargo.toml +++ b/frontend/rust-lib/Cargo.toml @@ -147,5 +147,5 @@ collab-user = { version = "0.2", git = "https://github.com/AppFlowy-IO/AppFlowy- # To update the commit ID, run: # scripts/tool/update_local_ai_rev.sh new_rev_id # ⚠️⚠️⚠️️ -appflowy-local-ai = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "c4ab1db44e96348f9b0770dd8ecc990f68ac415d" } -appflowy-plugin = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "c4ab1db44e96348f9b0770dd8ecc990f68ac415d" } +appflowy-local-ai = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "f3b678e36f22012b241f8e2f3cb811be2da245c0" } +appflowy-plugin = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "f3b678e36f22012b241f8e2f3cb811be2da245c0" } diff --git a/frontend/rust-lib/flowy-chat-pub/src/cloud.rs b/frontend/rust-lib/flowy-chat-pub/src/cloud.rs index 0054891140..5d64349b7c 100644 --- a/frontend/rust-lib/flowy-chat-pub/src/cloud.rs +++ b/frontend/rust-lib/flowy-chat-pub/src/cloud.rs @@ -15,7 +15,7 @@ use std::path::PathBuf; pub type ChatMessageStream = BoxStream<'static, Result>; pub type StreamAnswer = BoxStream<'static, Result>; -pub type StreamComplete = BoxStream<'static, Result>; +pub type StreamComplete = BoxStream<'static, Result>; #[async_trait] pub trait ChatCloudService: Send + Sync + 'static { fn create_chat( @@ -63,12 +63,12 @@ pub trait ChatCloudService: Send + Sync + 'static { limit: u64, ) -> FutureResult; - fn get_related_message( + async fn get_related_message( &self, workspace_id: &str, chat_id: &str, message_id: i64, - ) -> FutureResult; + ) -> Result; async fn stream_complete( &self, diff --git a/frontend/rust-lib/flowy-chat/src/chat.rs b/frontend/rust-lib/flowy-chat/src/chat.rs index 86727b5ce0..b6e08a0339 100644 --- a/frontend/rust-lib/flowy-chat/src/chat.rs +++ b/frontend/rust-lib/flowy-chat/src/chat.rs @@ -2,7 +2,7 @@ use crate::chat_manager::ChatUserService; use crate::entities::{ ChatMessageErrorPB, ChatMessageListPB, ChatMessagePB, RepeatedRelatedQuestionPB, }; -use crate::middleware::chat_service_mw::ChatServiceMiddleware; +use crate::middleware::chat_service_mw::CloudServiceMiddleware; use crate::notification::{make_notification, ChatNotification}; use crate::persistence::{insert_chat_messages, select_chat_messages, ChatMessageTable}; use allo_isolate::Isolate; @@ -27,7 +27,7 @@ pub struct Chat { chat_id: String, uid: i64, user_service: Arc, - chat_service: Arc, + chat_service: Arc, prev_message_state: Arc>, latest_message_id: Arc, stop_stream: Arc, @@ -39,7 +39,7 @@ impl Chat { uid: i64, chat_id: String, user_service: Arc, - chat_service: Arc, + chat_service: Arc, ) -> Chat { Chat { uid, diff --git a/frontend/rust-lib/flowy-chat/src/chat_manager.rs b/frontend/rust-lib/flowy-chat/src/chat_manager.rs index 219a098f1f..a81e603ed6 100644 --- a/frontend/rust-lib/flowy-chat/src/chat_manager.rs +++ b/frontend/rust-lib/flowy-chat/src/chat_manager.rs @@ -1,7 +1,7 @@ use crate::chat::Chat; use crate::entities::{ChatMessageListPB, ChatMessagePB, RepeatedRelatedQuestionPB}; use crate::local_ai::local_llm_chat::LocalAIController; -use crate::middleware::chat_service_mw::ChatServiceMiddleware; +use crate::middleware::chat_service_mw::CloudServiceMiddleware; use crate::persistence::{insert_chat, ChatTable}; use appflowy_plugin::manager::PluginManager; @@ -25,7 +25,7 @@ pub trait ChatUserService: Send + Sync + 'static { } pub struct ChatManager { - pub chat_service_wm: Arc, + pub cloud_service_wm: Arc, pub user_service: Arc, chats: Arc>>, pub local_ai_controller: Arc, @@ -46,21 +46,21 @@ impl ChatManager { cloud_service.clone(), )); - if local_ai_controller.can_init() { - if let Err(err) = local_ai_controller.initialize_chat_plugin(None) { + if local_ai_controller.can_init_plugin() { + if let Err(err) = local_ai_controller.initialize_ai_plugin(None) { error!("[AI Plugin] failed to initialize local ai: {:?}", err); } } // setup local chat service - let chat_service_wm = Arc::new(ChatServiceMiddleware::new( + let cloud_service_wm = Arc::new(CloudServiceMiddleware::new( user_service.clone(), cloud_service, local_ai_controller.clone(), )); Self { - chat_service_wm, + cloud_service_wm, user_service, chats: Arc::new(DashMap::new()), local_ai_controller, @@ -74,12 +74,14 @@ impl ChatManager { self.user_service.user_id().unwrap(), chat_id.to_string(), self.user_service.clone(), - self.chat_service_wm.clone(), + self.cloud_service_wm.clone(), )) }); trace!("[AI Plugin] notify open chat: {}", chat_id); - self.local_ai_controller.open_chat(chat_id); + if self.local_ai_controller.is_running() { + self.local_ai_controller.open_chat(chat_id); + } Ok(()) } @@ -108,7 +110,7 @@ impl ChatManager { pub async fn create_chat(&self, uid: &i64, chat_id: &str) -> Result, FlowyError> { let workspace_id = self.user_service.workspace_id()?; self - .chat_service_wm + .cloud_service_wm .create_chat(uid, &workspace_id, chat_id) .await?; save_chat(self.user_service.sqlite_connection(*uid)?, chat_id)?; @@ -117,7 +119,7 @@ impl ChatManager { self.user_service.user_id().unwrap(), chat_id.to_string(), self.user_service.clone(), - self.chat_service_wm.clone(), + self.cloud_service_wm.clone(), )); self.chats.insert(chat_id.to_string(), chat.clone()); Ok(chat) @@ -145,7 +147,7 @@ impl ChatManager { self.user_service.user_id().unwrap(), chat_id.to_string(), self.user_service.clone(), - self.chat_service_wm.clone(), + self.cloud_service_wm.clone(), )); self.chats.insert(chat_id.to_string(), chat.clone()); Ok(chat) diff --git a/frontend/rust-lib/flowy-chat/src/entities.rs b/frontend/rust-lib/flowy-chat/src/entities.rs index d0906638bf..00157f9fb4 100644 --- a/frontend/rust-lib/flowy-chat/src/entities.rs +++ b/frontend/rust-lib/flowy-chat/src/entities.rs @@ -410,3 +410,9 @@ pub struct LocalAIChatPB { #[pb(index = 3)] pub plugin_state: LocalAIPluginStatePB, } + +#[derive(Default, ProtoBuf, Clone, Debug)] +pub struct LocalModelStoragePB { + #[pb(index = 1)] + pub file_path: String, +} diff --git a/frontend/rust-lib/flowy-chat/src/event_handler.rs b/frontend/rust-lib/flowy-chat/src/event_handler.rs index 02c44bf0de..63c54a1f29 100644 --- a/frontend/rust-lib/flowy-chat/src/event_handler.rs +++ b/frontend/rust-lib/flowy-chat/src/event_handler.rs @@ -87,9 +87,15 @@ pub(crate) async fn get_related_question_handler( ) -> DataResult { let chat_manager = upgrade_chat_manager(chat_manager)?; let data = data.into_inner(); - let messages = chat_manager - .get_related_questions(&data.chat_id, data.message_id) - .await?; + let (tx, rx) = tokio::sync::oneshot::channel(); + tokio::spawn(async move { + let messages = chat_manager + .get_related_questions(&data.chat_id, data.message_id) + .await?; + let _ = tx.send(messages); + Ok::<_, FlowyError>(()) + }); + let messages = rx.await?; data_result_ok(messages) } @@ -338,3 +344,14 @@ pub(crate) async fn get_local_ai_state_handler( let enabled = chat_manager.local_ai_controller.is_enabled(); data_result_ok(LocalAIPB { enabled }) } + +#[tracing::instrument(level = "debug", skip_all, err)] +pub(crate) async fn get_model_storage_directory_handler( + chat_manager: AFPluginState>, +) -> DataResult { + let chat_manager = upgrade_chat_manager(chat_manager)?; + let file_path = chat_manager + .local_ai_controller + .get_model_storage_directory()?; + data_result_ok(LocalModelStoragePB { file_path }) +} diff --git a/frontend/rust-lib/flowy-chat/src/event_map.rs b/frontend/rust-lib/flowy-chat/src/event_map.rs index 8594e751f4..679e7b567c 100644 --- a/frontend/rust-lib/flowy-chat/src/event_map.rs +++ b/frontend/rust-lib/flowy-chat/src/event_map.rs @@ -11,7 +11,7 @@ use crate::event_handler::*; pub fn init(chat_manager: Weak) -> AFPlugin { let user_service = Arc::downgrade(&chat_manager.upgrade().unwrap().user_service); - let cloud_service = Arc::downgrade(&chat_manager.upgrade().unwrap().chat_service_wm); + let cloud_service = Arc::downgrade(&chat_manager.upgrade().unwrap().cloud_service_wm); let ai_tools = Arc::new(AITools::new(cloud_service, user_service)); AFPlugin::new() .name("Flowy-Chat") @@ -53,6 +53,10 @@ pub fn init(chat_manager: Weak) -> AFPlugin { ChatEvent::ToggleChatWithFile, toggle_local_ai_chat_file_handler, ) + .event( + ChatEvent::GetModelStorageDirectory, + get_model_storage_directory_handler, + ) } #[derive(Clone, Copy, PartialEq, Eq, Debug, Display, Hash, ProtoBuf_Enum, Flowy_Event)] @@ -126,4 +130,7 @@ pub enum ChatEvent { #[event()] ToggleChatWithFile = 20, + + #[event(output = "LocalModelStoragePB")] + GetModelStorageDirectory = 21, } diff --git a/frontend/rust-lib/flowy-chat/src/local_ai/local_llm_chat.rs b/frontend/rust-lib/flowy-chat/src/local_ai/local_llm_chat.rs index 6c0d089c40..09c5d36f41 100644 --- a/frontend/rust-lib/flowy-chat/src/local_ai/local_llm_chat.rs +++ b/frontend/rust-lib/flowy-chat/src/local_ai/local_llm_chat.rs @@ -1,7 +1,5 @@ use crate::chat_manager::ChatUserService; -use crate::entities::{ - ChatStatePB, LocalAIPluginStatePB, LocalModelResourcePB, ModelTypePB, RunningStatePB, -}; +use crate::entities::{LocalAIPluginStatePB, LocalModelResourcePB, RunningStatePB}; use crate::local_ai::local_llm_resource::{LLMResourceController, LLMResourceService}; use crate::notification::{make_notification, ChatNotification, APPFLOWY_AI_NOTIFICATION_KEY}; use anyhow::Error; @@ -100,7 +98,7 @@ impl LocalAIController { tokio::spawn(async move { while rx.recv().await.is_some() { if let Ok(chat_config) = cloned_llm_res.get_chat_config(rag_enabled) { - if let Err(err) = initialize_chat_plugin(&cloned_llm_chat, chat_config) { + if let Err(err) = initialize_chat_plugin(&cloned_llm_chat, chat_config, None) { error!("[AI Plugin] failed to setup plugin: {:?}", err); } } @@ -113,79 +111,49 @@ impl LocalAIController { self.llm_res.refresh_llm_resource().await } - pub fn initialize_chat_plugin( + pub fn initialize_ai_plugin( &self, ret: Option>, ) -> FlowyResult<()> { - let mut chat_config = self.llm_res.get_chat_config(self.is_rag_enabled())?; - let llm_chat = self.llm_chat.clone(); - tokio::spawn(async move { - trace!("[AI Plugin] config: {:?}", chat_config); - if is_apple_silicon().await.unwrap_or(false) { - chat_config = chat_config.with_device("gpu"); - } - match llm_chat.init_chat_plugin(chat_config).await { - Ok(_) => { - make_notification( - APPFLOWY_AI_NOTIFICATION_KEY, - ChatNotification::UpdateChatPluginState, - ) - .payload(ChatStatePB { - model_type: ModelTypePB::LocalAI, - available: true, - }) - .send(); - }, - Err(err) => { - make_notification( - APPFLOWY_AI_NOTIFICATION_KEY, - ChatNotification::UpdateChatPluginState, - ) - .payload(ChatStatePB { - model_type: ModelTypePB::LocalAI, - available: false, - }) - .send(); - error!("[AI Plugin] failed to setup plugin: {:?}", err); - }, - } - if let Some(ret) = ret { - let _ = ret.send(()); - } - }); + let chat_config = self.llm_res.get_chat_config(self.is_rag_enabled())?; + initialize_chat_plugin(&self.llm_chat, chat_config, ret)?; Ok(()) } /// Returns true if the local AI is enabled and ready to use. - pub fn can_init(&self) -> bool { + pub fn can_init_plugin(&self) -> bool { self.is_enabled() && self.llm_res.is_resource_ready() } + /// Indicate whether the local AI plugin is running. pub fn is_running(&self) -> bool { self.llm_chat.get_plugin_running_state().is_ready() } + /// Indicate whether the local AI is enabled. pub fn is_enabled(&self) -> bool { - self.store_preferences.get_bool(APPFLOWY_LOCAL_AI_ENABLED) + self + .store_preferences + .get_bool(APPFLOWY_LOCAL_AI_ENABLED) + .unwrap_or(true) } + /// Indicate whether the local AI chat is enabled. In the future, we can support multiple + /// AI plugin. pub fn is_chat_enabled(&self) -> bool { self .store_preferences .get_bool(APPFLOWY_LOCAL_AI_CHAT_ENABLED) + .unwrap_or(true) } pub fn is_rag_enabled(&self) -> bool { self .store_preferences - .get_bool(APPFLOWY_LOCAL_AI_CHAT_RAG_ENABLED) + .get_bool_or_default(APPFLOWY_LOCAL_AI_CHAT_RAG_ENABLED) } pub fn open_chat(&self, chat_id: &str) { - if !self.is_chat_enabled() { - return; - } - if !self.is_running() { return; } @@ -234,7 +202,7 @@ impl LocalAIController { let state = self.llm_res.use_local_llm(llm_id)?; // Re-initialize the plugin if the setting is updated and ready to use if self.llm_res.is_resource_ready() { - self.initialize_chat_plugin(None)?; + self.initialize_ai_plugin(None)?; } Ok(state) } @@ -270,14 +238,24 @@ impl LocalAIController { pub fn restart_chat_plugin(&self) { let rag_enabled = self.is_rag_enabled(); if let Ok(chat_config) = self.llm_res.get_chat_config(rag_enabled) { - if let Err(err) = initialize_chat_plugin(&self.llm_chat, chat_config) { + if let Err(err) = initialize_chat_plugin(&self.llm_chat, chat_config, None) { error!("[AI Plugin] failed to setup plugin: {:?}", err); } } } + pub fn get_model_storage_directory(&self) -> FlowyResult { + self + .llm_res + .user_model_folder() + .map(|path| path.to_string_lossy().to_string()) + } + pub async fn toggle_local_ai(&self) -> FlowyResult { - let enabled = !self.store_preferences.get_bool(APPFLOWY_LOCAL_AI_ENABLED); + let enabled = !self + .store_preferences + .get_bool(APPFLOWY_LOCAL_AI_ENABLED) + .unwrap_or(true); self .store_preferences .set_bool(APPFLOWY_LOCAL_AI_ENABLED, enabled)?; @@ -287,7 +265,7 @@ impl LocalAIController { if enabled { let chat_enabled = self .store_preferences - .get_bool(APPFLOWY_LOCAL_AI_CHAT_ENABLED); + .get_bool_or_default(APPFLOWY_LOCAL_AI_CHAT_ENABLED); self.enable_chat_plugin(chat_enabled).await?; } else { self.enable_chat_plugin(false).await?; @@ -298,7 +276,8 @@ impl LocalAIController { pub async fn toggle_local_ai_chat(&self) -> FlowyResult { let enabled = !self .store_preferences - .get_bool(APPFLOWY_LOCAL_AI_CHAT_ENABLED); + .get_bool(APPFLOWY_LOCAL_AI_CHAT_ENABLED) + .unwrap_or(true); self .store_preferences .set_bool(APPFLOWY_LOCAL_AI_CHAT_ENABLED, enabled)?; @@ -310,7 +289,7 @@ impl LocalAIController { pub async fn toggle_local_ai_chat_rag(&self) -> FlowyResult { let enabled = !self .store_preferences - .get_bool(APPFLOWY_LOCAL_AI_CHAT_RAG_ENABLED); + .get_bool_or_default(APPFLOWY_LOCAL_AI_CHAT_RAG_ENABLED); self .store_preferences .set_bool(APPFLOWY_LOCAL_AI_CHAT_RAG_ENABLED, enabled)?; @@ -320,7 +299,7 @@ impl LocalAIController { async fn enable_chat_plugin(&self, enabled: bool) -> FlowyResult<()> { if enabled { let (tx, rx) = tokio::sync::oneshot::channel(); - if let Err(err) = self.initialize_chat_plugin(Some(tx)) { + if let Err(err) = self.initialize_ai_plugin(Some(tx)) { error!("[AI Plugin] failed to initialize local ai: {:?}", err); } let _ = rx.await; @@ -334,6 +313,7 @@ impl LocalAIController { fn initialize_chat_plugin( llm_chat: &Arc, mut chat_config: AIPluginConfig, + ret: Option>, ) -> FlowyResult<()> { let llm_chat = llm_chat.clone(); tokio::spawn(async move { @@ -342,29 +322,12 @@ fn initialize_chat_plugin( chat_config = chat_config.with_device("gpu"); } match llm_chat.init_chat_plugin(chat_config).await { - Ok(_) => { - make_notification( - APPFLOWY_AI_NOTIFICATION_KEY, - ChatNotification::UpdateChatPluginState, - ) - .payload(ChatStatePB { - model_type: ModelTypePB::LocalAI, - available: true, - }) - .send(); - }, - Err(err) => { - make_notification( - APPFLOWY_AI_NOTIFICATION_KEY, - ChatNotification::UpdateChatPluginState, - ) - .payload(ChatStatePB { - model_type: ModelTypePB::LocalAI, - available: false, - }) - .send(); - error!("[AI Plugin] failed to setup plugin: {:?}", err); - }, + Ok(_) => {}, + Err(err) => error!("[AI Plugin] failed to setup plugin: {:?}", err), + } + + if let Some(ret) = ret { + let _ = ret.send(()); } }); Ok(()) @@ -402,6 +365,6 @@ impl LLMResourceService for LLMResourceServiceImpl { fn is_rag_enabled(&self) -> bool { self .store_preferences - .get_bool(APPFLOWY_LOCAL_AI_CHAT_RAG_ENABLED) + .get_bool_or_default(APPFLOWY_LOCAL_AI_CHAT_RAG_ENABLED) } } diff --git a/frontend/rust-lib/flowy-chat/src/local_ai/local_llm_resource.rs b/frontend/rust-lib/flowy-chat/src/local_ai/local_llm_resource.rs index 548c2a73a4..33119ea748 100644 --- a/frontend/rust-lib/flowy-chat/src/local_ai/local_llm_resource.rs +++ b/frontend/rust-lib/flowy-chat/src/local_ai/local_llm_resource.rs @@ -478,7 +478,7 @@ impl LLMResourceController { self.resource_dir().map(|dir| dir.join(PLUGIN_DIR)) } - fn user_model_folder(&self) -> FlowyResult { + pub(crate) fn user_model_folder(&self) -> FlowyResult { self.resource_dir().map(|dir| dir.join(LLM_MODEL_DIR)) } diff --git a/frontend/rust-lib/flowy-chat/src/middleware/chat_service_mw.rs b/frontend/rust-lib/flowy-chat/src/middleware/chat_service_mw.rs index bf7f55bf1d..b4c2a1710e 100644 --- a/frontend/rust-lib/flowy-chat/src/middleware/chat_service_mw.rs +++ b/frontend/rust-lib/flowy-chat/src/middleware/chat_service_mw.rs @@ -7,7 +7,7 @@ use appflowy_plugin::error::PluginError; use flowy_chat_pub::cloud::{ ChatCloudService, ChatMessage, ChatMessageType, CompletionType, LocalAIConfig, MessageCursor, - RepeatedChatMessage, RepeatedRelatedQuestion, StreamAnswer, StreamComplete, + RelatedQuestion, RepeatedChatMessage, RepeatedRelatedQuestion, StreamAnswer, StreamComplete, }; use flowy_error::{FlowyError, FlowyResult}; use futures::{stream, StreamExt, TryStreamExt}; @@ -17,13 +17,13 @@ use lib_infra::future::FutureResult; use std::path::PathBuf; use std::sync::Arc; -pub struct ChatServiceMiddleware { - pub cloud_service: Arc, +pub struct CloudServiceMiddleware { + cloud_service: Arc, user_service: Arc, local_llm_controller: Arc, } -impl ChatServiceMiddleware { +impl CloudServiceMiddleware { pub fn new( user_service: Arc, cloud_service: Arc, @@ -67,7 +67,7 @@ impl ChatServiceMiddleware { } #[async_trait] -impl ChatCloudService for ChatServiceMiddleware { +impl ChatCloudService for CloudServiceMiddleware { fn create_chat( &self, uid: &i64, @@ -177,23 +177,34 @@ impl ChatCloudService for ChatServiceMiddleware { .get_chat_messages(workspace_id, chat_id, offset, limit) } - fn get_related_message( + async fn get_related_message( &self, workspace_id: &str, chat_id: &str, message_id: i64, - ) -> FutureResult { + ) -> Result { if self.local_llm_controller.is_running() { - FutureResult::new(async move { - Ok(RepeatedRelatedQuestion { - message_id, - items: vec![], + let questions = self + .local_llm_controller + .get_related_question(chat_id) + .await + .map_err(|err| FlowyError::local_ai().with_context(err))? + .into_iter() + .map(|content| RelatedQuestion { + content, + metadata: None, }) + .collect::>(); + + Ok(RepeatedRelatedQuestion { + message_id, + items: questions, }) } else { self .cloud_service .get_related_message(workspace_id, chat_id, message_id) + .await } } @@ -204,9 +215,21 @@ impl ChatCloudService for ChatServiceMiddleware { complete_type: CompletionType, ) -> Result { if self.local_llm_controller.is_running() { - return Err( - FlowyError::not_support().with_context("completion with local ai is not supported yet"), - ); + match self + .local_llm_controller + .complete_text(text, complete_type as u8) + .await + { + Ok(stream) => Ok( + stream + .map_err(|err| FlowyError::local_ai().with_context(err)) + .boxed(), + ), + Err(err) => { + self.handle_plugin_error(err); + Ok(stream::once(async { Err(FlowyError::local_ai_unavailable()) }).boxed()) + }, + } } else { self .cloud_service diff --git a/frontend/rust-lib/flowy-chat/src/middleware/mod.rs b/frontend/rust-lib/flowy-chat/src/middleware/mod.rs index 0965215e71..e1c0f454da 100644 --- a/frontend/rust-lib/flowy-chat/src/middleware/mod.rs +++ b/frontend/rust-lib/flowy-chat/src/middleware/mod.rs @@ -1 +1 @@ -pub mod chat_service_mw; +pub(crate) mod chat_service_mw; diff --git a/frontend/rust-lib/flowy-chat/src/tools.rs b/frontend/rust-lib/flowy-chat/src/tools.rs index 493416fff7..0219bbe230 100644 --- a/frontend/rust-lib/flowy-chat/src/tools.rs +++ b/frontend/rust-lib/flowy-chat/src/tools.rs @@ -11,7 +11,6 @@ use lib_infra::isolate_stream::IsolateSink; use std::sync::{Arc, Weak}; use tokio::select; -use tracing::trace; pub struct AITools { tasks: Arc>>, @@ -109,11 +108,10 @@ impl ToolTask { match result { Some(Ok(data)) => { let s = String::from_utf8(data.to_vec()).unwrap_or_default(); - trace!("stream completion data: {}", s); let _ = sink.send(format!("data:{}", s)).await; }, Some(Err(error)) => { - handle_error(&mut sink, FlowyError::from(error)).await; + handle_error(&mut sink, error).await; return; }, None => { diff --git a/frontend/rust-lib/flowy-core/src/integrate/trait_impls.rs b/frontend/rust-lib/flowy-core/src/integrate/trait_impls.rs index 18ec68c04d..f44dabce90 100644 --- a/frontend/rust-lib/flowy-core/src/integrate/trait_impls.rs +++ b/frontend/rust-lib/flowy-core/src/integrate/trait_impls.rs @@ -685,21 +685,17 @@ impl ChatCloudService for ServerProvider { }) } - fn get_related_message( + async fn get_related_message( &self, workspace_id: &str, chat_id: &str, message_id: i64, - ) -> FutureResult { - let workspace_id = workspace_id.to_string(); - let chat_id = chat_id.to_string(); - let server = self.get_server(); - FutureResult::new(async move { - server? - .chat_service() - .get_related_message(&workspace_id, &chat_id, message_id) - .await - }) + ) -> Result { + self + .get_server()? + .chat_service() + .get_related_message(workspace_id, chat_id, message_id) + .await } async fn generate_answer( diff --git a/frontend/rust-lib/flowy-server/src/af_cloud/impls/chat.rs b/frontend/rust-lib/flowy-server/src/af_cloud/impls/chat.rs index e69201a390..80fce796ef 100644 --- a/frontend/rust-lib/flowy-server/src/af_cloud/impls/chat.rs +++ b/frontend/rust-lib/flowy-server/src/af_cloud/impls/chat.rs @@ -146,24 +146,19 @@ where }) } - fn get_related_message( + async fn get_related_message( &self, workspace_id: &str, chat_id: &str, message_id: i64, - ) -> FutureResult { - let workspace_id = workspace_id.to_string(); - let chat_id = chat_id.to_string(); + ) -> Result { let try_get_client = self.inner.try_get_client(); + let resp = try_get_client? + .get_chat_related_question(workspace_id, chat_id, message_id) + .await + .map_err(FlowyError::from)?; - FutureResult::new(async move { - let resp = try_get_client? - .get_chat_related_question(&workspace_id, &chat_id, message_id) - .await - .map_err(FlowyError::from)?; - - Ok(resp) - }) + Ok(resp) } async fn stream_complete( @@ -181,7 +176,8 @@ where .try_get_client()? .stream_completion_text(workspace_id, params) .await - .map_err(FlowyError::from)?; + .map_err(FlowyError::from)? + .map_err(FlowyError::from); Ok(stream.boxed()) } diff --git a/frontend/rust-lib/flowy-server/src/default_impl.rs b/frontend/rust-lib/flowy-server/src/default_impl.rs index 955cf653da..662853d854 100644 --- a/frontend/rust-lib/flowy-server/src/default_impl.rs +++ b/frontend/rust-lib/flowy-server/src/default_impl.rs @@ -66,15 +66,13 @@ impl ChatCloudService for DefaultChatCloudServiceImpl { }) } - fn get_related_message( + async fn get_related_message( &self, _workspace_id: &str, _chat_id: &str, _message_id: i64, - ) -> FutureResult { - FutureResult::new(async move { - Err(FlowyError::not_support().with_context("Chat is not supported in local server.")) - }) + ) -> Result { + Err(FlowyError::not_support().with_context("Chat is not supported in local server.")) } async fn generate_answer( diff --git a/frontend/rust-lib/flowy-sqlite/src/kv/kv.rs b/frontend/rust-lib/flowy-sqlite/src/kv/kv.rs index d10da70823..da35facaf2 100644 --- a/frontend/rust-lib/flowy-sqlite/src/kv/kv.rs +++ b/frontend/rust-lib/flowy-sqlite/src/kv/kv.rs @@ -63,7 +63,7 @@ impl KVStorePreferences { } /// Get a bool value of a key - pub fn get_bool(&self, key: &str) -> bool { + pub fn get_bool_or_default(&self, key: &str) -> bool { self .get_key_value(key) .and_then(|kv| kv.value) @@ -71,6 +71,13 @@ impl KVStorePreferences { .unwrap_or(false) } + pub fn get_bool(&self, key: &str) -> Option { + self + .get_key_value(key) + .and_then(|kv| kv.value) + .and_then(|v| v.parse::().ok()) + } + /// Get a i64 value of a key pub fn get_i64(&self, key: &str) -> Option { self @@ -157,8 +164,8 @@ mod tests { assert_eq!(store.get_str("2"), None); store.set_bool("1", true).unwrap(); - assert!(store.get_bool("1")); - assert!(!store.get_bool("2")); + assert!(store.get_bool_or_default("1")); + assert!(!store.get_bool_or_default("2")); store.set_i64("1", 1).unwrap(); assert_eq!(store.get_i64("1").unwrap(), 1); diff --git a/frontend/rust-lib/flowy-user/src/migrations/session_migration.rs b/frontend/rust-lib/flowy-user/src/migrations/session_migration.rs index 172f88209f..df477f4a33 100644 --- a/frontend/rust-lib/flowy-user/src/migrations/session_migration.rs +++ b/frontend/rust-lib/flowy-user/src/migrations/session_migration.rs @@ -10,7 +10,7 @@ pub fn migrate_session_with_user_uuid( session_cache_key: &str, store_preferences: &Arc, ) -> Option { - if !store_preferences.get_bool(MIGRATION_USER_NO_USER_UUID) + if !store_preferences.get_bool_or_default(MIGRATION_USER_NO_USER_UUID) && store_preferences .set_bool(MIGRATION_USER_NO_USER_UUID, true) .is_ok() diff --git a/frontend/rust-lib/flowy-user/src/services/authenticate_user.rs b/frontend/rust-lib/flowy-user/src/services/authenticate_user.rs index 0065bddd14..1df4fda3e2 100644 --- a/frontend/rust-lib/flowy-user/src/services/authenticate_user.rs +++ b/frontend/rust-lib/flowy-user/src/services/authenticate_user.rs @@ -40,7 +40,10 @@ impl AuthenticateUser { } pub fn vacuum_database_if_need(&self) { - if !self.store_preferences.get_bool(SQLITE_VACUUM_042) { + if !self + .store_preferences + .get_bool_or_default(SQLITE_VACUUM_042) + { if let Ok(session) = self.get_session() { let _ = self.store_preferences.set_bool(SQLITE_VACUUM_042, true); if let Ok(conn) = self.database.get_connection(session.user_id) {