Go 反射实现 go-micro request 参数获取及校验

1,642 阅读2分钟

在使用 go-micro 微服务框架编写 API 中间层逻辑时，需要从 micro api 的 request 中获取 GET 或 POST 请求参数，然后对参数进行合法性校验。

为此封装一个函数：利用反射获取请求参数，校验请求方法和参数合法性，校验失败时返回相应的错误信息。

其中校验部分使用了开源项目 go-playground/validator（导入路径 gopkg.in/go-playground/validator）。注意：下面代码的 import 列表中未包含该包，使用时需自行补充。

package vodutils

import (
	"errors"
	go_api "github.com/micro/go-micro/api/proto"
	microApi "github.com/micro/go-micro/api/proto"
	"pkg.lwbedu.com/gk-micro/common/errs"
	"pkg.lwbedu.com/gk-micro/common/net/api"
	"reflect"
	"strconv"
	"strings"
)

// HTTP methods accepted by ParamValid; they are compared against the
// upper-cased method names.
const (
	GET  string = "GET"
	POST string = "POST"
)

// TagKey is the struct tag ParamValid reads to map a struct field to a
// request parameter name, e.g. `param:"page_size"`.
var TagKey = "param"

// Sentinel errors whose messages ParamValid wraps in errs.BadRequest.
var (
	MethodNotAllow = errors.New("method not allowed")
	NotStruct      = errors.New("not struct")
	NotPointer     = errors.New("not ptr")
)

// ParamValid binds the query (GET) or form (POST) parameters of req into the
// struct pointed to by getParam, then validates the populated struct.
//
// m is the HTTP method the endpoint accepts ("GET" or "POST",
// case-insensitive). A request whose method differs is rejected. Each
// settable struct field is bound from the request parameter named by its
// TagKey tag (`param:"name"`); untagged fields and absent parameters are
// left untouched.
//
// Supported field kinds are string, bool, all signed and unsigned integer
// widths, float32/float64, and slices of those kinds. Slices are filled from
// every value of a repeated parameter. A scalar field is set only when the
// parameter carries exactly one value.
//
// All failures are returned as errs.BadRequest errors:
//   - getParam is a non-pointer, nil, or non-struct value;
//   - req is nil;
//   - the request method does not match m;
//   - a value cannot be parsed into its field's type or overflows it;
//   - Validate reports a validation failure.
func ParamValid(m string, getParam interface{}, req *microApi.Request) error {
	pv := reflect.ValueOf(getParam)
	if pv.Kind() != reflect.Ptr {
		return errs.BadRequest(NotPointer.Error())
	}
	if pv.IsNil() {
		return errs.BadRequest("param is nil")
	}
	pv = pv.Elem()
	if pv.Kind() != reflect.Struct {
		return errs.BadRequest(NotStruct.Error())
	}
	if req == nil {
		return errs.BadRequest("request is nil")
	}

	method := strings.ToUpper(m)
	if method != strings.ToUpper(req.Method) {
		return errs.BadRequest(MethodNotAllow.Error())
	}

	var params map[string]*microApi.Pair
	switch method {
	case GET:
		params = req.Get
	case POST:
		params = req.Post
	default:
		return errs.BadRequest(MethodNotAllow.Error())
	}

	pt := pv.Type()
	for i := 0; i < pt.NumField(); i++ {
		tag := pt.Field(i).Tag.Get(TagKey)
		field := pv.Field(i)
		if tag == "" || !field.CanSet() {
			continue
		}
		// A missing parameter yields a nil *Pair (indexing a nil map is also
		// fine). Reading .Values through that nil pointer would panic, so
		// guard before dereferencing.
		pair := params[tag]
		if pair == nil || len(pair.Values) == 0 {
			continue
		}
		if err := bindParamValues(field, pair.Values); err != nil {
			return errs.BadRequest("invalid value for param " + tag + ": " + err.Error())
		}
	}

	if err := Validate(pv.Interface()); err != nil {
		return errs.BadRequest(err.Error())
	}
	return nil
}

// bindParamValues stores the raw request values of one parameter into field.
// Slice fields receive every value. Scalar fields are set only when exactly
// one value is present. Unsupported kinds are left untouched.
func bindParamValues(field reflect.Value, values []string) error {
	if field.Kind() == reflect.Slice {
		if !isBindableKind(field.Type().Elem().Kind()) {
			return nil
		}
		slice := reflect.MakeSlice(field.Type(), len(values), len(values))
		for i, raw := range values {
			if err := setParamValue(slice.Index(i), raw); err != nil {
				return err
			}
		}
		field.Set(slice)
		return nil
	}
	if len(values) != 1 {
		return nil
	}
	return setParamValue(field, values[0])
}

