[PUI] SSO Support (#6333)

* Add sso buttons
Fixes #5753

* Added more icons

* fix callback url

* made heading dynamic

* allow either sso or normal reg

* Added SSO registration

* added divider

* added preferred ui API

* fix test

* fix update function

* refactor

* fix naming

* fix import

* add coverage ignore

* more ignore

* fixed missing key

* renamed button

* revert coverage statements

* set prefered mode before sso login

* added dynamic login redirect

* fixed test assert

* use API Endpoints instead of hardcoding

* fix lookup
This commit is contained in:
Matthias Mair 2024-02-02 09:15:13 +00:00 committed by GitHub
parent c9c93bce39
commit be30cec2ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 337 additions and 118 deletions

View File

@ -1148,5 +1148,4 @@ if CUSTOM_FLAGS:
# Magic login django-sesame
SESAME_MAX_AGE = 300
# LOGIN_REDIRECT_URL = f"/{FRONTEND_URL_BASE}/logged-in/"
LOGIN_REDIRECT_URL = '/index/'
LOGIN_REDIRECT_URL = '/api/auth/login-redirect/'

View File

@ -1247,9 +1247,6 @@ class MagicLoginTest(InvenTreeTestCase):
# Check that the login works
resp = self.client.get(reverse('sesame-login') + '?sesame=' + token)
self.assertEqual(resp.status_code, 302)
self.assertEqual(resp.url, '/index/')
# Note: 2023-08-08 - This test has been changed because "platform UI" is not generally available yet
# TODO: In the future, the URL comparison will need to be reverted
# self.assertEqual(resp.url, f'/{settings.FRONTEND_URL_BASE}/logged-in/')
self.assertEqual(resp.url, '/api/auth/login-redirect/')
# And we should be logged in again
self.assertEqual(resp.wsgi_request.user, self.user)

View File

@ -35,6 +35,7 @@ from order.urls import order_urls
from part.urls import part_urls
from plugin.urls import get_plugin_urls
from stock.urls import stock_urls
from web.urls import api_urls as web_api_urls
from web.urls import urlpatterns as platform_urls
from .api import APISearchView, InfoView, NotFoundView, VersionTextView, VersionView
@ -84,6 +85,7 @@ apipatterns = [
path('report/', include(report.api.report_api_urls)),
path('user/', include(users.api.user_urls)),
path('admin/', include(common.api.admin_api_urls)),
path('web/', include(web_api_urls)),
# Plugin endpoints
path('', include(plugin.api.plugin_api_urls)),
# Common endpoints endpoint
@ -149,6 +151,11 @@ apipatterns = [
name='social_account_disconnect',
),
path('logout/', users.api.Logout.as_view(), name='api-logout'),
path(
'login-redirect/',
users.api.LoginRedirect.as_view(),
name='api-login-redirect',
),
path('', include('dj_rest_auth.urls')),
]),
),

View File

