summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEthel Morgan <eth@ethulhu.co.uk>2020-06-24 12:10:57 +0100
committerEthel Morgan <eth@ethulhu.co.uk>2020-06-24 12:10:57 +0100
commitaa380da6a61f9b29ee263d95d17a2953a0528b28 (patch)
treec2915f6508ed321eeacb6ca3c7ba8de313c0ce11
parent02f268cc3ba056be15097a7185a367585dd05275 (diff)
import package flag from helixv0.0.1
-rw-r--r--commandline.go46
-rw-r--r--doc.go33
-rw-r--r--flagset.go74
-rw-r--r--flagset_test.go34
-rw-r--r--go.mod7
-rw-r--r--parsefuncs.go48
6 files changed, 242 insertions, 0 deletions
diff --git a/commandline.go b/commandline.go
new file mode 100644
index 0000000..db3d413
--- /dev/null
+++ b/commandline.go
@@ -0,0 +1,46 @@
+// SPDX-FileCopyrightText: 2020 Ethel Morgan
+//
+// SPDX-License-Identifier: MIT
+
+package flag
+
+import (
+ "fmt"
+ "os"
+ "time"
+)
+
+var (
+ CommandLine = NewFlagSet(os.Args[0], ExitOnError)
+ Usage = func() {
+ fmt.Fprintf(CommandLine.Output(), "Usage of %s:\n", os.Args[0])
+ CommandLine.PrintDefaults()
+ }
+)
+
+func init() {
+ CommandLine.Usage = runUsageVariable
+}
+func runUsageVariable() {
+ Usage()
+}
+
+func Parse() {
+ _ = CommandLine.Parse(os.Args[1:])
+}
+
+func String(flagName, defaultValue, description string) *string {
+ return CommandLine.String(flagName, defaultValue, description)
+}
+func Int(flagName string, defaultValue int, description string) *int {
+ return CommandLine.Int(flagName, defaultValue, description)
+}
+func Duration(flagName string, defaultValue time.Duration, description string) *time.Duration {
+ return CommandLine.Duration(flagName, defaultValue, description)
+}
+func Custom(flagName, defaultValue, description string, parser ParseFunc) *interface{} {
+ return CommandLine.Custom(flagName, defaultValue, description, parser)
+}
+func Bool(flagName string, defaultValue bool, description string) *bool {
+ return CommandLine.Bool(flagName, defaultValue, description)
+}
diff --git a/doc.go b/doc.go
new file mode 100644
index 0000000..fa9c9fd
--- /dev/null
+++ b/doc.go
@@ -0,0 +1,33 @@
+// SPDX-FileCopyrightText: 2020 Ethel Morgan
+//
+// SPDX-License-Identifier: MIT
+
+/*
+Package flag wraps Go's built-in flag package, with the addition of idiomatic custom flags.
+
+Custom Flags
+
+Custom flags are wrappers around Go's built-in string flags, with a parser
+func. They can be used to parse custom flag types, or to have custom flag
+validators, while keeping the parsing & validation with the flag's definition.
+
+ var (
+ urlFlag = flag.Custom("url", "", "url to GET", func(raw string) (interface{}, error) {
+ return url.Parse(raw)
+ })
+ outputFlag = flag.Custom("output", "", "output format", func(raw string) (interface{}, error) {
+ if !(raw == "table" || raw == "json") {
+ return nil, fmt.Errorf("must be either json or table, got %v", raw)
+ }
+ return raw, nil
+ })
+ )
+
+ func main() {
+ flag.Parse()
+ urlFlag := (*urlFlag).(*url.URL)
+ outputFlag := (*outputFlag).(string)
+ }
+
+*/
+package flag
diff --git a/flagset.go b/flagset.go
new file mode 100644
index 0000000..6118177
--- /dev/null
+++ b/flagset.go
@@ -0,0 +1,74 @@
+// SPDX-FileCopyrightText: 2020 Ethel Morgan
+//
+// SPDX-License-Identifier: MIT
+
+package flag
+
+import (
+ "flag"
+ "fmt"
+ "os"
+)
+
+type (
+ ErrorHandling = flag.ErrorHandling
+
+ ParseFunc func(string) (interface{}, error)
+
+ FlagSet struct {
+ flag.FlagSet
+
+ customFlags []func() error
+ }
+)
+
+const (
+ ContinueOnError = flag.ContinueOnError
+ ExitOnError = flag.ExitOnError
+ PanicOnError = flag.PanicOnError
+)
+
+func (f *FlagSet) Parse(arguments []string) error {
+ if err := f.FlagSet.Parse(arguments); err != nil {
+ return err
+ }
+
+ for _, customFlag := range f.customFlags {
+ if err := customFlag(); err != nil {
+ switch f.FlagSet.ErrorHandling() {
+ case flag.ContinueOnError:
+ return err
+ case flag.ExitOnError:
+ fmt.Fprintf(os.Stdout, "%v\n\n", err)
+ f.Usage()
+ os.Exit(2)
+ case flag.PanicOnError:
+ panic(err)
+ }
+ }
+ }
+ return nil
+}
+
+func NewFlagSet(name string, handling ErrorHandling) *FlagSet {
+ return &FlagSet{
+ FlagSet: *flag.NewFlagSet(name, handling),
+ }
+}
+
+func (f *FlagSet) Custom(flagName, defaultValue, description string, parser ParseFunc) *interface{} {
+ rawFlag := f.String(flagName, defaultValue, description)
+
+ var value interface{}
+
+ f.customFlags = append(f.customFlags, func() error {
+ var err error
+ value, err = parser(*rawFlag)
+ if err != nil {
+ return fmt.Errorf("invalid value %q for flag -%s: %w", *rawFlag, flagName, err)
+ }
+ return nil
+ })
+
+ return &value
+}
diff --git a/flagset_test.go b/flagset_test.go
new file mode 100644
index 0000000..bef4dd2
--- /dev/null
+++ b/flagset_test.go
@@ -0,0 +1,34 @@
+// SPDX-FileCopyrightText: 2020 Ethel Morgan
+//
+// SPDX-License-Identifier: MIT
+
+package flag
+
+import (
+ "testing"
+ "time"
+)
+
+func TestFlagSetCustom(t *testing.T) {
+ fs := NewFlagSet("test", ContinueOnError)
+
+ i := fs.Int("i", 0, "i")
+ d := fs.Custom("d", "1s", "d", func(raw string) (interface{}, error) {
+ return time.ParseDuration(raw)
+ })
+ e := fs.Custom("e", "", "e", StringEnum("json", "table"))
+
+ if err := fs.Parse([]string{"-i", "12", "-d", "3m", "-e", "json"}); err != nil {
+ t.Fatalf("got error: %v", err)
+ }
+
+ if *i != 12 {
+ t.Errorf("-i == %v, wanted %v", *i, 12)
+ }
+ if (*d).(time.Duration) != 3*time.Minute {
+ t.Errorf("-d == %v, wanted %v", *d, 3*time.Minute)
+ }
+ if (*e).(string) != "json" {
+ t.Errorf("-e == %v, wanted %v", *e, "json")
+ }
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..97b2ae6
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,7 @@
+// SPDX-FileCopyrightText: 2020 Ethel Morgan
+//
+// SPDX-License-Identifier: MIT
+
+module go.eth.moe/flag
+
+go 1.14
diff --git a/parsefuncs.go b/parsefuncs.go
new file mode 100644
index 0000000..0263a5a
--- /dev/null
+++ b/parsefuncs.go
@@ -0,0 +1,48 @@
+// SPDX-FileCopyrightText: 2020 Ethel Morgan
+//
+// SPDX-License-Identifier: MIT
+
+package flag
+
+import (
+ "errors"
+ "fmt"
+ "strconv"
+ "strings"
+)
+
+func RequiredString(raw string) (interface{}, error) {
+ if raw == "" {
+ return nil, errors.New("must not be empty")
+ }
+ return raw, nil
+}
+
+func StringEnum(values ...string) ParseFunc {
+ return func(raw string) (interface{}, error) {
+ for _, value := range values {
+ if value == raw {
+ return raw, nil
+ }
+ }
+ return raw, fmt.Errorf("must be one of %q", values)
+ }
+}
+
+func IntList(raw string) (interface{}, error) {
+ var ints []int
+
+ if raw == "" {
+ return ints, nil
+ }
+
+ for _, raw := range strings.Split(raw, ",") {
+ i, err := strconv.Atoi(raw)
+ if err != nil {
+ return ints, err
+ }
+ ints = append(ints, i)
+ }
+
+ return ints, nil
+}