diff options
author | Tamir Duberstein <tamird@google.com> | 2018-09-17 20:42:48 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-09-17 20:44:04 -0700 |
commit | d6409b6564d6f908a217709010df2276497b264b (patch) | |
tree | eef606f06646507d0fbfdba60d48ffc3ff46b715 /pkg/tcpip/transport/tcp | |
parent | bb88c187c5457df14fa78e5e6b6f48cbc90fb489 (diff) |
Prevent TCP connect from picking bound ports
PiperOrigin-RevId: 213387851
Change-Id: Icc6850761bc11afd0525f34863acd77584155140
Diffstat (limited to 'pkg/tcpip/transport/tcp')
-rw-r--r-- | pkg/tcpip/transport/tcp/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 194 |
3 files changed, 198 insertions, 18 deletions
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index c7943f08e..5a77ee232 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -71,6 +71,8 @@ go_test( "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/tcp/testing/context", diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4085585b0..a048cadf8 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1000,23 +1000,26 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // address/port for both local and remote (otherwise this // endpoint would be trying to connect to itself). sameAddr := e.id.LocalAddress == e.id.RemoteAddress - _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { if sameAddr && p == e.id.RemotePort { return false, nil } + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p) { + return false, nil + } - e.id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e) - switch err { + id := e.id + id.LocalPort = p + switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) { case nil: + e.id = id return true, nil case tcpip.ErrPortInUse: return false, nil default: return false, err } - }) - if err != nil { + }); err != nil { return err } } @@ -1217,7 +1220,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { } // Bind binds the endpoint to a specific local port and optionally address. -func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (retErr *tcpip.Error) { +func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (err *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() @@ -1245,7 +1248,6 @@ func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (ret } } - // Reserve the port. port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port) if err != nil { return err @@ -1257,7 +1259,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (ret // Any failures beyond this point must remove the port registration. defer func() { - if retErr != nil { + if err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) e.isPortReserved = false e.effectiveNetProtos = nil diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 871177842..ac21e565b 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -28,6 +28,8 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback" "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6" + "gvisor.googlesource.com/gvisor/pkg/tcpip/ports" "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" @@ -3013,12 +3015,11 @@ func TestMinMaxBufferSizes(t *testing.T) { checkSendBufferSize(t, ep, tcp.DefaultBufferSize*30) } -func TestSelfConnect(t *testing.T) { - // This test ensures that intentional self-connects work. In particular, - // it checks that if an endpoint binds to say 127.0.0.1:1000 then - // connects to 127.0.0.1:1000, then it will be connected to itself, and - // is able to send and receive data through the same endpoint. - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) +func makeStack() (*stack.Stack, *tcpip.Error) { + s := stack.New([]string{ + ipv4.ProtocolName, + ipv6.ProtocolName, + }, []string{tcp.ProtocolName}, stack.Options{}) id := loopback.New() if testing.Verbose() { @@ -3026,11 +3027,19 @@ func TestSelfConnect(t *testing.T) { } if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + return nil, err } - if err := s.AddAddress(1, ipv4.ProtocolNumber, context.StackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + for _, ct := range []struct { + number tcpip.NetworkProtocolNumber + address tcpip.Address + }{ + {ipv4.ProtocolNumber, context.StackAddr}, + {ipv6.ProtocolNumber, context.StackV6Addr}, + } { + if err := s.AddAddress(1, ct.number, ct.address); err != nil { + return nil, err + } } s.SetRouteTable([]tcpip.Route{ @@ -3040,8 +3049,27 @@ func TestSelfConnect(t *testing.T) { Gateway: "", NIC: 1, }, + { + Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }, }) + return s, nil +} + +func TestSelfConnect(t *testing.T) { + // This test ensures that intentional self-connects work. In particular, + // it checks that if an endpoint binds to say 127.0.0.1:1000 then + // connects to 127.0.0.1:1000, then it will be connected to itself, and + // is able to send and receive data through the same endpoint. + s, err := makeStack() + if err != nil { + t.Fatal(err) + } + var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { @@ -3095,6 +3123,154 @@ func TestSelfConnect(t *testing.T) { } } +func TestConnectAvoidsBoundPorts(t *testing.T) { + addressTypes := func(t *testing.T, network string) []string { + switch network { + case "ipv4": + return []string{"v4"} + case "ipv6": + return []string{"v6"} + case "dual": + return []string{"v6", "mapped"} + default: + t.Fatalf("unknown network: '%s'", network) + } + + panic("unreachable") + } + + address := func(t *testing.T, addressType string, isAny bool) tcpip.Address { + switch addressType { + case "v4": + if isAny { + return "" + } + return context.StackAddr + case "v6": + if isAny { + return "" + } + return context.StackV6Addr + case "mapped": + if isAny { + return context.V4MappedWildcardAddr + } + return context.StackV4MappedAddr + default: + t.Fatalf("unknown address type: '%s'", addressType) + } + + panic("unreachable") + } + // This test ensures that Endpoint.Connect doesn't select already-bound ports. + networks := []string{"ipv4", "ipv6", "dual"} + for _, exhaustedNetwork := range networks { + t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) { + for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) { + t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) { + for _, isAny := range []bool{false, true} { + t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) { + for _, candidateNetwork := range networks { + t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) { + for _, candidateAddressType := range addressTypes(t, candidateNetwork) { + t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) { + s, err := makeStack() + if err != nil { + t.Fatal(err) + } + + var wq waiter.Queue + var eps []tcpip.Endpoint + defer func() { + for _, ep := range eps { + ep.Close() + } + }() + makeEP := func(network string) tcpip.Endpoint { + var networkProtocolNumber tcpip.NetworkProtocolNumber + switch network { + case "ipv4": + networkProtocolNumber = ipv4.ProtocolNumber + case "ipv6", "dual": + networkProtocolNumber = ipv6.ProtocolNumber + default: + t.Fatalf("unknown network: '%s'", network) + } + ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + eps = append(eps, ep) + switch network { + case "ipv4": + case "ipv6": + if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err) + } + case "dual": + if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { + t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err) + } + default: + t.Fatalf("unknown network: '%s'", network) + } + return ep + } + + var v4reserved, v6reserved bool + switch exhaustedAddressType { + case "v4", "mapped": + v4reserved = true + case "v6": + v6reserved = true + // Dual stack sockets bound to v6 any reserve on v4 as + // well. + if isAny { + switch exhaustedNetwork { + case "ipv6": + case "dual": + v4reserved = true + default: + t.Fatalf("unknown address type: '%s'", exhaustedNetwork) + } + } + default: + t.Fatalf("unknown address type: '%s'", exhaustedAddressType) + } + var collides bool + switch candidateAddressType { + case "v4", "mapped": + collides = v4reserved + case "v6": + collides = v6reserved + default: + t.Fatalf("unknown address type: '%s'", candidateAddressType) + } + + for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ { + if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}, nil); err != nil { + t.Fatalf("Bind(%d) failed: %v", i, err) + } + } + want := tcpip.ErrConnectStarted + if collides { + want = tcpip.ErrNoPortAvailable + } + if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { + t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want) + } + }) + } + }) + } + }) + } + }) + } + }) + } +} + func TestPathMTUDiscovery(t *testing.T) { // This test verifies the stack retransmits packets after it receives an // ICMP packet indicating that the path MTU has been exceeded. |