@ -6,6 +6,7 @@ import logging
from django.contrib.auth import get_user, login
from django.contrib.auth.models import Group, User
from django.urls import include, path, re_path
from django.views.generic.base import RedirectView
from dj_rest_auth.views import LogoutView
from rest_framework import exceptions, permissions
@ -22,6 +23,7 @@ from InvenTree.mixins import (
RetrieveUpdateDestroyAPI,
)
from InvenTree.serializers import ExendedUserSerializer, UserCreateSerializer
from InvenTree.settings import FRONTEND_URL_BASE
from users.models import ApiToken, Owner, RuleSet, check_user_role
from users.serializers import GroupSerializer, OwnerSerializer
@ -279,6 +281,17 @@ class GetAuthToken(APIView):
raise exceptions.NotAuthenticated()
class LoginRedirect(RedirectView):
"""Redirect to the correct starting page after backend login."""
def get_redirect_url(self, *args, **kwargs):
"""Return the URL to redirect to."""
session = self.request.session
if session.get('preferred_method', 'cui') == 'pui':
return f'/{FRONTEND_URL_BASE}/logged-in/'
return '/index/'
user_urls = [
path('roles/', RoleDetails.as_view(), name='api-user-roles'),
path('token/', GetAuthToken.as_view(), name='api-token'),

View File

@ -5,8 +5,10 @@ import os
from pathlib import Path
from unittest import mock
from django.urls import reverse
from InvenTree.config import get_frontend_settings
from InvenTree.unit_test import InvenTreeTestCase
from InvenTree.unit_test import InvenTreeAPITestCase, InvenTreeTestCase
from .templatetags import spa_helper
@ -73,3 +75,26 @@ class TemplateTagTest(InvenTreeTestCase):
rsp = get_frontend_settings(False)
self.assertFalse('show_server_selector' in rsp)
self.assertEqual(rsp['server_list'], ['aa', 'bb'])
class TestWebHelpers(InvenTreeAPITestCase):
"""Tests for the web helpers."""
def test_ui_preference(self):
"""Test the UI preference API."""
url = reverse('api-ui-preference')
# Test default
resp = self.get(url)
data = json.loads(resp.content)
self.assertTrue(data['cui'])
self.assertFalse(data['pui'])
self.assertEqual(data['preferred_method'], 'cui')
# Set to PUI
resp = self.put(url, {'preferred_method': 'pui'})
data = json.loads(resp.content)
self.assertEqual(resp.status_code, 200)
self.assertFalse(data['cui'])
self.assertTrue(data['pui'])
self.assertEqual(data['preferred_method'], 'pui')

View File

@ -1,11 +1,16 @@
"""URLs for web app."""
from django.conf import settings
from django.http import JsonResponse
from django.shortcuts import redirect
from django.urls import include, path, re_path
from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.generic import TemplateView
from rest_framework import permissions, serializers
from InvenTree.mixins import RetrieveUpdateAPI
class RedirectAssetView(TemplateView):
"""View to redirect to static asset."""
@ -17,6 +22,55 @@ class RedirectAssetView(TemplateView):
)
class PreferredSerializer(serializers.Serializer):
"""Serializer for the preferred serializer session setting."""
preferred_method = serializers.ChoiceField(choices=['cui', 'pui'])
pui = serializers.SerializerMethodField(read_only=True)
cui = serializers.SerializerMethodField(read_only=True)
def get_pui(self, obj):
"""Return true if preferred method is PUI."""
return obj['preferred_method'] == 'pui'
def get_cui(self, obj):
"""Return true if preferred method is CUI."""
return obj['preferred_method'] == 'cui'
class Meta:
"""Meta class for PreferedSerializer."""
fields = '__all__'
class PreferredUiView(RetrieveUpdateAPI):
"""Set preferred UI (CIU/PUI)."""
permission_classes = [permissions.AllowAny]
serializer_class = PreferredSerializer
http_method_names = ['get', 'post', 'put', 'head', 'options']
def retrieve(self, request, *args, **kwargs):
"""Retrieve the preferred UI method."""
session = self.request.session
session['preferred_method'] = session.get('preferred_method', 'cui')
serializer = self.get_serializer(data=dict(session))
serializer.is_valid(raise_exception=True)
return JsonResponse(serializer.data)
def update(self, request, *args, **kwargs):
"""Update the preferred UI method."""
serializer = self.get_serializer(data=self.clean_data(request.data))
serializer.is_valid(raise_exception=True)
# Run update
session = self.request.session
session['preferred_method'] = serializer.validated_data['preferred_method']
session.modified = True
return JsonResponse(serializer.data)
spa_view = ensure_csrf_cookie(TemplateView.as_view(template_name='web/index.html'))
assets_path = path('assets/<path:path>', RedirectAssetView.as_view())
@ -37,3 +91,8 @@ urlpatterns = [
assets_path,
path(settings.FRONTEND_URL_BASE, spa_view, name='platform'),
]
api_urls = [
# UI Preference
path('ui_preference/', PreferredUiView.as_view(), name='api-ui-preference')
]

