diff options
author | Ethel Morgan <eth@ethulhu.co.uk> | 2020-06-24 12:10:57 +0100 |
---|---|---|
committer | Ethel Morgan <eth@ethulhu.co.uk> | 2020-06-24 12:10:57 +0100 |
commit | aa380da6a61f9b29ee263d95d17a2953a0528b28 (patch) | |
tree | c2915f6508ed321eeacb6ca3c7ba8de313c0ce11 | |
parent | 02f268cc3ba056be15097a7185a367585dd05275 (diff) |
import package flag from helixv0.0.1
-rw-r--r-- | commandline.go | 46 | ||||
-rw-r--r-- | doc.go | 33 | ||||
-rw-r--r-- | flagset.go | 74 | ||||
-rw-r--r-- | flagset_test.go | 34 | ||||
-rw-r--r-- | go.mod | 7 | ||||
-rw-r--r-- | parsefuncs.go | 48 |
6 files changed, 242 insertions, 0 deletions
diff --git a/commandline.go b/commandline.go new file mode 100644 index 0000000..db3d413 --- /dev/null +++ b/commandline.go @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: 2020 Ethel Morgan +// +// SPDX-License-Identifier: MIT + +package flag + +import ( + "fmt" + "os" + "time" +) + +var ( + CommandLine = NewFlagSet(os.Args[0], ExitOnError) + Usage = func() { + fmt.Fprintf(CommandLine.Output(), "Usage of %s:\n", os.Args[0]) + CommandLine.PrintDefaults() + } +) + +func init() { + CommandLine.Usage = runUsageVariable +} +func runUsageVariable() { + Usage() +} + +func Parse() { + _ = CommandLine.Parse(os.Args[1:]) +} + +func String(flagName, defaultValue, description string) *string { + return CommandLine.String(flagName, defaultValue, description) +} +func Int(flagName string, defaultValue int, description string) *int { + return CommandLine.Int(flagName, defaultValue, description) +} +func Duration(flagName string, defaultValue time.Duration, description string) *time.Duration { + return CommandLine.Duration(flagName, defaultValue, description) +} +func Custom(flagName, defaultValue, description string, parser ParseFunc) *interface{} { + return CommandLine.Custom(flagName, defaultValue, description, parser) +} +func Bool(flagName string, defaultValue bool, description string) *bool { + return CommandLine.Bool(flagName, defaultValue, description) +} @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: 2020 Ethel Morgan +// +// SPDX-License-Identifier: MIT + +/* +Package flag wraps Go's built-in flag package, with the addition of idiomatic custom flags. + +Custom Flags + +Custom flags are wrappers around Go's built-in string flags, with a parser +func. They can be used to parse custom flag types, or to have custom flag +validators, while keeping the parsing & validation with the flag's definition. + + var ( + urlFlag = flag.Custom("url", "", "url to GET", func(raw string) (interface{}, error) { + return url.Parse(raw) + }) + outputFlag = flag.Custom("output", "", "output format", func(raw string) (interface{}, error) { + if !(raw == "table" || raw == "json") { + return nil, fmt.Errorf("must be either json or table, got %v", raw) + } + return raw, nil + }) + ) + + func main() { + flag.Parse() + urlFlag := (*urlFlag).(*url.URL) + outputFlag := (*outputFlag).(string) + } + +*/ +package flag diff --git a/flagset.go b/flagset.go new file mode 100644 index 0000000..6118177 --- /dev/null +++ b/flagset.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: 2020 Ethel Morgan +// +// SPDX-License-Identifier: MIT + +package flag + +import ( + "flag" + "fmt" + "os" +) + +type ( + ErrorHandling = flag.ErrorHandling + + ParseFunc func(string) (interface{}, error) + + FlagSet struct { + flag.FlagSet + + customFlags []func() error + } +) + +const ( + ContinueOnError = flag.ContinueOnError + ExitOnError = flag.ExitOnError + PanicOnError = flag.PanicOnError +) + +func (f *FlagSet) Parse(arguments []string) error { + if err := f.FlagSet.Parse(arguments); err != nil { + return err + } + + for _, customFlag := range f.customFlags { + if err := customFlag(); err != nil { + switch f.FlagSet.ErrorHandling() { + case flag.ContinueOnError: + return err + case flag.ExitOnError: + fmt.Fprintf(os.Stdout, "%v\n\n", err) + f.Usage() + os.Exit(2) + case flag.PanicOnError: + panic(err) + } + } + } + return nil +} + +func NewFlagSet(name string, handling ErrorHandling) *FlagSet { + return &FlagSet{ + FlagSet: *flag.NewFlagSet(name, handling), + } +} + +func (f *FlagSet) Custom(flagName, defaultValue, description string, parser ParseFunc) *interface{} { + rawFlag := f.String(flagName, defaultValue, description) + + var value interface{} + + f.customFlags = append(f.customFlags, func() error { + var err error + value, err = parser(*rawFlag) + if err != nil { + return fmt.Errorf("invalid value %q for flag -%s: %w", *rawFlag, flagName, err) + } + return nil + }) + + return &value +} diff --git a/flagset_test.go b/flagset_test.go new file mode 100644 index 0000000..bef4dd2 --- /dev/null +++ b/flagset_test.go @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: 2020 Ethel Morgan +// +// SPDX-License-Identifier: MIT + +package flag + +import ( + "testing" + "time" +) + +func TestFlagSetCustom(t *testing.T) { + fs := NewFlagSet("test", ContinueOnError) + + i := fs.Int("i", 0, "i") + d := fs.Custom("d", "1s", "d", func(raw string) (interface{}, error) { + return time.ParseDuration(raw) + }) + e := fs.Custom("e", "", "e", StringEnum("json", "table")) + + if err := fs.Parse([]string{"-i", "12", "-d", "3m", "-e", "json"}); err != nil { + t.Fatalf("got error: %v", err) + } + + if *i != 12 { + t.Errorf("-i == %v, wanted %v", *i, 12) + } + if (*d).(time.Duration) != 3*time.Minute { + t.Errorf("-d == %v, wanted %v", *d, 3*time.Minute) + } + if (*e).(string) != "json" { + t.Errorf("-e == %v, wanted %v", *e, "json") + } +} @@ -0,0 +1,7 @@ +// SPDX-FileCopyrightText: 2020 Ethel Morgan +// +// SPDX-License-Identifier: MIT + +module go.eth.moe/flag + +go 1.14 diff --git a/parsefuncs.go b/parsefuncs.go new file mode 100644 index 0000000..0263a5a --- /dev/null +++ b/parsefuncs.go @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: 2020 Ethel Morgan +// +// SPDX-License-Identifier: MIT + +package flag + +import ( + "errors" + "fmt" + "strconv" + "strings" +) + +func RequiredString(raw string) (interface{}, error) { + if raw == "" { + return nil, errors.New("must not be empty") + } + return raw, nil +} + +func StringEnum(values ...string) ParseFunc { + return func(raw string) (interface{}, error) { + for _, value := range values { + if value == raw { + return raw, nil + } + } + return raw, fmt.Errorf("must be one of %q", values) + } +} + +func IntList(raw string) (interface{}, error) { + var ints []int + + if raw == "" { + return ints, nil + } + + for _, raw := range strings.Split(raw, ",") { + i, err := strconv.Atoi(raw) + if err != nil { + return ints, err + } + ints = append(ints, i) + } + + return ints, nil +} |