Skip to content

Commit

Permalink
Pointer safety (#6)
Browse files Browse the repository at this point in the history
* removes uintptr and uses unsafe.Pointer

* experiment test for GC

* Update README.md
  • Loading branch information
tonyredondo authored Jun 2, 2020
1 parent 872e88c commit bfd2e1d
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 39 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# go-mpatch
Go library for monkey patching


Library inspired by the blog post: https://bou.ke/blog/monkey-patching-in-go/
36 changes: 20 additions & 16 deletions patcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@ type (
target *reflect.Value
redirection *reflect.Value
}
pointer struct {
length uintptr
ptr uintptr
sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
)

//go:linkname getInternalPtrFromValue reflect.(*Value).pointer
func getInternalPtrFromValue(v *reflect.Value) unsafe.Pointer

var (
patchLock = sync.Mutex{}
patches = make(map[uintptr]*Patch)
patches = make(map[unsafe.Pointer]*Patch)
pageSize = syscall.Getpagesize()
)

Expand Down Expand Up @@ -99,7 +103,7 @@ func isPatchable(target, redirection *reflect.Value) error {
if target.Type() != redirection.Type() {
return errors.New(fmt.Sprintf("the target and/or redirection doesn't have the same type: %s != %s", target.Type(), redirection.Type()))
}
if _, ok := patches[target.Pointer()]; ok {
if _, ok := patches[getSafePointer(target)]; ok {
return errors.New("the target is already patched")
}
return nil
Expand All @@ -108,8 +112,8 @@ func isPatchable(target, redirection *reflect.Value) error {
func applyPatch(patch *Patch) error {
patchLock.Lock()
defer patchLock.Unlock()
tPointer := patch.target.Pointer()
rPointer := getInternalPtrFromValue(*patch.redirection)
tPointer := getSafePointer(patch.target)
rPointer := getInternalPtrFromValue(patch.redirection)
rPointerJumpBytes, err := getJumpFuncBytes(rPointer)
if err != nil {
return err
Expand All @@ -132,7 +136,7 @@ func applyUnpatch(patch *Patch) error {
if patch.targetBytes == nil || len(patch.targetBytes) == 0 {
return errors.New("the target is not patched")
}
tPointer := patch.target.Pointer()
tPointer := getSafePointer(patch.target)
if _, ok := patches[tPointer]; !ok {
return errors.New("the target is not patched")
}
Expand All @@ -144,10 +148,6 @@ func applyUnpatch(patch *Patch) error {
return nil
}

func getInternalPtrFromValue(value reflect.Value) uintptr {
return (*pointer)(unsafe.Pointer(&value)).ptr
}

func getValueFrom(data interface{}) reflect.Value {
if cValue, ok := data.(reflect.Value); ok {
return cValue
Expand All @@ -156,14 +156,18 @@ func getValueFrom(data interface{}) reflect.Value {
}
}

func getMemorySliceFromPointer(p uintptr, length int) []byte {
return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{
func getMemorySliceFromPointer(p unsafe.Pointer, length int) []byte {
return *(*[]byte)(unsafe.Pointer(&sliceHeader{
Data: p,
Len: length,
Cap: length,
}))
}

func getPageStartPtr(ptr uintptr) uintptr {
return ptr & ^(uintptr(pageSize - 1))
func getSafePointer(value *reflect.Value) unsafe.Pointer {
p := getInternalPtrFromValue(value)
if p != nil {
p = *(*unsafe.Pointer)(p)
}
return p
}
Empty file added patcher.s
Empty file.
43 changes: 43 additions & 0 deletions patcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package mpatch

import (
"reflect"
"runtime"
"testing"
"time"
)

//go:noinline
Expand Down Expand Up @@ -93,3 +95,44 @@ func TestInstanceValuePatcher(t *testing.T) {
t.Fatal("The unpatch did not work")
}
}

var slice []int

//go:noinline
func TestGarbageCollectorExperiment(t *testing.T) {

for i := 0; i < 10000000; i++ {
slice = append(slice, i)
}
go func() {
var sl []int
for i := 0; i < 10000000; i++ {
sl = append(slice, i)
}
_ = sl
}()
<-time.After(time.Second)

aVal := methodA
ptr01 := reflect.ValueOf(aVal).Pointer()
slice = nil
runtime.GC()
for i := 0; i < 10000000; i++ {
slice = append(slice, i)
}
go func() {
var sl []int
for i := 0; i < 10000000; i++ {
sl = append(slice, i)
}
_ = sl
}()
<-time.After(time.Second)
slice = nil
runtime.GC()
ptr02 := reflect.ValueOf(aVal).Pointer()

if ptr01 != ptr02 {
t.Fail()
}
}
24 changes: 19 additions & 5 deletions patcher_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,28 @@

package mpatch

import "syscall"
import (
"reflect"
"syscall"
"unsafe"
)

var writeAccess = syscall.PROT_READ | syscall.PROT_WRITE | syscall.PROT_EXEC
var readAccess = syscall.PROT_READ | syscall.PROT_EXEC

func callMProtect(addr uintptr, length int, prot int) error {
for p := getPageStartPtr(addr); p < addr+uintptr(length); p += uintptr(pageSize) {
page := getMemorySliceFromPointer(p, pageSize)
//go:nosplit
func getMemorySliceFromUintptr(p uintptr, length int) []byte {
return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{
Data: p,
Len: length,
Cap: length,
}))
}

//go:nosplit
func callMProtect(addr unsafe.Pointer, length int, prot int) error {
for p := uintptr(addr) & ^(uintptr(pageSize - 1)); p < uintptr(addr)+uintptr(length); p += uintptr(pageSize) {
page := getMemorySliceFromUintptr(p, pageSize)
err := syscall.Mprotect(page, prot)
if err != nil {
return err
Expand All @@ -18,7 +32,7 @@ func callMProtect(addr uintptr, length int, prot int) error {
return nil
}

func copyDataToPtr(ptr uintptr, data []byte) error {
func copyDataToPtr(ptr unsafe.Pointer, data []byte) error {
dataLength := len(data)
ptrByteSlice := getMemorySliceFromPointer(ptr, len(data))
err := callMProtect(ptr, dataLength, writeAccess)
Expand Down
4 changes: 3 additions & 1 deletion patcher_unsupported.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"errors"
"fmt"
"runtime"
"unsafe"
)

// Gets the jump function rewrite bytes
func getJumpFuncBytes(to uintptr) ([]byte, error) {
//go:nosplit
func getJumpFuncBytes(to unsafe.Pointer) ([]byte, error) {
return nil, errors.New(fmt.Sprintf("Unsupported architecture: %s", runtime.GOARCH))
}
6 changes: 3 additions & 3 deletions patcher_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ const pageExecuteReadAndWrite = 0x40

var virtualProtectProc = syscall.NewLazyDLL("kernel32.dll").NewProc("VirtualProtect")

func callVirtualProtect(lpAddress uintptr, dwSize int, flNewProtect uint32, lpflOldProtect unsafe.Pointer) error {
ret, _, _ := virtualProtectProc.Call(lpAddress, uintptr(dwSize), uintptr(flNewProtect), uintptr(lpflOldProtect))
func callVirtualProtect(lpAddress unsafe.Pointer, dwSize int, flNewProtect uint32, lpflOldProtect unsafe.Pointer) error {
ret, _, _ := virtualProtectProc.Call(uintptr(lpAddress), uintptr(dwSize), uintptr(flNewProtect), uintptr(lpflOldProtect))
if ret == 0 {
return syscall.GetLastError()
}
return nil
}

func copyDataToPtr(ptr uintptr, data []byte) error {
func copyDataToPtr(ptr unsafe.Pointer, data []byte) error {
var oldPerms, tmp uint32
dataLength := len(data)
ptrByteSlice := getMemorySliceFromPointer(ptr, len(data))
Expand Down
13 changes: 8 additions & 5 deletions patcher_x32.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

package mpatch

import "unsafe"

// Gets the jump function rewrite bytes
func getJumpFuncBytes(to uintptr) ([]byte, error) {
//go:nosplit
func getJumpFuncBytes(to unsafe.Pointer) ([]byte, error) {
return []byte{
0xBA,
byte(to),
byte(to >> 8),
byte(to >> 16),
byte(to >> 24),
byte(uintptr(to)),
byte(uintptr(to) >> 8),
byte(uintptr(to) >> 16),
byte(uintptr(to) >> 24),
0xFF, 0x22,
}, nil
}
21 changes: 12 additions & 9 deletions patcher_x64.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@

package mpatch

import "unsafe"

// Gets the jump function rewrite bytes
func getJumpFuncBytes(to uintptr) ([]byte, error) {
//go:nosplit
func getJumpFuncBytes(to unsafe.Pointer) ([]byte, error) {
return []byte{
0x48, 0xBA,
byte(to),
byte(to >> 8),
byte(to >> 16),
byte(to >> 24),
byte(to >> 32),
byte(to >> 40),
byte(to >> 48),
byte(to >> 56),
byte(uintptr(to)),
byte(uintptr(to) >> 8),
byte(uintptr(to) >> 16),
byte(uintptr(to) >> 24),
byte(uintptr(to) >> 32),
byte(uintptr(to) >> 40),
byte(uintptr(to) >> 48),
byte(uintptr(to) >> 56),
0xFF, 0x22,
}, nil
}

0 comments on commit bfd2e1d

Please sign in to comment.