View File

@ -0,0 +1,64 @@
import { Button } from '@mantine/core';
import {
IconBrandAzure,
IconBrandBitbucket,
IconBrandDiscord,
IconBrandFacebook,
IconBrandFlickr,
IconBrandGithub,
IconBrandGitlab,
IconBrandGoogle,
IconBrandReddit,
IconBrandTwitch,
IconBrandTwitter,
IconLogin
} from '@tabler/icons-react';
import { api } from '../../App';
import { ApiEndpoints } from '../../enums/ApiEndpoints';
import { apiUrl } from '../../states/ApiState';
import { Provider } from '../../states/states';
const brandIcons: { [key: string]: JSX.Element } = {
google: <IconBrandGoogle />,
github: <IconBrandGithub />,
facebook: <IconBrandFacebook />,
discord: <IconBrandDiscord />,
twitter: <IconBrandTwitter />,
bitbucket: <IconBrandBitbucket />,
flickr: <IconBrandFlickr />,
gitlab: <IconBrandGitlab />,
reddit: <IconBrandReddit />,
twitch: <IconBrandTwitch />,
microsoft: <IconBrandAzure />
};
export function SsoButton({ provider }: { provider: Provider }) {
function login() {
// set preferred provider
api
.put(
apiUrl(ApiEndpoints.ui_preference),
{ preferred_method: 'pui' },
{ headers: { Authorization: '' } }
)
.then(() => {
// redirect to login
window.location.href = provider.login;
});
}
return (
<Button
leftIcon={getBrandIcon(provider)}
radius="xl"
component="a"
onClick={login}
>
{provider.display_name}{' '}
</Button>
);
}
function getBrandIcon(provider: Provider) {
return brandIcons[provider.id] || <IconLogin />;
}

View File

