03 GORM Source Code Walkthrough

 4 years ago
source link: https://studygolang.com/articles/26083

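Introduction

A walkthrough of the GORM source code, based on v1.9.11.

Model Interaction

The previous part looked at how models are defined and parsed. This part looks at how a model interacts with the database.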

package main

import (
  "github.com/jinzhu/gorm"
  _ "github.com/jinzhu/gorm/dialects/sqlite"
)

type Product struct {
  gorm.Model
  Code string
  Price uint
}

func main() {
  db, err := gorm.Open("sqlite3", "test.db")
  if err != nil {
    panic("failed to connect database")
  }
  defer db.Close()

  // Migrate the schema
  db.AutoMigrate(&Product{})

  // Create
  db.Create(&Product{Code: "L1212", Price: 1000})

  // Read
  var product Product
  db.First(&product, 1) // find the product with id 1
  db.First(&product, "code = ?", "L1212") // find the product with code L1212

  // Update - set the product's price to 2000
  db.Model(&product).Update("Price", 2000)

  // Delete - delete the product
  db.Delete(&product)
}

AutoMigrate

Once the model is defined, the first step is to migrate its schema with AutoMigrate:

db.AutoMigrate(&Product{})

Here is its source:

// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
func (s *DB) AutoMigrate(values ...interface{}) *DB {
    db := s.Unscoped()
    for _, value := range values {
        db = db.NewScope(value).autoMigrate().db
    }
    return db
}

Internally, it calls db.NewScope(value).autoMigrate() on each argument passed in.

So how exactly is the migration performed?

func (scope *Scope) autoMigrate() *Scope {
    tableName := scope.TableName()
    quotedTableName := scope.QuotedTableName()

    if !scope.Dialect().HasTable(tableName) {
        scope.createTable()
    } else {
        for _, field := range scope.GetModelStruct().StructFields {
            if !scope.Dialect().HasColumn(tableName, field.DBName) {
                if field.IsNormal {
                    sqlTag := scope.Dialect().DataTypeOf(field)
                    scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
                }
            }
            scope.createJoinTable(field)
        }
        scope.autoIndex()
    }
    return scope
}

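The if/else in the middle shows two paths. If the table does not exist yet, it is simply created.

Otherwise, every field in the model is checked: if its column does not exist, the table is altered to add it.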

scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()

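How the SQL actually gets executed can wait for now. The form of the code is quite concise, though: Raw builds the statement and Exec runs it.

In addition, the join tables are brought up to date for every field in the model with scope.createJoinTable(field).

After the for loop has processed all the fields, the indexes are updated as well with scope.autoIndex().

To sum up, auto-migration does the following: create the table, add missing columns, update the join tables, and update the indexes.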
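The two paths can be sketched without a database. This is a minimal illustration, not GORM's code: the `column` type, the `migrationSQL` function, and the `existing` map (standing in for the dialect's HasColumn check) are all hypothetical names, and join tables and indexes are left out.

```go
package main

import (
	"fmt"
	"strings"
)

// column is a hypothetical stand-in for a parsed model field.
type column struct {
	Name    string
	SQLType string
}

// migrationSQL mimics the two branches of autoMigrate: create the table when
// it is missing, otherwise emit one ALTER TABLE per column that is missing.
// existing plays the role of the dialect's HasColumn lookup.
func migrationSQL(table string, tableExists bool, existing map[string]bool, model []column) []string {
	if !tableExists {
		defs := make([]string, 0, len(model))
		for _, c := range model {
			defs = append(defs, fmt.Sprintf("%q %v", c.Name, c.SQLType))
		}
		return []string{fmt.Sprintf("CREATE TABLE %q (%v)", table, strings.Join(defs, ","))}
	}

	var stmts []string
	for _, c := range model {
		if !existing[c.Name] {
			stmts = append(stmts, fmt.Sprintf("ALTER TABLE %q ADD %q %v;", table, c.Name, c.SQLType))
		}
	}
	return stmts
}

func main() {
	model := []column{{"id", "integer"}, {"code", "varchar(255)"}, {"price", "integer"}}
	existing := map[string]bool{"id": true, "code": true}

	// Only the missing "price" column produces a statement.
	for _, stmt := range migrationSQL("products", true, existing, model) {
		fmt.Println(stmt)
	}
}
```

Columns that already exist are never altered or dropped here, which matches the "will only add missing fields" note in the AutoMigrate comment.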

createTable

The previous section skipped over how the table is actually created, so let's take a closer look.

func (scope *Scope) createTable() *Scope {
    var tags []string
    var primaryKeys []string
    var primaryKeyInColumnType = false
    for _, field := range scope.GetModelStruct().StructFields {
        if field.IsNormal {
            sqlTag := scope.Dialect().DataTypeOf(field)

            // Check if the primary key constraint was specified as
            // part of the column type. If so, we can only support
            // one column as the primary key.
            if strings.Contains(strings.ToLower(sqlTag), "primary key") {
                primaryKeyInColumnType = true
            }

            tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
        }

        if field.IsPrimaryKey {
            primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
        }
        scope.createJoinTable(field)
    }

    var primaryKeyStr string
    if len(primaryKeys) > 0 && !primaryKeyInColumnType {
        primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
    }

    scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()

    scope.autoIndex()
    return scope
}

This is how the CREATE TABLE statement is built. The key line is:

scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()

The code before it iterates over the model's fields, gets the sqlTag of each one, and appends it to tags:

tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)

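That is, the quoted column name, a space, and then the sqlTag.

Primary key detection is also involved here, and this part feels like a bit of a trap, because the result of

sqlTag := scope.Dialect().DataTypeOf(field) depends on how each database dialect implements DataTypeOf.

Issue 2270 reports a table that ends up with more than one primary key.

It used the following model definition, with sqlite3 as the database: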

type Permission struct {
    ID   int64  `gorm:"AUTO_INCREMENT;column:id;primary_key"`
    Name string `gorm:"column:name;type:varchar;unique;not null"`
    Idx  int64  `gorm:"AUTO_INCREMENT"`
}

Although this model marks only one field as primary_key, Idx became a primary key as well:

[2019-01-19 19:40:30]  table "permission" has more than one primary key

[2019-01-19 19:40:30]  [0.14ms]  CREATE TABLE "permission" ("id" integer primary key autoincrement,"name" varchar NOT NULL UNIQUE,"idx" integer primary key autoincrement )
[0 rows affected or returned ]

There is only one cause: the field uses the AUTO_INCREMENT option, and in sqlite3's DataTypeOf implementation:

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
  if s.fieldCanAutoIncrement(field) {
    field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
    sqlType = "integer primary key autoincrement"
  } else {
    sqlType = "integer"
  }
case reflect.Int64, reflect.Uint64:
  if s.fieldCanAutoIncrement(field) {
    field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
    sqlType = "integer primary key autoincrement"
  } else {
    sqlType = "bigint"
  }

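The AUTO_INCREMENT option makes the returned type include primary key.

I suspect this is a bug, because later code already checks whether a field is a primary key and builds primaryKeyStr from it.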

if field.IsPrimaryKey {
  primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
}
var primaryKeyStr string
if len(primaryKeys) > 0 && !primaryKeyInColumnType {
  primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
}

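I think sqlType should not carry any primary key information.

The primary key can be set afterwards through primaryKeyStr. One caveat: SQLite accepts AUTOINCREMENT only on an INTEGER PRIMARY KEY column, which may be why the sqlite3 dialect emits the constraint inline.

That wraps up the primary key discussion.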

Both migration and table creation call createJoinTable, but since I have not dug into how relationships are implemented yet, I will skip it for now.

callbacks

Create, read, update, and delete all go through the callbacks field of the DB struct:

// DB contains information for current db connection
type DB struct {
  ...
    // global db
    parent        *DB
    callbacks     *Callback
    dialect       Dialect
    singularTable bool
  ...
}

Here is the code of the Create method:

// Create insert the value into database
func (s *DB) Create(value interface{}) *DB {
    scope := s.NewScope(value)
    return scope.callCallbacks(s.parent.callbacks.creates).db
}

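It calls callCallbacks on a new scope, passing s.parent.callbacks.creates as the argument.

parent is also of type *DB. It points to the global DB that holds the shared callbacks, which works somewhat like inheritance.

Digging further into callCallbacks: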

func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
    defer func() {
        if err := recover(); err != nil {
            if db, ok := scope.db.db.(sqlTx); ok {
                db.Rollback()
            }
            panic(err)
        }
    }()
    for _, f := range funcs {
        (*f)(scope)
        if scope.skipLeft {
            break
        }
    }
    return scope
}

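It uses the defer plus recover pattern, which I have covered before and will not go into again. Here, if a callback panics, the transaction (if there is one) is rolled back and the panic is raised again.

The argument to callCallbacks is a slice of function pointers. They are called one after another, and the loop stops early once scope.skipLeft is true.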
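The calling convention is easy to reproduce in isolation. The sketch below is an illustration only: `scope`, its `log` and `rollback` fields, and `runCallbacks` are hypothetical stand-ins, with `rollback` taking the place of the sqlTx check in the real code.

```go
package main

import "fmt"

// scope is a hypothetical, stripped-down stand-in for gorm's Scope.
type scope struct {
	skipLeft bool
	log      []string
	rollback func() // stands in for the transaction rollback
}

// runCallbacks calls each function in order, stops once skipLeft is set, and
// rolls back before re-panicking if any callback panics.
func runCallbacks(s *scope, funcs []*func(*scope)) *scope {
	defer func() {
		if err := recover(); err != nil {
			if s.rollback != nil {
				s.rollback()
			}
			panic(err)
		}
	}()
	for _, f := range funcs {
		(*f)(s)
		if s.skipLeft {
			break
		}
	}
	return s
}

func main() {
	first := func(s *scope) { s.log = append(s.log, "first") }
	stop := func(s *scope) { s.log = append(s.log, "stop"); s.skipLeft = true }
	never := func(s *scope) { s.log = append(s.log, "never") }

	s := runCallbacks(&scope{}, []*func(*scope){&first, &stop, &never})
	fmt.Println(s.log) // [first stop]
}
```

Because the panic is raised again after the rollback, callers still see the original failure; the deferred function only makes sure the transaction is not left open.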

Having seen how callbacks are invoked, let's look at what Callback actually is.

// Callback is a struct that contains all CRUD callbacks
//   Field `creates` contains callbacks will be call when creating object
//   Field `updates` contains callbacks will be call when updating object
//   Field `deletes` contains callbacks will be call when deleting object
//   Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
//   Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
//   Field `processors` contains all callback processors, will be used to generate above callbacks in order
type Callback struct {
    logger     logger
    creates    []*func(scope *Scope)
    updates    []*func(scope *Scope)
    deletes    []*func(scope *Scope)
    queries    []*func(scope *Scope)
    rowQueries []*func(scope *Scope)
    processors []*CallbackProcessor
}

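Callback holds several function slices, one each for create, update, delete, query, and row query. The comments explain them clearly.

Pay attention to CallbackProcessor, which is used to generate all the callbacks in order.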

// CallbackProcessor contains callback informations
type CallbackProcessor struct {
    logger    logger
    name      string              // current callback's name
    before    string              // register current callback before a callback
    after     string              // register current callback after a callback
    replace   bool                // replace callbacks with same name
    remove    bool                // delete callbacks with same name
    kind      string              // callback type: create, update, delete, query, row_query
    processor *func(scope *Scope) // callback handler
    parent    *Callback
}
// Create could be used to register callbacks for creating object
//     db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
//       // business logic
//       ...
//
//       // set error if some thing wrong happened, will rollback the creating
//       scope.Err(errors.New("error"))
//     })
func (c *Callback) Create() *CallbackProcessor {
    return &CallbackProcessor{logger: c.logger, kind: "create", parent: c}
}

// Update could be used to register callbacks for updating object, refer `Create` for usage
func (c *Callback) Update() *CallbackProcessor {
    return &CallbackProcessor{logger: c.logger, kind: "update", parent: c}
}

// Delete could be used to register callbacks for deleting object, refer `Create` for usage
func (c *Callback) Delete() *CallbackProcessor {
    return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c}
}

// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
// Refer `Create` for usage
func (c *Callback) Query() *CallbackProcessor {
    return &CallbackProcessor{logger: c.logger, kind: "query", parent: c}
}

// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
func (c *Callback) RowQuery() *CallbackProcessor {
    return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c}
}

Callback has a method for creating a CallbackProcessor of each kind.

// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
    cp.after = callbackName
    return cp
}

// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
    cp.before = callbackName
    return cp
}

After and Before set the corresponding fields on the CallbackProcessor, which are used later to compute the order in which callbacks are called.

db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
  // business logic
  ...

  // set error if some thing wrong happened, will rollback the creating
  scope.Err(errors.New("error"))
})

That is the example from the comment. Next, the Register method.

// Register a new callback, refer `Callbacks.Create`
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
    if cp.kind == "row_query" {
        if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
            cp.logger.Print(fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName))
            cp.before = "gorm:row_query"
        }
    }

    cp.name = callbackName
    cp.processor = &callback
    cp.parent.processors = append(cp.parent.processors, cp)
    cp.parent.reorder()
}

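It mainly sets the name and processor of cp and appends cp to cp.parent.processors.

It then calls cp.parent.reorder() to sort everything again.

Where there is a register method, there is of course a matching remove method: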

// Remove a registered callback
//     db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
func (cp *CallbackProcessor) Remove(callbackName string) {
    cp.logger.Print(fmt.Sprintf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()))
    cp.name = callbackName
    cp.remove = true
    cp.parent.processors = append(cp.parent.processors, cp)
    cp.parent.reorder()
}

It sets remove to true and then sorts again.

Replace works in a similar way:

// Replace a registered callback with new callback
//     db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
//           scope.SetColumn("Created", now)
//           scope.SetColumn("Updated", now)
//     })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
    cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()))
    cp.name = callbackName
    cp.processor = &callback
    cp.replace = true
    cp.parent.processors = append(cp.parent.processors, cp)
    cp.parent.reorder()
}

Now let's see how the reordering is done:

// reorder all registered processors, and reset CRUD callbacks
func (c *Callback) reorder() {
    var creates, updates, deletes, queries, rowQueries []*CallbackProcessor

    for _, processor := range c.processors {
        if processor.name != "" {
            switch processor.kind {
            case "create":
                creates = append(creates, processor)
            case "update":
                updates = append(updates, processor)
            case "delete":
                deletes = append(deletes, processor)
            case "query":
                queries = append(queries, processor)
            case "row_query":
                rowQueries = append(rowQueries, processor)
            }
        }
    }

    c.creates = sortProcessors(creates)
    c.updates = sortProcessors(updates)
    c.deletes = sortProcessors(deletes)
    c.queries = sortProcessors(queries)
    c.rowQueries = sortProcessors(rowQueries)
}

The first half only groups the processors by kind. The real work is in sortProcessors:

// sortProcessors sort callback processors based on its before, after, remove, replace
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
    var (
        allNames, sortedNames []string
        sortCallbackProcessor func(c *CallbackProcessor)
    )

    for _, cp := range cps {
        // show warning message the callback name already exists
        if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
            cp.logger.Print(fmt.Sprintf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()))
        }
        allNames = append(allNames, cp.name)
    }

    sortCallbackProcessor = func(c *CallbackProcessor) {
        if getRIndex(sortedNames, c.name) == -1 { // if not sorted
            if c.before != "" { // if defined before callback
                if index := getRIndex(sortedNames, c.before); index != -1 {
                    // if before callback already sorted, append current callback just after it
                    sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
                } else if index := getRIndex(allNames, c.before); index != -1 {
                    // if before callback exists but haven't sorted, append current callback to last
                    sortedNames = append(sortedNames, c.name)
                    sortCallbackProcessor(cps[index])
                }
            }

            if c.after != "" { // if defined after callback
                if index := getRIndex(sortedNames, c.after); index != -1 {
                    // if after callback already sorted, append current callback just before it
                    sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
                } else if index := getRIndex(allNames, c.after); index != -1 {
                    // if after callback exists but haven't sorted
                    cp := cps[index]
                    // set after callback's before callback to current callback
                    if cp.before == "" {
                        cp.before = c.name
                    }
                    sortCallbackProcessor(cp)
                }
            }

            // if current callback haven't been sorted, append it to last
            if getRIndex(sortedNames, c.name) == -1 {
                sortedNames = append(sortedNames, c.name)
            }
        }
    }

    for _, cp := range cps {
        sortCallbackProcessor(cp)
    }

    var sortedFuncs []*func(scope *Scope)
    for _, name := range sortedNames {
        if index := getRIndex(allNames, name); !cps[index].remove {
            sortedFuncs = append(sortedFuncs, cps[index].processor)
        }
    }

    return sortedFuncs
}

It first collects the names of all cps, warning about duplicates along the way. sortedNames holds the names in sorted order.

// getRIndex get right index from string slice
func getRIndex(strs []string, str string) int {
    for i := len(strs) - 1; i >= 0; i-- {
        if strs[i] == str {
            return i
        }
    }
    return -1
}

getRIndex returns the rightmost index of str in the slice, or -1 if it is not found.
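Searching from the right matters because Remove and Replace do not edit earlier entries; they append a new processor with the same name, and the final loop of sortProcessors looks each name up with getRIndex. The following sketch shows that "latest entry wins" effect on its own. The `entry` type and `resolve` function are hypothetical, and a string stands in for the callback function.

```go
package main

import "fmt"

// entry is a hypothetical stand-in for a CallbackProcessor: handler replaces
// the function pointer so the result is easy to print.
type entry struct {
	name    string
	handler string
	remove  bool
}

// rIndex returns the rightmost index of name, or -1 if it is absent.
func rIndex(entries []entry, name string) int {
	for i := len(entries) - 1; i >= 0; i-- {
		if entries[i].name == name {
			return i
		}
	}
	return -1
}

// resolve returns the handler that is effectively active for name, or "" when
// the name was never registered or its most recent entry removed it.
func resolve(entries []entry, name string) string {
	i := rIndex(entries, name)
	if i == -1 || entries[i].remove {
		return ""
	}
	return entries[i].handler
}

func main() {
	// Registrations are only ever appended, in call order.
	entries := []entry{
		{name: "gorm:create", handler: "built-in create"},
		{name: "gorm:update_time_stamp", handler: "built-in timestamps"},
		{name: "gorm:create", handler: "custom create"}, // appended by a Replace
		{name: "gorm:update_time_stamp", remove: true},  // appended by a Remove
	}

	fmt.Printf("%q\n", resolve(entries, "gorm:create"))            // "custom create"
	fmt.Printf("%q\n", resolve(entries, "gorm:update_time_stamp")) // ""
}
```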

Let's see what the sortCallbackProcessor function actually does.

It contains two conditional parts. The first one:

if c.before != "" { // if defined before callback
  if index := getRIndex(sortedNames, c.before); index != -1 {
    // if before callback already sorted, append current callback just after it
    sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
  } else if index := getRIndex(allNames, c.before); index != -1 {
    // if before callback exists but haven't sorted, append current callback to last
    sortedNames = append(sortedNames, c.name)
    sortCallbackProcessor(cps[index])
  }
}

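There are two cases. If the before callback is already sorted, the current callback is inserted directly in front of it: the code inserts at index, so despite the source comment saying "just after it", the current callback ends up ahead of the before callback.

If the before callback exists but has not been sorted yet, the current name is appended to the end of sortedNames.

Then sortCallbackProcessor(cps[index]) is called recursively, which sorts the before callback right away, so it lands after the current one.

Now the second part: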

if c.after != "" { // if defined after callback
  if index := getRIndex(sortedNames, c.after); index != -1 {
    // if after callback already sorted, append current callback just before it
    sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
  } else if index := getRIndex(allNames, c.after); index != -1 {
    // if after callback exists but haven't sorted
    cp := cps[index]
    // set after callback's before callback to current callback
    if cp.before == "" {
      cp.before = c.name
    }
    sortCallbackProcessor(cp)
  }
}

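The logic mirrors the first part. If the after callback is already sorted, the current callback is inserted directly behind it: the code inserts at index+1, even though the source comment says "just before it".

If the after callback exists but has not been sorted yet, its before field is set to the current callback, unless it already has a before of its own.

Then sortCallbackProcessor(cp) is called recursively to sort the after callback.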

// if current callback haven't been sorted, append it to last
if getRIndex(sortedNames, c.name) == -1 {
  sortedNames = append(sortedNames, c.name)
}

If the callback has still not been placed, it is appended to the end. That is all there is to sortCallbackProcessor.

for _, cp := range cps {
  sortCallbackProcessor(cp)
}

This loop runs the sort. Once it finishes, sortedNames is complete:

var sortedFuncs []*func(scope *Scope)
for _, name := range sortedNames {
  if index := getRIndex(allNames, name); !cps[index].remove {
    sortedFuncs = append(sortedFuncs, cps[index].processor)
  }
}

return sortedFuncs

Callbacks that are not marked remove are appended to sortedFuncs in order.
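To check the ordering rules by hand, the sort can be reduced to names only. The sketch below follows the structure of sortCallbackProcessor but is not GORM's code: `proc`, `rIndex`, and `sortNames` are hypothetical names, the remove and replace flags are left out, and the plugin callbacks are invented for the example.

```go
package main

import "fmt"

// proc is a hypothetical, name-only stand-in for a CallbackProcessor.
type proc struct {
	name, before, after string
}

// rIndex returns the rightmost index of s in strs, or -1.
func rIndex(strs []string, s string) int {
	for i := len(strs) - 1; i >= 0; i-- {
		if strs[i] == s {
			return i
		}
	}
	return -1
}

// sortNames orders callback names by their before/after constraints, using
// the same recursive approach as sortProcessors.
func sortNames(cps []*proc) []string {
	var allNames, sorted []string
	for _, cp := range cps {
		allNames = append(allNames, cp.name)
	}

	var visit func(c *proc)
	visit = func(c *proc) {
		if rIndex(sorted, c.name) != -1 {
			return // already placed
		}
		if c.before != "" {
			if i := rIndex(sorted, c.before); i != -1 {
				// target already placed: insert c in front of it
				sorted = append(sorted[:i], append([]string{c.name}, sorted[i:]...)...)
			} else if i := rIndex(allNames, c.before); i != -1 {
				// target not placed yet: place c, then the target behind it
				sorted = append(sorted, c.name)
				visit(cps[i])
			}
		}
		if c.after != "" {
			if i := rIndex(sorted, c.after); i != -1 {
				// target already placed: insert c right behind it
				sorted = append(sorted[:i+1], append([]string{c.name}, sorted[i+1:]...)...)
			} else if i := rIndex(allNames, c.after); i != -1 {
				// target not placed yet: tell it to go before c, then place it
				cp := cps[i]
				if cp.before == "" {
					cp.before = c.name
				}
				visit(cp)
			}
		}
		if rIndex(sorted, c.name) == -1 {
			sorted = append(sorted, c.name)
		}
	}

	for _, cp := range cps {
		visit(cp)
	}
	return sorted
}

func main() {
	order := sortNames([]*proc{
		{name: "gorm:begin_transaction"},
		{name: "gorm:create"},
		{name: "gorm:commit_or_rollback_transaction"},
		{name: "plugin:audit", after: "gorm:create"},     // hypothetical plugin
		{name: "plugin:validate", before: "gorm:create"}, // hypothetical plugin
	})
	for _, name := range order {
		fmt.Println(name)
	}
}
```

With these five registrations the printed order is gorm:begin_transaction, plugin:validate, gorm:create, plugin:audit, gorm:commit_or_rollback_transaction: callbacks without constraints keep their registration order, and the two plugins are slotted in around gorm:create.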

Finally, there is a Get method for fetching a registered callback:

// Get registered callback
//    db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
    for _, p := range cp.parent.processors {
        if p.name == callbackName && p.kind == cp.kind {
            if p.remove {
                callback = nil
            } else {
                callback = *p.processor
            }
        }
    }
    return
}

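By now it should be clear how callbacks are registered and sorted, and how a single callback can be fetched by name.

The Actual Registration Flow

So far this has only covered the definitions in theory. Let's see where registration actually happens.

When DB is initialized, that is, in the Open method, the following statement runs: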

db = &DB{
  db:        dbSQL,
  logger:    defaultLogger,
  callbacks: DefaultCallback,
  dialect:   newDialect(dialect, dbSQL),
}

DefaultCallback is defined as follows:

// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}

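At first I was a little alarmed: this is just an empty value, so it must be populated somewhere. A glance at the directory made it clear.

The file callback_create.go defines the registration flow for create, inside an init function: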

// Define callbacks for creating
func init() {
    DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
    DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
    DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
    DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
    DefaultCallback.Create().Register("gorm:create", createCallback)
    DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
    DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
    DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
    DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}

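With the documentation as a reference, let's see how BeforeSave and BeforeCreate are implemented.

When you define a model, you can implement methods such as BeforeSave and BeforeCreate on it, and these methods are called at the appropriate time.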

func (u *User) BeforeSave() (err error) {
  if !u.IsValid() {
    err = errors.New("can't save invalid data")
  }
  return
}

func (u *User) AfterCreate(scope *gorm.Scope) (err error) {
  if u.ID == 1 {
    scope.DB().Model(u).Update("role", "admin")
  }
  return
}

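The above is the example from the official documentation. Earlier, in the comments, we saw how to register a callback by hand,

similar to DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback),

but how do the methods defined on the model get called?

Look at the beforeCreateCallback function: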

// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
func beforeCreateCallback(scope *Scope) {
    if !scope.HasError() {
        scope.CallMethod("BeforeSave")
    }
    if !scope.HasError() {
        scope.CallMethod("BeforeCreate")
    }
}

It turns out to go through scope.CallMethod: pass it a method name and that method is called on the model.

// CallMethod call scope value's method, if it is a slice, will call its element's method one by one
func (scope *Scope) CallMethod(methodName string) {
    if scope.Value == nil {
        return
    }

    if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice {
        for i := 0; i < indirectScopeValue.Len(); i++ {
            scope.callMethod(methodName, indirectScopeValue.Index(i))
        }
    } else {
        scope.callMethod(methodName, indirectScopeValue)
    }
}

After that detour, on to the code of callMethod:

func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) {
    // Only get address from non-pointer
    if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr {
        reflectValue = reflectValue.Addr()
    }

    if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() {
        switch method := methodValue.Interface().(type) {
        case func():
            method()
        case func(*Scope):
            method(scope)
        case func(*DB):
            newDB := scope.NewDB()
            method(newDB)
            scope.Err(newDB.Error)
        case func() error:
            scope.Err(method())
        case func(*Scope) error:
            scope.Err(method(scope))
        case func(*DB) error:
            newDB := scope.NewDB()
            scope.Err(method(newDB))
            scope.Err(newDB.Error)
        default:
            scope.Err(fmt.Errorf("unsupported function %v", methodName))
        }
    }
}

All of this flexibility is implemented with reflection. The key line is methodValue := reflectValue.MethodByName(methodName).

The switch shows that the method may have several different signatures:

switch method := methodValue.Interface().(type) {
case func():
  method()
case func(*Scope):
  method(scope)
case func(*DB):
  newDB := scope.NewDB()
  method(newDB)
  scope.Err(newDB.Error)
case func() error:
  scope.Err(method())
case func(*Scope) error:
  scope.Err(method(scope))
case func(*DB) error:
  newDB := scope.NewDB()
  scope.Err(method(newDB))
  scope.Err(newDB.Error)
default:
  scope.Err(fmt.Errorf("unsupported function %v", methodName))
}

So all of this can really be seen as one large worked example of reflect.
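The same technique works with nothing but the standard library. The sketch below is a simplified illustration, not GORM's implementation: the `user` type, its hook methods, and `callHook` are hypothetical, and only two of the six signatures are handled.

```go
package main

import (
	"errors"
	"fmt"
	"reflect"
)

// user is a hypothetical model with three hook-like methods.
type user struct {
	Name  string
	saved bool
}

// BeforeSave uses the `func() error` signature.
func (u *user) BeforeSave() error {
	if u.Name == "" {
		return errors.New("name is required")
	}
	return nil
}

// AfterSave uses the plain `func()` signature.
func (u *user) AfterSave() { u.saved = true }

// BeforeDelete has a signature the dispatcher does not know about.
func (u *user) BeforeDelete(force bool) {}

// callHook looks a method up by name and dispatches on its signature with a
// type switch. value must be a non-nil pointer so that pointer-receiver
// methods are visible to reflection.
func callHook(value interface{}, methodName string) error {
	method := reflect.ValueOf(value).MethodByName(methodName)
	if !method.IsValid() {
		return nil // the model does not define this hook
	}
	switch fn := method.Interface().(type) {
	case func():
		fn()
		return nil
	case func() error:
		return fn()
	default:
		return fmt.Errorf("unsupported function %v", methodName)
	}
}

func main() {
	fmt.Println(callHook(&user{}, "BeforeSave"))               // name is required
	fmt.Println(callHook(&user{Name: "jinzhu"}, "BeforeSave")) // <nil>
	fmt.Println(callHook(&user{Name: "jinzhu"}, "BeforeDelete"))
}
```

A missing method is silently skipped, while a method with an unexpected signature is reported as an error, which mirrors the IsValid check and the default branch in callMethod.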

createCallback

I will skip the other hooks and look at what happens when a single record is inserted:

// createCallback the callback used to insert data into database
func createCallback(scope *Scope) {
    if !scope.HasError() {
        defer scope.trace(scope.db.nowFunc())

        var (
            columns, placeholders        []string
            blankColumnsWithDefaultValue []string
        )

        for _, field := range scope.Fields() {
            if scope.changeableField(field) {
                if field.IsNormal && !field.IsIgnored {
                    if field.IsBlank && field.HasDefaultValue {
                        blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
                        scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
                    } else if !field.IsPrimaryKey || !field.IsBlank {
                        columns = append(columns, scope.Quote(field.DBName))
                        placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
                    }
                } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
                    for _, foreignKey := range field.Relationship.ForeignDBNames {
                        if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
                            columns = append(columns, scope.Quote(foreignField.DBName))
                            placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
                        }
                    }
                }
            }
        }

        var (
            returningColumn = "*"
            quotedTableName = scope.QuotedTableName()
            primaryField    = scope.PrimaryField()
            extraOption     string
            insertModifier  string
        )

        if str, ok := scope.Get("gorm:insert_option"); ok {
            extraOption = fmt.Sprint(str)
        }
        if str, ok := scope.Get("gorm:insert_modifier"); ok {
            insertModifier = strings.ToUpper(fmt.Sprint(str))
            if insertModifier == "INTO" {
                insertModifier = ""
            }
        }

        if primaryField != nil {
            returningColumn = scope.Quote(primaryField.DBName)
        }

        lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)

        if len(columns) == 0 {
            scope.Raw(fmt.Sprintf(
                "INSERT %v INTO %v %v%v%v",
                addExtraSpaceIfExist(insertModifier),
                quotedTableName,
                scope.Dialect().DefaultValueStr(),
                addExtraSpaceIfExist(extraOption),
                addExtraSpaceIfExist(lastInsertIDReturningSuffix),
            ))
        } else {
            scope.Raw(fmt.Sprintf(
                "INSERT %v INTO %v (%v) VALUES (%v)%v%v",
                addExtraSpaceIfExist(insertModifier),
                scope.QuotedTableName(),
                strings.Join(columns, ","),
                strings.Join(placeholders, ","),
                addExtraSpaceIfExist(extraOption),
                addExtraSpaceIfExist(lastInsertIDReturningSuffix),
            ))
        }

        // execute create sql
        if lastInsertIDReturningSuffix == "" || primaryField == nil {
            if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
                // set rows affected count
                scope.db.RowsAffected, _ = result.RowsAffected()

                // set primary value to primary field
                if primaryField != nil && primaryField.IsBlank {
                    if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
                        scope.Err(primaryField.Set(primaryValue))
                    }
                }
            }
        } else {
            if primaryField.Field.CanAddr() {
                if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
                    primaryField.IsBlank = false
                    scope.db.RowsAffected = 1
                }
            } else {
                scope.Err(ErrUnaddressable)
            }
        }
    }
}

First, the for loop iterates over all the fields and fills the three slices declared at the top.

for _, field := range scope.Fields() {
  if scope.changeableField(field) {
    if field.IsNormal && !field.IsIgnored {
      if field.IsBlank && field.HasDefaultValue {
        blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
        scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
      } else if !field.IsPrimaryKey || !field.IsBlank {
        columns = append(columns, scope.Quote(field.DBName))
        placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
      }
    } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
      for _, foreignKey := range field.Relationship.ForeignDBNames {
        if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
          columns = append(columns, scope.Quote(foreignField.DBName))
          placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
        }
      }
    }
  }
}

Then some information is fetched and set:

var (
  returningColumn = "*"
  quotedTableName = scope.QuotedTableName()
  primaryField    = scope.PrimaryField()
  extraOption     string
  insertModifier  string
)

Once all the information has been gathered, the insert statement is built:

if len(columns) == 0 {
  scope.Raw(fmt.Sprintf(
    "INSERT %v INTO %v %v%v%v",
    addExtraSpaceIfExist(insertModifier),
    quotedTableName,
    scope.Dialect().DefaultValueStr(),
    addExtraSpaceIfExist(extraOption),
    addExtraSpaceIfExist(lastInsertIDReturningSuffix),
  ))
} else {
  scope.Raw(fmt.Sprintf(
    "INSERT %v INTO %v (%v) VALUES (%v)%v%v",
    addExtraSpaceIfExist(insertModifier),
    scope.QuotedTableName(),
    strings.Join(columns, ","),
    strings.Join(placeholders, ","),
    addExtraSpaceIfExist(extraOption),
    addExtraSpaceIfExist(lastInsertIDReturningSuffix),
  ))
}
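The shape of the generated statement is easier to see in a reduced form. The sketch below is an assumption-laden illustration, not GORM's code: `insertSQL` is a hypothetical helper, double quotes and `?` placeholders are hard-coded instead of coming from the dialect, `DEFAULT VALUES` is assumed as the dialect's default-value string, and the insert modifier, extra option, and returning suffix are omitted.

```go
package main

import (
	"fmt"
	"strings"
)

// insertSQL builds a parameterized INSERT statement. With no columns it falls
// back to the default-values form, like the len(columns) == 0 branch above.
func insertSQL(table string, columns []string, values []interface{}) (string, []interface{}) {
	if len(columns) == 0 {
		// "DEFAULT VALUES" is an assumption; the real string comes from the dialect.
		return fmt.Sprintf("INSERT INTO %q DEFAULT VALUES", table), nil
	}

	quoted := make([]string, len(columns))
	placeholders := make([]string, len(columns))
	for i, c := range columns {
		quoted[i] = fmt.Sprintf("%q", c)
		placeholders[i] = "?" // the value itself travels separately as an argument
	}

	sql := fmt.Sprintf(
		"INSERT INTO %q (%v) VALUES (%v)",
		table,
		strings.Join(quoted, ","),
		strings.Join(placeholders, ","),
	)
	return sql, values
}

func main() {
	sql, args := insertSQL("products", []string{"code", "price"}, []interface{}{"L1212", 1000})
	fmt.Println(sql)  // INSERT INTO "products" ("code","price") VALUES (?,?)
	fmt.Println(args) // [L1212 1000]

	sql, _ = insertSQL("products", nil, nil)
	fmt.Println(sql) // INSERT INTO "products" DEFAULT VALUES
}
```

Keeping values out of the SQL text and passing them as arguments is the role scope.AddToVars and scope.SQLVars play in createCallback.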

Finally, the SQL statement is executed:

// execute create sql
if lastInsertIDReturningSuffix == "" || primaryField == nil {
  if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
    // set rows affected count
    scope.db.RowsAffected, _ = result.RowsAffected()

    // set primary value to primary field
    if primaryField != nil && primaryField.IsBlank {
      if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
        scope.Err(primaryField.Set(primaryValue))
      }
    }
  }
} else {
  if primaryField.Field.CanAddr() {
    if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
      primaryField.IsBlank = false
      scope.db.RowsAffected = 1
    }
  } else {
    scope.Err(ErrUnaddressable)
  }
}

The first condition depends on lastInsertIDReturningSuffix, and PostgreSQL is the dialect that returns a non-empty string for it.

var userid int
err := db.QueryRow(`INSERT INTO users(name, favorite_fruit, age)
    VALUES('beatrice', 'starfruit', 93) RETURNING id`).Scan(&userid)

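The PostgreSQL driver does not support the LastInsertId() method, so the ID has to be fetched as shown above, with a RETURNING clause read through QueryRow.

See PostgreSQL Queries.

That is why the two execution paths differ.

That completes the createCallback callback, and with it the process of inserting data.

Summary

This part covered how tables are created and migrated, how hooks are registered and sorted, and when they are called.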

