diff --git a/utils/ip.go b/utils/ip.go index cca055d..8a8985d 100644 --- a/utils/ip.go +++ b/utils/ip.go @@ -8,7 +8,6 @@ import ( "fmt" "net" "strconv" - "strings" ) var ( @@ -33,21 +32,21 @@ func (ipDesc IPDesc) PortString() string { } func (ipDesc IPDesc) String() string { - return fmt.Sprintf("%s%s", ipDesc.IP, ipDesc.PortString()) + return net.JoinHostPort(ipDesc.IP.String(), fmt.Sprintf("%d", ipDesc.Port)) } // ToIPDesc ... -// TODO: this was kinda hacked together, it should be verified. func ToIPDesc(str string) (IPDesc, error) { - parts := strings.Split(str, ":") - if len(parts) != 2 { + host, portStr, err := net.SplitHostPort(str) + if err != nil { return IPDesc{}, errBadIP } - port, err := strconv.ParseUint(parts[1], 10 /*=base*/, 16 /*=size*/) + port, err := strconv.ParseUint(portStr, 10 /*=base*/, 16 /*=size*/) if err != nil { + // TODO: Should this return a locally defined error? (e.g. errBadPort) return IPDesc{}, err } - ip := net.ParseIP(parts[0]) + ip := net.ParseIP(host) if ip == nil { return IPDesc{}, errBadIP } diff --git a/utils/ip_test.go b/utils/ip_test.go index a950a5a..179014f 100644 --- a/utils/ip_test.go +++ b/utils/ip_test.go @@ -88,7 +88,7 @@ func TestIPDescString(t *testing.T) { result string }{ {IPDesc{net.ParseIP("127.0.0.1"), 0}, "127.0.0.1:0"}, - {IPDesc{net.ParseIP("::1"), 42}, "::1:42"}, + {IPDesc{net.ParseIP("::1"), 42}, "[::1]:42"}, {IPDesc{net.ParseIP("::ffff:127.0.0.1"), 65535}, "127.0.0.1:65535"}, {IPDesc{net.IP{}, 1234}, ":1234"}, } @@ -100,3 +100,52 @@ func TestIPDescString(t *testing.T) { }) } } + +func TestToIPDescError(t *testing.T) { + tests := []struct { + in string + out IPDesc + }{ + {"", IPDesc{}}, + {":", IPDesc{}}, + {"abc:", IPDesc{}}, + {":abc", IPDesc{}}, + {"abc:abc", IPDesc{}}, + {"127.0.0.1:", IPDesc{}}, + {":1", IPDesc{}}, + {"::1", IPDesc{}}, + {"::1:42", IPDesc{}}, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + result, err := ToIPDesc(tt.in) + if err == nil { + t.Errorf("Unexpected success") + } + if !tt.out.Equal(result) { + t.Errorf("Expected %v, got %v", tt.out, result) + } + }) + } +} + +func TestToIPDesc(t *testing.T) { + tests := []struct { + in string + out IPDesc + }{ + {"127.0.0.1:42", IPDesc{net.ParseIP("127.0.0.1"), 42}}, + {"[::1]:42", IPDesc{net.ParseIP("::1"), 42}}, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + result, err := ToIPDesc(tt.in) + if err != nil { + t.Errorf("Unexpected error %v", err) + } + if !tt.out.Equal(result) { + t.Errorf("Expected %#v, got %#v", tt.out, result) + } + }) + } +}