29

Go+typescript+GraphQL+react构建简书网站(三) 编写Model

 4 years ago
source link: https://studygolang.com/articles/27018
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

补遗:数据库增加Tag表

新建tag表:

CREATE TABLE "public"."tag" (
  "id" int8 NOT NULL,
  "name" varchar(255) NOT NULL,
  "created_at" timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
  "updated_at" timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
  "deleted_at" timestamp(6) NOT NULL,
  PRIMARY KEY ("id")
)
;

COMMENT ON COLUMN "public"."tag"."id" IS 'ID';

COMMENT ON COLUMN "public"."tag"."name" IS '标签名';

COMMENT ON COLUMN "public"."tag"."created_at" IS '创建时间';

COMMENT ON COLUMN "public"."tag"."updated_at" IS '更新时间';

COMMENT ON COLUMN "public"."tag"."deleted_at" IS '删除时间';

这里不得不说一下,由于是一边写代码一边写文章(文章的作用只是用来给自己厘清思路),所以文章中的代码内容很可能下一次就变了,毕竟文章中的代码,只是我初步写时的思路,肯定存在错漏之处,后续会慢慢完善。如要看最新的代码,还请移步: https://github.com/unrotten/h...

编写CURD基础方法

依然先看结果,修改 db.go 文件:

package model

import (
    "context"
    "database/sql"
    "database/sql/driver"
    "fmt"
    "github.com/jmoiron/sqlx"
    _ "github.com/lib/pq"
    "github.com/rs/zerolog"
    "github.com/sony/sonyflake"
    "github.com/spf13/viper"
    "github.com/unrotten/builder"
    "github.com/unrotten/sqlex"
    "log"
    "os"
    "reflect"
    "time"
)

var (
    DB        *sqlx.DB
    psql      sqlex.StatementBuilderType
    idfetcher *sonyflake.Sonyflake
)

const defaultSkip int = 2

type cv map[string]interface{}

type where []sqlex.Sqlex

type result struct {
    b       builder.Builder
    success bool
}

// 初始化数据库连接
func init() {
    viper.AddConfigPath("../config") // 测试使用
    viper.ReadInConfig()
    // 获取数据库配置信息
    user := viper.Get("storage.user")
    password := viper.Get("storage.password")
    host := viper.Get("storage.host")
    port := viper.Get("storage.port")
    dbname := viper.Get("storage.dbname")

    // 连接数据库
    psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
        host, port, user, password, dbname)
    DB = sqlx.MustOpen("postgres", psqlInfo)
    if err := DB.Ping(); err != nil {
        log.Fatalf("连接数据库失败:%s", err)
    }

    // 初始化sql构建器,指定format形式
    psql = sqlex.StatementBuilder.PlaceholderFormat(sqlex.Dollar)
    sqlex.SetLogger(os.Stdout)

    // 初始化sonyflake
    st := sonyflake.Settings{
        StartTime: time.Date(2020, 1, 1, 0, 0, 0, 0, time.Local),
    }
    idfetcher = sonyflake.NewSonyflake(st)
}

func get(query *sql.Rows, columnTypes []*sql.ColumnType, logger zerolog.Logger) result {
    dest := make([]interface{}, len(columnTypes))
    for index, col := range columnTypes {
        switch col.ScanType().String() {
        case "string", "interface {}":
            dest[index] = &sql.NullString{}
        case "bool":
            dest[index] = &sql.NullBool{}
        case "float64":
            dest[index] = &sql.NullFloat64{}
        case "int32":
            dest[index] = &sql.NullInt32{}
        case "int64":
            dest[index] = &sql.NullInt64{}
        case "time.Time":
            dest[index] = &sql.NullTime{}
        default:
            dest[index] = reflect.New(col.ScanType()).Interface()
        }
    }
    err := query.Scan(dest...)
    if err != nil {
        logger.Error().Caller(2).Err(err).Send()
        return result{success: false}
    }
    build := builder.EmptyBuilder
    for index, col := range columnTypes {
        switch val := dest[index].(type) {
        case driver.Valuer:
            var value interface{}
            switch col.ScanType().String() {
            case "string", "interface {}":
                value = dest[index].(*sql.NullString).String
            case "bool":
                value = dest[index].(*sql.NullBool).Bool
            case "float64":
                value = dest[index].(*sql.NullFloat64).Float64
            case "int32":
                value = dest[index].(*sql.NullInt32).Int32
            case "int64":
                value = dest[index].(*sql.NullInt64).Int64
            case "time.Time":
                value = dest[index].(*sql.NullTime).Time
            }
            build = builder.Set(build, col.Name(), value).(builder.Builder)
        default:
            build = builder.Set(build, col.Name(), val).(builder.Builder)
        }
    }
    return result{success: true, b: build}
}

func selectList(ctx context.Context, table string, where where, columns ...string) result {
    logger := ctx.Value("logger").(zerolog.Logger)
    tx := ctx.Value("tx").(*sqlx.Tx)

    var selectBuilder sqlex.SelectBuilder
    if len(columns) > 0 {
        selectBuilder = psql.Select(columns...).From(table).Where("deleted_at is null")
    } else {
        selectBuilder = psql.Select("*").From(table).Where("deleted_at is null")
    }
    for _, arg := range where {
        selectBuilder = selectBuilder.Where(arg)
    }
    query, err := selectBuilder.RunWith(tx).Query()
    if err != nil {
        logger.Error().Caller(1).Err(err).Send()
        return result{success: false}
    }

    columnTypes, err := query.ColumnTypes()
    if err != nil {
        logger.Error().Caller(1).Err(err).Send()
        return result{success: false}
    }
    var resultSlice []interface{}
    for query.Next() {
        r := get(query, columnTypes, logger)
        if !r.success {
            return r
        }
        resultSlice = append(resultSlice, r.b)
    }
    return result{success: true, b: builder.Set(builder.EmptyBuilder, "list", resultSlice).(builder.Builder)}
}

func selectOne(ctx context.Context, table string, where where, columns ...string) result {
    logger := ctx.Value("logger").(zerolog.Logger)
    tx := ctx.Value("tx").(*sqlx.Tx)

    var selectBuilder sqlex.SelectBuilder
    if len(columns) > 0 {
        selectBuilder = psql.Select(columns...).From(table).Where("deleted_at is null").Limit(1)
    } else {
        selectBuilder = psql.Select("*").From(table).Where("deleted_at is null").Limit(1)
    }
    for _, arg := range where {
        selectBuilder = selectBuilder.Where(arg)
    }
    query, err := selectBuilder.RunWith(tx).Query()
    if err != nil {
        logger.Error().Caller(1).Err(err).Send()
        return result{success: false}
    }

    columnTypes, err := query.ColumnTypes()
    if err != nil {
        logger.Error().Caller(1).Err(err).Send()
        return result{success: false}
    }

    if query.Next() {
        return get(query, columnTypes, logger)
    }
    return result{success: false}
}

func selectReal(ctx context.Context, table string, where where, columns ...string) result {
    logger := ctx.Value("logger").(zerolog.Logger)
    tx := ctx.Value("tx").(*sqlx.Tx)

    var selectBuilder sqlex.SelectBuilder
    if len(columns) > 0 {
        selectBuilder = psql.Select(columns...).From(table).Where("deleted_at is not null")
    } else {
        selectBuilder = psql.Select("*").From(table).Where("deleted_at is not null")
    }
    for _, arg := range where {
        selectBuilder = selectBuilder.Where(arg)
    }
    query, err := selectBuilder.RunWith(tx).Query()
    if err != nil {
        logger.Error().Caller(1).Err(err).Send()
        return result{success: false}
    }

    columnTypes, err := query.ColumnTypes()
    if err != nil {
        logger.Error().Caller(1).Err(err).Send()
        return result{success: false}
    }
    var resultSlice []interface{}
    for query.Next() {
        r := get(query, columnTypes, logger)
        if !r.success {
            return r
        }
        resultSlice = append(resultSlice, r.b)
    }
    return result{success: true, b: builder.Set(builder.EmptyBuilder, "list", resultSlice).(builder.Builder)}
}

func insertOne(ctx context.Context, table string, cv cv) result {
    logger := ctx.Value("logger").(zerolog.Logger)
    tx := ctx.Value("tx").(*sqlx.Tx)
    build := builder.EmptyBuilder
    cv["created_at"], cv["updated_at"] = time.Now(), time.Now()
    columns, values := make([]string, 0, len(cv)), make([]interface{}, 0, len(cv))
    for col, value := range cv {
        build = builder.Set(build, col, value).(builder.Builder)
        columns, values = append(columns, col), append(values, value)
    }
    r, err := psql.Insert(table).Columns(columns...).Values(values...).RunWith(tx).Exec()
    return assertSqlResult(r, err, logger)
}

func update(ctx context.Context, table string, cv cv, where where, directSet ...string) result {
    logger := ctx.Value("logger").(zerolog.Logger)
    tx := ctx.Value("tx").(*sqlx.Tx)
    cv["updated_at"] = time.Now()
    updateBuilder := psql.Update(table).SetMap(cv).Where("deleted_at is null")
    for _, set := range directSet {
        updateBuilder = updateBuilder.DirectSet(set)
    }
    for _, arg := range where {
        updateBuilder = updateBuilder.Where(arg)
    }
    r, err := updateBuilder.RunWith(tx).Exec()
    return assertSqlResult(r, err, logger)
}

// note: if where is null,then will delete the whole table
func remove(ctx context.Context, table string, where where) result {
    logger := ctx.Value("logger").(zerolog.Logger)
    tx := ctx.Value("tx").(*sqlx.Tx)

    updateBuilder := psql.Update(table).Set("deleted_at", time.Now()).Where("deleted_at is null")
    for _, arg := range where {
        updateBuilder = updateBuilder.Where(arg)
    }
    r, err := updateBuilder.RunWith(tx).Exec()
    return assertSqlResult(r, err, logger)
}

func assertSqlResult(r sql.Result, err error, logger zerolog.Logger, skip ...int) result {
    sk := defaultSkip
    if len(skip) > 0 {
        sk += skip[0]
    }
    if err != nil {
        logger.Error().Caller(sk).Err(err).Send()
        return result{success: false}
    }
    affected, err := r.RowsAffected()
    if err != nil {
        logger.Error().Caller(2).Err(err).Send()
        return result{success: false}
    }
    if affected == 0 {
        return result{success: false}
    }
    return result{success: true}
}

在这里我们只看查询,selectList和selectOne依托于get方法实现,而get的核心就是设值。因为在数据库中,数据存在NULL的情况,而Go中的基础类型如string,int64等并不支持,所以我们必须使用其对应的sql.NullString等类型去scan。作者这里为了保持model中定义的struct能够继续使用string等基础类型,在get中进行了类型的判断,不可空的基础类型通过两次switch转换,最终即便对于NULL值,也会得到基础类型的默认空值。

在get方法中,我们使用 reflect.New(col.ScanType()).Interface() 方法,获得字段对应的指针值,这里使用了反射,效果等同于new()。

在记录错误日志 logger.Error().Caller(sk).Err(err).Send() 时,我们先指定了日志的类别为Error,再调用了Caller(sk),获取运行时上下文。Caller的原理是调用 runtime.Caller(skip) 方法,以获取指定的代码段位置。最终效果就是通常我们程序报错时,在控制台能够看到的,各个文件的指定行。

在get方法的最后,我们通过 builder.Set(build, col.Name(), value).(builder.Builder) 这样的代码段,将数据对应的名字和值存入指定的builer中。builder的效果类似于map,只是使用builder库可以更方便直接将map转为指定的struct。

再把目光转到selectOne方法,可以看到我们从上下文context中获取了logger和事务tx,这里是方便后续的工作。我们需要注意的是,sqlex库进行sql构建时,严格按照了sql语法的规定,当然where和from之间的顺序在这里可以不用管。我们在初始化selectBuilder的时候, Where("1=1") 给定了一个初始的where条件,这样做的用意是,由于sqlex库提供了IF操作,譬如:

psql.Select("*").From("user").Where(sqlex.IF{Condition: "a" == "", Sq: sqlex.Eq{"a": "3"}})

这样的代码,由于 “a”==“” 不满足,所以IF中的 ”a”==“3” 并不会被纳入构建器中,可是也因为调用了Where,所以构建器中sql中必然会增加一个where,最终得到错误的 sql:SELECT * FROM "user" WHERE

编写Model

model 目录下新建 user.go 文件:

package model

import (
    "context"
    "errors"
    "github.com/unrotten/builder"
    "time"
)

type User struct {
    Id        int64     `json:"id" db:"id"`
    Username  string    `json:"username" db:"username"`
    Email     string    `json:"email" db:"email"`
    Password  string    `json:"password" db:"password"`
    Avatar    string    `json:"avatar" db:"avatar"`
    Gender    string    `json:"gender" db:"gender"`
    Introduce string    `json:"introduce" db:"introduce"`
    State     string    `json:"state" db:"state"`
    Root      bool      `json:"root" db:"root"`
    CreatedAt time.Time `json:"createdAt" db:"created_at"`
    UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
    DeletedAt time.Time `json:"deletedAt" db:"deleted_at"`
}

func GetUsers(ctx context.Context, where where) ([]User, error) {
    result := selectList(ctx, `"user"`, where)
    if !result.success {
        return nil, errors.New("获取用户列表失败")
    }
    list, ok := builder.Get(result.b, "list")
    if !ok {
        return nil, errors.New("获取用户列表失败")
    }
    users := make([]User, 0, len(list.([]interface{})))
    for _, item := range list.([]interface{}) {
        users = append(users, builder.GetStructLikeByTag(item.(builder.Builder), User{}, "db").(User))
    }
    return users, nil
}

func GetUser(ctx context.Context, where where) (User, error) {
    result := selectOne(ctx, `"user"`, where)
    if !result.success {
        return User{}, errors.New("查询用户数据失败")
    }
    return builder.GetStructLikeByTag(result.b, User{}, "db").(User), nil
}

func InsertUser(ctx context.Context, cv map[string]interface{}) (User, error) {
    id, err := idfetcher.NextID()
    if err != nil {
        return User{}, err
    }

    cv["id"] = int64(id)
    result := insertOne(ctx, `"user"`, cv)
    if !result.success {
        return User{}, errors.New("插入用户数据失败")
    }
    return builder.GetStructLikeByTag(result.b, User{}, "db").(User), nil
}

func UpdateUser(ctx context.Context, cv cv, where where) error {
    result := update(ctx, `"user"`, cv, where)
    if !result.success {
        return errors.New("更新用户数据失败")
    }
    return nil
}

这里唯一需要注意的是,我们使用 builder.GetStructLikeByTag(result.b, User{}, "db").(User) 方法,将CURD中获得的Builder根据指定的tag内容,转化为对应结构体。

接下来,就是继续完善其他的model。

userCount.go :

package model

import (
    "context"
    "errors"
    "github.com/unrotten/builder"
    "github.com/unrotten/sqlex"
    "time"
)

type UserCount struct {
    Uid        int64     `json:"uid" db:"uid"`
    FansNum    int32     `json:"fansNum" db:"fans_num"`
    FollowNum  int32     `json:"followNum" db:"follow_num"`
    ArticleNum int32     `json:"articleNum" db:"article_num"`
    Words      int32     `json:"words" db:"words"`
    ZanNum     int32     `json:"zanNum" db:"zan_num"`
    CreatedAt  time.Time `json:"createdAt" db:"created_at"`
    UpdatedAt  time.Time `json:"updatedAt" db:"updated_at"`
    DeletedAt  time.Time `json:"deletedAt" db:"deleted_at"`
}

func GetUserCount(ctx context.Context, uid int64, columns ...string) (UserCount, error) {
    result := selectOne(ctx, "user_count", append(where{}, sqlex.Eq{"uid": uid}), columns...)
    if !result.success {
        return UserCount{}, errors.New("查询用户计数失败")
    }
    return builder.GetStructLikeByTag(result.b, UserCount{}, "db").(UserCount), nil
}

func InsertUserCount(ctx context.Context, uid int64) error {
    result := insertOne(ctx, "user_count", cv{"uid": uid})
    if !result.success {
        return errors.New("保存用户计数表失败")
    }
    return nil
}

func UpdateUserCount(ctx context.Context, uid int64, add bool, columns ...string) error {
    directSets, directSet := make([]string, 0, len(columns)), " + 1"
    if !add {
        directSet = " - 1"
    }
    for _, col := range columns {
        directSets = append(directSets, col+directSet)
    }
    if !update(ctx, "user_count", cv{}, where{sqlex.Eq{"uid": uid}}, directSets...).success {
        return errors.New("增加用户计数失败")
    }
    return nil
}

我们为了改变userCount中的计数值,定义了方法UpdateUserCount。可以通过指定加减和相应字段来实现计数值的加减。我们可以注意到了,这里在调用update的时候,传入了directSets,最终将通过update中的:

for _, set := range directSet {
        updateBuilder = updateBuilder.DirectSet(set)
}

将设置好的值构建到SQL中。DirectSet目的是构建无参数的set语句,所以并不建议暴露给从接口传入的参数,否则会有SQL注入的风险。

userFollow.go

package model

import (
    "context"
    "errors"
    "github.com/unrotten/builder"
    "github.com/unrotten/sqlex"
    "time"
)

type UserFollow struct {
    Id        int64     `json:"id" db:"id"`
    Uid       int64     `json:"uid" db:"uid"`
    Fuid      int64     `json:"fuid" db:"fuid"`
    CreatedAt time.Time `json:"createdAt" db:"created_at"`
    UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
    DeletedAt time.Time `json:"deletedAt" db:"deleted_at"`
}

func InsertUserFollow(ctx context.Context, uid, fuid int64) error {
    id, err := idfetcher.NextID()
    if err != nil {
        return err
    }
    if result := insertOne(ctx, "user_follow", cv{"id": int64(id), "uid": uid, "fuid": fuid}); !result.success {
        return errors.New("插入用户关注表失败")
    }
    return nil
}

func RemoveUserFollow(ctx context.Context, uid, fuid int64) error {
    if !remove(ctx, "user_follow", where{sqlex.Eq{"uid": uid, "fuid": fuid}}).success {
        return errors.New("删除用户关注失败")
    }
    return nil
}

// 获取用户关注列表
func GetUserFollowList(ctx context.Context, fuid int64) ([]int64, error) {
    result := selectList(ctx, "user_follow", where{sqlex.Eq{"fuid": fuid}}, "uid")
    if !result.success {
        return nil, errors.New("获取用户关注列表失败")
    }
    b, _ := builder.Get(result.b, "list")
    list := b.([]interface{})
    userList := make([]int64, 0, len(list))
    for _, item := range list {
        uid, _ := builder.Get(item.(builder.Builder), "uid")
        userList = append(userList, uid.(int64))
    }
    return userList, nil
}

// 获取用户粉丝列表
func GetFollowUserList(ctx context.Context, uid int64) ([]int64, error) {
    result := selectList(ctx, "user_follow", where{sqlex.Eq{"uid": uid}}, "fuid")
    if !result.success {
        return nil, errors.New("获取用户关注列表失败")
    }
    b, _ := builder.Get(result.b, "list")
    list := b.([]interface{})
    userList := make([]int64, 0, len(list))
    for _, item := range list {
        uid, _ := builder.Get(item.(builder.Builder), "fuid")
        userList = append(userList, uid.(int64))
    }
    return userList, nil
}

在这里无论是粉丝列表还是关注列表,我们都指定了获取对应的userId列表,而非UserFollow数组。这是为了便于后续dataloader的使用,以后会提到。

到这里用户相关的model就编写完了,后面真正与前端一起联调时,定还有许多更改。而其他诸如文章,评论等的model,便不再赘述。用户相关的model,已经将基本的CURD涵盖。

看完这里,我们可以发现,对于user的扩展表user_count 和 user_follow, 我们并没有在model层面去设计他们的关系,在数据的获取,新增,修改上,也都是独立的。这是因为我们所有定义的数据之间的关系,都交由GraphQL去描述了,在数据层我们反而不用多在意这些关系的实现。

作者个人博客地址: https://unrotten.org

作者微信公众号地址:

ZvuA7br.jpg!web


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK