diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/transport/tcp/dual_stack_test.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 95 |
3 files changed, 94 insertions, 36 deletions
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 560b4904c..a6f25896b 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -236,6 +236,25 @@ func TestV6ConnectWhenBoundToWildcard(t *testing.T) { testV6Connect(t, c) } +func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) { + c := context.NewWithOpts(t, context.Options{ + EnableV6: true, + MTU: defaultMTU, + }) + defer c.Cleanup() + + // Create a v6 endpoint but don't set the v6-only TCP option. + c.CreateV6Endpoint(false) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV6Connect(t, c) +} + func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 127c19b00..8f5e3a42d 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2690,14 +2690,16 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { return err } - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { - netProtos = []tcpip.NetworkProtocolNumber{ - header.IPv6ProtocolNumber, - header.IPv4ProtocolNumber, + + // Expand netProtos to include v4 and v6 under dual-stack if the caller is + // binding to a wildcard (empty) address, and this is an IPv6 endpoint with + // v6only set to false. + if netProto == header.IPv6ProtocolNumber { + stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) + alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4 + if alsoBindToV4 { + netProtos = append(netProtos, header.IPv4ProtocolNumber) } } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 79646fefe..f791f8f13 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -112,6 +112,18 @@ type Headers struct { TCPOpts []byte } +// Options contains options for creating a new test context. +type Options struct { + // EnableV4 indicates whether IPv4 should be enabled. + EnableV4 bool + + // EnableV6 indicates whether IPv4 should be enabled. + EnableV6 bool + + // MTU indicates the maximum transmission unit on the link layer. + MTU uint32 +} + // Context provides an initialized Network stack and a link layer endpoint // for use in TCP tests. type Context struct { @@ -154,10 +166,30 @@ type Context struct { // New allocates and initializes a test context containing a new // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + return NewWithOpts(t, Options{ + EnableV4: true, + EnableV6: true, + MTU: mtu, }) +} + +// NewWithOpts allocates and initializes a test context containing a new +// stack and a link-layer endpoint with specific options. +func NewWithOpts(t *testing.T, opts Options) *Context { + if opts.MTU == 0 { + panic("MTU must be greater than 0") + } + + stackOpts := stack.Options{ + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + } + if opts.EnableV4 { + stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol) + } + if opts.EnableV6 { + stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol) + } + s := stack.New(stackOpts) const sendBufferSize = 1 << 20 // 1 MiB const recvBufferSize = 1 << 20 // 1 MiB @@ -182,50 +214,55 @@ func New(t *testing.T, mtu uint32) *Context { // Some of the congestion control tests send up to 640 packets, we so // set the channel size to 1000. - ep := channel.New(1000, mtu, "") + ep := channel.New(1000, opts.MTU, "") wep := stack.LinkEndpoint(ep) if testing.Verbose() { wep = sniffer.New(ep) } - opts := stack.NICOptions{Name: "nic1"} - if err := s.CreateNICWithOptions(1, wep, opts); err != nil { + nicOpts := stack.NICOptions{Name: "nic1"} + if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } - wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) + wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, "")) if testing.Verbose() { - wep2 = sniffer.New(channel.New(1000, mtu, "")) + wep2 = sniffer.New(channel.New(1000, opts.MTU, "")) } opts2 := stack.NICOptions{Name: "nic2"} if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } - v4ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: StackAddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) - } - - v6ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: StackV6AddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) - } + var routeTable []tcpip.Route - s.SetRouteTable([]tcpip.Route{ - { + if opts.EnableV4 { + v4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: StackAddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) + } + routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv4EmptySubnet, NIC: 1, - }, - { + }) + } + + if opts.EnableV6 { + v6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: StackV6AddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) + } + routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv6EmptySubnet, NIC: 1, - }, - }) + }) + } + + s.SetRouteTable(routeTable) return &Context{ t: t, |