Skip to content

Commit

Permalink
refactor: xcmd support pre middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
hui.wang committed Jan 19, 2022
1 parent 7aac31b commit 2835e56
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 97 deletions.
110 changes: 35 additions & 75 deletions xcmd/commander.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@ package xcmd

import (
"context"
"errors"
"flag"
"fmt"
"io"
"os"
"reflect"
"strings"

"github.com/sandwich-go/xconf"
"github.com/sandwich-go/xconf/xflag"
"github.com/sandwich-go/xconf/xutil"
)

type MiddlewareFunc = func(ctx context.Context, c *Command, ff *flag.FlagSet, args []string, next Executer) error
Expand All @@ -22,6 +18,7 @@ type Command struct {
Output io.Writer
commands []*Command
middleware []MiddlewareFunc
preMiddleware []MiddlewareFunc
usageNamePath []string
}

Expand All @@ -46,15 +43,21 @@ func (c *Command) Use(middleware ...MiddlewareFunc) *Command {
return c
}

func (c *Command) UsePre(preMiddleware ...MiddlewareFunc) *Command {
c.preMiddleware = append(c.preMiddleware, preMiddleware...)
return c
}

func (c *Command) AddCommand(sub *Command, middleware ...MiddlewareFunc) {
sub.usageNamePath = append(c.usageNamePath, sub.usageNamePath...)
sub.middleware = c.combineMiddlewareFunc(middleware...)
sub.middleware = combineMiddlewareFunc(c.middleware, middleware...)
sub.preMiddleware = combineMiddlewareFunc(c.preMiddleware, sub.preMiddleware...)
c.commands = append(c.commands, sub)
}

