在使用 go-micro 微服务框架编写 API 中间层逻辑时,需要从 micro api 的 request 中获取 GET 或 POST 请求参数,然后校验参数的合法性。
为此封装了一个函数:通过反射获取参数、校验请求方法并校验参数合法性,出错时返回相应的错误信息。
其中校验部分使用了开源项目 gopkg.in/go-playground/validator。
package vodutils
import (
"errors"
go_api "github.com/micro/go-micro/api/proto"
microApi "github.com/micro/go-micro/api/proto"
"pkg.lwbedu.com/gk-micro/common/errs"
"pkg.lwbedu.com/gk-micro/common/net/api"
"reflect"
"strconv"
"strings"
)
// HTTP methods that ParamValid knows how to read parameters for.
const (
	GET  string = "GET"
	POST string = "POST"
)

// TagKey names the struct tag ParamValid reads to find the request
// parameter bound to each field.
var TagKey = "param"

// Sentinel errors whose messages ParamValid reports through errs.BadRequest.
var (
	MethodNotAllow = errors.New("method not allowed")
	NotStruct      = errors.New("not struct")
	NotPointer     = errors.New("not ptr")
)
// ParamValid binds request parameters into the struct pointed to by getParam
// and then validates that struct with Validate.
//
// m is the method the endpoint accepts (compared case-insensitively). It must
// match req.Method and be GET or POST. GET parameters are read from req.Get and
// POST parameters from req.Post.
//
// Each settable (exported) field is bound from the request parameter named by
// its TagKey struct tag. The following fields are left untouched:
//   - fields with an empty or "-" tag;
//   - unexported fields;
//   - fields whose parameter is absent from the request.
//
// Scalar fields (string, bool, signed/unsigned integers, floats) are set only
// when the parameter carries exactly one value. Slice fields of those element
// kinds receive every value.
//
// Every failure is returned as errs.BadRequest. This includes:
//   - a non-pointer or nil getParam;
//   - a non-struct target;
//   - a nil req;
//   - a method mismatch;
//   - a value that cannot be parsed into, or overflows, the field's type;
//   - a validation error.
func ParamValid(m string, getParam interface{}, req *microApi.Request) error {
	pv := reflect.ValueOf(getParam)
	if pv.Kind() != reflect.Ptr {
		return errs.BadRequest(NotPointer.Error())
	}
	if pv.IsNil() {
		return errs.BadRequest("param is nil")
	}
	pv = pv.Elem()
	if pv.Kind() != reflect.Struct {
		return errs.BadRequest(NotStruct.Error())
	}
	pt := pv.Type()

	if req == nil {
		return errs.BadRequest("request is nil")
	}
	method := strings.ToUpper(m)
	if method != strings.ToUpper(req.Method) {
		return errs.BadRequest(MethodNotAllow.Error())
	}

	var params map[string]*go_api.Pair
	switch method {
	case GET:
		params = req.Get
	case POST:
		params = req.Post
	default:
		return errs.BadRequest(MethodNotAllow.Error())
	}

	for i := 0; i < pt.NumField(); i++ {
		tag := pt.Field(i).Tag.Get(TagKey)
		if tag == "" || tag == "-" {
			continue
		}
		field := pv.Field(i)
		if !field.CanSet() {
			continue
		}
		// A missing key yields a nil *Pair; reading .Values from it would panic.
		pair := params[tag]
		if pair == nil {
			continue
		}
		if err := bindParamValues(field, pair.Values); err != nil {
			return errs.BadRequest("invalid param " + strconv.Quote(tag) + ": " + err.Error())
		}
	}

	if err := Validate(pv.Interface()); err != nil {
		return errs.BadRequest(err.Error())
	}
	return nil
}

// bindParamValues stores the raw request values for one parameter into field.
// A slice field receives all values. A scalar field is set only when exactly
// one value is present, so ambiguous input leaves it unchanged.
func bindParamValues(field reflect.Value, values []string) error {
	if field.Kind() != reflect.Slice {
		if len(values) != 1 {
			return nil
		}
		return setParamScalar(field, values[0])
	}
	if len(values) == 0 || !isParamScalarKind(field.Type().Elem().Kind()) {
		return nil
	}
	list := reflect.MakeSlice(field.Type(), len(values), len(values))
	for i, raw := range values {
		if err := setParamScalar(list.Index(i), raw); err != nil {
			return err
		}
	}
	field.Set(list)
	return nil
}

