diff --git a/lib/flagutil/array.go b/lib/flagutil/array.go index 984e3b9794..231cf41c14 100644 --- a/lib/flagutil/array.go +++ b/lib/flagutil/array.go @@ -2,6 +2,8 @@ package flagutil import ( "flag" + "fmt" + "strconv" "strings" ) @@ -12,21 +14,103 @@ func NewArray(name, description string) *Array { return &a } -// Array holds an array of flag values +// Array is a flag that holds an array of values. +// +// It may be set either by specifying multiple flags with the given name +// passed to NewArray or by joining flag values by comma. +// +// The following example sets equivalent flag array with two items (value1, value2): +// +// -foo=value1 -foo=value2 +// -foo=value1,value2 +// +// Flag values may be quoted. For instance, the following arg creates an array of ("a", "b, c") items: +// +// -foo='a,"b, c"' +// type Array []string // String implements flag.Value interface func (a *Array) String() string { - return strings.Join(*a, ",") + aEscaped := make([]string, len(*a)) + for i, v := range *a { + if strings.ContainsAny(v, `", `+"\n") { + v = fmt.Sprintf("%q", v) + } + aEscaped[i] = v + } + return strings.Join(aEscaped, ",") } // Set implements flag.Value interface func (a *Array) Set(value string) error { - values := strings.Split(value, ",") + values := parseArrayValues(value) *a = append(*a, values...) return nil } +func parseArrayValues(s string) []string { + if len(s) == 0 { + return nil + } + var values []string + for { + v, tail := getNextArrayValue(s) + values = append(values, v) + if len(tail) == 0 { + return values + } + if tail[0] == ',' { + tail = tail[1:] + } + s = tail + } +} + +func getNextArrayValue(s string) (string, string) { + if len(s) == 0 { + return "", "" + } + if s[0] != '"' { + // Fast path - unquoted string + n := strings.IndexByte(s, ',') + if n < 0 { + // The last item + return s, "" + } + return s[:n], s[n:] + } + + // Find the end of quoted string + end := 1 + ss := s[1:] + for { + n := strings.IndexByte(ss, '"') + if n < 0 { + // Cannot find trailing quote. Return the whole string till the end. + return s, "" + } + end += n + 1 + // Verify whether the trailing quote is escaped with backslash. + backslashes := 0 + for n > backslashes && ss[n-backslashes-1] == '\\' { + backslashes++ + } + if backslashes&1 == 0 { + // The trailing quote isn't escaped. + break + } + // The trailing quote is escaped. Continue searching for the next quote. + ss = ss[n+1:] + } + v := s[:end] + vUnquoted, err := strconv.Unquote(v) + if err == nil { + v = vUnquoted + } + return v, s[end:] +} + // GetOptionalArg returns optional arg under the given argIdx. func (a *Array) GetOptionalArg(argIdx int) string { x := *a diff --git a/lib/flagutil/array_test.go b/lib/flagutil/array_test.go index 835d9cfa54..a26b6ada3e 100644 --- a/lib/flagutil/array_test.go +++ b/lib/flagutil/array_test.go @@ -3,6 +3,7 @@ package flagutil import ( "flag" "os" + "reflect" "testing" ) @@ -32,3 +33,61 @@ func TestArray(t *testing.T) { } } } + +func TestArraySet(t *testing.T) { + f := func(s string, expectedValues []string) { + t.Helper() + var a Array + a.Set(s) + if !reflect.DeepEqual([]string(a), expectedValues) { + t.Fatalf("unexpected values parsed;\ngot\n%q\nwant\n%q", a, expectedValues) + } + } + f("", nil) + f(`foo`, []string{`foo`}) + f(`foo,b ar,baz`, []string{`foo`, `b ar`, `baz`}) + f(`foo,b\"'ar,"baz,d`, []string{`foo`, `b\"'ar`, `"baz,d`}) + f(`,foo,,ba"r,`, []string{``, `foo`, ``, `ba"r`, ``}) + f(`""`, []string{``}) + f(`"foo,b\nar"`, []string{`foo,b` + "\n" + `ar`}) + f(`"foo","bar",baz`, []string{`foo`, `bar`, `baz`}) + f(`,fo,"\"b, a'\\",,r,`, []string{``, `fo`, `"b, a'\`, ``, `r`, ``}) +} + +func TestArrayGetOptionalArg(t *testing.T) { + f := func(s string, argIdx int, expectedValue string) { + t.Helper() + var a Array + a.Set(s) + v := a.GetOptionalArg(argIdx) + if v != expectedValue { + t.Fatalf("unexpected value; got %q; want %q", v, expectedValue) + } + } + f("", 0, "") + f("", 1, "") + f("foo", 0, "foo") + f("foo", 23, "foo") + f("foo,bar", 0, "foo") + f("foo,bar", 1, "bar") + f("foo,bar", 2, "") +} + +func TestArrayString(t *testing.T) { + f := func(s string) { + t.Helper() + var a Array + a.Set(s) + result := a.String() + if result != s { + t.Fatalf("unexpected string;\ngot\n%s\nwant\n%s", result, s) + } + } + f("") + f("foo") + f("foo,bar") + f(",") + f(",foo,") + f(`", foo","b\"ar",`) + f(`,"\nfoo\\",bar`) +}