76

Gorm的使用心得和一些常用扩展(二)

 4 years ago
source link: https://www.tuicool.com/articles/7NnUVre
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

上一篇 文章,我分享了自己在新增和更新的场景下,自己使用gorm的一些心得和扩展。本文,我将分享一些在查询的方面的心得。

首先,我把查询按照涉及到的表的数量分为:

  • 单表查询
  • 多表查询

按照查询范围又可以分为:

  • 查询一个
  • 范围查询
    • 查询一组
    • 有序查询
    • 查询前几个
    • 分页查询

在日常使用中,单表查询占据了多半的场景,把这部分的代码按照查询范围做一些封装,可以大大减少冗余的代码。

单表查询

于是,我仿照gorm API的风格,做了如下的封装:

ps:以下例子均以假定已定义user对象

查询一个

func (dw *DBExtension) GetOne(result interface{}, query interface{}, args ...interface{}) (found bool, err error) {
	var (
		tableNameAble TableNameAble
		ok            bool
	)

	if tableNameAble, ok = query.(TableNameAble); !ok {
		if tableNameAble, ok = result.(TableNameAble); !ok {
			return false, errors.New("neither the query nor result implement TableNameAble")
		}
	}

	err = dw.Table(tableNameAble.TableName()).Where(query, args...).First(result).Error

	if err == gorm.ErrRecordNotFound {
		dw.logger.LogInfoc("mysql", fmt.Sprintf("record not found for query %s, the query is %+v, args are %+v", tableNameAble.TableName(), query, args))
		return false, nil
	}

	if err != nil {
		dw.logger.LogErrorc("mysql", err, fmt.Sprintf("failed to query %s, the query is %+v, args are %+v", tableNameAble.TableName(), query, args))
		return false, err
	}

	return true, nil
}
复制代码

这段值得说明的就是对查询不到数据时的处理,gorm是报了 gorm.ErrRecordNotFound 的error, 我是对这个错误做了特殊处理,用found这个boolean值表述这个特殊状态。

调用代码如下:

condition := User{Id:1}
result := User{}

if  found, err := dw.GetOne(&result, condition); !found {
	//not found
    if err != nil {
    	// has error
        return err
    }
    
}
复制代码

也可以这样写,更加灵活的指定的查询条件:

result := User{}

if  found, err := dw.GetOne(&result, "id = ?", 1); !found {
	//not found
    if err != nil {
    	// has error
        return err
    }
    
}
复制代码

两种写法执行的语句都是:

select * from test.user where id = 1
复制代码

范围查询

针对四种范国查询,我做了如下封装:

func (dw *DBExtension) GetList(result interface{}, query interface{}, args ...interface{}) error {
	return dw.getListCore(result, "", 0, 0, query, args)
}

func (dw *DBExtension) GetOrderedList(result interface{}, order string, query interface{}, args ...interface{}) error {
	return dw.getListCore(result, order, 0, 0, query, args)
}

func (dw *DBExtension) GetFirstNRecords(result interface{}, order string, limit int, query interface{}, args ...interface{}) error {
	return dw.getListCore(result, order, limit, 0, query, args)
}

func (dw *DBExtension) GetPageRangeList(result interface{}, order string, limit, offset int, query interface{}, args ...interface{}) error {
	return dw.getListCore(result, order, limit, offset, query, args)
}

func (dw *DBExtension) getListCore(result interface{}, order string, limit, offset int, query interface{}, args []interface{}) error {
	var (
		tableNameAble TableNameAble
		ok            bool
	)

	if tableNameAble, ok = query.(TableNameAble); !ok {
		// type Result []*Item{}
		// result := &Result{}
		resultType := reflect.TypeOf(result)
		if resultType.Kind() != reflect.Ptr {
			return errors.New("result is not a pointer")
		}

		sliceType := resultType.Elem()
		if sliceType.Kind() != reflect.Slice {
			return errors.New("result doesn't point to a slice")
		}
		// *Item
		itemPtrType := sliceType.Elem()
		// Item
		itemType := itemPtrType.Elem()

		elemValue := reflect.New(itemType)
		elemValueType := reflect.TypeOf(elemValue)
		tableNameAbleType := reflect.TypeOf((*TableNameAble)(nil)).Elem()

		if elemValueType.Implements(tableNameAbleType) {
			return errors.New("neither the query nor result implement TableNameAble")
		}

		tableNameAble = elemValue.Interface().(TableNameAble)
	}

	db := dw.Table(tableNameAble.TableName()).Where(query, args...)
	if len(order) != 0 {
		db = db.Order(order)
	}

	if offset > 0 {
		db = db.Offset(offset)
	}

	if limit > 0 {
		db = db.Limit(limit)
	}

	if err := db.Find(result).Error; err != nil {
		dw.logger.LogErrorc("mysql", err, fmt.Sprintf("failed to query %s, query is %+v, args are %+v, order is %s, limit is %d", tableNameAble.TableName(), query, args, order, limit))
		return err
	}

	return nil
}
复制代码

为了减少冗余的代码,通用的逻辑写在 getListCore 函数里,里面用到了一些golang反射的知识。

但只要记得golang的反射和其它语言的反射最大的不同,是 golang的反射是基本值而不是类型的 ,一切就好理解了。

其中的一个小技巧是如何判断一个类型是否实现了某个接口,用到了指向nil的指针。

elemValue := reflect.New(itemType)
	elemValueType := reflect.TypeOf(elemValue)
	tableNameAbleType := reflect.TypeOf((*TableNameAble)(nil)).Elem()

	if elemValueType.Implements(tableNameAbleType) {
		return errors.New("neither the query nor result implement TableNameAble")
	}
复制代码

关于具体的使用,就不再一一举例子了,熟悉gorm api的同学可以一眼看出。

多表查询

关于多表查询,因为不同场景很难抽取出不同,也就没有再做封装,但是我的经验是 优先多使用gorm的方法,而不是自己拼sql 。你想要做的gorm都可以实现。

这里,我偷个懒,贴出自己在项目中写的最复杂的一段代码,供各位看官娱乐。

一个复杂的例子

这段代码是从埋点数据的中间表,为了用通用的代码实现不同展示场景下的查询,代码设计的比较灵活,其中涉及了关联多表的查询,按查询条件动态过滤和聚合,还有分页查询的逻辑。

func buildCommonStatisticQuery(tableName, startDate, endDate string) *gorm.DB {
	query := models.DB().Table(tableName)

	if startDate == endDate || endDate == "" {
		query = query.Where("date = ?", startDate)
	} else {
		query = query.Where("date >= ? and date <= ?", startDate, endDate)
	}

	return query
}

func buildElementsStatisticQuery(startDate, endDate,  elemId string,  elemType int32) *gorm.DB {
	query := buildCommonStatisticQuery("spotanalysis.element_statistics", startDate, endDate)

	if elemId != "" && elemType != 0 {
		query = query.Where("element_id = ? and element_type = ?", elemId, elemType)
	}

	return query
}

func CountElementsStatistics(count *int32, startDate, endDate, instId, appId, elemId string, elemType int32, groupFields []string ) error {
	query := buildElementsStatisticQuery(startDate, endDate,  elemId, elemType)

	query = whereInstAndApp(query, instId, appId)

	if len(groupFields) != 0 {
		query = query.Select(fmt.Sprintf("count(distinct(concat(%s)))", strings.Join(groupFields, ",")))
	} else {
		query = query.Select("count(id)")
	}

	query = query.Count(count)
	return query.Error
}


func GetElementsStatistics(result interface{}, startDate, endDate, instId, appId, elemId string, elemType int32, groupFields []string, orderBy string, ascOrder bool, limit, offset int32) error {
	query := buildElementsStatisticQuery(startDate, endDate, elemId, elemType)
	if len(groupFields) != 0 {
		groupBy := strings.Join(groupFields, "`,`")
		groupBy = "`" + groupBy + "`"
		query = query.Group(groupBy)
		query = havingInstAndApp(query, instId, appId)

		sumFields := strings.Join([]string{
			"SUM(`element_statistics`.`mp_count`) AS `mp_count`",
			"SUM(`element_statistics`.`h5_count`) AS `h5_count`",
			"SUM(`element_statistics`.`total_count`) AS `total_count`",
			"SUM(`element_statistics`.`collection_count`) AS `collection_count`",
			"SUM(`element_statistics`.`mp_share_count`) AS `mp_share_count`",
			"SUM(`element_statistics`.`h5_share_count`) AS `h5_share_count`",
			"SUM(`element_statistics`.`poster_share_count`) AS `poster_share_count`",
			"SUM(`element_statistics`.`total_share_count`) AS `total_share_count`",
		}, ",")

		query = query.Select(groupBy + "," + sumFields)
	} else {
		query = whereInstAndApp(query, instId, appId)
	}

	query = getPagedList(query, orderBy, ascOrder, limit, offset)

	return query.Find(result).Error
}

func getPagedList(query *gorm.DB, orderBy string, ascOrder bool, limit , offset int32) *gorm.DB {
	if orderBy != "" {
		if ascOrder {
			orderBy += " asc"
		} else {
			orderBy += " desc"
		}
		query = query.Order(orderBy)
	}

	if offset != 0 {
		query = query.Offset(offset)
	}
	if limit != 0 {
		query = query.Limit(limit)
	}
	return query
}

func whereInstAndApp(query *gorm.DB, instId string, appId string) *gorm.DB {
	query = query.Where("inst_id = ?", instId)
	if appId != "" {
		query = query.Where("app_id = ?", appId)
	}
	return query
}

func havingInstAndApp(query *gorm.DB, instId string, appId string) *gorm.DB {
	query = query.Having("inst_id = ?", instId)
	if appId != "" {
		query = query.Having("app_id = ?", appId)
	}
	return query
}

复制代码

感谢各位看官耐心看完,如果本文对你有用,请点个赞~~~

如果能到代码仓库: Github:Ksloveyuan/gorm-ex 给个✩star✩, 楼主就更加感谢了!


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK