22

一次构建简单HTTP框架的尝试 —— Golang实现

 4 years ago
source link: https://segmentfault.com/a/1190000021341804
Go to the source link to read the original article, which has images, later updates and better formatting. If the link is broken, click the button below to view the archived snapshot.

程序实现了基本的GET、POST方式路由,不依赖net/http包。

程序只解析 Content-Type 为 application/json 的 POST 请求参数,响应数据也仅支持 JSON 格式。程序仅支持 GET、POST 两种方式的路由。

router.go:注册路由、启动服务

package http_server

import (
    "fmt"
    "net"
)

type (
	// handlerFunc handles one routed request. It prepares the response
	// (for example via Request.Json); the router writes that response after
	// the handler returns.
	handlerFunc func(*Request)

	// methodTree holds every route registered for one HTTP method.
	methodTree struct {
		method string // HTTP method, e.g. "GET"
		nodes  []Node // routes in registration order
	}

	// Node is a single route: a request path (matched exactly) and its handler.
	Node struct {
		path   string
		handle handlerFunc
	}

	// Router dispatches requests to handlers by method and exact path.
	Router struct {
		Trees []methodTree // one entry per registered method
	}
)

// MaxRequestSize is the size in bytes (40 KiB) of each connection's read
// buffer, and therefore the largest single request, head plus body, that a
// connection can accept.
const MaxRequestSize = 40 << 10

// Default returns an empty Router with no routes registered.
func Default() *Router {
	router := new(Router)
	return router
}

// Run listens for IPv4 TCP connections on addr and serves each accepted
// connection on its own goroutine. It returns only when the listener cannot
// be created; Accept errors are printed and accepting continues.
func (r *Router) Run(addr string) error {
	ln, listenErr := net.Listen("tcp4", addr)
	if listenErr != nil {
		return fmt.Errorf("listen error:%v", listenErr)
	}
	for {
		c, acceptErr := ln.Accept()
		if acceptErr == nil {
			go r.handle(c)
			continue
		}
		fmt.Println(acceptErr)
	}
}

// handle serves a single client connection.
//
// A Reader on this goroutine splits the incoming bytes into requests. Each
// parsed request is sent over accept to a parseRequest goroutine, which runs
// the matching handler and writes the response to conn.
//
// Reading stops in one of two cases:
//   - the peer closed the connection or a read failed (read returns nil);
//   - the request was malformed or too large (read returns an error).
//
// Either way, the dispatcher goroutine is shut down and waited for. In the
// error case a 400 response is then written. The connection is always closed
// before handle returns.
func (r *Router) handle(conn net.Conn) {
	// Runs last: by then the dispatcher has finished writing its responses.
	defer conn.Close()

	accept := make(chan readerData)
	done := make(chan struct{})
	go func() {
		defer close(done)
		r.parseRequest(accept, conn)
	}()

	reader := NewReader(conn, MaxRequestSize)
	err := reader.read(accept)

	// Stop the dispatcher and wait for it. This lets any in-flight response
	// finish and ensures nothing else is writing to conn when we write the
	// error response or close the connection.
	close(accept)
	<-done

	if err != nil {
		fmt.Println(err)
		// Reading the request failed: answer with 400 Bad Request.
		response := newResponse(conn)
		response.errWrite(400)
	}
}

// parseRequest consumes parsed request data from accept until the channel is
// closed. Each item becomes a Request whose response is written to conn, and
// is dispatched to its handler before the next item is received.
func (r *Router) parseRequest(accept chan readerData, conn net.Conn) {
	for data := range accept {
		req := newRequest()
		req.response = newResponse(conn)
		req.parse(data)
		r.handleHTTPRequest(req)
	}
}

// handleHTTPRequest runs the first handler registered for the request's method
// and exact path, then writes the response that handler prepared. When no
// route matches, a 404 error response is written instead.
func (r *Router) handleHTTPRequest(request *Request) {
	var (
		handler handlerFunc
		found   bool
	)
lookup:
	for _, tree := range r.Trees {
		if tree.method != request.Method {
			continue
		}
		for _, node := range tree.nodes {
			if node.path == request.Path {
				handler, found = node.handle, true
				break lookup
			}
		}
	}
	if !found {
		// No handler matched: respond with 404.
		request.response.errWrite(404)
		return
	}
	handler(request)
	request.response.write()
}

// POST registers handle for POST requests whose path equals path exactly.
func (r *Router) POST(path string, handle handlerFunc) {
	const method = "POST"
	r.addRoute(method, path, handle)
}

// GET registers handle for GET requests whose path equals path exactly.
func (r *Router) GET(path string, handle handlerFunc) {
	const method = "GET"
	r.addRoute(method, path, handle)
}

// addRoute appends a route for method and path. Routes are grouped into one
// methodTree per method; the tree is created the first time a method is seen.
// Duplicate paths are appended as well; they are not rejected.
func (r *Router) addRoute(method string, path string, handle handlerFunc) {
	node := Node{path: path, handle: handle}
	for i := range r.Trees {
		if r.Trees[i].method == method {
			r.Trees[i].nodes = append(r.Trees[i].nodes, node)
			return
		}
	}
	r.Trees = append(r.Trees, methodTree{method: method, nodes: []Node{node}})
}

reader.go:读取并解析HTTP请求行、请求头、请求体

package http_server

import (
    "bytes"
    "fmt"
    "net"
    "strconv"
    "strings"
)

// Reader reads HTTP/1.x requests from a connection into a fixed-size buffer
// and parses them one at a time.
//
// buff[start:end] holds bytes that were received but not yet consumed. A
// request (head plus body) must fit into buffLen bytes, otherwise read fails.
type Reader struct {
	conn       net.Conn
	readerData // the request currently being parsed
	buff       []byte
	buffLen    int
	start      int // index of the first unconsumed byte in buff
	end        int // index one past the last received byte in buff
}

// readerData is one parsed request as produced by Reader.read.
type readerData struct {
	Line   map[string]string // request line: keys "method", "url" and "version"
	Header map[string]string // headers keyed by upper-cased name; values verbatim, surrounding spaces/tabs trimmed
	Body   string            // request body; only filled for POST requests
}

// newReaderData returns a readerData with empty, non-nil maps.
func newReaderData() readerData {
	return readerData{
		Line:   make(map[string]string),
		Header: make(map[string]string),
	}
}

// NewReader returns a Reader for conn whose buffer holds buffLen bytes.
func NewReader(conn net.Conn, buffLen int) *Reader {
	return &Reader{
		conn:       conn,
		readerData: newReaderData(),
		buffLen:    buffLen,
		buff:       make([]byte, buffLen),
	}
}

// parseLine parses the request line at buff[start:end] into reader.Line and
// advances start past it. It returns false with a nil error when no complete
// line has been received yet, and an error when the line does not have
// exactly three space-separated parts.
func (reader *Reader) parseLine() (isOK bool, err error) {
	pending := reader.buff[reader.start:reader.end]
	index := bytes.Index(pending, []byte("\r\n"))
	if index == -1 {
		return false, nil
	}
	arr := strings.Split(string(pending[:index]), " ")
	if len(arr) != 3 {
		return false, fmt.Errorf("bad request line")
	}
	reader.Line["method"] = arr[0]
	reader.Line["url"] = arr[1]
	reader.Line["version"] = arr[2]

	reader.start += index + 2
	return true, nil
}

// parseHeader parses the header lines that follow the request line into
// reader.Header and advances start past the blank line that ends the head.
// It returns false, consuming nothing, when that blank line has not been
// received yet. Lines without a ':' are ignored.
func (reader *Reader) parseHeader() bool {
	pending := reader.buff[reader.start:reader.end]
	// No headers at all: the blank line directly follows the request line.
	if bytes.HasPrefix(pending, []byte("\r\n")) {
		reader.start += 2
		return true
	}
	index := bytes.Index(pending, []byte("\r\n\r\n"))
	if index == -1 {
		return false
	}
	for _, v := range strings.Split(string(pending[:index]), "\r\n") {
		arr := strings.SplitN(v, ":", 2)
		if len(arr) < 2 {
			continue
		}
		// Names are case-insensitive, so they are normalized. Values may be
		// case-sensitive (tokens, cookies), so they are kept as sent.
		reader.Header[strings.ToUpper(arr[0])] = strings.Trim(arr[1], " \t")
	}
	reader.start += index + 4
	return true
}

// parseBody reads Content-Length bytes at buff[start:] into reader.Body and
// advances start past them. It returns false with a nil error when the body
// has not been fully received yet. It returns an error when Content-Length is
// missing, not an integer, or negative.
func (reader *Reader) parseBody() (isOk bool, err error) {
	contentLenStr, ok := reader.Header["CONTENT-LENGTH"]
	if !ok {
		return false, fmt.Errorf("bad request:no content-length")
	}
	contentLen, err := strconv.ParseInt(contentLenStr, 10, 64)
	if err != nil || contentLen < 0 {
		// A negative length would otherwise produce an invalid slice bound.
		return false, fmt.Errorf("parse content-length error:%s", contentLenStr)
	}
	if contentLen > int64(reader.end-reader.start) {
		// Body not complete yet; wait for more data.
		return false, nil
	}
	bodyEnd := reader.start + int(contentLen)
	reader.Body = string(reader.buff[reader.start:bodyEnd])
	reader.start = bodyEnd
	return true, nil
}

// parseNext tries to parse one complete request from buff[start:end] into
// reader.readerData. It returns false with a nil error when more bytes are
// needed. In that case start is restored and the partially filled request
// data is discarded, so the request is parsed again from the beginning once
// more data arrives.
func (reader *Reader) parseNext() (bool, error) {
	begin := reader.start
	// Wait until the whole head (request line, headers and blank line) is buffered.
	if !bytes.Contains(reader.buff[begin:reader.end], []byte("\r\n\r\n")) {
		return false, nil
	}
	if _, err := reader.parseLine(); err != nil {
		return false, fmt.Errorf("parse request line error:%v", err)
	}
	reader.parseHeader()
	// Only POST requests carry a body.
	if len(reader.Header) > 0 && strings.EqualFold(reader.Line["method"], "POST") {
		isOk, err := reader.parseBody()
		if err != nil {
			return false, fmt.Errorf("parse request body error:%v", err)
		}
		if !isOk {
			reader.start = begin
			reader.readerData = newReaderData()
			return false, nil
		}
	}
	return true, nil
}

// read parses requests from the connection and sends each one on accept, in
// order. It returns:
//   - nil when reading from the connection fails (e.g. the peer closed it);
//   - an error when a request is malformed or does not fit into the buffer.
//
// Requests already in the buffer are processed before more bytes are read,
// so pipelined requests are not delayed. Every value sent on accept owns its
// maps: the reader switches to fresh maps and never modifies them again.
func (reader *Reader) read(accept chan readerData) error {
	for {
		isOk, err := reader.parseNext()
		if err != nil {
			return err
		}
		if isOk {
			accept <- reader.readerData
			reader.readerData = newReaderData()
			reader.move()
			continue
		}
		if reader.end == reader.buffLen {
			// The buffer is full but still holds no complete request.
			return fmt.Errorf("request is too large: exceeds %d bytes", reader.buffLen)
		}
		// Append new bytes after those already buffered.
		n, err := reader.conn.Read(reader.buff[reader.end:])
		if err != nil {
			// The connection was closed or failed.
			return nil
		}
		reader.end += n
	}
}

// move shifts the unconsumed bytes buff[start:end] to the front of buff so the
// whole buffer is available for the next request.
func (reader *Reader) move() {
	if reader.start == 0 {
		return
	}
	copy(reader.buff, reader.buff[reader.start:reader.end])
	reader.end -= reader.start
	reader.start = 0
}

request.go:解析请求头、请求参数等

package http_server

import (
    "encoding/json"
    "strings"
)