func (c *Command) combineMiddlewareFunc(middleware ...MiddlewareFunc) []MiddlewareFunc {
m := make([]MiddlewareFunc, 0, len(c.middleware)+len(middleware))
m = append(m, c.middleware...)
func combineMiddlewareFunc(middlewareNow []MiddlewareFunc, middleware ...MiddlewareFunc) []MiddlewareFunc {
m := make([]MiddlewareFunc, 0, len(middlewareNow)+len(middleware))
m = append(m, middlewareNow...)
m = append(m, middleware...)
return m
}
Expand All @@ -72,60 +75,6 @@ func (c *Command) wrapErr(err error) error {
return fmt.Errorf("command: %s err:%s", strings.Join(c.usageNamePath, " "), err.Error())
}

func (c *Command) ApplyArgs(ff *flag.FlagSet, args ...string) error {
// 默认 usage 无参
ff.Usage = func() {
c.Explain(c.Output)
xflag.PrintDefaults(ff)
}
if c.cc.GetBind() == nil {
return ff.Parse(args)
}
cc := xconf.NewOptions(
xconf.WithErrorHandling(xconf.ContinueOnError),
xconf.WithFlagSet(ff),
xconf.WithFlagArgs(args...))
cc.ApplyOption(c.cc.GetXConfOption()...)

// 获取bindto结构合法的FieldPath,并过滤合法的BindToFieldPath
_, fieldsMap := xconf.NewStruct(
reflect.New(reflect.ValueOf(c.cc.GetBind()).Type().Elem()).Interface(),
cc.TagName,
cc.TagNameForDefaultValue,
cc.FieldTagConvertor).Map()
var ignorePath []string
if len(c.cc.GetBindFieldPath()) > 0 {
for k := range fieldsMap {
if !xutil.ContainStringEqualFold(c.cc.GetBindFieldPath(), k) {
ignorePath = append(ignorePath, k)
}
}
}
var invalidKeys []string
for _, v := range c.cc.GetBindFieldPath() {
if _, ok := fieldsMap[v]; !ok {
invalidKeys = append(invalidKeys, v)
}
}

if len(invalidKeys) > 0 {
return c.wrapErr(fmt.Errorf("option BindFieldPath has invalid item:%s", strings.Join(invalidKeys, ",")))
}

cc.ApplyOption(xconf.WithFlagCreateIgnoreFiledPath(ignorePath...))
x := xconf.NewWithConf(cc)
// Available Commands + Flags
cc.FlagSet.Usage = func() {
c.Explain(c.Output)
x.UsageToWriter(c.Output, args...)
}
err := x.Parse(c.cc.GetBind())
if err != nil {
return fmt.Errorf("got err:%s while Parse", err.Error())
}
return nil
}

func (c *Command) Execute(ctx context.Context, args ...string) error {
if len(args) != 0 {
// 尝试在当前命令集下寻找子命令
Expand All @@ -138,25 +87,29 @@ func (c *Command) Execute(ctx context.Context, args ...string) error {
}
}
ff := flag.NewFlagSet(strings.Join(c.usageNamePath, "/"), flag.ContinueOnError)
return ChainMiddleware(c.middleware...)(ctx, c, ff, args, exec)
var allMiddlewares []MiddlewareFunc
allMiddlewares = append(allMiddlewares, c.preMiddleware...)
allMiddlewares = append(allMiddlewares, c.cc.GetParser())
allMiddlewares = append(allMiddlewares, c.middleware...)
return ChainMiddleware(allMiddlewares...)(ctx, c, ff, args, exec)
}

func exec(ctx context.Context, c *Command, ff *flag.FlagSet, args []string) error {
// arg为空或子命令没有找到,说明传入的参数不是命令名称,在当前层执行
if err := c.ApplyArgs(ff, args...); err != nil {
if IsErrHelp(err) {
return nil
}
return fmt.Errorf("[ApplyArgs] %s", err.Error())
executer := c.Config().GetExecute()
if executer == nil {
executer = c.cc.GetOnExecuterLost()
}
return c.Config().GetExecute()(context.Background(), c, ff, args)
return executer(context.Background(), c, ff, args)
}

func (c *Command) GetName() string { return c.name }
func (c *Command) Usage() string { return c.cc.GetSynopsis() }

func (c *Command) CommandInheritBind(name string, opts ...ConfigOption) *Command {
cc := NewConfig(WithBind(c.cc.GetBind()), WithBindFieldPath(c.cc.GetBindFieldPath()...))
func (c *Command) SubCommand(name string, opts ...ConfigOption) *Command {
cc := NewConfig(
WithBind(c.cc.GetBind()),
WithBindFieldPath(c.cc.GetBindFieldPath()...),
WithXConfOption(c.cc.GetXConfOption()...),
)
cc.ApplyOption(opts...)
sub := NewCommandWithConfig(name, cc)
c.AddCommand(sub)
Expand All @@ -165,7 +118,14 @@ func (c *Command) CommandInheritBind(name string, opts ...ConfigOption) *Command

func (c *Command) Check() error {
for _, v := range c.commands {
err := v.ApplyArgs(flag.NewFlagSet(strings.Join(c.usageNamePath, "/"), flag.ContinueOnError))
binder := c.cc.GetParser()
if binder == nil {
return errors.New("need Parser")
}
ff := flag.NewFlagSet(strings.Join(v.usageNamePath, "/"), flag.ContinueOnError)
err := binder(context.Background(), v, ff, nil, func(ctx context.Context, c *Command, ff *flag.FlagSet, args []string) error {
return nil
})
if err != nil {
return err
}
Expand Down
6 changes: 4 additions & 2 deletions xcmd/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ var rootCmd = NewCommand(path.Base(os.Args[0]), WithExecute(func(ctx context.Con
}))

func Use(middleware ...MiddlewareFunc) *Command { return rootCmd.Use(middleware...) }
func UsePre(middleware ...MiddlewareFunc) *Command { return rootCmd.UsePre(middleware...) }
func AddCommand(sub *Command, middleware ...MiddlewareFunc) { rootCmd.AddCommand(sub, middleware...) }
func Config() ConfigInterface { return rootCmd.cc }
func Execute(ctx context.Context, args ...string) error { return rootCmd.Execute(ctx, args...) }
func Explain(w io.Writer) { rootCmd.Explain(w) }
func Check() error { return rootCmd.Check() }

func CommandInheritBind(name string, opts ...ConfigOption) *Command {
return rootCmd.CommandInheritBind(name, opts...)
func SubCommand(name string, opts ...ConfigOption) *Command {
return rootCmd.SubCommand(name, opts...)
}

func SetRootCommand(root *Command) { rootCmd = root }
func RootCommand() *Command { return rootCmd }
2 changes: 1 addition & 1 deletion xcmd/explain.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func applyPadding(filler string) string {
const magic = "\x00"

func PrintCommand(c *Command, lvl []bool) (lines []string) {
lines = append(lines, fmt.Sprintf("%s%s(%d middleware) %s %s", getPrefix(lvl), c.name, len(c.middleware), magic, c.cc.GetSynopsis()))
lines = append(lines, fmt.Sprintf("%s%s(%d,%d) %s %s", getPrefix(lvl), c.name, len(c.preMiddleware), len(c.middleware), magic, c.cc.GetSynopsis()))
var level = append(lvl, false)
for i := 0; i < len(c.commands); i++ {
if i+1 == len(c.commands) {
Expand Down
50 changes: 42 additions & 8 deletions xcmd/gen_config_optiongen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions xcmd/main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func main() {
// sub命令export,继承上游命令的bind信息
// export 派生go命令,只绑定http_address字段
// go 派生export命令,追加绑定timeouts字段
xcmd.CommandInheritBind("export",
xcmd.SubCommand("export",
xcmd.WithSynopsis("export proto to golang/cs/python/lua"),
xcmd.WithExecute(func(ctx context.Context, c *xcmd.Command, ff *flag.FlagSet, args []string) error {
fmt.Println("export command")
Expand All @@ -30,13 +30,13 @@ func main() {
).Use(func(ctx context.Context, c *xcmd.Command, ff *flag.FlagSet, args []string, next xcmd.Executer) error {
return next(ctx, c, ff, args)
}).
CommandInheritBind("go",
SubCommand("go",
xcmd.WithBindFieldPath("http_address"),
xcmd.WithSynopsis("generate golang code"),
).Use(func(ctx context.Context, c *xcmd.Command, ff *flag.FlagSet, args []string, next xcmd.Executer) error {
return next(ctx, c, ff, args)
}).
CommandInheritBind("service", xcmd.WithBindFieldPathAppend("timeouts"))
SubCommand("service", xcmd.WithBindFieldPathAppend("timeouts"))

// sub命令log绑定到新的配置项
anotherBind := xcmdtest.NewLog()
Expand Down Expand Up @@ -70,9 +70,10 @@ func main() {
fmt.Println("manual command got log_level:", logLevel)
return nil
}))
xcmd.AddCommand(manual, binding)
manual.UsePre(binding)
xcmd.AddCommand(manual)

panicPrintErr("comamnd check with err: %v", xcmd.Check())
// panicPrintErr("comamnd check with err: %v", xcmd.Check())
panicPrintErr("comamnd Execute with err: %v", xcmd.Execute(context.Background(), os.Args[1:]...))

}
Expand Down
Loading

0 comments on commit 2835e56

Please sign in to comment.