@ -2,6 +2,7 @@ import { Trans, t } from '@lingui/macro';
import {
Anchor,
Button,
Divider,
Group,
Loader,
PasswordInput,
@ -21,6 +22,7 @@ import { ApiEndpoints } from '../../enums/ApiEndpoints';
import { doBasicLogin, doSimpleLogin } from '../../functions/auth';
import { apiUrl, useServerApiState } from '../../states/ApiState';
import { useSessionState } from '../../states/SessionState';
import { SsoButton } from '../buttons/SSOButton';
export function AuthenticationForm() {
const classicForm = useForm({
@ -83,76 +85,93 @@ export function AuthenticationForm() {
}
return (
<form onSubmit={classicForm.onSubmit(() => {})}>
{classicLoginMode ? (
<Stack spacing={0}>
<TextInput
required
label={t`Username`}
placeholder={t`Your username`}
{...classicForm.getInputProps('username')}
/>
<PasswordInput
required
label={t`Password`}
placeholder={t`Your password`}
{...classicForm.getInputProps('password')}
/>
{auth_settings?.password_forgotten_enabled === true && (
<Group position="apart" mt="0">
<Anchor
component="button"
type="button"
color="dimmed"
size="xs"
onClick={() => navigate('/reset-password')}
>
<Trans>Reset password</Trans>
</Anchor>
</Group>
)}
</Stack>
) : (
<Stack>
<TextInput
required
label={t`Email`}
description={t`We will send you a link to login - if you are registered`}
placeholder="email@example.org"
{...simpleForm.getInputProps('email')}
/>
</Stack>
)}
<>
{auth_settings?.sso_enabled === true ? (
<>
<Group grow mb="md" mt="md">
{auth_settings.providers.map((provider) => (
<SsoButton provider={provider} key={provider.id} />
))}
</Group>
<Group position="apart" mt="xl">
<Anchor
component="button"
type="button"
color="dimmed"
size="xs"
onClick={() => setMode.toggle()}
>
{classicLoginMode ? (
<Trans>Send me an email</Trans>
) : (
<Trans>Use username and password</Trans>
)}
</Anchor>
<Button type="submit" disabled={isLoggingIn} onClick={handleLogin}>
{isLoggingIn ? (
<Loader size="sm" />
) : (
<>
{classicLoginMode ? (
<Trans>Log In</Trans>
) : (
<Trans>Send Email</Trans>
)}
</>
)}
</Button>
</Group>
</form>
<Divider
label={t`Or continue with other methods`}
labelPosition="center"
my="lg"
/>
</>
) : null}
<form onSubmit={classicForm.onSubmit(() => {})}>
{classicLoginMode ? (
<Stack spacing={0}>
<TextInput
required
label={t`Username`}
placeholder={t`Your username`}
{...classicForm.getInputProps('username')}
/>
<PasswordInput
required
label={t`Password`}
placeholder={t`Your password`}
{...classicForm.getInputProps('password')}
/>
{auth_settings?.password_forgotten_enabled === true && (
<Group position="apart" mt="0">
<Anchor
component="button"
type="button"
color="dimmed"
size="xs"
onClick={() => navigate('/reset-password')}
>
<Trans>Reset password</Trans>
</Anchor>
</Group>
)}
</Stack>
) : (
<Stack>
<TextInput
required
label={t`Email`}
description={t`We will send you a link to login - if you are registered`}
placeholder="email@example.org"
{...simpleForm.getInputProps('email')}
/>
</Stack>
)}
<Group position="apart" mt="xl">
<Anchor
component="button"
type="button"
color="dimmed"
size="xs"
onClick={() => setMode.toggle()}
>
{classicLoginMode ? (
<Trans>Send me an email</Trans>
) : (
<Trans>Use username and password</Trans>
)}
</Anchor>
<Button type="submit" disabled={isLoggingIn} onClick={handleLogin}>
{isLoggingIn ? (
<Loader size="sm" />
) : (
<>
{classicLoginMode ? (
<Trans>Log In</Trans>
) : (
<Trans>Send Email</Trans>
)}
</>
)}
</Button>
</Group>
</form>
</>
);
}
@ -161,6 +180,7 @@ export function RegistrationForm() {
initialValues: { username: '', email: '', password1: '', password2: '' }
});
const navigate = useNavigate();
const [auth_settings] = useServerApiState((state) => [state.auth_settings]);
const [isRegistering, setIsRegistering] = useState<boolean>(false);
function handleRegistration() {
@ -201,47 +221,63 @@ export function RegistrationForm() {
});
}
const both_reg_enabled =
auth_settings?.registration_enabled && auth_settings?.sso_registration;
return (
<form onSubmit={registrationForm.onSubmit(() => {})}>
<Stack spacing={0}>
<TextInput
required
label={t`Username`}
placeholder={t`Your username`}
{...registrationForm.getInputProps('username')}
/>
<TextInput
required
label={t`Email`}
description={t`This will be used for a confirmation`}
placeholder="email@example.org"
{...registrationForm.getInputProps('email')}
/>
<PasswordInput
required
label={t`Password`}
placeholder={t`Your password`}
{...registrationForm.getInputProps('password1')}
/>
<PasswordInput
required
label={t`Password repeat`}
placeholder={t`Repeat password`}
{...registrationForm.getInputProps('password2')}
/>
</Stack>
<>
{auth_settings?.registration_enabled && (
<form onSubmit={registrationForm.onSubmit(() => {})}>
<Stack spacing={0}>
<TextInput
required
label={t`Username`}
placeholder={t`Your username`}
{...registrationForm.getInputProps('username')}
/>
<TextInput
required
label={t`Email`}
description={t`This will be used for a confirmation`}
placeholder="email@example.org"
{...registrationForm.getInputProps('email')}
/>
<PasswordInput
required
label={t`Password`}
placeholder={t`Your password`}
{...registrationForm.getInputProps('password1')}
/>
<PasswordInput
required
label={t`Password repeat`}
placeholder={t`Repeat password`}
{...registrationForm.getInputProps('password2')}
/>
</Stack>
<Group position="apart" mt="xl">
<Button
type="submit"
disabled={isRegistering}
onClick={handleRegistration}
fullWidth
>
<Trans>Register</Trans>
</Button>
</Group>
</form>
<Group position="apart" mt="xl">
<Button
type="submit"
disabled={isRegistering}
onClick={handleRegistration}
fullWidth
>
<Trans>Register</Trans>
</Button>
</Group>
</form>
)}
{both_reg_enabled && (
<Divider label={t`Or use SSO`} labelPosition="center" my="lg" />
)}
{auth_settings?.sso_registration === true && (
<Group grow mb="md" mt="md">
{auth_settings.providers.map((provider) => (
<SsoButton provider={provider} key={provider.id} />
))}
</Group>
)}
</>
);
}
@ -253,8 +289,12 @@ export function ModeSelector({
setMode: any;
}) {
const [auth_settings] = useServerApiState((state) => [state.auth_settings]);
const registration_enabled =
auth_settings?.registration_enabled ||
auth_settings?.sso_registration ||
false;
if (auth_settings?.registration_enabled === false) return null;
if (registration_enabled === false) return null;
return (
<Text ta="center" size={'xs'} mt={'md'}>
{loginMode ? (

View File

@ -26,6 +26,7 @@ export enum ApiEndpoints {
user_logout = 'auth/logout/',
user_register = 'auth/registration/',
// Generic API endpoints
currency_list = 'currency/exchange/',
currency_refresh = 'currency/refresh/',
task_overview = 'background-task/',
@ -43,6 +44,8 @@ export enum ApiEndpoints {
sso_providers = 'auth/providers/',
group_list = 'user/group/',
owner_list = 'user/owner/',
// Build API endpoints
build_order_list = 'build/',
build_order_attachment_list = 'build/attachment/',
build_line_list = 'build/line/',
@ -61,6 +64,7 @@ export enum ApiEndpoints {
part_attachment_list = 'part/attachment/',
part_test_template_list = 'part/test-template/',
// Company API endpoints
company_list = 'company/',
contact_list = 'company/contact/',
address_list = 'company/address/',
@ -76,6 +80,8 @@ export enum ApiEndpoints {
stock_location_list = 'stock/location/',
stock_location_tree = 'stock/location/tree/',
stock_attachment_list = 'stock/attachment/',
// Order API endpoints
purchase_order_list = 'order/po/',
purchase_order_line_list = 'order/po-line/',
purchase_order_attachment_list = 'order/po/attachment/',
@ -84,12 +90,17 @@ export enum ApiEndpoints {
sales_order_shipment_list = 'order/so/shipment/',
return_order_list = 'order/ro/',
return_order_attachment_list = 'order/ro/attachment/',
// Plugin API endpoints
plugin_list = 'plugins/',
plugin_setting_list = 'plugins/:plugin/settings/',
plugin_registry_status = 'plugins/status/',
plugin_install = 'plugins/install/',
plugin_reload = 'plugins/reload/',
// Miscellaneous API endpoints
error_report_list = 'error-report/',
project_code_list = 'project-code/',
custom_unit_list = 'units/'
custom_unit_list = 'units/',
ui_preference = 'web/ui_preference/'
}

View File

@ -71,7 +71,11 @@ export default function Login() {
<>
<Paper radius="md" p="xl" withBorder>
<Text size="lg" weight={500}>
<Trans>Welcome, log in below</Trans>
{loginMode ? (
<Trans>Welcome, log in below</Trans>
) : (
<Trans>Register below</Trans>
)}
</Text>
{loginMode ? <AuthenticationForm /> : <RegistrationForm />}
<ModeSelector loginMode={loginMode} setMode={setMode} />