5

Golang Mock 原理分析

 2 years ago
source link: http://yangxikun.github.io/golang/2021/06/19/golang-mock.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

在写单元测试时,通常需要对某些不容易构造或者不容易获取的对象进行 mock,那么在 Golang 中,我们可以 mock 哪些对象,又有哪些 mock 的方法呢,以及它们是如何实现的?本文将对几个 Golang 常见的开源库进行分析,了解其实现原理。

gomock 的实现

https://github.com/golang/mock 用于 interface 接口的 mock,需要先通过命令行工具 mockgen 生成 interface 的 mock 类型,通常会把命令用 //go:generate 写在代码中,比如:

package main

//go:generate mockgen -source=main.go -destination=foo_mock.go -package=main Foo

type Foo interface {
	Say(string, []string) string
}

func main() {
}

mockgen 会为 Foo 生成 MockFoo 类型的实现:

// MockFoo is a mock of Foo interface.
type MockFoo struct {
   ctrl     *gomock.Controller
   recorder *MockFooMockRecorder
}

// MockFooMockRecorder is the mock recorder for MockFoo.
type MockFooMockRecorder struct {
   mock *MockFoo
}

// NewMockFoo creates a new mock instance.
func NewMockFoo(ctrl *gomock.Controller) *MockFoo {
   mock := &MockFoo{ctrl: ctrl}
   mock.recorder = &MockFooMockRecorder{mock}
   return mock
}

MockFoo 的方法有:

  • EXPECT() *MockFooMockRecorder:返回 MockFoo.recorder 用于对 MockFoo 的方法进行 mock
  • Say(string, []string) string:实现 Foo 的接口,调用 MockFoo.ctrl.Call(MockFoo, “Say”),查找匹配的 gomock.Call 并执行,没找到会报错

MockFooMockRecorder 的方法有:

  • Say(arg0 string, arg1 []string) *gomock.Call:调用 gomock.Controller.RecordCallWithMethodType(MockFoo, "Say", reflect.TypeOf((*MockFoo)(nil).Say), arg0) 创建 gomock.Call

gomock.Call 用于表示期望的 mock 调用,根据方法的接收者、方法名、参数进行匹配,部分字段说明:

  • preReqs:依赖的 gomock.Call 必须执行至少 gomock.Call.minCalls 才能调用当前的 gomock.Call
  • minCalls: 最少调用次数
  • maxCalls:最多调用次数
  • numCalls:已调用次数
  • actions:每次调用,会顺序执行数组里的每个 action 函数,最后一个返回非 nil 值的结果作为这次调用的返回结果。newCall(...) *Call 时默认会插入一个返回零值的 action 函数。我们在使用时会调用 gomock.Call 的四个方法往 actions 追加:
    • DoAndReturn(f interface{}):执行函数 f,同时使用 f 的返回值
    • Do(f interface{}):执行函数 f,但返回 nil
    • Return(rets …interface{}):设置返回值
    • SetArg(n int, value interface{}):修改第 n 个参数的值,参数类型必须为 ptr、interface、slice
// Call represents an expected call to a mock.
type Call struct {
   t TestHelper // for triggering test failures on invalid call setup

   receiver   interface{}  // the receiver of the method call
   method     string       // the name of the method
   methodType reflect.Type // the type of the method
   args       []Matcher    // the args
   origin     string       // file and line number of call setup

   preReqs []*Call // prerequisite calls

   // Expectations
   minCalls, maxCalls int

   numCalls int // actual number made

   // actions are called when this Call is called. Each action gets the args and
   // can set the return values by returning a non-nil slice. Actions run in the
   // order they are created.
   actions []func([]interface{}) []interface{}
}

gomock.Controller 负责存储,匹配和调用 gomock.Call,字段说明:

  • T:通常是 *testing.T
  • mu:并发安全锁
  • expectedCalls:存储创建出来的 gomock.Call
    • callSet.expected:待调用的 gomock.Call
      • 根据方法的接收者和方法名对 gomock.Call 进行索引,参数匹配由 gomock.Call.matches 完成
    • callSet.exhausted:有两种情况的 gomock.Call 会存储在这里
      • 调用次数 >= gomock.Call.maxCalls
      • 某个 gomock.Call 被匹配到了,并且其依赖的调用次数都 >= gomock.Call.minCalls,那么这些依赖都会被放到 callSet.exhausted
  • finished:标记测试执行完了
