Skip to content

Commit

Permalink
Code style and refactoring changes (#7)
Browse files Browse the repository at this point in the history
* Code style and refactoring changes

* fmt changes

* add go1.10 in the test suite

* adding go1.9 in the test matrix

* adding go1.8 in the test matrix

* test matrix changes

* adding go1.7 in the test matrix

* try to add macos to the test matrix

* go1.6 in test matrix

* remove go1.6 from the test matrix due an error in go get for macos
  • Loading branch information
tonyredondo authored Jun 12, 2020
1 parent bfd2e1d commit 608c12d
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 37 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ jobs:
name: Test
strategy:
matrix:
go_version: [1.11, 1.12, 1.13, 1.14]
os: [ubuntu-latest, windows-latest]
go_version: [1.7, 1.8, 1.9, "1.10", 1.11, 1.12, 1.13, 1.14, 1.15]
os: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.os }}
steps:
- name: Set up Go ${{ matrix.go_version }}
Expand Down
40 changes: 16 additions & 24 deletions patcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,11 @@ var (
func PatchMethod(target, redirection interface{}) (*Patch, error) {
tValue := getValueFrom(target)
rValue := getValueFrom(redirection)
err := isPatchable(&tValue, &rValue)
if err != nil {
if err := isPatchable(&tValue, &rValue); err != nil {
return nil, err
}
patch := &Patch{target: &tValue, redirection: &rValue}
err = applyPatch(patch)
if err != nil {
if err := applyPatch(patch); err != nil {
return nil, err
}
return patch, nil
Expand All @@ -54,37 +52,32 @@ func PatchInstanceMethodByName(target reflect.Type, methodName string, redirecti
if !ok {
return nil, errors.New(fmt.Sprintf("Method '%v' not found", methodName))
}
return PatchMethodByReflect(method, redirection)
return PatchMethodByReflect(method.Func, redirection)
}
func PatchMethodByReflect(target reflect.Method, redirection interface{}) (*Patch, error) {
tValue := &target.Func
func PatchMethodByReflect(target reflect.Value, redirection interface{}) (*Patch, error) {
tValue := &target
rValue := getValueFrom(redirection)
err := isPatchable(tValue, &rValue)
if err != nil {
if err := isPatchable(tValue, &rValue); err != nil {
return nil, err
}
patch := &Patch{target: tValue, redirection: &rValue}
err = applyPatch(patch)
if err != nil {
if err := applyPatch(patch); err != nil {
return nil, err
}
return patch, nil
}
func PatchMethodWithMakeFunc(target reflect.Method, fn func(args []reflect.Value) (results []reflect.Value)) (*Patch, error) {
rValue := reflect.MakeFunc(target.Type, fn)
return PatchMethodByReflect(target, rValue)
func PatchMethodWithMakeFunc(target reflect.Value, fn func(args []reflect.Value) (results []reflect.Value)) (*Patch, error) {
return PatchMethodByReflect(target, reflect.MakeFunc(target.Type(), fn))
}

func (p *Patch) Patch() error {
if p == nil {
return errors.New("patch is nil")
}
err := isPatchable(p.target, p.redirection)
if err != nil {
if err := isPatchable(p.target, p.redirection); err != nil {
return err
}
err = applyPatch(p)
if err != nil {
if err := applyPatch(p); err != nil {
return err
}
return nil
Expand All @@ -103,7 +96,7 @@ func isPatchable(target, redirection *reflect.Value) error {
if target.Type() != redirection.Type() {
return errors.New(fmt.Sprintf("the target and/or redirection doesn't have the same type: %s != %s", target.Type(), redirection.Type()))
}
if _, ok := patches[getSafePointer(target)]; ok {
if _, ok := patches[getSafeCodePointer(target)]; ok {
return errors.New("the target is already patched")
}
return nil
Expand All @@ -112,7 +105,7 @@ func isPatchable(target, redirection *reflect.Value) error {
func applyPatch(patch *Patch) error {
patchLock.Lock()
defer patchLock.Unlock()
tPointer := getSafePointer(patch.target)
tPointer := getSafeCodePointer(patch.target)
rPointer := getInternalPtrFromValue(patch.redirection)
rPointerJumpBytes, err := getJumpFuncBytes(rPointer)
if err != nil {
Expand All @@ -121,8 +114,7 @@ func applyPatch(patch *Patch) error {
tPointerBytes := getMemorySliceFromPointer(tPointer, len(rPointerJumpBytes))
targetBytes := make([]byte, len(tPointerBytes))
copy(targetBytes, tPointerBytes)
err = copyDataToPtr(tPointer, rPointerJumpBytes)
if err != nil {
if err := copyDataToPtr(tPointer, rPointerJumpBytes); err != nil {
return err
}
patch.targetBytes = targetBytes
Expand All @@ -136,7 +128,7 @@ func applyUnpatch(patch *Patch) error {
if patch.targetBytes == nil || len(patch.targetBytes) == 0 {
return errors.New("the target is not patched")
}
tPointer := getSafePointer(patch.target)
tPointer := getSafeCodePointer(patch.target)
if _, ok := patches[tPointer]; !ok {
return errors.New("the target is not patched")
}
Expand Down Expand Up @@ -164,7 +156,7 @@ func getMemorySliceFromPointer(p unsafe.Pointer, length int) []byte {
}))
}

func getSafePointer(value *reflect.Value) unsafe.Pointer {
func getSafeCodePointer(value *reflect.Value) unsafe.Pointer {
p := getInternalPtrFromValue(value)
if p != nil {
p = *(*unsafe.Pointer)(p)
Expand Down
38 changes: 38 additions & 0 deletions patcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,44 @@ func TestPatcher(t *testing.T) {
}
}

func TestPatcherUsingReflect(t *testing.T) {
patch, err := PatchMethodByReflect(reflect.ValueOf(methodA), methodB)
if err != nil {
t.Fatal(err)
}
if methodA() != 2 {
t.Fatal("The patch did not work")
}

err = patch.Unpatch()
if err != nil {
t.Fatal(err)
}
if methodA() != 1 {
t.Fatal("The unpatch did not work")
}
}

func TestPatcherUsingMakeFunc(t *testing.T) {
patch, err := PatchMethodWithMakeFunc(reflect.ValueOf(methodA), func(args []reflect.Value) (results []reflect.Value) {
return []reflect.Value{reflect.ValueOf(42)}
})
if err != nil {
t.Fatal(err)
}
if methodA() != 42 {
t.Fatal("The patch did not work")
}

err = patch.Unpatch()
if err != nil {
t.Fatal(err)
}
if methodA() != 1 {
t.Fatal("The unpatch did not work")
}
}

func TestInstancePatcher(t *testing.T) {
mStruct := myStruct{}

Expand Down
9 changes: 3 additions & 6 deletions patcher_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ func getMemorySliceFromUintptr(p uintptr, length int) []byte {
func callMProtect(addr unsafe.Pointer, length int, prot int) error {
for p := uintptr(addr) & ^(uintptr(pageSize - 1)); p < uintptr(addr)+uintptr(length); p += uintptr(pageSize) {
page := getMemorySliceFromUintptr(p, pageSize)
err := syscall.Mprotect(page, prot)
if err != nil {
if err := syscall.Mprotect(page, prot); err != nil {
return err
}
}
Expand All @@ -35,13 +34,11 @@ func callMProtect(addr unsafe.Pointer, length int, prot int) error {
func copyDataToPtr(ptr unsafe.Pointer, data []byte) error {
dataLength := len(data)
ptrByteSlice := getMemorySliceFromPointer(ptr, len(data))
err := callMProtect(ptr, dataLength, writeAccess)
if err != nil {
if err := callMProtect(ptr, dataLength, writeAccess); err != nil {
return err
}
copy(ptrByteSlice, data[:])
err = callMProtect(ptr, dataLength, readAccess)
if err != nil {
if err := callMProtect(ptr, dataLength, readAccess); err != nil {
return err
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion patcher_unsupported.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ import (
// Gets the jump function rewrite bytes
//go:nosplit
func getJumpFuncBytes(to unsafe.Pointer) ([]byte, error) {
return nil, errors.New(fmt.Sprintf("Unsupported architecture: %s", runtime.GOARCH))
return nil, errors.New(fmt.Sprintf("unsupported architecture: %s", runtime.GOARCH))
}
6 changes: 2 additions & 4 deletions patcher_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ func copyDataToPtr(ptr unsafe.Pointer, data []byte) error {
var oldPerms, tmp uint32
dataLength := len(data)
ptrByteSlice := getMemorySliceFromPointer(ptr, len(data))
err := callVirtualProtect(ptr, dataLength, pageExecuteReadAndWrite, unsafe.Pointer(&oldPerms))
if err != nil {
if err := callVirtualProtect(ptr, dataLength, pageExecuteReadAndWrite, unsafe.Pointer(&oldPerms)); err != nil {
return err
}
copy(ptrByteSlice, data[:])
err = callVirtualProtect(ptr, dataLength, oldPerms, unsafe.Pointer(&tmp))
if err != nil {
if err := callVirtualProtect(ptr, dataLength, oldPerms, unsafe.Pointer(&tmp)); err != nil {
return err
}
return nil
Expand Down
2 changes: 2 additions & 0 deletions patcher_x32.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package mpatch

import "unsafe"

const jumpLength = 7

// Gets the jump function rewrite bytes
//go:nosplit
func getJumpFuncBytes(to unsafe.Pointer) ([]byte, error) {
Expand Down
2 changes: 2 additions & 0 deletions patcher_x64.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package mpatch

import "unsafe"

const jumpLength = 12

// Gets the jump function rewrite bytes
//go:nosplit
func getJumpFuncBytes(to unsafe.Pointer) ([]byte, error) {
Expand Down

0 comments on commit 608c12d

Please sign in to comment.