// isBindableKind reports whether setParamValue knows how to fill kind k.
func isBindableKind(k reflect.Kind) bool {
	switch k {
	case reflect.String, reflect.Bool,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64:
		return true
	}
	return false
}

// setParamValue parses raw according to v's kind and stores the result in v.
// Unsupported kinds are ignored.
func setParamValue(v reflect.Value, raw string) error {
	// An empty value for a non-string field is treated as "not provided".
	// This keeps the zero value, as the original implementation did,
	// rather than rejecting e.g. `?page=`.
	if raw == "" && v.Kind() != reflect.String {
		return nil
	}
	switch v.Kind() {
	case reflect.String:
		v.SetString(raw)
	case reflect.Bool:
		b, err := strconv.ParseBool(raw)
		if err != nil {
			return err
		}
		v.SetBool(b)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// Parse with the field's own bit size so out-of-range input is
		// rejected instead of being silently truncated by SetInt.
		n, err := strconv.ParseInt(raw, 10, v.Type().Bits())
		if err != nil {
			return err
		}
		v.SetInt(n)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		// Unsigned fields need SetUint; SetInt panics on them.
		n, err := strconv.ParseUint(raw, 10, v.Type().Bits())
		if err != nil {
			return err
		}
		v.SetUint(n)
	case reflect.Float32, reflect.Float64:
		f, err := strconv.ParseFloat(raw, v.Type().Bits())
		if err != nil {
			return err
		}
		v.SetFloat(f)
	}
	return nil
}


// paramValidator is shared by every Validate call. A validator instance is
// safe for concurrent use and caches per-struct metadata, so reusing one
// avoids rebuilding that cache on every request.
var paramValidator = validator.New()

// Validate runs the struct-tag validation rules declared on params (expected
// to be a struct or a pointer to one). It returns nil when every rule passes.
// Otherwise it returns an error whose message is built by ValidateMsg.
func Validate(params interface{}) error {
	if msg := ValidateMsg(paramValidator.Struct(params)); msg != "" {
		return errors.New(msg)
	}
	return nil
}

// ValidateMsg converts an error returned by the validator into a readable
// message:
//   - "" for a nil error;
//   - "invalid params" for an error that is not validator.ValidationErrors;
//   - otherwise one message per failed field, joined with "; ".
//
// Every failed rule produces a message; tags without dedicated wording get a
// generic one. This matters because Validate treats an empty message as
// success, so a failed validation must never map to "".
func ValidateMsg(err error) string {
	if err == nil {
		return ""
	}

	validationErrors, ok := err.(validator.ValidationErrors)
	if !ok {
		return "invalid params"
	}

	msgs := make([]string, 0, len(validationErrors))
	for _, fe := range validationErrors {
		msgs = append(msgs, fieldErrorMsg(fe))
	}
	return strings.Join(msgs, "; ")
}

// fieldErrorMsg renders the message for a single failed validation rule.
func fieldErrorMsg(fe validator.FieldError) string {
	field, param := fe.Field(), fe.Param()
	switch fe.Tag() {
	case "required":
		return field + " is required"
	case "min":
		return field + " should be greater than or equal to " + param
	case "max":
		return field + " should be less than or equal to " + param
	case "required_with":
		return field + " should be present and not empty if any of '" + param + "' are present"
	case "required_with_all":
		return field + " should be present and not empty if all of '" + param + "' are present"
	case "required_without":
		return field + " should be present and not empty when any of '" + param + "' are not present"
	case "required_without_all":
		return field + " should be present and not empty when all of '" + param + "' are not present"
	case "len":
		return field + "'s length should be equal to " + param
	case "eq":
		return field + " should be equal to " + param
	case "ne":
		return field + " should not be equal to " + param
	case "oneof":
		return field + " should be one of the values in " + param
	case "gt":
		return field + " should be greater than " + param
	case "gte":
		return field + " should be greater than or equal to " + param
	case "lt":
		return field + " should be less than " + param
	case "lte":
		return field + " should be less than or equal to " + param
	case "numeric":
		return field + " should be numeric"
	case "email":
		return field + " should be a valid email address"
	case "url":
		return field + " should be a valid url"
	default:
		// Unknown tags still fail loudly instead of producing "".
		return field + " failed the '" + fe.Tag() + "' validation"
	}
}