type (
	// H is a shorthand map type for building JSON objects, e.g. H{"ok": true}.
	H map[string]interface{}

	// Request is one parsed HTTP request together with the response that
	// will be sent for it.
	Request struct {
		Path    string            // request path without the query string
		Method  string            // request method as sent by the client
		headers map[string]string // request headers
		queries map[string]string // query-string parameters
		posts   map[string]string // POST parameters decoded from a JSON body
		*response
	}
)

// newRequest returns a Request with empty header, query and POST maps and no
// response attached yet.
func newRequest() *Request {
	req := new(Request)
	req.headers = make(map[string]string)
	req.queries = make(map[string]string)
	req.posts = make(map[string]string)
	return req
}

// parse fills the request from the data produced by the connection reader.
//
// It copies the method and headers and splits the request target into Path
// and query string. The query string is parsed for every method. For non-GET
// requests whose Content-Type is application/json, the top-level fields of the
// JSON object body are stored as POST parameters.
//
// Query parameters are split on '&' and on the first '='. Values are NOT
// percent-decoded. A parameter without '=' gets an empty value, and empty
// pairs are skipped.
//
// JSON string fields (and null) are stored as their string value. Other field
// types (numbers, booleans, arrays, objects) are stored as their raw JSON text.
// A body that is not a valid JSON object leaves the POST parameters empty.
func (request *Request) parse(readerData readerData) {
	request.Method = readerData.Line["method"]
	request.headers = readerData.Header

	// Split the request target into path and query string.
	target := readerData.Line["url"]
	request.Path = target
	var queries string
	if index := strings.Index(target, "?"); index != -1 {
		request.Path = target[:index]
		queries = target[index+1:]
	}

	if queries != "" {
		for _, pair := range strings.Split(queries, "&") {
			if pair == "" {
				// Tolerate "a=1&&b=2" and a trailing '&'.
				continue
			}
			kv := strings.SplitN(pair, "=", 2)
			value := ""
			if len(kv) == 2 {
				value = kv[1]
			}
			request.queries[kv[0]] = value
		}
	}

	if request.Method == "GET" {
		return
	}
	// Only application/json bodies are decoded.
	contentTypes, isExist := request.headers["CONTENT-TYPE"]
	if !isExist {
		return
	}
	mediaType := strings.TrimSpace(strings.Split(contentTypes, ";")[0])
	if !strings.EqualFold(mediaType, "application/json") {
		return
	}
	// Decode into raw values so non-string fields are kept instead of dropped.
	var fields map[string]json.RawMessage
	if err := json.Unmarshal([]byte(readerData.Body), &fields); err != nil {
		return
	}
	for name, raw := range fields {
		var s string
		if json.Unmarshal(raw, &s) == nil {
			request.posts[name] = s
		} else {
			request.posts[name] = string(raw)
		}
	}
}

// Query returns the query-string parameter name, or "" when it is absent.
func (request *Request) Query(name string) string {
	// A missing key yields the map's zero value, "".
	return request.queries[name]
}

// Post returns the POST (JSON body) parameter name, or "" when it is absent.
func (request *Request) Post(name string) string {
	// A missing key yields the map's zero value, "".
	return request.posts[name]
}

// DefaultQuery returns the query-string parameter name, or def when it is
// absent. A parameter that is present with an empty value yields "", not def.
func (request *Request) DefaultQuery(name, def string) string {
	val, found := request.queries[name]
	if !found {
		val = def
	}
	return val
}

// DefaultPost returns the POST (JSON body) parameter name, or def when it is
// absent. A parameter that is present with an empty value yields "", not def.
func (request *Request) DefaultPost(name, def string) string {
	val, found := request.posts[name]
	if !found {
		val = def
	}
	return val
}

// GetHeader returns the value of request header name, or "" when the request
// did not send it. The name is matched case-insensitively by upper-casing it,
// because request header names are stored upper-cased.
func (request *Request) GetHeader(name string) string {
	// Look in the request headers (previously this wrongly read the POST
	// parameters). A missing key yields the zero value "".
	return request.headers[strings.ToUpper(name)]
}

// Json sets the response status to code and the response body to the JSON
// encoding of obj. The response itself is written after the route handler
// returns.
//
// If obj cannot be encoded as JSON, the response becomes a 500 with the body
// "Request Error". Previously the error was ignored, so the requested status
// went out with an empty (or previously set) body.
func (request *Request) Json(code int, obj interface{}) {
	ret, err := json.Marshal(obj)
	if err != nil {
		request.response.status = 500
		request.response.body = []byte("Request Error")
		request.response.bodyLen = len(request.response.body)
		return
	}
	request.response.status = code
	request.response.body = ret
	// Used for the Content-Length header.
	request.response.bodyLen = len(ret)
}

// Header sets the response header name to val.
//
// Names are case-insensitive and stored lower-cased. The first value set for
// a name wins; later calls for the same name, in any letter case, are
// ignored. The "server", "content-type" and "content-length" headers are
// overwritten when the response is written.
func (request *Request) Header(name string, val string) {
	// Normalize once so the existence check and the store use the same key.
	key := strings.ToLower(name)
	if _, isExist := request.response.headers[key]; isExist {
		return
	}
	request.response.headers[key] = val
}

response.go:构造HTTP响应

package http_server

import (
    "fmt"
    "net"
    "strconv"
)

// response accumulates the status, headers and body of one HTTP response and
// writes it to conn.
type response struct {
	status  int               // status code; 0 (never set) is sent as 200
	body    []byte            // response body; Content-Length is derived from it
	bodyLen int               // kept for compatibility; writeHeader uses len(body)
	headers map[string]string // extra response headers, keyed by header name
	buff    []byte            // serialized response built by write
	conn    net.Conn
}

// newResponse returns an empty response that will be written to conn.
func newResponse(conn net.Conn) *response {
	return &response{
		conn:    conn,
		headers: make(map[string]string),
	}
}

// reasonPhrases maps common status codes to their standard reason phrases.
// Codes not listed here are sent with an empty reason phrase, which HTTP/1.1
// permits.
var reasonPhrases = map[int]string{
	200: "OK",
	201: "Created",
	204: "No Content",
	301: "Moved Permanently",
	302: "Found",
	304: "Not Modified",
	400: "Bad Request",
	401: "Unauthorized",
	403: "Forbidden",
	404: "Not Found",
	405: "Method Not Allowed",
	413: "Payload Too Large",
	500: "Internal Server Error",
	502: "Bad Gateway",
	503: "Service Unavailable",
}

// writeLine appends the status line, e.g. "HTTP/1.1 404 Not Found\r\n".
func (response *response) writeLine() {
	status := response.status
	if status == 0 {
		// No status was set (e.g. the handler never called Json): use 200.
		status = 200
	}
	line := fmt.Sprintf("HTTP/1.1 %d %s\r\n", status, reasonPhrases[status])
	response.buff = append(response.buff, line...)
}

// writeHeader appends the headers and the blank line that ends them. It sets
// "server", "content-type" (always application/json) and "content-length",
// overwriting any values set earlier for those names.
func (response *response) writeHeader() {
	response.headers["server"] = "^_^"
	response.headers["content-type"] = "application/json"
	// Derive the length from the body itself so the two can never disagree.
	response.headers["content-length"] = strconv.Itoa(len(response.body))
	for k, v := range response.headers {
		response.buff = append(response.buff, fmt.Sprintf("%s: %v\r\n", k, v)...)
	}
	response.buff = append(response.buff, "\r\n"...)
}

// write serializes the status line, headers and body and sends them on conn.
// It returns the error from conn.Write, if any.
func (response *response) write() error {
	// Start from an empty buffer so a repeated write does not resend old bytes.
	response.buff = response.buff[:0]
	response.writeLine()
	response.writeHeader()
	response.buff = append(response.buff, response.body...)
	_, err := response.conn.Write(response.buff)
	return err
}

// errWrite sends a response with the given status and the body "Request Error".
func (response *response) errWrite(status int) error {
	response.status = status
	response.body = []byte("Request Error")
	response.bodyLen = len(response.body)
	return response.write()
}

项目放在 Github 上,欢迎给star~


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK