diff --git a/new-ui/src/pages/compact/CompactLocationsPage/CompactLocationsPage.tsx b/new-ui/src/pages/compact/CompactLocationsPage/CompactLocationsPage.tsx index 8bba5170..002408fc 100644 --- a/new-ui/src/pages/compact/CompactLocationsPage/CompactLocationsPage.tsx +++ b/new-ui/src/pages/compact/CompactLocationsPage/CompactLocationsPage.tsx @@ -27,34 +27,42 @@ export const CompactLocationsPage = () => { const routeData = useLoaderData({ from: '/compact/' }); + const { data: instances } = useQuery(getInstancesQueryOptions); + + const allInstances = instances ?? routeData.instances; + const allTunnels = routeData.tunnels; + const queryInstanceId = useMemo(() => { if (!isPresent(selection)) return routeData.instances[0].id; - if (selection.kind === 'instance') return selection.data.id; - return selection.data.instance_id; - }, [selection, routeData.instances]); + if (selection.kind === 'instance') return selection.id; + return ( + allTunnels.find((t) => t.id === selection.id)?.instance_id ?? + routeData.instances[0].id + ); + }, [selection, routeData.instances, allTunnels]); const { data: locations } = useQuery(getLocationsQueryOptions(queryInstanceId)); - const { data: instances } = useQuery(getInstancesQueryOptions); - const instanceInfo = useMemo(() => { - const allInstances = instances ?? routeData.instances; if (!isPresent(selection)) return allInstances[0]; if (selection.kind === 'instance') - return allInstances.find((i) => i.id === selection.data.id); - return allInstances.find((i) => i.id === selection.data.instance_id); - }, [selection, instances, routeData.instances]); + return allInstances.find((i) => i.id === selection.id); + const tunnel = allTunnels.find((t) => t.id === selection.id); + return tunnel ? allInstances.find((i) => i.id === tunnel.instance_id) : undefined; + }, [selection, allInstances, allTunnels]); const displayedLocations = useMemo(() => { if (!isPresent(selection) || selection.kind === 'instance') { return locations ?? routeData.locations; } - return [selection.data]; - }, [selection, locations, routeData.locations]); + const tunnel = allTunnels.find((t) => t.id === selection.id); + return tunnel ? [tunnel] : []; + }, [selection, locations, routeData.locations, allTunnels]); useEffect(() => { + if (selection?.kind === 'tunnel') return; if (selection === null || instanceInfo === undefined) { - setViewSelection({ kind: 'instance', data: routeData.instances[0] }); + setViewSelection({ kind: 'instance', id: routeData.instances[0].id }); } }, [routeData.instances, instanceInfo, selection, setViewSelection]); diff --git a/new-ui/src/pages/compact/CompactLocationsPage/components/InstanceSwitcher.tsx b/new-ui/src/pages/compact/CompactLocationsPage/components/InstanceSwitcher.tsx index 87eff3ec..1394885b 100644 --- a/new-ui/src/pages/compact/CompactLocationsPage/components/InstanceSwitcher.tsx +++ b/new-ui/src/pages/compact/CompactLocationsPage/components/InstanceSwitcher.tsx @@ -28,7 +28,7 @@ export const InstanceSwitcher = () => { options: instances.map((instance) => ({ key: instance.id, label: instance.name, - value: { kind: 'instance', data: instance }, + value: { kind: 'instance', id: instance.id }, })), }; @@ -38,7 +38,7 @@ export const InstanceSwitcher = () => { options: tunnels.map((tunnel) => ({ key: tunnel.id ?? tunnel.name, label: tunnel.name, - value: { kind: 'tunnel', data: tunnel }, + value: { kind: 'tunnel', id: tunnel.id }, })), }; @@ -54,13 +54,9 @@ export const InstanceSwitcher = () => { if (!isPresent(selectedInstance)) return undefined; for (const group of groups) { const found = group.options.find((o) => { - if (selectedInstance.kind === 'instance' && o.value.kind === 'instance') { - return o.value.data.id === selectedInstance.data.id; - } - if (selectedInstance.kind === 'tunnel' && o.value.kind === 'tunnel') { - return o.value.data.id === selectedInstance.data.id; - } - return false; + return ( + o.value.kind === selectedInstance.kind && o.value.id === selectedInstance.id + ); }); if (found) return found; } diff --git a/new-ui/src/pages/full/OverviewPage/OverviewPage.tsx b/new-ui/src/pages/full/OverviewPage/OverviewPage.tsx index e7b29482..c71f1f8a 100644 --- a/new-ui/src/pages/full/OverviewPage/OverviewPage.tsx +++ b/new-ui/src/pages/full/OverviewPage/OverviewPage.tsx @@ -21,11 +21,27 @@ export const OverviewPage = () => { const { instances, tunnels } = useAppData(); const { viewSelection: selection } = useAppData(); + const selectedTunnel = useMemo( + () => + selection?.kind === 'tunnel' + ? tunnels.find((t) => t.id === selection.id) + : undefined, + [selection, tunnels], + ); + + const selectedInstance = useMemo( + () => + selection?.kind === 'instance' + ? instances.find((i) => i.id === selection.id) + : undefined, + [selection, instances], + ); + const queryInstanceId = useMemo(() => { if (!isPresent(selection)) return instances[0].id; - if (selection.kind === 'instance') return selection.data.id; - return selection.data.instance_id; - }, [selection, instances]); + if (selection.kind === 'instance') return selection.id; + return selectedTunnel?.instance_id ?? instances[0].id; + }, [selection, instances, selectedTunnel]); const { data: locations } = useQuery(getLocationsQueryOptions(queryInstanceId)); @@ -33,8 +49,8 @@ export const OverviewPage = () => { if (!isPresent(selection) || selection.kind === 'instance') { return locations ?? []; } - return [selection.data]; - }, [selection, locations]); + return selectedTunnel ? [selectedTunnel] : []; + }, [selection, locations, selectedTunnel]); return ( @@ -53,10 +69,8 @@ export const OverviewPage = () => {
{displayedLocations.map((location) => { - let instance: InstanceInfo | undefined; - if (selection?.kind === 'instance') { - instance = selection.data; - } + const instance: InstanceInfo | undefined = + selection?.kind === 'instance' ? selectedInstance : undefined; return ( { const isSelected = (candidate: OverviewViewSelection): boolean => { if (!selection) return false; - if (candidate.kind !== selection.kind) return false; - return candidate.data.id === selection.data.id; + return candidate.kind === selection.kind && candidate.id === selection.id; }; return ( @@ -44,10 +43,10 @@ export const OverviewSelection = ({ instances, tunnels }: Props) => {

Instances

{instances.map((instance) => { - const value: OverviewViewSelection = { kind: 'instance', data: instance }; + const value: OverviewViewSelection = { kind: 'instance', id: instance.id }; return ( setSelection(value)} @@ -62,10 +61,10 @@ export const OverviewSelection = ({ instances, tunnels }: Props) => {

Tunnels

{tunnels.map((tunnel) => { - const value: OverviewViewSelection = { kind: 'tunnel', data: tunnel }; + const value: OverviewViewSelection = { kind: 'tunnel', id: tunnel.id }; return ( setSelection(value)} diff --git a/new-ui/src/pages/full/TunnelWizardPage/hooks/useTunnelWizardStore.tsx b/new-ui/src/pages/full/TunnelWizardPage/hooks/useTunnelWizardStore.tsx index 6fd26bd0..2567a0ac 100644 --- a/new-ui/src/pages/full/TunnelWizardPage/hooks/useTunnelWizardStore.tsx +++ b/new-ui/src/pages/full/TunnelWizardPage/hooks/useTunnelWizardStore.tsx @@ -3,35 +3,95 @@ import { TunnelWizardStep, type TunnelWizardStepValue } from '../types'; type StoreValues = { activeStep: TunnelWizardStepValue; + tunnelData: { + name: string; + pubkey: string; + prvkey: string; + address: string; + server_pubkey: string; + preshared_key: string; + allowed_ips?: string; + endpoint: string; + dns?: string; + persistent_keep_alive: number; + route_all_traffic: boolean; + pre_up?: string; + post_up?: string; + pre_down?: string; + post_down?: string; + }; }; const defaults: StoreValues = { activeStep: TunnelWizardStep.GeneralInformation, + tunnelData: { + name: '', + address: '', + endpoint: '', + persistent_keep_alive: 25, + preshared_key: '', + prvkey: '', + pubkey: '', + route_all_traffic: false, + server_pubkey: '', + allowed_ips: '', + dns: '', + post_down: '', + post_up: '', + pre_down: '', + pre_up: '', + }, +}; + +const nextStep = (step: TunnelWizardStepValue): TunnelWizardStepValue => { + switch (step) { + case TunnelWizardStep.GeneralInformation: + return TunnelWizardStep.Keys; + case TunnelWizardStep.Keys: + return TunnelWizardStep.VpnServer; + case TunnelWizardStep.VpnServer: + return TunnelWizardStep.AdvancedSettings; + case TunnelWizardStep.AdvancedSettings: + return TunnelWizardStep.Finish; + default: + return step; + } }; -const STEPS: TunnelWizardStepValue[] = [ - TunnelWizardStep.GeneralInformation, - TunnelWizardStep.Keys, - TunnelWizardStep.VpnServer, - TunnelWizardStep.AdvancedSettings, - TunnelWizardStep.Finish, -]; +const prevStep = (step: TunnelWizardStepValue): TunnelWizardStepValue => { + switch (step) { + case TunnelWizardStep.Keys: + return TunnelWizardStep.GeneralInformation; + case TunnelWizardStep.VpnServer: + return TunnelWizardStep.Keys; + case TunnelWizardStep.AdvancedSettings: + return TunnelWizardStep.VpnServer; + case TunnelWizardStep.Finish: + return TunnelWizardStep.AdvancedSettings; + default: + return step; + } +}; interface Store extends StoreValues { - next: () => void; - back: () => void; + next: (values?: Partial) => void; + back: (values?: Partial) => void; reset: () => void; } export const useTunnelWizardStore = create()((set, get) => ({ ...defaults, - next: () => { - const idx = STEPS.indexOf(get().activeStep); - if (idx < STEPS.length - 1) set({ activeStep: STEPS[idx + 1] }); + next: (tunnelData) => { + set({ + activeStep: nextStep(get().activeStep), + tunnelData: { ...get().tunnelData, ...tunnelData }, + }); }, - back: () => { - const idx = STEPS.indexOf(get().activeStep); - if (idx > 0) set({ activeStep: STEPS[idx - 1] }); + back: (tunnelData) => { + set({ + activeStep: prevStep(get().activeStep), + tunnelData: { ...get().tunnelData, ...tunnelData }, + }); }, reset: () => set(defaults), })); diff --git a/new-ui/src/pages/full/TunnelWizardPage/steps/AdvancedSettingsStep/AdvancedSettingsStep.tsx b/new-ui/src/pages/full/TunnelWizardPage/steps/AdvancedSettingsStep/AdvancedSettingsStep.tsx index 5ade7c89..51375de1 100644 --- a/new-ui/src/pages/full/TunnelWizardPage/steps/AdvancedSettingsStep/AdvancedSettingsStep.tsx +++ b/new-ui/src/pages/full/TunnelWizardPage/steps/AdvancedSettingsStep/AdvancedSettingsStep.tsx @@ -1 +1,112 @@ -export const AdvancedSettingsStep = () =>
Advanced Settings
; +import { useMutation } from '@tanstack/react-query'; +import { useMemo } from 'react'; +import z from 'zod'; +import { Button } from '../../../../../shared/components/Button/Button'; +import { ButtonVariant } from '../../../../../shared/components/Button/types'; +import { Controls } from '../../../../../shared/components/Controls/Controls'; +import { Divider } from '../../../../../shared/components/Divider/Divider'; +import { SizedBox } from '../../../../../shared/components/SizedBox/SizedBox'; +import { Split } from '../../../../../shared/components/Split/Split'; +import { useAppForm } from '../../../../../shared/form'; +import { formChangeLogic } from '../../../../../shared/formLogic'; +import { api } from '../../../../../shared/rust-api/api'; +import { ThemeSpacing } from '../../../../../shared/types'; +import { useTunnelWizardStore } from '../../hooks/useTunnelWizardStore'; + +const formSchema = z.object({ + pre_up: z.string(), + post_up: z.string(), + pre_down: z.string(), + post_down: z.string(), +}); + +type FormFields = z.infer; + +export const AdvancedSettingsStep = () => { + const initData = useTunnelWizardStore((s) => s.tunnelData); + + const { mutateAsync } = useMutation({ mutationFn: api.saveTunnel }); + + const defaultValues = useMemo( + (): FormFields => ({ + pre_up: initData.pre_up ?? '', + post_up: initData.post_up ?? '', + pre_down: initData.pre_down ?? '', + post_down: initData.post_down ?? '', + }), + [initData.pre_up, initData.post_up, initData.pre_down, initData.post_down], + ); + + const form = useAppForm({ + defaultValues, + validationLogic: formChangeLogic, + validators: { + onSubmit: formSchema, + onChange: formSchema, + }, + onSubmit: async ({ value }) => { + const storeValues = useTunnelWizardStore.getState().tunnelData; + const toSend = { ...storeValues, ...value }; + await mutateAsync(toSend); + useTunnelWizardStore.getState().next(); + }, + }); + + return ( +
+
+

Advanced settings (optional)

+ +

+ Define optional shell commands to run before or after the tunnel interface is + brought up or down. Useful for custom routing rules, firewall adjustments, or + other network configuration. +

+
+ +
{ + e.stopPropagation(); + e.preventDefault(); + form.handleSubmit(); + }} + > + + + + {(field) => } + + + {(field) => } + + + + + + {(field) => } + + + {(field) => } + + + +
+ +
+ +
+ ); +}; diff --git a/new-ui/src/pages/full/TunnelWizardPage/steps/FinishStep/FinishStep.tsx b/new-ui/src/pages/full/TunnelWizardPage/steps/FinishStep/FinishStep.tsx index ae8ed27e..ef3a656e 100644 --- a/new-ui/src/pages/full/TunnelWizardPage/steps/FinishStep/FinishStep.tsx +++ b/new-ui/src/pages/full/TunnelWizardPage/steps/FinishStep/FinishStep.tsx @@ -1 +1,40 @@ -export const FinishStep = () =>
Finish
; +import './style.scss'; + +import { useNavigate } from '@tanstack/react-router'; +import { Button } from '../../../../../shared/components/Button/Button'; +import { ButtonVariant } from '../../../../../shared/components/Button/types'; +import { useTunnelWizardStore } from '../../hooks/useTunnelWizardStore'; +import bannerSrc from './assets/banner.png'; + +export const FinishStep = () => { + const navigate = useNavigate(); + + return ( +
+
+ +
+

Your WireGuard tunnel added successfully

+

You can now connect this device, check its status and view statistics.

+
+
+
+ ); +}; diff --git a/new-ui/src/pages/full/TunnelWizardPage/steps/FinishStep/assets/banner.png b/new-ui/src/pages/full/TunnelWizardPage/steps/FinishStep/assets/banner.png new file mode 100644 index 00000000..001974a7 Binary files /dev/null and b/new-ui/src/pages/full/TunnelWizardPage/steps/FinishStep/assets/banner.png differ diff --git a/new-ui/src/pages/full/TunnelWizardPage/steps/FinishStep/style.scss b/new-ui/src/pages/full/TunnelWizardPage/steps/FinishStep/style.scss new file mode 100644 index 00000000..acc13e86 --- /dev/null +++ b/new-ui/src/pages/full/TunnelWizardPage/steps/FinishStep/style.scss @@ -0,0 +1,28 @@ +#finish-step { + .banner { + box-sizing: border-box; + padding-bottom: var(--spacing-3xl); + } + + h1 { + font: var(--t-h4); + color: var(--fg-white-100); + padding-bottom: var(--spacing-sm); + user-select: none; + } + + > p { + user-select: none; + font: var(--t-small-400); + color: var(--fg-white-70); + padding-bottom: var(--spacing-lg); + } + + > .actions { + display: flex; + flex-flow: row nowrap; + align-items: center; + justify-content: flex-start; + column-gap: var(--spacing-md); + } +} diff --git a/new-ui/src/pages/full/TunnelWizardPage/steps/GeneralInformationStep/GeneralInformationStep.tsx b/new-ui/src/pages/full/TunnelWizardPage/steps/GeneralInformationStep/GeneralInformationStep.tsx index 367887bb..e686e748 100644 --- a/new-ui/src/pages/full/TunnelWizardPage/steps/GeneralInformationStep/GeneralInformationStep.tsx +++ b/new-ui/src/pages/full/TunnelWizardPage/steps/GeneralInformationStep/GeneralInformationStep.tsx @@ -1 +1,150 @@ -export const GeneralInformationStep = () =>
General Information
; +import './style.scss'; +import { useMutation } from '@tanstack/react-query'; +import { useNavigate } from '@tanstack/react-router'; +import { open } from '@tauri-apps/plugin-dialog'; +import { readFile } from '@tauri-apps/plugin-fs'; +import { useMemo, useRef } from 'react'; +import { Subject } from 'rxjs'; +import z from 'zod'; +import { Button } from '../../../../../shared/components/Button/Button'; +import { ButtonVariant } from '../../../../../shared/components/Button/types'; +import { Controls } from '../../../../../shared/components/Controls/Controls'; +import { Divider } from '../../../../../shared/components/Divider/Divider'; +import { IconKind } from '../../../../../shared/components/Icon'; +import { SizedBox } from '../../../../../shared/components/SizedBox/SizedBox'; +import { TooltipButton } from '../../../../../shared/components/TooltipButton/TooltipButton'; +import { useAppForm } from '../../../../../shared/form'; +import { formChangeLogic } from '../../../../../shared/formLogic'; +import { api } from '../../../../../shared/rust-api/api'; +import { ThemeSpacing } from '../../../../../shared/types'; +import { patternValidIp, patternValidIpV6 } from '../../../../../shared/utils/patterns'; +import { useTunnelWizardStore } from '../../hooks/useTunnelWizardStore'; + +const formSchema = z.object({ + name: z.string().trim().min(1, 'Field is required'), + address: z.string().refine((value) => { + if (value) { + const ips = value.split(',').map((ip) => ip.trim()); + return ips.every((ip) => patternValidIp.test(ip) || patternValidIpV6.test(ip)); + } + return false; + }, 'Field is invalid'), +}); + +type FormFields = z.infer; + +export const GeneralInformationStep = () => { + const navigate = useNavigate(); + const initData = useTunnelWizardStore((s) => s.tunnelData); + + const { mutate: importTunnelFile, isPending } = useMutation({ + mutationFn: async () => { + const filePath = await open({ + multiple: false, + directory: false, + filters: [{ name: 'wg-conf', extensions: ['conf', 'txt', 'config'] }], + }); + if (filePath) { + const decoder = new TextDecoder(); + const fileContents = await readFile(filePath); + const fileString = decoder.decode(fileContents); + const config = await api.parseTunnelConfig({ + filename: filePath, + config: fileString, + }); + const current = useTunnelWizardStore.getState().tunnelData; + useTunnelWizardStore.setState({ tunnelData: { ...current, ...config } }); + if (config.name) { + form.setFieldValue('name', config.name); + } + if (config.address) { + form.setFieldValue('address', config.address); + } + importTooltipRef.current.next(); + } + }, + }); + + const importTooltipRef = useRef(new Subject()); + + const defaultValues = useMemo( + (): FormFields => ({ + address: initData.address, + name: initData.name, + }), + [initData.address, initData.name], + ); + + const form = useAppForm({ + defaultValues, + validationLogic: formChangeLogic, + validators: { + onSubmit: formSchema, + onChange: formSchema, + }, + onSubmit: ({ value }) => { + useTunnelWizardStore.getState().next(value); + }, + }); + + return ( +
+
+

General information

+ +

{`Upload your config file (optional) and we'll securely extract the connection settings for you. This is the fastest and recommended way to get started.`}

+
+ { + importTunnelFile(); + }, + }} + tooltipText="Config file applied" + tooltipTrigger={importTooltipRef.current} + /> +
+
+ +
{ + e.stopPropagation(); + e.preventDefault(); + form.handleSubmit(); + }} + > + + + {(field) => } + + + + {(field) => } + + +
+ +
+ +
+ ); +}; diff --git a/new-ui/src/pages/full/TunnelWizardPage/steps/GeneralInformationStep/style.scss b/new-ui/src/pages/full/TunnelWizardPage/steps/GeneralInformationStep/style.scss new file mode 100644 index 00000000..d4a26a02 --- /dev/null +++ b/new-ui/src/pages/full/TunnelWizardPage/steps/GeneralInformationStep/style.scss @@ -0,0 +1,7 @@ +#general-info-step { + header { + .actions { + padding-top: var(--spacing-xl); + } + } +} diff --git a/new-ui/src/pages/full/TunnelWizardPage/steps/KeysStep/KeysStep.tsx b/new-ui/src/pages/full/TunnelWizardPage/steps/KeysStep/KeysStep.tsx index d3298b2b..b32b7c98 100644 --- a/new-ui/src/pages/full/TunnelWizardPage/steps/KeysStep/KeysStep.tsx +++ b/new-ui/src/pages/full/TunnelWizardPage/steps/KeysStep/KeysStep.tsx @@ -1 +1,109 @@ -export const KeysStep = () =>
Keys
; +import './style.scss'; +import { useMemo } from 'react'; +import z from 'zod'; +import { Button } from '../../../../../shared/components/Button/Button'; +import { ButtonVariant } from '../../../../../shared/components/Button/types'; +import { Controls } from '../../../../../shared/components/Controls/Controls'; +import { SizedBox } from '../../../../../shared/components/SizedBox/SizedBox'; +import { TooltipButton } from '../../../../../shared/components/TooltipButton/TooltipButton'; +import { useAppForm } from '../../../../../shared/form'; +import { formChangeLogic } from '../../../../../shared/formLogic'; +import { ThemeSpacing } from '../../../../../shared/types'; +import { generateWGKeys } from '../../../../../shared/utils/generateWGKeys'; +import { patternValidWireguardKey } from '../../../../../shared/utils/patterns'; +import { useTunnelWizardStore } from '../../hooks/useTunnelWizardStore'; + +const formSchema = z.object({ + prvkey: z + .string() + .refine((v) => patternValidWireguardKey.test(v), 'Invalid WireGuard key'), + pubkey: z + .string() + .refine((v) => patternValidWireguardKey.test(v), 'Invalid WireGuard key'), +}); + +type FormFields = z.infer; + +export const KeysStep = () => { + const initData = useTunnelWizardStore((s) => s.tunnelData); + + const defaultValues = useMemo( + (): FormFields => ({ + prvkey: initData.prvkey, + pubkey: initData.pubkey, + }), + [initData.prvkey, initData.pubkey], + ); + + const form = useAppForm({ + defaultValues, + validationLogic: formChangeLogic, + validators: { + onSubmit: formSchema, + onChange: formSchema, + }, + onSubmit: ({ value }) => { + useTunnelWizardStore.getState().next(value); + }, + }); + + return ( +
+
+

Keys

+ +

{`Upload your config file (optional) and we'll securely extract the connection settings for you. This is the fastest and recommended way to get started.`}

+
+ +
{ + e.stopPropagation(); + e.preventDefault(); + form.handleSubmit(); + }} + > + + + {(field) => } + + + + {(field) => } + + +
+ { + const pair = generateWGKeys(); + form.setFieldValue('prvkey', pair.privateKey); + form.setFieldValue('pubkey', pair.publicKey); + }, + }} + tooltipText="New keys set" + /> +
+
+
+ +
+ +
+ ); +}; diff --git a/new-ui/src/pages/full/TunnelWizardPage/steps/KeysStep/style.scss b/new-ui/src/pages/full/TunnelWizardPage/steps/KeysStep/style.scss new file mode 100644 index 00000000..385f8b46 --- /dev/null +++ b/new-ui/src/pages/full/TunnelWizardPage/steps/KeysStep/style.scss @@ -0,0 +1,9 @@ +#keys-step { + .actions { + width: 100%; + display: flex; + flex-flow: row nowrap; + align-items: center; + justify-content: flex-end; + } +} diff --git a/new-ui/src/pages/full/TunnelWizardPage/steps/VpnServerStep/VpnServerStep.tsx b/new-ui/src/pages/full/TunnelWizardPage/steps/VpnServerStep/VpnServerStep.tsx index 603fe379..e65d9af0 100644 --- a/new-ui/src/pages/full/TunnelWizardPage/steps/VpnServerStep/VpnServerStep.tsx +++ b/new-ui/src/pages/full/TunnelWizardPage/steps/VpnServerStep/VpnServerStep.tsx @@ -1 +1,136 @@ -export const VpnServerStep = () =>
VPN Server
; +import { useMemo } from 'react'; +import z from 'zod'; +import { Button } from '../../../../../shared/components/Button/Button'; +import { ButtonVariant } from '../../../../../shared/components/Button/types'; +import { Controls } from '../../../../../shared/components/Controls/Controls'; +import { SizedBox } from '../../../../../shared/components/SizedBox/SizedBox'; +import { Split } from '../../../../../shared/components/Split/Split'; +import { useAppForm } from '../../../../../shared/form'; +import { formChangeLogic } from '../../../../../shared/formLogic'; +import { ThemeSpacing } from '../../../../../shared/types'; +import { + cidrRegex, + patternValidEndpoint, + patternValidWireguardKey, +} from '../../../../../shared/utils/patterns'; +import { useTunnelWizardStore } from '../../hooks/useTunnelWizardStore'; + +const formSchema = z.object({ + server_pubkey: z + .string() + .refine((v) => patternValidWireguardKey.test(v), 'Invalid WireGuard key'), + preshared_key: z + .string() + .refine((v) => !v || patternValidWireguardKey.test(v), 'Invalid WireGuard key'), + endpoint: z.string().refine((v) => patternValidEndpoint.test(v), 'Invalid address'), + dns: z.string(), + allowed_ips: z.string().refine((v) => { + if (!v) return true; + return v + .split(',') + .map((s) => s.trim()) + .every((cidr) => cidrRegex.test(cidr)); + }, 'Invalid CIDR notation'), + persistent_keep_alive: z.number().int().min(0), +}); + +type FormFields = z.infer; + +export const VpnServerStep = () => { + const initData = useTunnelWizardStore((s) => s.tunnelData); + + const defaultValues = useMemo( + (): FormFields => ({ + server_pubkey: initData.server_pubkey, + preshared_key: initData.preshared_key, + endpoint: initData.endpoint, + dns: initData.dns ?? '', + allowed_ips: initData.allowed_ips ?? '', + persistent_keep_alive: initData.persistent_keep_alive, + }), + [ + initData.server_pubkey, + initData.preshared_key, + initData.endpoint, + initData.dns, + initData.allowed_ips, + initData.persistent_keep_alive, + ], + ); + + const form = useAppForm({ + defaultValues, + validationLogic: formChangeLogic, + validators: { + onSubmit: formSchema, + onChange: formSchema, + }, + onSubmit: ({ value }) => { + useTunnelWizardStore.getState().next(value); + }, + }); + + return ( +
+
+

VPN Server

+ +

{`Upload your config file (optional) and we'll securely extract the connection settings for you. This is the fastest and recommended way to get started.`}

+
+ +
{ + e.stopPropagation(); + e.preventDefault(); + form.handleSubmit(); + }} + > + + + + {(field) => } + + + {(field) => } + + + + + + {(field) => } + + + {(field) => } + + + + + {(field) => ( + + )} + + + + {(field) => } + + +
+ +
+ + + ); +}; diff --git a/new-ui/src/routes/compact/index.tsx b/new-ui/src/routes/compact/index.tsx index 5f302f0b..566d5a42 100644 --- a/new-ui/src/routes/compact/index.tsx +++ b/new-ui/src/routes/compact/index.tsx @@ -29,18 +29,18 @@ export const Route = createFileRoute('/compact/')({ if (stored === null) { storedIsValid = false; } else if (stored.kind === 'instance') { - storedIsValid = instances.some((i) => i.id === stored.data.id); + storedIsValid = instances.some((i) => i.id === stored.id); } else { - storedIsValid = tunnels.some((t) => t.id === stored.data.id); + storedIsValid = tunnels.some((t) => t.id === stored.id); } let selected: OverviewViewSelection; if (storedIsValid && stored !== null) { selected = stored; } else if (instances.length > 0) { - selected = { kind: 'instance', data: instances[0] }; + selected = { kind: 'instance', id: instances[0].id }; } else { - selected = { kind: 'tunnel', data: tunnels[0] }; + selected = { kind: 'tunnel', id: tunnels[0].id }; } if (!storedIsValid) { @@ -51,7 +51,7 @@ export const Route = createFileRoute('/compact/')({ let locations: LocationInfo[]; if (selected.kind === 'instance') { locations = await context.queryClient.fetchQuery( - getLocationsQueryOptions(selected.data.id), + getLocationsQueryOptions(selected.id), ); } else { locations = []; diff --git a/new-ui/src/routes/full/_default/overview.tsx b/new-ui/src/routes/full/_default/overview.tsx index 4c434138..c083fd44 100644 --- a/new-ui/src/routes/full/_default/overview.tsx +++ b/new-ui/src/routes/full/_default/overview.tsx @@ -27,16 +27,16 @@ export const Route = createFileRoute('/full/_default/overview')({ if (stored === null) { storedIsValid = false; } else if (stored.kind === 'instance') { - storedIsValid = instances.some((i) => i.id === stored.data.id); + storedIsValid = instances.some((i) => i.id === stored.id); } else { - storedIsValid = tunnels.some((t) => t.id === stored.data.id); + storedIsValid = tunnels.some((t) => t.id === stored.id); } if (!storedIsValid) { const selected = instances.length > 0 - ? { kind: 'instance' as const, data: instances[0] } - : { kind: 'tunnel' as const, data: tunnels[0] }; + ? { kind: 'instance' as const, id: instances[0].id } + : { kind: 'tunnel' as const, id: tunnels[0].id }; await api.patchSessionState({ view_selection: selected }); await context.queryClient.invalidateQueries({ queryKey: ['session-state'] }); } diff --git a/new-ui/src/shared/components/Split/Split.tsx b/new-ui/src/shared/components/Split/Split.tsx new file mode 100644 index 00000000..daea8804 --- /dev/null +++ b/new-ui/src/shared/components/Split/Split.tsx @@ -0,0 +1,22 @@ +import type { CSSProperties, PropsWithChildren } from 'react'; +import { ThemeSpacing, type ThemeSpacingValue } from '../../types'; + +type Props = PropsWithChildren<{ + split?: number; + spacing?: ThemeSpacingValue; +}>; + +export const Split = ({ children, split = 2, spacing = ThemeSpacing.Sm }: Props) => { + const style: CSSProperties = { + display: 'grid', + gridTemplateColumns: `repeat(${split}, 1fr)`, + columnGap: spacing, + width: '100%', + }; + + return ( +
+ {children} +
+ ); +}; diff --git a/new-ui/src/shared/components/wizard/WizardPage/WizardPage.tsx b/new-ui/src/shared/components/wizard/WizardPage/WizardPage.tsx index 8fb4d8de..3e5537da 100644 --- a/new-ui/src/shared/components/wizard/WizardPage/WizardPage.tsx +++ b/new-ui/src/shared/components/wizard/WizardPage/WizardPage.tsx @@ -1,4 +1,10 @@ -import { type HTMLProps, type PropsWithChildren, Suspense, useMemo } from 'react'; +import { + Fragment, + type HTMLProps, + type PropsWithChildren, + Suspense, + useMemo, +} from 'react'; import './style.scss'; import clsx from 'clsx'; import { sort } from 'radashi'; @@ -47,10 +53,14 @@ export const WizardPage = ({
-
-

{`Step ${activeStepIndex + 1} of ${visibleSteps.length}`}

-
- + {activeStepIndex !== visibleSteps.length - 1 && ( + +
+

{`Step ${activeStepIndex + 1} of ${visibleSteps.length}`}

+
+ +
+ )} }>{children}
diff --git a/new-ui/src/shared/rust-api/api.ts b/new-ui/src/shared/rust-api/api.ts index ab6f7589..2ccb9e06 100644 --- a/new-ui/src/shared/rust-api/api.ts +++ b/new-ui/src/shared/rust-api/api.ts @@ -74,11 +74,10 @@ const getTunnels = (): Promise => invoke(TauriCommand.AllTunnels const getTunnelDetails = (tunnelId: number): Promise => invoke(TauriCommand.TunnelDetails, { tunnelId }); -const parseTunnelConfig = ( - filename: string, - config: string, -): Promise> => - invoke(TauriCommand.ParseTunnelConfig, { filename, config }); +const parseTunnelConfig = (data: { + filename: string; + config: string; +}): Promise> => invoke(TauriCommand.ParseTunnelConfig, data); const saveTunnel = (tunnel: TunnelRequest): Promise => invoke(TauriCommand.SaveTunnel, { tunnel }); diff --git a/new-ui/src/shared/rust-api/types.ts b/new-ui/src/shared/rust-api/types.ts index 3912cbe7..b7823393 100644 --- a/new-ui/src/shared/rust-api/types.ts +++ b/new-ui/src/shared/rust-api/types.ts @@ -368,9 +368,10 @@ export type SetLocationMfaMethodArgs = { mfaMethod: MfaMethodValue; }; -export type OverviewViewSelection = - | { kind: 'instance'; data: InstanceInfo } - | { kind: 'tunnel'; data: LocationInfo }; +export type OverviewViewSelection = { + kind: 'instance' | 'tunnel'; + id: number; +}; /** Mirrors `SessionState` in src/session_state.rs. Fields are snake_case (raw serde output). */ export type SessionState = { diff --git a/new-ui/src/shared/utils/generateWGKeys copy.ts b/new-ui/src/shared/utils/generateWGKeys copy.ts new file mode 100644 index 00000000..6e981855 --- /dev/null +++ b/new-ui/src/shared/utils/generateWGKeys copy.ts @@ -0,0 +1,9 @@ +import { encode } from '@stablelib/base64'; +import { generateKeyPair } from '@stablelib/x25519'; + +export const generateWGKeys = () => { + const keys = generateKeyPair(); + const publicKey = encode(keys.publicKey); + const privateKey = encode(keys.secretKey); + return { publicKey, privateKey }; +}; diff --git a/new-ui/src/shared/utils/patterns.ts b/new-ui/src/shared/utils/patterns.ts new file mode 100644 index 00000000..0dc65fbb --- /dev/null +++ b/new-ui/src/shared/utils/patterns.ts @@ -0,0 +1,85 @@ +/* eslint-disable no-useless-escape */ +export const patternNoSpecialChars = /^\w+$/; + +export const patternDigitOrLowercase = /^[0-9a-z]+$/g; + +export const patternValidEmail = + // eslint-disable-next-line max-len + /[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?/g; + +export const patternAtLeastOneUpperCaseChar = /(?=.*?[A-Z])/g; + +export const patternAtLeastOneLowerCaseChar = /(?=.*?[a-z])/g; + +export const patternAtLeastOneDigit = /(?=.*?[0-9])/g; + +export const patternStartsWithDigit = /^\d/; + +export const patternAtLeastOneSpecialChar = /(?=.*?[#?!@$%^&*-])/g; + +export const patternValidPhoneNumber = + /^(\+?\d{1,3}\s?)?(\(\d{1,3}\)|\d{1,3})[-\s]?\d{1,4}[-\s]?\d{1,4}?$/; + +export const patternValidWireguardKey = + /^[A-Za-z0-9+/]{42}[A|E|I|M|Q|U|Y|c|g|k|o|s|w|4|8|0]=$/; + +export const patternBaseUrl = /:\/\/(.[^/]+)/; + +// https://gist.github.com/dperini/729294 +export const patternValidUrl = new RegExp( + '^' + + // protocol identifier (optional) + // short syntax // still required + '(?:(?:(?:https?):)?\\/\\/)' + + // user:pass BasicAuth (optional) + '(?:\\S+(?::\\S*)?@)?' + + '(?:' + + // IP address dotted notation octets + // excludes loopback network 0.0.0.0 + // excludes reserved space >= 224.0.0.0 + // excludes network & broadcast addresses + // (first & last IP address of each class) + '(?:[1-9]\\d?|1\\d\\d|2[01]\\d|22[0-3])' + + '(?:\\.(?:1?\\d{1,2}|2[0-4]\\d|25[0-5])){2}' + + '(?:\\.(?:[1-9]\\d?|1\\d\\d|2[0-4]\\d|25[0-4]))' + + '|' + + // host & domain names, may end with dot + // can be replaced by a shortest alternative + // (?![-_])(?:[-\\w\\u00a1-\\uffff]{0,63}[^-_]\\.)+ + '(?:' + + '(?:' + + '[a-z0-9\\u00a1-\\uffff]' + + '[a-z0-9\\u00a1-\\uffff_-]{0,62}' + + ')?' + + '[a-z0-9\\u00a1-\\uffff]\\.' + + ')+' + + // TLD identifier name, may end with dot + '(?:[a-z\\u00a1-\\uffff]{2,}\\.?)' + + ')' + + // port number (optional) + '(?::\\d{2,5})?' + + // resource path (optional) + '(?:[/?#]\\S*)?' + + '$', + 'i', +); + +export const patternValidDomain = + /^(?:(?:(?:[a-zA-z-]+):\/{1,3})?(?:[a-zA-Z0-9])(?:[a-zA-Z0-9\-.]){1,61}(?:\.[a-zA-Z]{2,})+|\[(?:(?:(?:[a-fA-F0-9]){1,4})(?::(?:[a-fA-F0-9]){1,4}){7}|::1|::)\]|(?:(?:[0-9]{1,3})(?:\.[0-9]{1,3}){3}))(?::[0-9]{1,5})?$/; + +export const patternValidIp = + /^(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(?:\/32)?$/; + +export const cidrRegex = + /^(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\/\d{1,2}|[0-9a-fA-F:.]+\/\d{1,3})$/; +// Regular expression to match IPv4, IPv6, domain name, or localhost with port +export const patternValidEndpoint = + /^(localhost|\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b|\b(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}\b)(?::(\d+))?$/; + +// Copied from zod source code and added optional mask at the end to match WireguardRequirements +export const patternValidIpV6 = + /^(([a-f0-9]{1,4}:){7}|::([a-f0-9]{1,4}:){0,6}|([a-f0-9]{1,4}:){1}:([a-f0-9]{1,4}:){0,5}|([a-f0-9]{1,4}:){2}:([a-f0-9]{1,4}:){0,4}|([a-f0-9]{1,4}:){3}:([a-f0-9]{1,4}:){0,3}|([a-f0-9]{1,4}:){4}:([a-f0-9]{1,4}:){0,2}|([a-f0-9]{1,4}:){5}:([a-f0-9]{1,4}:){0,1})([a-f0-9]{1,4}|(((25[0-5])|(2[0-4][0-9])|(1[0-9]{2})|([0-9]{1,2}))\.){3}((25[0-5])|(2[0-4][0-9])|(1[0-9]{2})|([0-9]{1,2})))(?:\/128)?$/; + +// Reuse pattern from above to support format [ipv6]:port +export const patternValidIpV6WithPort = + /^\[((([a-f0-9]{1,4}:){7}|::([a-f0-9]{1,4}:){0,6}|([a-f0-9]{1,4}:){1}:([a-f0-9]{1,4}:){0,5}|([a-f0-9]{1,4}:){2}:([a-f0-9]{1,4}:){0,4}|([a-f0-9]{1,4}:){3}:([a-f0-9]{1,4}:){0,3}|([a-f0-9]{1,4}:){4}:([a-f0-9]{1,4}:){0,2}|([a-f0-9]{1,4}:){5}:([a-f0-9]{1,4}:){0,1})([a-f0-9]{1,4}|(((25[0-5])|(2[0-4][0-9])|(1[0-9]{2})|([0-9]{1,2}))\.){3}((25[0-5])|(2[0-4][0-9])|(1[0-9]{2})|([0-9]{1,2})))(\/128)?)\]:(\d{1,5})$/; diff --git a/src-tauri/core/src/events.rs b/src-tauri/core/src/events.rs index cc8802a2..71f859e3 100644 --- a/src-tauri/core/src/events.rs +++ b/src-tauri/core/src/events.rs @@ -16,6 +16,7 @@ pub enum EventKey { UuidMismatch, WindowSwapped, SessionStateChanged, + InstanceUpdated, } impl From for &'static str { @@ -35,6 +36,7 @@ impl From for &'static str { EventKey::UuidMismatch => "uuid-mismatch", EventKey::WindowSwapped => "window-swapped", EventKey::SessionStateChanged => "session-state-changed", + EventKey::InstanceUpdated => "instance-updated", } } } diff --git a/src-tauri/enterprise/config-sync/src/commands.rs b/src-tauri/enterprise/config-sync/src/commands.rs index 7eb4a472..f74fc8b4 100644 --- a/src-tauri/enterprise/config-sync/src/commands.rs +++ b/src-tauri/enterprise/config-sync/src/commands.rs @@ -49,7 +49,7 @@ pub async fn do_update_instance( transaction: &mut Transaction<'_, Sqlite>, instance: &mut Instance, response: DeviceConfigResponse, -) -> Result<(), Error> { +) -> Result { debug!("Updating instance {instance}"); let locations_changed_val = locations_changed(transaction, instance, &response).await?; let instance_info = response @@ -232,7 +232,7 @@ pub async fn do_update_instance( } } - Ok(()) + Ok(locations_changed_val) } pub async fn disable_enterprise_features<'e, E>( diff --git a/src-tauri/src/bin/defguard-client.rs b/src-tauri/src/bin/defguard-client.rs index f8b1687e..d73a05ef 100644 --- a/src-tauri/src/bin/defguard-client.rs +++ b/src-tauri/src/bin/defguard-client.rs @@ -24,7 +24,7 @@ use defguard_client::{ DB_POOL, }, enterprise::provisioning::handle_client_initialization, - events::handle_deep_link, + events::{handle_deep_link, EventKey}, periodic::run_periodic_tasks, service, session_state, tray::{configure_tray_icon, setup_tray}, @@ -34,7 +34,7 @@ use defguard_client::{ }; use defguard_client_core::connection::active_connections::close_all_connections; use log::{Level, LevelFilter}; -use tauri::{async_runtime, AppHandle, Builder, Manager, RunEvent, WindowEvent}; +use tauri::{async_runtime, AppHandle, Builder, Emitter, Listener, Manager, RunEvent, WindowEvent}; use tauri_plugin_deep_link::DeepLinkExt; use tauri_plugin_log::{Target, TargetKind}; @@ -48,6 +48,35 @@ const LOGGING_TARGET_IGNORE_LIST: [&str; 5] = ["tauri", "sqlx", "hyper", "h2", " static LOG_INCLUDES: LazyLock> = LazyLock::new(load_log_targets); async fn startup(app_handle: &AppHandle) { + // When instance locations change, re-validate MFA preferences and update session state. + { + let handle = app_handle.clone(); + app_handle.listen(Into::<&'static str>::into(EventKey::InstanceUpdated), move |_| { + let handle = handle.clone(); + async_runtime::spawn(async move { + let app_state = handle.state::(); + let preference = match app_state.session_state.lock() { + Ok(guard) => guard.location_mfa_preference.clone(), + Err(err) => { + error!("Session state mutex poisoned during MFA preference validation: {err}"); + return; + } + }; + match session_state::validate_location_mfa_preference(&DB_POOL, preference).await { + Ok(validated) => { + if let Ok(mut guard) = app_state.session_state.lock() { + guard.location_mfa_preference = validated; + } + if let Err(err) = handle.emit(EventKey::SessionStateChanged.into(), ()) { + error!("Failed to emit session-state-changed after MFA preference validation: {err}"); + } + } + Err(err) => error!("Failed to validate location MFA preference: {err}"), + } + }); + }); + } + debug!("Purging old stats from the database."); if let Err(err) = LocationStats::purge(&*DB_POOL).await { error!("Failed to purge location stats: {err}"); diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index 06968dc1..494f9476 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -660,9 +660,15 @@ pub async fn update_instance( if let Some(mut instance) = Instance::find_by_id(&*DB_POOL, instance_id).await? { debug!("The instance with id {instance_id} to update was found: {instance}"); let mut transaction = DB_POOL.begin().await?; - do_update_instance(&mut transaction, &mut instance, response).await?; + let locations_changed = + do_update_instance(&mut transaction, &mut instance, response).await?; transaction.commit().await?; + if locations_changed { + if let Err(err) = app_handle.emit(EventKey::InstanceUpdated.into(), ()) { + error!("Failed to emit instance-updated event: {err}"); + } + } app_handle .emit(EventKey::InstanceUpdate.into(), ()) .map_err(crate::tauri_err_to_app_err)?; @@ -1006,7 +1012,7 @@ pub async fn save_tunnel(tunnel: Tunnel, handle: AppHandle) -> Result<(), Ok(()) } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct TunnelInfo { pub id: I, pub name: String, diff --git a/src-tauri/src/enterprise/periodic/config.rs b/src-tauri/src/enterprise/periodic/config.rs index 303c1773..4a0b48ae 100644 --- a/src-tauri/src/enterprise/periodic/config.rs +++ b/src-tauri/src/enterprise/periodic/config.rs @@ -164,11 +164,17 @@ pub async fn poll_instance( "Updating instance {}({}) configuration: {device_config:?}", instance.name, instance.id, ); - do_update_instance(transaction, instance, device_config.clone()).await?; + let locations_changed = + do_update_instance(transaction, instance, device_config.clone()).await?; info!( "Updated instance {}({}) configuration based on core's response", instance.name, instance.id ); + if locations_changed { + if let Err(err) = handle.emit(EventKey::InstanceUpdated.into(), ()) { + error!("Failed to emit instance-updated event: {err}"); + } + } } else { debug!( "Emitting config-changed event for instance {}({})", diff --git a/src-tauri/src/session_state.rs b/src-tauri/src/session_state.rs index 3866ce46..6f6c1e5d 100644 --- a/src-tauri/src/session_state.rs +++ b/src-tauri/src/session_state.rs @@ -5,22 +5,27 @@ use struct_patch::Patch; use tauri::{AppHandle, Emitter, Manager, State}; use defguard_client_core::{ - database::models::{instance::InstanceInfo, location::LocationMfaMethod, Id}, + database::models::location::{LocationMfaMethod, LocationMfaMode}, events::EventKey, }; use crate::{ appstate::AppState, - commands::LocationInfo, - database::{models::location::Location, DB_POOL}, + database::{models::location::Location, DbPool, DB_POOL}, error::Error, }; #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] -#[serde(tag = "kind", content = "data", rename_all = "lowercase")] -pub enum OverviewViewSelection { - Instance(InstanceInfo), - Tunnel(LocationInfo), +#[serde(rename_all = "lowercase")] +pub enum ViewSelectionKind { + Instance, + Tunnel, +} + +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +pub struct OverviewViewSelection { + pub kind: ViewSelectionKind, + pub id: i64, } #[derive(Clone, Debug, Default, Deserialize, Patch, Serialize)] @@ -30,6 +35,59 @@ pub struct SessionState { pub location_mfa_preference: HashMap, } +// verifies integrity of the mfa preference in session, this needs to be ran on location update so polling doesn't break the session store +pub async fn validate_location_mfa_preference( + pool: &DbPool, + mut preference: HashMap, +) -> Result, Error> { + if preference.is_empty() { + return Ok(preference); + } + + let ids: Vec = preference.keys().filter_map(|k| k.parse().ok()).collect(); + + if ids.is_empty() { + preference.clear(); + return Ok(preference); + } + + let placeholders = (0..ids.len()).map(|_| "?").collect::>().join(","); + let sql = format!("SELECT id, location_mfa_mode FROM location WHERE id IN ({placeholders})"); + let mut q = sqlx::query_as::<_, (i64, LocationMfaMode)>(&sql); + for id in &ids { + q = q.bind(*id); + } + let rows = q.fetch_all(pool).await?; + + let mut found = std::collections::HashSet::with_capacity(rows.len()); + for (id, mfa_mode) in rows { + let key = id.to_string(); + found.insert(key.clone()); + match mfa_mode { + LocationMfaMode::Disabled => { + preference.remove(&key); + } + LocationMfaMode::External => { + preference.insert(key, LocationMfaMethod::Oidc); + } + LocationMfaMode::Internal => { + if let Some(m) = preference.get(&key) { + match m { + LocationMfaMethod::Totp + | LocationMfaMethod::Email + | LocationMfaMethod::MobileApprove => {} + _ => { + preference.insert(key, LocationMfaMethod::Totp); + } + } + } + } + } + } + preference.retain(|k, _| found.contains(k)); + Ok(preference) +} + pub async fn initialize_session_state() -> Result { let locations = Location::all(&*DB_POOL, false).await?; let location_mfa_preference = locations