diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index 5abc0860..c027036f 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -359,20 +359,20 @@ func checkDefaultValue(args []string, i int) (val string) { 'L': "|", // | is the sentinel for no value since users are unlikely to use it. It's "reserved" in most shells 'X': "0", } - if isFlag(args[i]) && len(args[i]) == 2 && (len(args) == i+1 || args[i+1][0] == '-') { + if isFlag(args[i]) && len(args[i]) == 2 && (len(args) == i+1 || isFlag(args[i+1])) { if v, ok := flags[rune(args[i][1])]; ok { val = v return } } - if args[i] == "-N" && (len(args) == i+1 || args[i+1][0] == '-') { + if args[i] == "-N" && (len(args) == i+1 || isFlag(args[i+1])) { val = "true" } return } func isFlag(arg string) bool { - return arg[0] == '-' + return len(arg) > 0 && arg[0] == '-' } func isListFlag(arg string) bool { diff --git a/cmd/sqlcmd/sqlcmd_test.go b/cmd/sqlcmd/sqlcmd_test.go index 4554998a..44326c18 100644 --- a/cmd/sqlcmd/sqlcmd_test.go +++ b/cmd/sqlcmd/sqlcmd_test.go @@ -584,11 +584,16 @@ func TestConvertOsArgs(t *testing.T) { []string{"-X", "-k2"}, []string{"-X", "0", "-k2"}, }, + { + "flag with empty value", + []string{"-S", "server", "-U", "sa", "-d", "", "-Q", "SELECT 1", "-b"}, + []string{"-S", "server", "-U", "sa", "-d", "", "-Q", "SELECT 1", "-b"}, + }, } for _, c := range tests { t.Run(c.name, func(t *testing.T) { actual := convertOsArgs(c.in) - assert.ElementsMatch(t, c.expected, actual, "Incorrect converted args") + assert.Equal(t, c.expected, actual, "Incorrect converted args") }) } }