type Controller struct {
   // T should only be called within a generated mock. It is not intended to
   // be used in user code and may be changed in future versions. T is the
   // TestReporter passed in when creating the Controller via NewController.
   // If the TestReporter does not implement a TestHelper it will be wrapped
   // with a nopTestHelper.
   T             TestHelper
   mu            sync.Mutex
   expectedCalls *callSet
   finished      bool
}

// callSet represents a set of expected calls, indexed by receiver and method
// name.
type callSet struct {
	// Calls that are still expected.
	expected map[callSetKey][]*Call
	// Calls that have been exhausted.
	exhausted map[callSetKey][]*Call
}

// callSetKey is the key in the maps in callSet
type callSetKey struct {
	receiver interface{}
	fname    string
}

一个具体的调用例子:

func TestMockFoo(t *testing.T) {
   ctrl := gomock.NewController(t)
   mockFoo := NewMockFoo(ctrl)
   mockFoo.EXPECT().
      Say("foo", []string{"a", "b"}). // 创建匹配 Say("foo", []string{"a", "b"}) 的 gomock.Call
      Do(func(arg string, arg2 []string) string { // 只执行,不会使用其返回值
         t.Logf("Do(%s, %v)", arg, arg2)
         return arg
      }).
      SetArg(1, []string{"c", "d"}). // 修改第 2 个参数
      DoAndReturn(func(arg string, arg2 []string) string { // 执行,并使用其返回值
         t.Logf("DoAndReturn(%s, %v)", arg, arg2)
         return arg
      }).
      Return("Return"). // 设置返回值为 "Return"
      AnyTimes()
   t.Log(mockFoo.Say("foo", []string{"a", "b"})) // 调用 Say,从 gomock.Controller.expectedCalls 中查找匹配的 gomock.Call,顺序调用 gomock.Call.actions
}

执行结果:

=== RUN   TestMockFoo
    main_test.go:15: Do(foo, [a b])
    main_test.go:20: DoAndReturn(foo, [c d])
    main_test.go:25: Return
--- PASS: TestMockFoo (0.00s)
PASS

gostub 的实现

https://github.com/prashantv/gostub 使用 reflect 包实现,只能 mock 变量。比如需要 mock 一个方法时,需要这么写:

bar := Bar{field: "field"}
say := bar.Say // 必须把方法复制给一个变量
stub = gostub.StubFunc(&say, "bar")
println(bar.Say("foo"), say("foo")) // fieldfoo bar
stub.Reset()
println(bar.Say("foo"), say("foo")) // fieldfoo fieldfoo

gostub.Stubs 字段说明:

  • stubs:以 reflect.ValueOf(varToStub) 作为 KEY,Value 保存原始的值
  • origEnv:存储原始的环境变量值,用于 mock 环境变量
// Stubs represents a set of stubbed variables that can be reset.
type Stubs struct {
   // stubs is a map from the variable pointer (being stubbed) to the original value.
   stubs   map[reflect.Value]reflect.Value
   origEnv map[string]envVal
}

gostub.Stubs 用于 mock 的方法只有如下 3 个:

  • Stub(varToStub interface{}, stubVal interface{}):修改 varToStub 的值为 stubVal
  • StubFunc(funcVarToStub interface{}, stubVal …interface{}):设置 funcVarToStub 的返回值为 stubVal…
  • SetEnv(k, v string):设置环境变量

StubFunc 的实现很简单,最后还是调用的 Stub:

  • funcVarToStub:必须为指针函数类型
  • FuncReturning:使用 reflect.MakeFunc 创建一个函数,返回 stubVal…
// StubFunc replaces a function variable with a function that returns stubVal.
// funcVarToStub must be a pointer to a function variable. If the function
// returns multiple values, then multiple values should be passed to stubFunc.
// The values must match be assignable to the return values' types.
func (s *Stubs) StubFunc(funcVarToStub interface{}, stubVal ...interface{}) *Stubs {
   funcPtrType := reflect.TypeOf(funcVarToStub)
   if funcPtrType.Kind() != reflect.Ptr ||
      funcPtrType.Elem().Kind() != reflect.Func {
      panic("func variable to stub must be a pointer to a function")
   }
   funcType := funcPtrType.Elem()
   if funcType.NumOut() != len(stubVal) {
      panic(fmt.Sprintf("func type has %v return values, but only %v stub values provided",
         funcType.NumOut(), len(stubVal)))
   }

   return s.Stub(funcVarToStub, FuncReturning(funcPtrType.Elem(), stubVal...).Interface())
}

Stub 的实现也不复杂:

// Stub replaces the value stored at varToStub with stubVal.
// varToStub must be a pointer to the variable. stubVal should have a type
// that is assignable to the variable.
func (s *Stubs) Stub(varToStub interface{}, stubVal interface{}) *Stubs {
   v := reflect.ValueOf(varToStub)
   stub := reflect.ValueOf(stubVal)

   // varToStub 必须为变量的指针
   if v.Type().Kind() != reflect.Ptr {
      panic("variable to stub is expected to be a pointer")
   }

   if _, ok := s.stubs[v]; !ok {
      // 存储 varToStub 原始值
      s.stubs[v] = reflect.ValueOf(v.Elem().Interface())
   }

   // 设置为新的值
   // *varToStub = stubVal
   v.Elem().Set(stub)
   return s
}

一个具体的调用例子:

type Bar struct {
   field string
}

func (b Bar) Say(arg string) string {
   return b.field + arg
}

func fn(i int) int {
   return i+1
}

func main() {
   foo := "foo"
   stub := gostub.Stub(&foo, "stubFoo")
   println(foo) // stubFoo
   stub.Reset()
   println(foo) // foo

   slice := []int{1, 2}
   stub = gostub.Stub(&slice, []int{3})
   fmt.Println(slice) // [3]
   stub.Reset()
   fmt.Println(slice) // [1 2]

   ff := fn
   stub = gostub.Stub(&ff, func(i int) int{return 0})
   println(ff(1)) // 0
   stub.Reset()
   println(ff(1)) // 2

   stub = gostub.StubFunc(&ff, 1)
   println(ff(1)) // 1
   stub.Reset()
   println(ff(1)) // 2

   bar := Bar{field: "field"}
   say := bar.Say
   stub = gostub.StubFunc(&say, "bar")
   println(bar.Say("foo"), say("foo")) // fieldfoo bar
   stub.Reset()
   println(bar.Say("foo"), say("foo")) //fieldfoo fieldfoo

   os.Setenv("GOSTUB_VAR", "value")
   stub = gostub.New()
   stub.SetEnv("GOSTUB_VAR", "stub_value")
   println(os.Getenv("GOSTUB_VAR")) // stub_value
   stub.Reset()
   println(os.Getenv("GOSTUB_VAR")) // value
}

gomonkey 的实现

https://github.com/agiledragon/gomonkey 对变量的 mock 实现原理跟 gostub 一样都是通过 reflect 包实现的。除了 mock 变量,gomonkey 还可以直接 mock 导出函数/方法、mock 代码所在包的非导出函数。

gomonkey 提供了如下 mock 方法:

  • ApplyGlobalVar(target, double interface{}):使用 reflect 包,将 target 的值修改为 double
  • ApplyFuncVar(target, double interface{}):检查 target 是否为指针类型,与 double 函数声明是否相同,最后调用 ApplyGlobalVar
  • ApplyFunc(target, double interface{}):修改 target 的机器指令,跳转到 double 执行
  • ApplyMethod(target reflect.Type, methodName string, double interface{}):修改 method 的机器指令,跳转到 double 执行
  • ApplyFuncSeq(target interface{}, outputs []OutputCell):修改 target 的机器指令,跳转到 gomonkey 生成的一个函数执行,每次调用会顺序从 outputs 取出一个值返回
  • ApplyMethodSeq(target reflect.Type, methodName string, outputs []OutputCell):修改 target 的机器指令,跳转到 gomonkey 生成的一个方法执行,每次调用会顺序从 outputs 取出一个值返回
  • ApplyFuncVarSeq(target interface{}, outputs []OutputCell):gomonkey 生成一个函数顺序返回 outputs 中的值,调用 ApplyGlobalVar

Apply*Seq 的实现

getDoubleFunc 会通过 reflect 包创建一个函数,每次调用会顺序返回 outputs 中的值。

func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value {
    // 判断返回值个数是否正确
   if funcType.NumOut() != len(outputs[0].Values) {
      panic(fmt.Sprintf("func type has %v return values, but only %v values provided as double",
         funcType.NumOut(), len(outputs[0].Values)))
   }

    // 构造返回结果集
   slice := make([]Params, 0)
   for _, output := range outputs {
      t := 0
      if output.Times <= 1 { // 每个值至少被返回一次
         t = 1
      } else {
         t = output.Times
      }
      for j := 0; j < t; j++ { // 根据返回次数将值追加到结果集中
         slice = append(slice, output.Values)
      }
   }

   i := 0 // 调用次数统计
   len := len(slice)
   return reflect.MakeFunc(funcType, func(_ []reflect.Value) []reflect.Value {
      if i < len {
         i++
         return GetResultValues(funcType, slice[i-1]...)
      }
      panic("double seq is less than call seq")
   })
}

指令替换的实现

通过将函数开头的机器指令替换为无条件JMP指令,跳转到 mock 函数执行。要实现这个功能,需要分三步走:

1、获取函数的内存地址

以如下代码为例子:

//go:noinline
func bar() string {
   return "bar"
}

func main() {
   fmt.Println(bar) // 0x10a2e20
   println(unsafe.Pointer(reflect.ValueOf(bar).Pointer())) // 0x10a2e20
}

执行命令 go build -o main . && go tool objdump -s 'bar' main 查看 bar 函数的内存地址为:0x10a2e20,与程序的输出一致,也就是我们可以使用 reflect 包获取到函数在内存中的地址。

TEXT main.bar(SB) /Users/roketyyang/Work/mock/gomonkey/f/f.go
  f.go:11               0x10a2e20               488d05f2450200          LEAQ go.string.*+217(SB), AX    
  f.go:11               0x10a2e27               4889442408              MOVQ AX, 0x8(SP)                
  f.go:11               0x10a2e2c               48c744241003000000      MOVQ $0x3, 0x10(SP)             
  f.go:11               0x10a2e35               c3                      RET 

在 gomonkey 中替换指令的实现为:

func (this *Patches) ApplyCore(target, double reflect.Value) *Patches {
   this.check(target, double) // 类型检查
   if _, ok := this.originals[target]; ok {
      panic("patch has been existed")
   }

   this.valueHolders[double] = double // 因为 mock 函数通常是一个闭包,也就是个局部作用域的对象,为了防止 mock 函数被 GC 回收掉,需要增加引用
   // 替换 target 的机器指令,返回的 origin 是 target 会被覆盖的机器指令
   original := replace(*(*uintptr)(getPointer(target)), uintptr(getPointer(double)))
   // 保存 target 被覆盖的机器指令,用于恢复 target
   this.originals[target] = original
   return this
}

其中 *(*uintptr)(getPointer(target)) 为 target 的函数地址,等同于 target.Pointer(),getPointer 返回的是指向 target 函数的指针,其实现如下:

type funcValue struct {
	_ uintptr
	p unsafe.Pointer
}

func getPointer(v reflect.Value) unsafe.Pointer {
	return (*funcValue)(unsafe.Pointer(&v)).p
}

reflect.Value 的结构如下,getPointer 相当于直接拿到了未导出的属性 reflect.Value.ptr,这是指向 target 函数的指针,所以要拿到 target 的函数地址,还得进行一次解引用。通过 target.Pointer() 可以直接拿到 target 的函数地址是因为 reflect.Value.Pointer() 在返回的时候就对 reflect.Value.ptr做了一次解引用。

type Value struct {
   typ *rtype
   ptr unsafe.Pointer
   flag
}

以如下代码为例子,看下函数变量的值是怎么存储的:

package main

//go:noinline
func foo() string {
   return "foo"
}

func main() {
   funcVar := foo
   println(funcVar())
   funcVar2 := foo
   println(funcVar2())
}

查看汇编代码:go tool compile -S ff.go,可以看到两个函数变量的调用都是通过把符号 "".foo·f(SB) 所指向的内存值放到 AX 寄存器,然后执行 CALL 指令。而 "".foo·f(SB) 使用到的内存大小是 8 个字节,并且值为 "".foo+0,即函数 foo 的地址,而 reflect.Value.ptr 实际上是符号 "".foo·f(SB) 的地址。

"".main STEXT size=205 args=0x0 locals=0x28 funcid=0x0
		        ......
        0x0021 00033 (ff.go:10) MOVQ    "".foo·f(SB), AX ; funcVar()
        0x0028 00040 (ff.go:10) LEAQ    "".foo·f(SB), DX
        0x002f 00047 (ff.go:10) PCDATA  $1, $0
        0x002f 00047 (ff.go:10) CALL    AX
		        ......
        0x006f 00111 (ff.go:12) MOVQ    "".foo·f(SB), AX ; funcVar2()
        0x0076 00118 (ff.go:12) LEAQ    "".foo·f(SB), DX
        0x007d 00125 (ff.go:12) CALL    AX
		        ......
"".foo·f SRODATA dupok size=8
        0x0000 00 00 00 00 00 00 00 00                          ........
        rel 0+8 t=1 "".foo+0

2、生成跳转指令

gomonkey 替换指令的代码:

// target 目标函数地址
// double mock 函数的指针
func replace(target, double uintptr) []byte {
	code := buildJmpDirective(double) // 生成跳转到 mock 函数的机器指令
	bytes := entryAddress(target, len(code))
	original := make([]byte, len(bytes))
	copy(original, bytes)
	modifyBinary(target, code)
	return original
}

func buildJmpDirective(double uintptr) []byte {
    d0 := byte(double)
    d1 := byte(double >> 8)
    d2 := byte(double >> 16)
    d3 := byte(double >> 24)
    d4 := byte(double >> 32)
    d5 := byte(double >> 40)
    d6 := byte(double >> 48)
    d7 := byte(double >> 56)

    // 返回跳转的机器指令
    return []byte{
        0x48, 0xBA, d0, d1, d2, d3, d4, d5, d6, d7, // MOV rdx, double 将 mock 函数的指针值放到 rdx 中
        0xFF, 0x22,     // JMP [rdx] 因为rdx 中存储的是 mock 函数的指针,所以需要使用[],从内存中获得 mock 函数的地址,然后跳转
    }
}

buildJmpDirective 的实现是间接近转移,这里其实也可以用直接转移 JMP rdx,rdx 中直接放 mock 函数的地址,这样就不需要 getPointer 了。当然也可以通过计算 target 函数地址和 mock 函数地址之间的距离,使用偏移量进行转移。

3、修改函数开头的指令

replace(target, double uintptr) []byte 中,首先通过 bytes := entryAddress(target, len(code)) 拿到 target 函数开头 12 字节数据,放在 []byte 变量中。再执行 copy(original, bytes) 把这 12 字节数据保存下来,便于之后恢复用。最后执行 modifyBinary(target, code) 修改指令:

func modifyBinary(target uintptr, bytes []byte) {
    function := entryAddress(target, len(bytes))
    
    // 默认情况下,代码段的内存页是不可写的,需要调用 Mprotect 修改 target 所在页的权限为可写
    page := entryAddress(pageStart(target), syscall.Getpagesize())
    err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC)
    if err != nil {
        panic(err)
    }
    // 替换 target 函数开头的指令为跳转指令
    copy(function, bytes)
    // 恢复为读和执行权限
    err = syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC)
    if err != nil {
        panic(err)
    }
}

super monkey 如何解决 mock 非导出函数/方法

gomonkey 依赖 reflect 包获取目标函数的地址,但如果我们在写测试代码的时候,依赖包中的非导出的函数/方法就没办法 mock 到了。上面分析了 gomonkey 的实现,可以知道最关键的是拿到目标函数的地址,如果有不通过 reflect 包就能拿到目标函数的地址的方法,那么问题就解决了。

https://github.com/cch123/supermonkey 的解决办法就是通过符号表获取到函数的内存地址,supermonkey 读取符号表:

func init() {
   content, _ := nm.Parse(os.Args[0])

   lines := strings.Split(content, "\n")
   for _, line := range lines {
      line := strings.TrimSpace(line)
      arr := strings.Split(line, " ")
      if len(arr) < 3 {
         continue
      }

      funcSymbol, addr := arr[2], arr[0]
      addrUint, _ := strconv.ParseUint(addr, 16, 64) // addrUint 为函数地址
      symbolTable[funcSymbol] = uintptr(addrUint) // funcSymbol 函数符号,比如 supermonkey/pkg.Foo.say
   }
}

使用例子,PatchByFullSymbolName 需要传入函数符号,可以通过 go tool nm -type supermonkey | grep 'Foo' 命令获取:

func main() {
   f := pkg.Foo{}
   patchGuard := sm.PatchByFullSymbolName("supermonkey/pkg.Foo.say", func() string {
      return "mock say"
   })
   println(f.Say())
   patchGuard.Unpatch()
}

supermonkey 只提供了 mock 函数的方法,通常是搭配 gomonkey 使用。

在 Golang 中 mock 代码的作用域可见的接口、变量、函数、方法,以及不可见的非导出函数、方法都可以进行 mock 。接口的 mock 建议使用 gomock,其他对象的 mock 可以使用 gomonkey + supermonkey 的组合。

参考文章:


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK