Skip to content

Commit

Permalink
Call initialize only once, set up manager reference fresh each time
Browse files Browse the repository at this point in the history
  • Loading branch information
RebeccaMahany committed Sep 23, 2024
1 parent 1fa02bf commit 262fce2
Showing 1 changed file with 58 additions and 42 deletions.
100 changes: 58 additions & 42 deletions ee/presencedetection/presencedetection_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package presencedetection
import (
"errors"
"fmt"
"sync"
"syscall"
"time"
"unsafe"
Expand Down Expand Up @@ -106,27 +107,16 @@ type KeyCredentialAttestationResultVTable struct {
GetStatus uintptr
}

var roInitialize = sync.OnceFunc(func() {
ole.RoInitialize(1)
})

// Register creates a credential under the given name for the given user.
func Register(credentialName string) error {
if err := ole.RoInitialize(1); err != nil {
return fmt.Errorf("initializing: %w", err)
}

// Get access to the KeyCredentialManager
factory, err := ole.RoGetActivationFactory("Windows.Security.Credentials.KeyCredentialManager", ole.IID_IInspectable)
if err != nil {
return fmt.Errorf("getting activation factory for KeyCredentialManager: %w", err)
}
defer factory.Release()
managerObj, err := factory.QueryInterface(keyCredentialManagerGuid)
if err != nil {
return fmt.Errorf("getting KeyCredentialManager from factory: %w", err)
}
defer managerObj.Release()
keyCredentialManager := (*KeyCredentialManager)(unsafe.Pointer(managerObj))
roInitialize()

// Check to see if Hello is an option
isHelloSupported, err := isSupported(keyCredentialManager)
isHelloSupported, err := isSupported()
if err != nil {
return fmt.Errorf("determining whether Hello is supported: %w", err)
}
Expand All @@ -135,7 +125,7 @@ func Register(credentialName string) error {
}

// Create a credential that will be tied to the current user and this application
if err := register(keyCredentialManager, credentialName); err != nil {
if err := register(credentialName); err != nil {
return fmt.Errorf("creating credential: %w", err)
}

Expand All @@ -144,10 +134,20 @@ func Register(credentialName string) error {

// Detect prompts the user via Hello.
func Detect(_ string, credentialName string) (bool, error) {
if err := ole.RoInitialize(1); err != nil {
return false, fmt.Errorf("initializing: %w", err)
roInitialize()

// Create a credential that will be tied to the current user and this application
if err := authenticate(credentialName); err != nil {
return false, fmt.Errorf("authenticating with credential: %w", err)
}

return true, nil
}

// isSupported calls Windows.Security.Credentials.KeyCredentialManager.IsSupportedAsync.
// It determines whether the current device and user is capable of provisioning a key credential.
// See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredentialmanager.issupportedasync?view=winrt-26100
func isSupported() (bool, error) {
// Get access to the KeyCredentialManager
factory, err := ole.RoGetActivationFactory("Windows.Security.Credentials.KeyCredentialManager", ole.IID_IInspectable)
if err != nil {
Expand All @@ -161,18 +161,6 @@ func Detect(_ string, credentialName string) (bool, error) {
defer managerObj.Release()
keyCredentialManager := (*KeyCredentialManager)(unsafe.Pointer(managerObj))

// Create a credential that will be tied to the current user and this application
if err := authenticate(keyCredentialManager, credentialName); err != nil {
return false, fmt.Errorf("authenticating with credential: %w", err)
}

return true, nil
}

// isSupported calls Windows.Security.Credentials.KeyCredentialManager.IsSupportedAsync.
// It determines whether the current device and user is capable of provisioning a key credential.
// See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredentialmanager.issupportedasync?view=winrt-26100
func isSupported(keyCredentialManager *KeyCredentialManager) (bool, error) {
var isSupportedAsyncOperation *foundation.IAsyncOperation
ret, _, _ := syscall.SyscallN(
keyCredentialManager.VTable().IsSupportedAsync,
Expand Down Expand Up @@ -212,7 +200,20 @@ func isSupported(keyCredentialManager *KeyCredentialManager) (bool, error) {
// register calls Windows.Security.Credentials.KeyCredentialManager.RequestCreateAsync.
// It creates a new key credential for the current user and application.
// See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredentialmanager.requestcreateasync?view=winrt-26100
func register(keyCredentialManager *KeyCredentialManager, credentialName string) error {
func register(credentialName string) error {

Check failure on line 203 in ee/presencedetection/presencedetection_windows.go

View workflow job for this annotation

GitHub Actions / lint (windows-latest)

confusing-naming: Method 'register' differs only by capitalization to function 'Register' in the same source file (revive)

Check failure on line 203 in ee/presencedetection/presencedetection_windows.go

View workflow job for this annotation

GitHub Actions / lint (windows-latest)

confusing-naming: Method 'register' differs only by capitalization to function 'Register' in the same source file (revive)
// Get access to the KeyCredentialManager
factory, err := ole.RoGetActivationFactory("Windows.Security.Credentials.KeyCredentialManager", ole.IID_IInspectable)
if err != nil {
return fmt.Errorf("getting activation factory for KeyCredentialManager: %w", err)
}
defer factory.Release()
managerObj, err := factory.QueryInterface(keyCredentialManagerGuid)
if err != nil {
return fmt.Errorf("getting KeyCredentialManager from factory: %w", err)
}
defer managerObj.Release()
keyCredentialManager := (*KeyCredentialManager)(unsafe.Pointer(managerObj))

credentialNameHString, err := ole.NewHString(credentialName)
if err != nil {
return fmt.Errorf("creating credential name hstring: %w", err)
Expand Down Expand Up @@ -285,11 +286,10 @@ func register(keyCredentialManager *KeyCredentialManager, credentialName string)

// For now, we retrieve but do not return/store the pubkey and attestation. In the future
// we may want to store these.
credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj))
if _, err := getPubkey(credential); err != nil {
if _, err := getPubkey(keyCredentialObj); err != nil {
return fmt.Errorf("getting pubkey from credential: %w", err)
}
if _, err := getAttestation(credential); err != nil {
if _, err := getAttestation(keyCredentialObj); err != nil {
return fmt.Errorf("getting attestation from credential: %w", err)
}

Expand All @@ -299,7 +299,20 @@ func register(keyCredentialManager *KeyCredentialManager, credentialName string)
// authenticate calls Windows.Security.Credentials.KeyCredentialManager.OpenAsync.
// It retrieves the key credential stored under `credentialName` for the given user and application.
// See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredentialmanager.openasync?view=winrt-26100
func authenticate(keyCredentialManager *KeyCredentialManager, credentialName string) error {
func authenticate(credentialName string) error {
// Get access to the KeyCredentialManager
factory, err := ole.RoGetActivationFactory("Windows.Security.Credentials.KeyCredentialManager", ole.IID_IInspectable)
if err != nil {
return fmt.Errorf("getting activation factory for KeyCredentialManager: %w", err)
}
defer factory.Release()
managerObj, err := factory.QueryInterface(keyCredentialManagerGuid)
if err != nil {
return fmt.Errorf("getting KeyCredentialManager from factory: %w", err)
}
defer managerObj.Release()
keyCredentialManager := (*KeyCredentialManager)(unsafe.Pointer(managerObj))

credentialNameHString, err := ole.NewHString(credentialName)
if err != nil {
return fmt.Errorf("creating credential name hstring: %w", err)
Expand Down Expand Up @@ -371,11 +384,10 @@ func authenticate(keyCredentialManager *KeyCredentialManager, credentialName str

// For now, we retrieve but do not return/store the pubkey and attestation. In the future
// we may want to store these.
credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj))
if _, err := getPubkey(credential); err != nil {
if _, err := getPubkey(keyCredentialObj); err != nil {
return fmt.Errorf("getting pubkey from credential: %w", err)
}
if _, err := getAttestation(credential); err != nil {
if _, err := getAttestation(keyCredentialObj); err != nil {
return fmt.Errorf("getting attestation from credential: %w", err)
}

Expand All @@ -385,7 +397,9 @@ func authenticate(keyCredentialManager *KeyCredentialManager, credentialName str
// getPubkey calls Windows.Security.Credentials.KeyCredential.RetrievePubkey.
// It returns the pubkey for the given key credential.
// See https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredential.retrievepublickey?view=winrt-26100.
func getPubkey(credential *KeyCredential) ([]byte, error) {
func getPubkey(keyCredentialObj *ole.IDispatch) ([]byte, error) {
credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj))

var pubkeyBufferPointer unsafe.Pointer
retrievePubKeyReturn, _, _ := syscall.SyscallN(
credential.VTable().RetrievePublicKeyWithDefaultBlobType,
Expand Down Expand Up @@ -422,7 +436,9 @@ func getPubkey(credential *KeyCredential) ([]byte, error) {
// getAttestation calls Windows.Security.Credentials.KeyCredential.GetAttestationAsync.
// It gets an attestation for a key credential.
// See: https://learn.microsoft.com/en-us/uwp/api/windows.security.credentials.keycredential.getattestationasync?view=winrt-26100
func getAttestation(credential *KeyCredential) ([]byte, error) {
func getAttestation(keyCredentialObj *ole.IDispatch) ([]byte, error) {
credential := (*KeyCredential)(unsafe.Pointer(keyCredentialObj))

var getAttestationAsyncOperation *foundation.IAsyncOperation
getAttestationReturn, _, _ := syscall.SyscallN(
credential.VTable().GetAttestationAsync,
Expand Down

0 comments on commit 262fce2

Please sign in to comment.