// isParamScalarKind reports whether setParamScalar can convert into kind k.
func isParamScalarKind(k reflect.Kind) bool {
	switch k {
	case reflect.String, reflect.Bool,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64:
		return true
	}
	return false
}

// setParamScalar parses raw according to v's kind and stores it in v.
// The bit size of v's type is passed to strconv, so out-of-range values are
// reported as errors instead of being truncated. Unsupported kinds are left
// untouched.
func setParamScalar(v reflect.Value, raw string) error {
	switch v.Kind() {
	case reflect.String:
		v.SetString(raw)
	case reflect.Bool:
		b, err := strconv.ParseBool(raw)
		if err != nil {
			return err
		}
		v.SetBool(b)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		n, err := strconv.ParseInt(raw, 10, v.Type().Bits())
		if err != nil {
			return err
		}
		v.SetInt(n)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		// SetInt panics on unsigned kinds, so unsigned values need SetUint.
		n, err := strconv.ParseUint(raw, 10, v.Type().Bits())
		if err != nil {
			return err
		}
		v.SetUint(n)
	case reflect.Float32, reflect.Float64:
		f, err := strconv.ParseFloat(raw, v.Type().Bits())
		if err != nil {
			return err
		}
		v.SetFloat(f)
	}
	return nil
}
// structValidator is shared by every Validate call. The validator package
// documents its Validate type as safe for concurrent use and designed to be
// used as a singleton that caches per-struct metadata. Building a new one per
// request would throw that cache away.
var structValidator = validator.New()

// Validate checks params with the go-playground validator and returns nil when
// every rule passes.
//
// On failure, the returned error carries the message built by ValidateMsg.
// If ValidateMsg produces no message, the validator's own error is returned
// instead, so a failed validation is never reported as success.
func Validate(params interface{}) error {
	err := structValidator.Struct(params)
	if err == nil {
		return nil
	}
	if msg := ValidateMsg(err); msg != "" {
		return errors.New(msg)
	}
	return err
}
// ValidateMsg turns the error returned by the validator into a human-readable
// message.
//
// It returns:
//   - "" when err is nil;
//   - "invalid params" when err is not a validator.ValidationErrors (for
//     example, when the validated value is not a struct);
//   - otherwise, one message per failed field, joined with "; " in the order
//     the validator reported them.
//
// Rules without a dedicated message fall back to a generic text naming the
// rule, so a failed field never yields an empty message.
func ValidateMsg(err error) string {
	if err == nil {
		return ""
	}
	validationErrors, ok := err.(validator.ValidationErrors)
	if !ok {
		return "invalid params"
	}

	msgs := make([]string, 0, len(validationErrors))
	for _, fe := range validationErrors {
		field, param := fe.Field(), fe.Param()
		var msg string
		switch fe.Tag() {
		case "required":
			msg = field + " is required"
		case "min":
			msg = field + " should be greater than or equal to " + param
		case "max":
			msg = field + " should be less than or equal to " + param
		case "required_with":
			msg = field + " should be present and not empty only if any of '" + param + "' are present"
		case "required_with_all":
			msg = field + " should be present and not empty only if all of '" + param + "' are present"
		case "required_without":
			msg = field + " should be present and not empty only when any of '" + param + "' are not present"
		case "required_without_all":
			msg = field + " should be present and not empty only when all of '" + param + "' are not present"
		case "len":
			msg = field + "'s length should be equal to " + param
		case "eq":
			msg = field + " should be equal to " + param
		case "ne":
			msg = field + " should not be equal to " + param
		case "oneof":
			msg = field + " should be one of the values in " + param
		case "gt":
			msg = field + " should be greater than " + param
		case "gte":
			msg = field + " should be greater than or equal to " + param
		case "lt":
			msg = field + " should be less than " + param
		case "lte":
			msg = field + " should be less than or equal to " + param
		case "numeric":
			msg = field + " should be a numeric value"
		case "email":
			msg = field + " should be a valid email address"
		case "url":
			msg = field + " should be a valid url"
		default:
			// Unknown rules must still produce a message; otherwise callers that
			// treat "" as success would let the failure through.
			msg = field + " failed on the '" + fe.Tag() + "' rule"
		}
		msgs = append(msgs, msg)
	}
	return strings.Join(msgs, "; ")
}