diff --git a/ee/presencedetection/presencedetection_windows.go b/ee/presencedetection/presencedetection_windows.go index c7daea92b..035c45fa9 100644 --- a/ee/presencedetection/presencedetection_windows.go +++ b/ee/presencedetection/presencedetection_windows.go @@ -6,6 +6,7 @@ package presencedetection import ( "errors" "fmt" + "sync" "syscall" "time" "unsafe" @@ -106,27 +107,16 @@ type KeyCredentialAttestationResultVTable struct { GetStatus uintptr } +var roInitialize = sync.OnceFunc(func() { + ole.RoInitialize(1) +}) + // Register creates a credential under the given name for the given user. func Register(credentialName string) error { - if err := ole.RoInitialize(1); err != nil { - return fmt.Errorf("initializing: %w", err) - } - - // Get access to the KeyCredentialManager - factory, err := ole.RoGetActivationFactory("Windows.Security.Credentials.KeyCredentialManager", ole.IID_IInspectable) - if err != nil { - return fmt.Errorf("getting activation factory for KeyCredentialManager: %w", err) - } - defer factory.Release() - managerObj, err := factory.QueryInterface(keyCredentialManagerGuid) - if err != nil { - return fmt.Errorf("getting KeyCredentialManager from factory: %w", err) - } - defer managerObj.Release() - keyCredentialManager := (*KeyCredentialManager)(unsafe.Pointer(managerObj)) + roInitialize() // Check to see if Hello is an option - isHelloSupported, err := isSupported(keyCredentialManager) + isHelloSupported, err := isSupported() if err != nil { return fmt.Errorf("determining whether Hello is supported: %w", err) } @@ -135,7 +125,7 @@ func Register(credentialName string) error { } // Create a credential that will be tied to the current user and this application - if err := register(keyCredentialManager, credentialName); err != nil { + if err := register(credentialName); err != nil { return fmt.Errorf("creating credential: %w", err) } @@ -144,10 +134,20 @@ func Register(credentialName string) error { // Detect prompts the user via Hello. func Detect(_ string, credentialName string) (bool, error) { - if err := ole.RoInitialize(1); err != nil { - return false, fmt.Errorf("initializing: %w", err) + roInitialize() + + // Create a credential that will be tied to the current user and this application + if err := authenticate(credentialName); err != nil { + return false, fmt.Errorf("authenticating with credential: %w", err) } + return true, nil +} + +// isSupported calls Windows.Security.Credentials.KeyCredentialManager.IsSupportedAsync. +// It determines whether the current device and user is capable of provisioning a key credential. +// See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredentialmanager.issupportedasync?view=winrt-26100 +func isSupported() (bool, error) { // Get access to the KeyCredentialManager factory, err := ole.RoGetActivationFactory("Windows.Security.Credentials.KeyCredentialManager", ole.IID_IInspectable) if err != nil { @@ -161,18 +161,6 @@ func Detect(_ string, credentialName string) (bool, error) { defer managerObj.Release() keyCredentialManager := (*KeyCredentialManager)(unsafe.Pointer(managerObj)) - // Create a credential that will be tied to the current user and this application - if err := authenticate(keyCredentialManager, credentialName); err != nil { - return false, fmt.Errorf("authenticating with credential: %w", err) - } - - return true, nil -} - -// isSupported calls Windows.Security.Credentials.KeyCredentialManager.IsSupportedAsync. -// It determines whether the current device and user is capable of provisioning a key credential. -// See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredentialmanager.issupportedasync?view=winrt-26100 -func isSupported(keyCredentialManager *KeyCredentialManager) (bool, error) { var isSupportedAsyncOperation *foundation.IAsyncOperation ret, _, _ := syscall.SyscallN( keyCredentialManager.VTable().IsSupportedAsync, @@ -212,7 +200,20 @@ func isSupported(keyCredentialManager *KeyCredentialManager) (bool, error) { // register calls Windows.Security.Credentials.KeyCredentialManager.RequestCreateAsync. // It creates a new key credential for the current user and application. // See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredentialmanager.requestcreateasync?view=winrt-26100 -func register(keyCredentialManager *KeyCredentialManager, credentialName string) error { +func register(credentialName string) error { + // Get access to the KeyCredentialManager + factory, err := ole.RoGetActivationFactory("Windows.Security.Credentials.KeyCredentialManager", ole.IID_IInspectable) + if err != nil { + return fmt.Errorf("getting activation factory for KeyCredentialManager: %w", err) + } + defer factory.Release() + managerObj, err := factory.QueryInterface(keyCredentialManagerGuid) + if err != nil { + return fmt.Errorf("getting KeyCredentialManager from factory: %w", err) + } + defer managerObj.Release() + keyCredentialManager := (*KeyCredentialManager)(unsafe.Pointer(managerObj)) + credentialNameHString, err := ole.NewHString(credentialName) if err != nil { return fmt.Errorf("creating credential name hstring: %w", err) @@ -285,11 +286,10 @@ func register(keyCredentialManager *KeyCredentialManager, credentialName string) // For now, we retrieve but do not return/store the pubkey and attestation. In the future // we may want to store these. - credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj)) - if _, err := getPubkey(credential); err != nil { + if _, err := getPubkey(keyCredentialObj); err != nil { return fmt.Errorf("getting pubkey from credential: %w", err) } - if _, err := getAttestation(credential); err != nil { + if _, err := getAttestation(keyCredentialObj); err != nil { return fmt.Errorf("getting attestation from credential: %w", err) } @@ -299,7 +299,20 @@ func register(keyCredentialManager *KeyCredentialManager, credentialName string) // authenticate calls Windows.Security.Credentials.KeyCredentialManager.OpenAsync. // It retrieves the key credential stored under `credentialName` for the given user and application. // See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredentialmanager.openasync?view=winrt-26100 -func authenticate(keyCredentialManager *KeyCredentialManager, credentialName string) error { +func authenticate(credentialName string) error { + // Get access to the KeyCredentialManager + factory, err := ole.RoGetActivationFactory("Windows.Security.Credentials.KeyCredentialManager", ole.IID_IInspectable) + if err != nil { + return fmt.Errorf("getting activation factory for KeyCredentialManager: %w", err) + } + defer factory.Release() + managerObj, err := factory.QueryInterface(keyCredentialManagerGuid) + if err != nil { + return fmt.Errorf("getting KeyCredentialManager from factory: %w", err) + } + defer managerObj.Release() + keyCredentialManager := (*KeyCredentialManager)(unsafe.Pointer(managerObj)) + credentialNameHString, err := ole.NewHString(credentialName) if err != nil { return fmt.Errorf("creating credential name hstring: %w", err) @@ -371,11 +384,10 @@ func authenticate(keyCredentialManager *KeyCredentialManager, credentialName str // For now, we retrieve but do not return/store the pubkey and attestation. In the future // we may want to store these. - credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj)) - if _, err := getPubkey(credential); err != nil { + if _, err := getPubkey(keyCredentialObj); err != nil { return fmt.Errorf("getting pubkey from credential: %w", err) } - if _, err := getAttestation(credential); err != nil { + if _, err := getAttestation(keyCredentialObj); err != nil { return fmt.Errorf("getting attestation from credential: %w", err) } @@ -385,7 +397,9 @@ func authenticate(keyCredentialManager *KeyCredentialManager, credentialName str // getPubkey calls Windows.Security.Credentials.KeyCredential.RetrievePubkey. // It returns the pubkey for the given key credential. // See https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredential.retrievepublickey?view=winrt-26100. -func getPubkey(credential *KeyCredential) ([]byte, error) { +func getPubkey(keyCredentialObj *ole.IDispatch) ([]byte, error) { + credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj)) + var pubkeyBufferPointer unsafe.Pointer retrievePubKeyReturn, _, _ := syscall.SyscallN( credential.VTable().RetrievePublicKeyWithDefaultBlobType, @@ -422,7 +436,9 @@ func getPubkey(credential *KeyCredential) ([]byte, error) { // getAttestation calls Windows.Security.Credentials.KeyCredential.GetAttestationAsync. // It gets an attestation for a key credential. // See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredential.getattestationasync?view=winrt-26100 -func getAttestation(credential *KeyCredential) ([]byte, error) { +func getAttestation(keyCredentialObj *ole.IDispatch) ([]byte, error) { + credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj)) + var getAttestationAsyncOperation *foundation.IAsyncOperation getAttestationReturn, _, _ := syscall.SyscallN( credential.VTable().GetAttestationAsync,