Skip to content

Commit

Permalink
Library documentation (#8)
Browse files Browse the repository at this point in the history
* code and readme docs update

* readme update
  • Loading branch information
tonyredondo authored Jun 15, 2020
1 parent 608c12d commit b295ad1
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 14 deletions.
179 changes: 178 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,182 @@
# go-mpatch
Go library for monkey patching

## Compatibility

Library inspired by the blog post: https://bou.ke/blog/monkey-patching-in-go/
- **Go version:** tested from `go1.7` to `go1.15-beta`
- **Architectures:** `x86`, `amd64`
- **Operating systems:** tested in `macos`, `linux` and `windows`.

## Features

- Can patch package functions, instance functions (by pointer or by value), and create new functions from scratch.

## Limitations

- Target functions could be inlined, making those functions unpatcheables. You can use `//go:noinline` directive or build with the `gcflags=-l`
to disable inlining at compiler level.

- Write permission to memory pages containing executable code is needed, some operating systems could restrict this access.

- Not thread safe.

## Usage

### Patching a func
```go
//go:noinline
func methodA() int { return 1 }

//go:noinline
func methodB() int { return 2 }

func TestPatcher(t *testing.T) {
patch, err := mpatch.PatchMethod(methodA, methodB)
if err != nil {
t.Fatal(err)
}
if methodA() != 2 {
t.Fatal("The patch did not work")
}

err = patch.Unpatch()
if err != nil {
t.Fatal(err)
}
if methodA() != 1 {
t.Fatal("The unpatch did not work")
}
}
```

### Patching using `reflect.ValueOf`
```go
//go:noinline
func methodA() int { return 1 }

//go:noinline
func methodB() int { return 2 }

func TestPatcherUsingReflect(t *testing.T) {
reflectA := reflect.ValueOf(methodA)
patch, err := mPatch.PatchMethodByReflect(reflectA, methodB)
if err != nil {
t.Fatal(err)
}
if methodA() != 2 {
t.Fatal("The patch did not work")
}

err = patch.Unpatch()
if err != nil {
t.Fatal(err)
}
if methodA() != 1 {
t.Fatal("The unpatch did not work")
}
}
```

### Patching creating a new func at runtime
```go
//go:noinline
func methodA() int { return 1 }

func TestPatcherUsingMakeFunc(t *testing.T) {
reflectA := reflect.ValueOf(methodA)
patch, err := PatchMethodWithMakeFunc(reflectA,
func(args []reflect.Value) (results []reflect.Value) {
return []reflect.Value{reflect.ValueOf(42)}
})
if err != nil {
t.Fatal(err)
}
if methodA() != 42 {
t.Fatal("The patch did not work")
}

err = patch.Unpatch()
if err != nil {
t.Fatal(err)
}
if methodA() != 1 {
t.Fatal("The unpatch did not work")
}
}
```

### Patching an instance func
```go
type myStruct struct {
}

//go:noinline
func (s *myStruct) Method() int {
return 1
}

func TestInstancePatcher(t *testing.T) {
mStruct := myStruct{}

var patch *Patch
var err error
patch, err = PatchInstanceMethodByName(reflect.TypeOf(mStruct), "Method", func(m *myStruct) int {
patch.Unpatch()
defer patch.Patch()
return 41 + m.Method()
})
if err != nil {
t.Fatal(err)
}

if mStruct.Method() != 42 {
t.Fatal("The patch did not work")
}
err = patch.Unpatch()
if err != nil {
t.Fatal(err)
}
if mStruct.Method() != 1 {
t.Fatal("The unpatch did not work")
}
}
```

### Patching an instance func by Value
```go
type myStruct struct {
}

//go:noinline
func (s myStruct) ValueMethod() int {
return 1
}

func TestInstanceValuePatcher(t *testing.T) {
mStruct := myStruct{}

var patch *Patch
var err error
patch, err = PatchInstanceMethodByName(reflect.TypeOf(mStruct), "ValueMethod", func(m myStruct) int {
patch.Unpatch()
defer patch.Patch()
return 41 + m.Method()
})
if err != nil {
t.Fatal(err)
}

if mStruct.ValueMethod() != 42 {
t.Fatal("The patch did not work")
}
err = patch.Unpatch()
if err != nil {
t.Fatal(err)
}
if mStruct.ValueMethod() != 1 {
t.Fatal("The unpatch did not work")
}
}
```

> Library inspired by the blog post: https://bou.ke/blog/monkey-patching-in-go/
23 changes: 16 additions & 7 deletions patcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ var (
pageSize = syscall.Getpagesize()
)

// Patches a target func to redirect calls to "redirection" func. Both function must have same arguments and return types.
func PatchMethod(target, redirection interface{}) (*Patch, error) {
tValue := getValueFrom(target)
rValue := getValueFrom(redirection)
Expand All @@ -43,6 +44,8 @@ func PatchMethod(target, redirection interface{}) (*Patch, error) {
}
return patch, nil
}
// Patches an instance func by using two parameters, the target struct type and the method name inside that type,
//this func will be redirected to the "redirection" func. Note: The first parameter of the redirection func must be the object instance.
func PatchInstanceMethodByName(target reflect.Type, methodName string, redirection interface{}) (*Patch, error) {
method, ok := target.MethodByName(methodName)
if !ok && target.Kind() == reflect.Struct {
Expand All @@ -54,6 +57,8 @@ func PatchInstanceMethodByName(target reflect.Type, methodName string, redirecti
}
return PatchMethodByReflect(method.Func, redirection)
}
// Patches a target func by passing the reflect.ValueOf of the func. The target func will be redirected to the "redirection" func.
// Both function must have same arguments and return types.
func PatchMethodByReflect(target reflect.Value, redirection interface{}) (*Patch, error) {
tValue := &target
rValue := getValueFrom(redirection)
Expand All @@ -66,10 +71,11 @@ func PatchMethodByReflect(target reflect.Value, redirection interface{}) (*Patch
}
return patch, nil
}
// Patches a target func with a "redirection" function created at runtime by using "reflect.MakeFunc".
func PatchMethodWithMakeFunc(target reflect.Value, fn func(args []reflect.Value) (results []reflect.Value)) (*Patch, error) {
return PatchMethodByReflect(target, reflect.MakeFunc(target.Type(), fn))
}

// Patch the target func with the redirection func.
func (p *Patch) Patch() error {
if p == nil {
return errors.New("patch is nil")
Expand All @@ -82,6 +88,7 @@ func (p *Patch) Patch() error {
}
return nil
}
// Unpatch the target func and recover the original func.
func (p *Patch) Unpatch() error {
if p == nil {
return errors.New("patch is nil")
Expand All @@ -96,7 +103,7 @@ func isPatchable(target, redirection *reflect.Value) error {
if target.Type() != redirection.Type() {
return errors.New(fmt.Sprintf("the target and/or redirection doesn't have the same type: %s != %s", target.Type(), redirection.Type()))
}
if _, ok := patches[getSafeCodePointer(target)]; ok {
if _, ok := patches[getCodePointer(target)]; ok {
return errors.New("the target is already patched")
}
return nil
Expand All @@ -105,7 +112,7 @@ func isPatchable(target, redirection *reflect.Value) error {
func applyPatch(patch *Patch) error {
patchLock.Lock()
defer patchLock.Unlock()
tPointer := getSafeCodePointer(patch.target)
tPointer := getCodePointer(patch.target)
rPointer := getInternalPtrFromValue(patch.redirection)
rPointerJumpBytes, err := getJumpFuncBytes(rPointer)
if err != nil {
Expand All @@ -114,7 +121,7 @@ func applyPatch(patch *Patch) error {
tPointerBytes := getMemorySliceFromPointer(tPointer, len(rPointerJumpBytes))
targetBytes := make([]byte, len(tPointerBytes))
copy(targetBytes, tPointerBytes)
if err := copyDataToPtr(tPointer, rPointerJumpBytes); err != nil {
if err := writeDataToPointer(tPointer, rPointerJumpBytes); err != nil {
return err
}
patch.targetBytes = targetBytes
Expand All @@ -128,12 +135,12 @@ func applyUnpatch(patch *Patch) error {
if patch.targetBytes == nil || len(patch.targetBytes) == 0 {
return errors.New("the target is not patched")
}
tPointer := getSafeCodePointer(patch.target)
tPointer := getCodePointer(patch.target)
if _, ok := patches[tPointer]; !ok {
return errors.New("the target is not patched")
}
delete(patches, tPointer)
err := copyDataToPtr(tPointer, patch.targetBytes)
err := writeDataToPointer(tPointer, patch.targetBytes)
if err != nil {
return err
}
Expand All @@ -148,6 +155,7 @@ func getValueFrom(data interface{}) reflect.Value {
}
}

// Extracts a memory slice from a pointer
func getMemorySliceFromPointer(p unsafe.Pointer, length int) []byte {
return *(*[]byte)(unsafe.Pointer(&sliceHeader{
Data: p,
Expand All @@ -156,7 +164,8 @@ func getMemorySliceFromPointer(p unsafe.Pointer, length int) []byte {
}))
}

func getSafeCodePointer(value *reflect.Value) unsafe.Pointer {
// Gets the code pointer of a func
func getCodePointer(value *reflect.Value) unsafe.Pointer {
p := getInternalPtrFromValue(value)
if p != nil {
p = *(*unsafe.Pointer)(p)
Expand Down
11 changes: 7 additions & 4 deletions patcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ func TestPatcher(t *testing.T) {
}

func TestPatcherUsingReflect(t *testing.T) {
patch, err := PatchMethodByReflect(reflect.ValueOf(methodA), methodB)
reflectA := reflect.ValueOf(methodA)
patch, err := PatchMethodByReflect(reflectA, methodB)
if err != nil {
t.Fatal(err)
}
Expand All @@ -63,9 +64,11 @@ func TestPatcherUsingReflect(t *testing.T) {
}

func TestPatcherUsingMakeFunc(t *testing.T) {
patch, err := PatchMethodWithMakeFunc(reflect.ValueOf(methodA), func(args []reflect.Value) (results []reflect.Value) {
return []reflect.Value{reflect.ValueOf(42)}
})
reflectA := reflect.ValueOf(methodA)
patch, err := PatchMethodWithMakeFunc(reflectA,
func(args []reflect.Value) (results []reflect.Value) {
return []reflect.Value{reflect.ValueOf(42)}
})
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion patcher_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func callMProtect(addr unsafe.Pointer, length int, prot int) error {
return nil
}

func copyDataToPtr(ptr unsafe.Pointer, data []byte) error {
func writeDataToPointer(ptr unsafe.Pointer, data []byte) error {
dataLength := len(data)
ptrByteSlice := getMemorySliceFromPointer(ptr, len(data))
if err := callMProtect(ptr, dataLength, writeAccess); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion patcher_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func callVirtualProtect(lpAddress unsafe.Pointer, dwSize int, flNewProtect uint3
return nil
}

func copyDataToPtr(ptr unsafe.Pointer, data []byte) error {
func writeDataToPointer(ptr unsafe.Pointer, data []byte) error {
var oldPerms, tmp uint32
dataLength := len(data)
ptrByteSlice := getMemorySliceFromPointer(ptr, len(data))
Expand Down

0 comments on commit b295ad1

Please sign in to comment.