From bfd2e1d1eb250b84e1976e5744f1e6004b1fb500 Mon Sep 17 00:00:00 2001 From: Daniel Redondo Date: Tue, 2 Jun 2020 22:05:20 +0200 Subject: [PATCH] Pointer safety (#6) * removes uintptr and uses unsafe.Pointer * experiment test for GC * Update README.md --- README.md | 1 + patcher.go | 36 +++++++++++++++++++---------------- patcher.s | 0 patcher_test.go | 43 ++++++++++++++++++++++++++++++++++++++++++ patcher_unix.go | 24 ++++++++++++++++++----- patcher_unsupported.go | 4 +++- patcher_windows.go | 6 +++--- patcher_x32.go | 13 ++++++++----- patcher_x64.go | 21 ++++++++++++--------- 9 files changed, 109 insertions(+), 39 deletions(-) create mode 100644 patcher.s diff --git a/README.md b/README.md index dbecad2..c5a85e2 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # go-mpatch Go library for monkey patching + Library inspired by the blog post: https://bou.ke/blog/monkey-patching-in-go/ diff --git a/patcher.go b/patcher.go index ffe89ef..dc8a050 100644 --- a/patcher.go +++ b/patcher.go @@ -15,15 +15,19 @@ type ( target *reflect.Value redirection *reflect.Value } - pointer struct { - length uintptr - ptr uintptr + sliceHeader struct { + Data unsafe.Pointer + Len int + Cap int } ) +//go:linkname getInternalPtrFromValue reflect.(*Value).pointer +func getInternalPtrFromValue(v *reflect.Value) unsafe.Pointer + var ( patchLock = sync.Mutex{} - patches = make(map[uintptr]*Patch) + patches = make(map[unsafe.Pointer]*Patch) pageSize = syscall.Getpagesize() ) @@ -99,7 +103,7 @@ func isPatchable(target, redirection *reflect.Value) error { if target.Type() != redirection.Type() { return errors.New(fmt.Sprintf("the target and/or redirection doesn't have the same type: %s != %s", target.Type(), redirection.Type())) } - if _, ok := patches[target.Pointer()]; ok { + if _, ok := patches[getSafePointer(target)]; ok { return errors.New("the target is already patched") } return nil @@ -108,8 +112,8 @@ func isPatchable(target, redirection *reflect.Value) error { func applyPatch(patch *Patch) error { patchLock.Lock() defer patchLock.Unlock() - tPointer := patch.target.Pointer() - rPointer := getInternalPtrFromValue(*patch.redirection) + tPointer := getSafePointer(patch.target) + rPointer := getInternalPtrFromValue(patch.redirection) rPointerJumpBytes, err := getJumpFuncBytes(rPointer) if err != nil { return err @@ -132,7 +136,7 @@ func applyUnpatch(patch *Patch) error { if patch.targetBytes == nil || len(patch.targetBytes) == 0 { return errors.New("the target is not patched") } - tPointer := patch.target.Pointer() + tPointer := getSafePointer(patch.target) if _, ok := patches[tPointer]; !ok { return errors.New("the target is not patched") } @@ -144,10 +148,6 @@ func applyUnpatch(patch *Patch) error { return nil } -func getInternalPtrFromValue(value reflect.Value) uintptr { - return (*pointer)(unsafe.Pointer(&value)).ptr -} - func getValueFrom(data interface{}) reflect.Value { if cValue, ok := data.(reflect.Value); ok { return cValue @@ -156,14 +156,18 @@ func getValueFrom(data interface{}) reflect.Value { } } -func getMemorySliceFromPointer(p uintptr, length int) []byte { - return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{ +func getMemorySliceFromPointer(p unsafe.Pointer, length int) []byte { + return *(*[]byte)(unsafe.Pointer(&sliceHeader{ Data: p, Len: length, Cap: length, })) } -func getPageStartPtr(ptr uintptr) uintptr { - return ptr & ^(uintptr(pageSize - 1)) +func getSafePointer(value *reflect.Value) unsafe.Pointer { + p := getInternalPtrFromValue(value) + if p != nil { + p = *(*unsafe.Pointer)(p) + } + return p } diff --git a/patcher.s b/patcher.s new file mode 100644 index 0000000..e69de29 diff --git a/patcher_test.go b/patcher_test.go index f34f24a..2ce9d1e 100644 --- a/patcher_test.go +++ b/patcher_test.go @@ -2,7 +2,9 @@ package mpatch import ( "reflect" + "runtime" "testing" + "time" ) //go:noinline @@ -93,3 +95,44 @@ func TestInstanceValuePatcher(t *testing.T) { t.Fatal("The unpatch did not work") } } + +var slice []int + +//go:noinline +func TestGarbageCollectorExperiment(t *testing.T) { + + for i := 0; i < 10000000; i++ { + slice = append(slice, i) + } + go func() { + var sl []int + for i := 0; i < 10000000; i++ { + sl = append(slice, i) + } + _ = sl + }() + <-time.After(time.Second) + + aVal := methodA + ptr01 := reflect.ValueOf(aVal).Pointer() + slice = nil + runtime.GC() + for i := 0; i < 10000000; i++ { + slice = append(slice, i) + } + go func() { + var sl []int + for i := 0; i < 10000000; i++ { + sl = append(slice, i) + } + _ = sl + }() + <-time.After(time.Second) + slice = nil + runtime.GC() + ptr02 := reflect.ValueOf(aVal).Pointer() + + if ptr01 != ptr02 { + t.Fail() + } +} diff --git a/patcher_unix.go b/patcher_unix.go index 82371bb..75fdc2a 100644 --- a/patcher_unix.go +++ b/patcher_unix.go @@ -2,14 +2,28 @@ package mpatch -import "syscall" +import ( + "reflect" + "syscall" + "unsafe" +) var writeAccess = syscall.PROT_READ | syscall.PROT_WRITE | syscall.PROT_EXEC var readAccess = syscall.PROT_READ | syscall.PROT_EXEC -func callMProtect(addr uintptr, length int, prot int) error { - for p := getPageStartPtr(addr); p < addr+uintptr(length); p += uintptr(pageSize) { - page := getMemorySliceFromPointer(p, pageSize) +//go:nosplit +func getMemorySliceFromUintptr(p uintptr, length int) []byte { + return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{ + Data: p, + Len: length, + Cap: length, + })) +} + +//go:nosplit +func callMProtect(addr unsafe.Pointer, length int, prot int) error { + for p := uintptr(addr) & ^(uintptr(pageSize - 1)); p < uintptr(addr)+uintptr(length); p += uintptr(pageSize) { + page := getMemorySliceFromUintptr(p, pageSize) err := syscall.Mprotect(page, prot) if err != nil { return err @@ -18,7 +32,7 @@ func callMProtect(addr uintptr, length int, prot int) error { return nil } -func copyDataToPtr(ptr uintptr, data []byte) error { +func copyDataToPtr(ptr unsafe.Pointer, data []byte) error { dataLength := len(data) ptrByteSlice := getMemorySliceFromPointer(ptr, len(data)) err := callMProtect(ptr, dataLength, writeAccess) diff --git a/patcher_unsupported.go b/patcher_unsupported.go index a2fe4b7..2933271 100644 --- a/patcher_unsupported.go +++ b/patcher_unsupported.go @@ -7,9 +7,11 @@ import ( "errors" "fmt" "runtime" + "unsafe" ) // Gets the jump function rewrite bytes -func getJumpFuncBytes(to uintptr) ([]byte, error) { +//go:nosplit +func getJumpFuncBytes(to unsafe.Pointer) ([]byte, error) { return nil, errors.New(fmt.Sprintf("Unsupported architecture: %s", runtime.GOARCH)) } diff --git a/patcher_windows.go b/patcher_windows.go index 38f816a..736211f 100644 --- a/patcher_windows.go +++ b/patcher_windows.go @@ -11,15 +11,15 @@ const pageExecuteReadAndWrite = 0x40 var virtualProtectProc = syscall.NewLazyDLL("kernel32.dll").NewProc("VirtualProtect") -func callVirtualProtect(lpAddress uintptr, dwSize int, flNewProtect uint32, lpflOldProtect unsafe.Pointer) error { - ret, _, _ := virtualProtectProc.Call(lpAddress, uintptr(dwSize), uintptr(flNewProtect), uintptr(lpflOldProtect)) +func callVirtualProtect(lpAddress unsafe.Pointer, dwSize int, flNewProtect uint32, lpflOldProtect unsafe.Pointer) error { + ret, _, _ := virtualProtectProc.Call(uintptr(lpAddress), uintptr(dwSize), uintptr(flNewProtect), uintptr(lpflOldProtect)) if ret == 0 { return syscall.GetLastError() } return nil } -func copyDataToPtr(ptr uintptr, data []byte) error { +func copyDataToPtr(ptr unsafe.Pointer, data []byte) error { var oldPerms, tmp uint32 dataLength := len(data) ptrByteSlice := getMemorySliceFromPointer(ptr, len(data)) diff --git a/patcher_x32.go b/patcher_x32.go index ddb676f..7b61f48 100644 --- a/patcher_x32.go +++ b/patcher_x32.go @@ -2,14 +2,17 @@ package mpatch +import "unsafe" + // Gets the jump function rewrite bytes -func getJumpFuncBytes(to uintptr) ([]byte, error) { +//go:nosplit +func getJumpFuncBytes(to unsafe.Pointer) ([]byte, error) { return []byte{ 0xBA, - byte(to), - byte(to >> 8), - byte(to >> 16), - byte(to >> 24), + byte(uintptr(to)), + byte(uintptr(to) >> 8), + byte(uintptr(to) >> 16), + byte(uintptr(to) >> 24), 0xFF, 0x22, }, nil } diff --git a/patcher_x64.go b/patcher_x64.go index 392a954..99523e3 100644 --- a/patcher_x64.go +++ b/patcher_x64.go @@ -2,18 +2,21 @@ package mpatch +import "unsafe" + // Gets the jump function rewrite bytes -func getJumpFuncBytes(to uintptr) ([]byte, error) { +//go:nosplit +func getJumpFuncBytes(to unsafe.Pointer) ([]byte, error) { return []byte{ 0x48, 0xBA, - byte(to), - byte(to >> 8), - byte(to >> 16), - byte(to >> 24), - byte(to >> 32), - byte(to >> 40), - byte(to >> 48), - byte(to >> 56), + byte(uintptr(to)), + byte(uintptr(to) >> 8), + byte(uintptr(to) >> 16), + byte(uintptr(to) >> 24), + byte(uintptr(to) >> 32), + byte(uintptr(to) >> 40), + byte(uintptr(to) >> 48), + byte(uintptr(to) >> 56), 0xFF, 0x22, }, nil }