From 3789c34b22e7a7466149bfbeedf05bf49188130c Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 3 Sep 2019 15:59:58 -0700 Subject: Make UDP traceroute work. Adds support to generate Port Unreachable messages for UDP datagrams received on a port for which there is no valid endpoint. Fixes #703 PiperOrigin-RevId: 267034418 --- pkg/tcpip/transport/tcp/testing/context/context.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'pkg/tcpip/transport/tcp/testing') diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 272481aa0..18c707a57 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -267,7 +267,7 @@ func (c *Context) GetPacketNonBlocking() []byte { // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p1) + len(p2)) + buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) if len(buf) > maxTotalSize { buf = buf[:maxTotalSize] } @@ -286,9 +286,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) icmp.SetType(typ) icmp.SetCode(code) - - copy(icmp[header.ICMPv4PayloadOffset:], p1) - copy(icmp[header.ICMPv4PayloadOffset+len(p1):], p2) + const icmpv4VariableHeaderOffset = 4 + copy(icmp[icmpv4VariableHeaderOffset:], p1) + copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) -- cgit v1.2.3 From fe1f5210774d015d653df164d6f676658863780c Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Fri, 6 Sep 2019 17:59:46 -0700 Subject: Remove reundant global tcpip.LinkEndpointID. PiperOrigin-RevId: 267709597 --- pkg/tcpip/link/channel/channel.go | 6 +- pkg/tcpip/link/fdbased/endpoint.go | 18 +- pkg/tcpip/link/fdbased/endpoint_test.go | 3 +- pkg/tcpip/link/loopback/loopback.go | 4 +- pkg/tcpip/link/muxed/injectable.go | 5 +- pkg/tcpip/link/muxed/injectable_test.go | 4 +- pkg/tcpip/link/sharedmem/sharedmem.go | 8 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 4 +- pkg/tcpip/link/sniffer/sniffer.go | 18 +- pkg/tcpip/link/waitable/waitable.go | 7 +- pkg/tcpip/link/waitable/waitable_test.go | 6 +- pkg/tcpip/network/arp/arp_test.go | 10 +- pkg/tcpip/network/ipv4/ipv4_test.go | 22 +- pkg/tcpip/network/ipv6/icmp_test.go | 22 +- pkg/tcpip/network/ipv6/ndp_test.go | 15 +- pkg/tcpip/stack/registration.go | 28 --- pkg/tcpip/stack/stack.go | 27 +- pkg/tcpip/stack/stack_test.go | 278 ++++++++++----------- pkg/tcpip/stack/transport_test.go | 24 +- pkg/tcpip/tcpip.go | 3 - pkg/tcpip/transport/tcp/testing/context/context.go | 9 +- pkg/tcpip/transport/udp/udp_test.go | 9 +- runsc/boot/network.go | 18 +- 23 files changed, 250 insertions(+), 298 deletions(-) (limited to 'pkg/tcpip/transport/tcp/testing') diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index c40744b8e..eec430d0a 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -44,14 +44,12 @@ type Endpoint struct { } // New creates a new channel endpoint. -func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ +func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { + return &Endpoint{ C: make(chan PacketInfo, size), mtu: mtu, linkAddr: linkAddr, } - - return stack.RegisterLinkEndpoint(e), e } // Drain removes all outbound packets from the channel and counts them. diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 77f988b9f..adcf21371 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -165,7 +165,7 @@ type Options struct { // // Makes fd non-blocking, but does not take ownership of fd, which must remain // open for the lifetime of the returned endpoint. -func New(opts *Options) (tcpip.LinkEndpointID, error) { +func New(opts *Options) (stack.LinkEndpoint, error) { caps := stack.LinkEndpointCapabilities(0) if opts.RXChecksumOffload { caps |= stack.CapabilityRXChecksumOffload @@ -190,7 +190,7 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { } if len(opts.FDs) == 0 { - return 0, fmt.Errorf("opts.FD is empty, at least one FD must be specified") + return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified") } e := &endpoint{ @@ -207,12 +207,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { for i := 0; i < len(e.fds); i++ { fd := e.fds[i] if err := syscall.SetNonblock(fd, true); err != nil { - return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err) + return nil, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err) } isSocket, err := isSocketFD(fd) if err != nil { - return 0, err + return nil, err } if isSocket { if opts.GSOMaxSize != 0 { @@ -222,12 +222,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { } inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket) if err != nil { - return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err) + return nil, fmt.Errorf("createInboundDispatcher(...) = %v", err) } e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher) } - return stack.RegisterLinkEndpoint(e), nil + return e, nil } func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) { @@ -435,14 +435,12 @@ func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buf } // NewInjectable creates a new fd-based InjectableEndpoint. -func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) (tcpip.LinkEndpointID, *InjectableEndpoint) { +func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) *InjectableEndpoint { syscall.SetNonblock(fd, true) - e := &InjectableEndpoint{endpoint: endpoint{ + return &InjectableEndpoint{endpoint: endpoint{ fds: []int{fd}, mtu: mtu, caps: capabilities, }} - - return stack.RegisterLinkEndpoint(e), e } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index e305252d6..04406bc9a 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -68,11 +68,10 @@ func newContext(t *testing.T, opt *Options) *context { } opt.FDs = []int{fds[1]} - epID, err := New(opt) + ep, err := New(opt) if err != nil { t.Fatalf("Failed to create FD endpoint: %v", err) } - ep := stack.FindLinkEndpoint(epID).(*endpoint) c := &context{ t: t, diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index ab6a53988..e121ea1a5 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -32,8 +32,8 @@ type endpoint struct { // New creates a new loopback endpoint. This link-layer endpoint just turns // outbound packets into inbound packets. -func New() tcpip.LinkEndpointID { - return stack.RegisterLinkEndpoint(&endpoint{}) +func New() stack.LinkEndpoint { + return &endpoint{} } // Attach implements stack.LinkEndpoint.Attach. It just saves the stack network- diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index a577a3d52..3ed7b98d1 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -105,9 +105,8 @@ func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) * } // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. -func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) (tcpip.LinkEndpointID, *InjectableEndpoint) { - e := &InjectableEndpoint{ +func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { + return &InjectableEndpoint{ routes: routes, } - return stack.RegisterLinkEndpoint(e), e } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 174b9330f..3086fec00 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -87,8 +87,8 @@ func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tc if err != nil { t.Fatal("Failed to create socket pair:", err) } - _, underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) + underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint} - _, endpoint := NewInjectableEndpoint(routes) + endpoint := NewInjectableEndpoint(routes) return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 834ea5c40..ba387af73 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -94,7 +94,7 @@ type endpoint struct { // New creates a new shared-memory-based endpoint. Buffers will be broken up // into buffers of "bufferSize" bytes. -func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tcpip.LinkEndpointID, error) { +func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) { e := &endpoint{ mtu: mtu, bufferSize: bufferSize, @@ -102,15 +102,15 @@ func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tc } if err := e.tx.init(bufferSize, &tx); err != nil { - return 0, err + return nil, err } if err := e.rx.init(bufferSize, &rx); err != nil { e.tx.cleanup() - return 0, err + return nil, err } - return stack.RegisterLinkEndpoint(e), nil + return e, nil } // Close frees all resources associated with the endpoint. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 98036f367..0e9ba0846 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -119,12 +119,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress initQueue(t, &c.txq, &c.txCfg) initQueue(t, &c.rxq, &c.rxCfg) - id, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) + ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) if err != nil { t.Fatalf("New failed: %v", err) } - c.ep = stack.FindLinkEndpoint(id).(*endpoint) + c.ep = ep.(*endpoint) c.ep.Attach(c) return c diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 36c8c46fc..e7b6d7912 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -58,10 +58,10 @@ type endpoint struct { // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets and they traverse the endpoint. -func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID { - return stack.RegisterLinkEndpoint(&endpoint{ - lower: stack.FindLinkEndpoint(lower), - }) +func New(lower stack.LinkEndpoint) stack.LinkEndpoint { + return &endpoint{ + lower: lower, + } } func zoneOffset() (int32, error) { @@ -102,15 +102,15 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error { // snapLen is the maximum amount of a packet to be saved. Packets with a length // less than or equal too snapLen will be saved in their entirety. Longer // packets will be truncated to snapLen. -func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) { +func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) { if err := writePCAPHeader(file, snapLen); err != nil { - return 0, err + return nil, err } - return stack.RegisterLinkEndpoint(&endpoint{ - lower: stack.FindLinkEndpoint(lower), + return &endpoint{ + lower: lower, file: file, maxPCAPLen: snapLen, - }), nil + }, nil } // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 3b6ac2ff7..408cc62f7 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -40,11 +40,10 @@ type Endpoint struct { // New creates a new waitable link-layer endpoint. It wraps around another // endpoint and allows the caller to block new write/dispatch calls and wait for // the inflight ones to finish before returning. -func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ - lower: stack.FindLinkEndpoint(lower), +func New(lower stack.LinkEndpoint) *Endpoint { + return &Endpoint{ + lower: lower, } - return stack.RegisterLinkEndpoint(e), e } // DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 56e18ecb0..1031438b1 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -72,7 +72,7 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P func TestWaitWrite(t *testing.T) { ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) // Write and check that it goes through. wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) @@ -97,7 +97,7 @@ func TestWaitWrite(t *testing.T) { func TestWaitDispatch(t *testing.T) { ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) // Check that attach happens. wep.Attach(ep) @@ -139,7 +139,7 @@ func TestOtherMethods(t *testing.T) { hdrLen: hdrLen, linkAddr: linkAddr, } - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) if v := wep.MTU(); v != mtu { t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu) diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 4c4b54469..387fca96e 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -47,11 +47,13 @@ func newTestContext(t *testing.T) *testContext { s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{}) const defaultMTU = 65536 - id, linkEP := channel.New(256, defaultMTU, stackLinkAddr) + ep := channel.New(256, defaultMTU, stackLinkAddr) + wep := stack.LinkEndpoint(ep) + if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -73,7 +75,7 @@ func newTestContext(t *testing.T) *testContext { return &testContext{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, } } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 1b5a55bea..ae827ca27 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -36,11 +36,11 @@ func TestExcludeBroadcast(t *testing.T) { s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) const defaultMTU = 65536 - id, _ := channel.New(256, defaultMTU, "") + ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { - id = sniffer.New(id) + ep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -184,15 +184,12 @@ type errorChannel struct { // newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket // will return successive errors from packetCollectorErrors until the list is // empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) { - _, e := channel.New(size, mtu, linkAddr) - ec := errorChannel{ - Endpoint: e, +func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { + return &errorChannel{ + Endpoint: channel.New(size, mtu, linkAddr), Ch: make(chan packetInfo, size), packetCollectorErrors: packetCollectorErrors, } - - return stack.RegisterLinkEndpoint(e), &ec } // packetInfo holds all the information about an outbound packet. @@ -242,9 +239,8 @@ type context struct { func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { // Make the packet and write it. s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{}) - _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - linkEPId := stack.RegisterLinkEndpoint(linkEP) - s.CreateNIC(1, linkEPId) + ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) + s.CreateNIC(1, ep) const ( src = "\x10\x00\x00\x01" dst = "\x10\x00\x00\x02" @@ -266,7 +262,7 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32 } return context{ Route: r, - linkEP: linkEP, + linkEP: ep, } } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 227a65cf2..a6a1a5232 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -83,8 +83,7 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li func TestICMPCounts(t *testing.T) { s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { @@ -211,14 +210,13 @@ func newTestContext(t *testing.T) *testContext { } const defaultMTU = 65536 - _, linkEP0 := channel.New(256, defaultMTU, linkAddr0) - c.linkEP0 = linkEP0 - wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0} - id0 := stack.RegisterLinkEndpoint(wrappedEP0) + c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + + wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) if testing.Verbose() { - id0 = sniffer.New(id0) + wrappedEP0 = sniffer.New(wrappedEP0) } - if err := c.s0.CreateNIC(1, id0); err != nil { + if err := c.s0.CreateNIC(1, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { @@ -228,11 +226,9 @@ func newTestContext(t *testing.T) *testContext { t.Fatalf("AddAddress sn lladdr0: %v", err) } - _, linkEP1 := channel.New(256, defaultMTU, linkAddr1) - c.linkEP1 = linkEP1 - wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1} - id1 := stack.RegisterLinkEndpoint(wrappedEP1) - if err := c.s1.CreateNIC(1, id1); err != nil { + c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) + if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 8e4cf0e74..571915d3f 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -32,15 +32,14 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack t.Helper() s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) - { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } + + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) } + { subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) if err != nil { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 67b70b2ee..88a698b18 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -15,8 +15,6 @@ package stack import ( - "sync" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -379,10 +377,6 @@ var ( networkProtocols = make(map[string]NetworkProtocolFactory) unassociatedFactory UnassociatedEndpointFactory - - linkEPMu sync.RWMutex - nextLinkEndpointID tcpip.LinkEndpointID = 1 - linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint) ) // RegisterTransportProtocolFactory registers a new transport protocol factory @@ -406,28 +400,6 @@ func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) { unassociatedFactory = f } -// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an -// ID that can be used to refer to it. -func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID { - linkEPMu.Lock() - defer linkEPMu.Unlock() - - v := nextLinkEndpointID - nextLinkEndpointID++ - - linkEndpoints[v] = linkEP - - return v -} - -// FindLinkEndpoint finds the link endpoint associated with the given ID. -func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint { - linkEPMu.RLock() - defer linkEPMu.RUnlock() - - return linkEndpoints[id] -} - // GSOType is the type of GSO segments. // // +stateify savable diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6beca6ae8..a961e8ebe 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -620,12 +620,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error { - ep := FindLinkEndpoint(linkEP) - if ep == nil { - return tcpip.ErrBadLinkEndpoint - } - +func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -645,33 +640,33 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint } // CreateNIC creates a NIC with the provided id and link-layer endpoint. -func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, true, false) +func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, true, false) } // CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, // and a human-readable name. -func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, false) +func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, false) } // CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer // endpoint, and a human-readable name. -func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, true) +func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, true) } // CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, // but leave it disable. Stack.EnableNIC must be called before the link-layer // endpoint starts delivering packets to it. -func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, false, false) +func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, false, false) } // CreateDisabledNamedNIC is a combination of CreateNamedNIC and // CreateDisabledNIC. -func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, false, false) +func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, false, false) } // EnableNIC enables the given NIC so that the link-layer endpoint can start diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c6a8160af..0c26c9911 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -60,11 +60,11 @@ type fakeNetworkEndpoint struct { prefixLen int proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher - linkEP stack.LinkEndpoint + ep stack.LinkEndpoint } func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.linkEP.MTU() - uint32(f.MaxHeaderLength()) + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { @@ -108,7 +108,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen + return f.ep.MaxHeaderLength() + fakeNetHeaderLen } func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -116,7 +116,7 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto } func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.linkEP.Capabilities() + return f.ep.Capabilities() } func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error { @@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu return nil } - return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) } func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { @@ -189,14 +189,14 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ nicid: nicid, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, - linkEP: linkEP, + ep: ep, }, nil } @@ -225,9 +225,9 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. - id, linkEP := channel.New(10, defaultMTU, "") + ep := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -245,7 +245,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -255,7 +255,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -265,7 +265,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -274,7 +274,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -284,7 +284,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -307,59 +307,59 @@ func send(r stack.Route, payload buffer.View) *tcpip.Error { return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123) } -func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) { +func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { t.Helper() - linkEP.Drain() + ep.Drain() if err := sendTo(s, addr, payload); err != nil { t.Error("sendTo failed:", err) } - if got, want := linkEP.Drain(), 1; got != want { + if got, want := ep.Drain(), 1; got != want { t.Errorf("sendTo packet count: got = %d, want %d", got, want) } } -func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) { +func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { t.Helper() - linkEP.Drain() + ep.Drain() if err := send(r, payload); err != nil { t.Error("send failed:", err) } - if got, want := linkEP.Drain(), 1; got != want { + if got, want := ep.Drain(), 1; got != want { t.Errorf("send packet count: got = %d, want %d", got, want) } } -func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) } } -func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { +func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { t.Helper() // testRecvInternal injects one packet, and we expect to receive it. want := fakeNet.PacketCount(localAddrByte) + 1 - testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) } -func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { +func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { t.Helper() // testRecvInternal injects one packet, and we do NOT expect to receive it. want := fakeNet.PacketCount(localAddrByte) - testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) } -func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) { +func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -369,9 +369,9 @@ func TestNetworkSend(t *testing.T) { // Create a stack with the fake network protocol, one nic, and one // address: 1. The route table sends all packets through the only // existing nic. - id, linkEP := channel.New(10, defaultMTU, "") + ep := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) } @@ -388,7 +388,7 @@ func TestNetworkSend(t *testing.T) { } // Make sure that the link-layer endpoint received the outbound packet. - testSendTo(t, s, "\x03", linkEP, nil) + testSendTo(t, s, "\x03", ep, nil) } func TestNetworkSendMultiRoute(t *testing.T) { @@ -397,8 +397,8 @@ func TestNetworkSendMultiRoute(t *testing.T) { // even addresses. s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -410,8 +410,8 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -442,10 +442,10 @@ func TestNetworkSendMultiRoute(t *testing.T) { } // Send a packet to an odd destination. - testSendTo(t, s, "\x05", linkEP1, nil) + testSendTo(t, s, "\x05", ep1, nil) // Send a packet to an even destination. - testSendTo(t, s, "\x06", linkEP2, nil) + testSendTo(t, s, "\x06", ep2, nil) } func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { @@ -478,8 +478,8 @@ func TestRoutes(t *testing.T) { // even addresses. s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -491,8 +491,8 @@ func TestRoutes(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -556,8 +556,8 @@ func TestAddressRemoval(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -578,15 +578,15 @@ func TestAddressRemoval(t *testing.T) { // Send and receive packets, and verify they are received. buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // Remove the address, then check that send/receive doesn't work anymore. if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { @@ -601,9 +601,9 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatal("CreateNIC failed:", err) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) @@ -626,17 +626,17 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { // Send and receive packets, and verify they are received. buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSend(t, r, linkEP, nil) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) // Remove the address, then check that send/receive doesn't work anymore. if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) - testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { @@ -690,8 +690,8 @@ func TestEndpointExpiration(t *testing.T) { t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -724,15 +724,15 @@ func TestEndpointExpiration(t *testing.T) { //----------------------- verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 2. Add Address, everything should work. @@ -741,8 +741,8 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 3. Remove the address, send should only work for spoofing, receive // for promiscuous mode. @@ -752,15 +752,15 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 4. Add Address back, everything should work again. @@ -769,8 +769,8 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 5. Take a reference to the endpoint by getting a route. Verify that // we can still send/receive, including sending using the route. @@ -779,9 +779,9 @@ func TestEndpointExpiration(t *testing.T) { if err != nil { t.Fatal("FindRoute failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) // 6. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. @@ -791,16 +791,16 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { - testSend(t, r, linkEP, nil) - testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 7. Add Address back, everything should work again. @@ -809,16 +809,16 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) // 8. Remove the route, sendTo/recv should still work. //----------------------- r.Release() verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 9. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. @@ -828,15 +828,15 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } }) } @@ -846,8 +846,8 @@ func TestEndpointExpiration(t *testing.T) { func TestPromiscuousMode(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -867,13 +867,13 @@ func TestPromiscuousMode(t *testing.T) { // have a matching endpoint. const localAddrByte byte = 0x01 buf[0] = localAddrByte - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) @@ -886,7 +886,7 @@ func TestPromiscuousMode(t *testing.T) { if err := s.SetPromiscuousMode(1, false); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestSpoofingWithAddress(t *testing.T) { @@ -896,8 +896,8 @@ func TestSpoofingWithAddress(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -936,8 +936,8 @@ func TestSpoofingWithAddress(t *testing.T) { t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } // Sending a packet works. - testSendTo(t, s, dstAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testSendTo(t, s, dstAddr, ep, nil) + testSend(t, r, ep, nil) // FindRoute should also work with a local address that exists on the NIC. r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) @@ -951,7 +951,7 @@ func TestSpoofingWithAddress(t *testing.T) { t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } // Sending a packet using the route works. - testSend(t, r, linkEP, nil) + testSend(t, r, ep, nil) } func TestSpoofingNoAddress(t *testing.T) { @@ -960,8 +960,8 @@ func TestSpoofingNoAddress(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -980,7 +980,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -999,14 +999,14 @@ func TestSpoofingNoAddress(t *testing.T) { } // Sending a packet works. // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } func TestBroadcastNeedsNoRoute(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{}) @@ -1076,8 +1076,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { t.Run(tc.name, func(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1132,8 +1132,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1159,7 +1159,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { t.Fatal("AddAddressRange failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { @@ -1198,8 +1198,8 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { const nicID tcpip.NICID = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1236,8 +1236,8 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1262,7 +1262,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { t.Fatal("AddAddressRange failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestNetworkOptions(t *testing.T) { @@ -1320,8 +1320,8 @@ func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.S func TestAddresRangeAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1361,8 +1361,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } // Insert primary and never-primary addresses. @@ -1426,8 +1426,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1501,8 +1501,8 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1526,8 +1526,8 @@ func TestAddAddress(t *testing.T) { func TestAddProtocolAddress(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1558,8 +1558,8 @@ func TestAddProtocolAddress(t *testing.T) { func TestAddAddressWithOptions(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1587,8 +1587,8 @@ func TestAddAddressWithOptions(t *testing.T) { func TestAddProtocolAddressWithOptions(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1621,8 +1621,8 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { func TestNICStats(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed: ", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { @@ -1639,7 +1639,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1653,9 +1653,9 @@ func TestNICStats(t *testing.T) { if err := sendTo(s, "\x01", payload); err != nil { t.Fatal("sendTo failed: ", err) } - want := uint64(linkEP1.Drain()) + want := uint64(ep1.Drain()) if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want) + t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) } if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { @@ -1669,16 +1669,16 @@ func TestNICForwarding(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) s.SetForwarding(true) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress #1 failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -1697,10 +1697,10 @@ func TestNICForwarding(t *testing.T) { // Send a packet to address 3. buf := buffer.NewView(30) buf[0] = 3 - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) select { - case <-linkEP2.C: + case <-ep2.C: default: t.Fatal("Packet not forwarded") } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index ca185279e..87d1e0d0d 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -278,9 +278,9 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } func TestTransportReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") + linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -340,9 +340,9 @@ func TestTransportReceive(t *testing.T) { } func TestTransportControlReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") + linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -408,9 +408,9 @@ func TestTransportControlReceive(t *testing.T) { } func TestTransportSend(t *testing.T) { - id, _ := channel.New(10, defaultMTU, "") + linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -497,16 +497,16 @@ func TestTransportForwarding(t *testing.T) { s.SetForwarding(true) // TODO(b/123449044): Change this to a channel NIC. - id1 := loopback.New() - if err := s.CreateNIC(1, id1); err != nil { + ep1 := loopback.New() + if err := s.CreateNIC(1, ep1); err != nil { t.Fatalf("CreateNIC #1 failed: %v", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress #1 failed: %v", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatalf("CreateNIC #2 failed: %v", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -545,7 +545,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - linkEP2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.Inject(fakeNetNumber, req.ToVectorisedView()) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -559,7 +559,7 @@ func TestTransportForwarding(t *testing.T) { var p channel.PacketInfo select { - case p = <-linkEP2.C: + case p = <-ep2.C: default: t.Fatal("Response packet not forwarded") } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 418e771d2..ebf8a2d04 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -600,9 +600,6 @@ func (r Route) String() string { return out.String() } -// LinkEndpointID represents a data link layer endpoint. -type LinkEndpointID uint64 - // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 18c707a57..16783e716 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -150,11 +150,12 @@ func New(t *testing.T, mtu uint32) *Context { // Some of the congestion control tests send up to 640 packets, we so // set the channel size to 1000. - id, linkEP := channel.New(1000, mtu, "") + ep := channel.New(1000, mtu, "") + wep := stack.LinkEndpoint(ep) if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -180,7 +181,7 @@ func New(t *testing.T, mtu uint32) *Context { return &Context{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)), } } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 995d6e8a1..c6deab892 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -275,12 +275,13 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + ep := channel.New(256, mtu, "") + wep := stack.LinkEndpoint(ep) - id, linkEP := channel.New(256, mtu, "") if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -306,7 +307,7 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { return &testContext{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, } } diff --git a/runsc/boot/network.go b/runsc/boot/network.go index ea0d9f790..32cba5ac1 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -121,10 +121,10 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct nicID++ nicids[link.Name] = nicID - linkEP := loopback.New() + ep := loopback.New() log.Infof("Enabling loopback interface %q with id %d on addresses %+v", link.Name, nicID, link.Addresses) - if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, true /* loopback */); err != nil { + if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, true /* loopback */); err != nil { return err } @@ -156,7 +156,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } mac := tcpip.LinkAddress(link.LinkAddress) - linkEP, err := fdbased.New(&fdbased.Options{ + ep, err := fdbased.New(&fdbased.Options{ FDs: FDs, MTU: uint32(link.MTU), EthernetHeader: true, @@ -170,7 +170,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels) - if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, false /* loopback */); err != nil { + if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, false /* loopback */); err != nil { return err } @@ -203,14 +203,14 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct // createNICWithAddrs creates a NIC in the network stack and adds the given // addresses. -func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, addrs []net.IP, loopback bool) error { +func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP, loopback bool) error { if loopback { - if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(linkEP)); err != nil { - return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v, %v) failed: %v", id, name, linkEP, err) + if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(ep)); err != nil { + return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v) failed: %v", id, name, err) } } else { - if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(linkEP)); err != nil { - return fmt.Errorf("CreateNamedNIC(%v, %v, %v) failed: %v", id, name, linkEP, err) + if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(ep)); err != nil { + return fmt.Errorf("CreateNamedNIC(%v, %v) failed: %v", id, name, err) } } -- cgit v1.2.3 From 03ee55cc62c99c5b8f5d6fb00423a66ef44589e3 Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Mon, 23 Sep 2019 14:37:39 -0700 Subject: netstack: convert more socket options to {Set,Get}SockOptInt PiperOrigin-RevId: 270763208 --- pkg/sentry/socket/epsocket/epsocket.go | 22 ++-- pkg/sentry/socket/unix/transport/unix.go | 82 ++++++------ pkg/tcpip/stack/transport_test.go | 5 + pkg/tcpip/tcpip.go | 32 +++-- pkg/tcpip/transport/icmp/endpoint.go | 29 +++-- pkg/tcpip/transport/raw/endpoint.go | 30 +++-- pkg/tcpip/transport/tcp/endpoint.go | 144 +++++++++++---------- pkg/tcpip/transport/tcp/tcp_noracedetector_test.go | 10 +- pkg/tcpip/transport/tcp/tcp_test.go | 121 ++++++++--------- pkg/tcpip/transport/tcp/testing/context/context.go | 8 +- pkg/tcpip/transport/udp/endpoint.go | 32 +++-- 11 files changed, 276 insertions(+), 239 deletions(-) (limited to 'pkg/tcpip/transport/tcp/testing') diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 25adca090..3e66f9cbb 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -209,6 +209,10 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and + // transport.Endpoint.SetSockOptInt. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error @@ -887,8 +891,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -903,8 +907,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1275,7 +1279,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v))) case linux.SO_RCVBUF: if len(optVal) < sizeOfInt32 { @@ -1283,7 +1287,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v))) case linux.SO_REUSEADDR: if len(optVal) < sizeOfInt32 { @@ -2317,9 +2321,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc return 0, err case linux.TIOCOUTQ: - var v tcpip.SendQueueSizeOption - if err := ep.GetSockOpt(&v); err != nil { - return 0, syserr.TranslateNetstackError(err).ToError() + v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() } if v > math.MaxInt32 { diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 2b0ad6395..1867b3a5c 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -175,6 +175,10 @@ type Endpoint interface { // types. SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOptInt sets a socket option for simple cases when a value has + // the int type. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error @@ -838,6 +842,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -853,65 +861,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrQueueSizeNotSupported } return v, nil - default: - return -1, tcpip.ErrUnknownProtocolOption - } -} - -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - case *tcpip.SendQueueSizeOption: + case tcpip.SendQueueSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize()) + v := e.connected.SendQueuedSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported - } - *o = qs - return nil - - case *tcpip.PasscredOption: - if e.Passcred() { - *o = tcpip.PasscredOption(1) - } else { - *o = tcpip.PasscredOption(0) + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - return nil + return int(v), nil - case *tcpip.SendBufferSizeOption: + case tcpip.SendBufferSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize()) + v := e.connected.SendMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - *o = qs - return nil + return int(v), nil - case *tcpip.ReceiveBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: e.Lock() if e.receiver == nil { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize()) + v := e.receiver.RecvMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported + } + return int(v), nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.PasscredOption: + if e.Passcred() { + *o = tcpip.PasscredOption(1) + } else { + *o = tcpip.PasscredOption(0) } - *o = qs return nil case *tcpip.KeepaliveEnabledOption: diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 847d02982..0e69ac7c8 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -91,6 +91,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// SetSockOptInt sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 2534069ab..c021c67ac 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -401,6 +401,10 @@ type Endpoint interface { // SetSockOpt sets a socket option. opt should be one of the *Option types. SetSockOpt(opt interface{}) *Error + // SetSockOptInt sets a socket option, for simple cases where a value + // has the int type. + SetSockOptInt(opt SockOpt, v int) *Error + // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. GetSockOpt(opt interface{}) *Error @@ -446,10 +450,22 @@ type WriteOptions struct { type SockOpt int const ( - // ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of - // unread bytes in the input buffer should be returned. + // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the + // number of unread bytes in the input buffer should be returned. ReceiveQueueSizeOption SockOpt = iota + // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to + // specify the send buffer size option. + SendBufferSizeOption + + // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to + // specify the receive buffer size option. + ReceiveBufferSizeOption + + // SendQueueSizeOption is used in GetSockOptInt to specify that the + // number of unread bytes in the output buffer should be returned. + SendQueueSizeOption + // TODO(b/137664753): convert all int socket options to be handled via // GetSockOptInt. ) @@ -458,18 +474,6 @@ const ( // the endpoint should be cleared and returned. type ErrorOption struct{} -// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send -// buffer size option. -type SendBufferSizeOption int - -// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the -// receive buffer size option. -type ReceiveBufferSizeOption int - -// SendQueueSizeOption is used in GetSockOpt to specify that the number of -// unread bytes in the output buffer should be returned. -type SendQueueSizeOption int - // V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. type V6OnlyOption int diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 3db060384..a111fdb2a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -319,6 +319,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// SetSockOptInt sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { @@ -331,6 +336,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } @@ -341,18 +358,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) - e.rcvMu.Unlock() - return nil - case *tcpip.KeepaliveEnabledOption: *o = 0 return nil diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index cf1c5c433..a02731a5d 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -492,6 +492,11 @@ func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { @@ -504,6 +509,19 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } ep.rcvMu.Unlock() return v, nil + + case tcpip.SendBufferSizeOption: + ep.mu.Lock() + v := ep.sndBufSize + ep.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + v := ep.rcvBufSizeMax + ep.rcvMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption @@ -515,18 +533,6 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - ep.mu.Lock() - *o = tcpip.SendBufferSizeOption(ep.sndBufSize) - ep.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - ep.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax) - ep.rcvMu.Unlock() - return nil - case *tcpip.KeepaliveEnabledOption: *o = 0 return nil diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index dd931f88c..35b489c68 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -952,62 +952,9 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool { return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 } -// SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { - case tcpip.DelayOption: - if v == 0 { - atomic.StoreUint32(&e.delay, 0) - - // Handle delayed data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.delay, 1) - } - return nil - - case tcpip.CorkOption: - if v == 0 { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - return nil - - case tcpip.ReuseAddressOption: - e.mu.Lock() - e.reuseAddr = v != 0 - e.mu.Unlock() - return nil - - case tcpip.ReusePortOption: - e.mu.Lock() - e.reusePort = v != 0 - e.mu.Unlock() - return nil - - case tcpip.QuickAckOption: - if v == 0 { - atomic.StoreUint32(&e.slowAck, 1) - } else { - atomic.StoreUint32(&e.slowAck, 0) - } - return nil - - case tcpip.MaxSegOption: - userMSS := v - if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue - } - e.mu.Lock() - e.userMSS = int(userMSS) - e.mu.Unlock() - e.notifyProtocolGoroutine(notifyMSSChanged) - return nil - +// SetSockOptInt sets a socket option. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + switch opt { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1071,6 +1018,67 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.sndBufMu.Unlock() return nil + default: + return nil + } +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { + case tcpip.DelayOption: + if v == 0 { + atomic.StoreUint32(&e.delay, 0) + + // Handle delayed data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.delay, 1) + } + return nil + + case tcpip.CorkOption: + if v == 0 { + atomic.StoreUint32(&e.cork, 0) + + // Handle the corked data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.cork, 1) + } + return nil + + case tcpip.ReuseAddressOption: + e.mu.Lock() + e.reuseAddr = v != 0 + e.mu.Unlock() + return nil + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v != 0 + e.mu.Unlock() + return nil + + case tcpip.QuickAckOption: + if v == 0 { + atomic.StoreUint32(&e.slowAck, 1) + } else { + atomic.StoreUint32(&e.slowAck, 0) + } + return nil + + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.mu.Lock() + e.userMSS = int(userMSS) + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyMSSChanged) + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.netProto != header.IPv6ProtocolNumber { @@ -1182,6 +1190,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() + case tcpip.SendBufferSizeOption: + e.sndBufMu.Lock() + v := e.sndBufSize + e.sndBufMu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvListMu.Lock() + v := e.rcvBufSize + e.rcvListMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } @@ -1204,18 +1224,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = header.TCPDefaultMSS return nil - case *tcpip.SendBufferSizeOption: - e.sndBufMu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.sndBufMu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvListMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize) - e.rcvListMu.Unlock() - return nil - case *tcpip.DelayOption: *o = 0 if v := atomic.LoadUint32(&e.delay); v != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index 272bbcdbd..9fa97528b 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -38,7 +38,7 @@ func TestFastRecovery(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -190,7 +190,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -232,7 +232,7 @@ func TestCongestionAvoidance(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -336,7 +336,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { enableCUBIC(t, c) - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -445,7 +445,7 @@ func TestRetransmit(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 32bb45224..7fa5cfb6e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -84,7 +84,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want) } @@ -97,7 +97,7 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.FailedConnectionAttempts.Value() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want) } @@ -131,7 +131,7 @@ func TestTCPSegmentsSentIncrement(t *testing.T) { stats := c.Stack().Stats() // SYN and ACK want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.SegmentsSent.Value(); got != want { t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) @@ -299,7 +299,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) { want := stats.TCP.ResetsReceived.Value() + 1 iss := seqnum.Value(789) rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, nil) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -323,7 +323,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) { want := stats.TCP.ResetsReceived.Value() + 1 iss := seqnum.Value(789) rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, nil) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -344,14 +344,14 @@ func TestActiveHandshake(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) } func TestNonBlockingClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -367,7 +367,7 @@ func TestConnectResetAfterClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -417,7 +417,7 @@ func TestSimpleReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -469,7 +469,7 @@ func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -557,8 +557,7 @@ func TestOutOfOrderFlood(t *testing.T) { defer c.Cleanup() // Create a new connection with initial window size of 10. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) @@ -631,7 +630,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -700,7 +699,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -785,7 +784,7 @@ func TestShutdownRead(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) @@ -804,8 +803,7 @@ func TestFullWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -872,11 +870,9 @@ func TestNoWindowShrinking(t *testing.T) { defer c.Cleanup() // Start off with a window size of 10, then shrink it to 5. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) - opt = 5 - if err := c.EP.SetSockOpt(opt); err != nil { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { t.Fatalf("SetSockOpt failed: %v", err) } @@ -976,7 +972,7 @@ func TestSimpleSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1017,7 +1013,7 @@ func TestZeroWindowSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 0, nil) + c.CreateConnected(789, 0, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1075,8 +1071,7 @@ func TestScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -1110,8 +1105,7 @@ func TestNonScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 65535*3) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1151,7 +1145,7 @@ func TestScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1224,7 +1218,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1293,8 +1287,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { // Set the window size such that a window scale of 4 will be used. const wnd = 65535 * 10 const ws = uint32(4) - opt := tcpip.ReceiveBufferSizeOption(wnd) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -1399,7 +1392,7 @@ func TestSegmentMerging(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Prevent the endpoint from processing packets. test.stop(c.EP) @@ -1449,7 +1442,7 @@ func TestDelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.DelayOption(1)) @@ -1497,7 +1490,7 @@ func TestUndelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.DelayOption(1)) @@ -1579,7 +1572,7 @@ func TestMSSNotDelayed(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -1695,7 +1688,7 @@ func TestSendGreaterThanMTU(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) testBrokenUpWrite(t, c, maxPayload) } @@ -1704,7 +1697,7 @@ func TestActiveSendMSSLessThanMTU(t *testing.T) { c := context.New(t, 65535) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) testBrokenUpWrite(t, c, maxPayload) @@ -1727,7 +1720,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1871,7 +1864,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 const wndScale = 2 - if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1973,7 +1966,7 @@ func TestReceiveOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send RST segment. c.SendPacket(nil, &context.Headers{ @@ -2010,7 +2003,7 @@ func TestSendOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send RST segment. c.SendPacket(nil, &context.Headers{ @@ -2035,7 +2028,7 @@ func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -2078,7 +2071,7 @@ func TestFinRetransmit(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -2132,7 +2125,7 @@ func TestFinWithNoPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and have it acknowledged. view := buffer.NewView(10) @@ -2203,7 +2196,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write enough segments to fill the congestion window before ACK'ing // any of them. @@ -2291,7 +2284,7 @@ func TestFinWithPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. view := buffer.NewView(10) @@ -2377,7 +2370,7 @@ func TestFinWithPartialAck(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. @@ -2509,7 +2502,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { defer c.Cleanup() maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(789, 0, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), header.TCPOptionWS, 3, scale, header.TCPOptionNOP, }) @@ -2559,7 +2552,7 @@ func TestScaledSendWindow(t *testing.T) { func TestReceivedValidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ValidSegmentsReceived.Value() + 1 @@ -2580,7 +2573,7 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.InvalidSegmentsReceived.Value() + 1 vv := c.BuildSegment(nil, &context.Headers{ @@ -2604,7 +2597,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { func TestReceivedIncorrectChecksumIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ChecksumErrors.Value() + 1 vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ @@ -2635,7 +2628,7 @@ func TestReceivedSegmentQueuing(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send 200 segments. data := []byte{1, 2, 3} @@ -2681,7 +2674,7 @@ func TestReadAfterClosedState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -2856,8 +2849,8 @@ func TestReusePort(t *testing.T) { func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - var s tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { + s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -2869,8 +2862,8 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - var s tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { + s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -2945,26 +2938,26 @@ func TestMinMaxBufferSizes(t *testing.T) { } // Set values below the min. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkSendBufferSize(t, ep, 300) // Set values above the max. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -3231,7 +3224,7 @@ func TestPathMTUDiscovery(t *testing.T) { // Create new connection with MSS of 1460. const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -3308,7 +3301,7 @@ func TestTCPEndpointProbe(t *testing.T) { invoked <- struct{}{} }) - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -3482,7 +3475,7 @@ func TestKeepalive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 16783e716..78eff5c3a 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -512,7 +512,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { } // CreateConnected creates a connected TCP endpoint. -func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) { +func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) { c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) } @@ -590,7 +590,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) // // It also sets the receive buffer for the endpoint to the specified // value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -598,8 +598,8 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.t.Fatalf("NewEndpoint failed: %v", err) } - if epRcvBuf != nil { - if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { + if epRcvBuf != -1 { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil { c.t.Fatalf("SetSockOpt failed failed: %v", err) } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 6ac7c067a..0bec7e62d 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -389,7 +389,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOpt sets a socket option. Currently not supported. +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { switch v := opt.(type) { case tcpip.V6OnlyOption: @@ -568,7 +573,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil } + return -1, tcpip.ErrUnknownProtocolOption } @@ -578,18 +596,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) - e.rcvMu.Unlock() - return nil - case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.netProto != header.IPv6ProtocolNumber { -- cgit v1.2.3 From 59ccbb10446063f5347fb026e35549bc2f677971 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 25 Sep 2019 12:56:00 -0700 Subject: Remove centralized registration of protocols. Also removes the need for protocol names. PiperOrigin-RevId: 271186030 --- pkg/tcpip/adapters/gonet/gonet_test.go | 5 +- pkg/tcpip/network/arp/arp.go | 16 ++- pkg/tcpip/network/arp/arp_test.go | 5 +- pkg/tcpip/network/ip_test.go | 10 +- pkg/tcpip/network/ipv4/ipv4.go | 44 +++----- pkg/tcpip/network/ipv4/ipv4_test.go | 9 +- pkg/tcpip/network/ipv6/icmp_test.go | 15 ++- pkg/tcpip/network/ipv6/ipv6.go | 24 ++--- pkg/tcpip/network/ipv6/ipv6_test.go | 34 ++++--- pkg/tcpip/network/ipv6/ndp_test.go | 5 +- pkg/tcpip/sample/tun_tcp_connect/main.go | 5 +- pkg/tcpip/sample/tun_tcp_echo/main.go | 5 +- pkg/tcpip/stack/registration.go | 36 ------- pkg/tcpip/stack/stack.go | 45 ++++----- pkg/tcpip/stack/stack_test.go | 111 +++++++++++++++------ pkg/tcpip/stack/transport_test.go | 35 +++++-- pkg/tcpip/transport/icmp/protocol.go | 27 ++--- pkg/tcpip/transport/raw/protocol.go | 9 +- pkg/tcpip/transport/tcp/protocol.go | 22 ++-- pkg/tcpip/transport/tcp/tcp_test.go | 21 ++-- pkg/tcpip/transport/tcp/testing/context/context.go | 5 +- pkg/tcpip/transport/udp/protocol.go | 12 +-- pkg/tcpip/transport/udp/udp_test.go | 5 +- runsc/boot/BUILD | 1 + runsc/boot/loader.go | 17 ++-- 25 files changed, 278 insertions(+), 245 deletions(-) (limited to 'pkg/tcpip/transport/tcp/testing') diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 672f026b2..8ced960bb 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -60,7 +60,10 @@ func TestTimeouts(t *testing.T) { func newLoopbackStack() (*stack.Stack, *tcpip.Error) { // Create the stack and add a NIC. - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}, + }) if err := s.CreateNIC(NICID, loopback.New()); err != nil { return nil, err diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index fd6395fc1..26cf1c528 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -16,9 +16,9 @@ // IPv4 addresses into link-local MAC addresses, and advertises IPv4 // addresses of its stack with the local network. // -// To use it in the networking stack, pass arp.ProtocolName as one of the -// network protocols when calling stack.New. Then add an "arp" address to -// every NIC on the stack that should respond to ARP requests. That is: +// To use it in the networking stack, pass arp.NewProtocol() as one of the +// network protocols when calling stack.New. Then add an "arp" address to every +// NIC on the stack that should respond to ARP requests. That is: // // if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil { // // handle err @@ -33,9 +33,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ARP protocol name. - ProtocolName = "arp" - // ProtocolNumber is the ARP protocol number. ProtocolNumber = header.ARPProtocolNumber @@ -200,8 +197,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) -func init() { - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) +// NewProtocol returns an ARP network protocol. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 387fca96e..88b57ec03 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -44,7 +44,10 @@ type testContext struct { } func newTestContext(t *testing.T) *testContext { - s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, + }) const defaultMTU = 65536 ep := channel.New(256, defaultMTU, stackLinkAddr) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 6a40e7ee3..a9741622e 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -172,7 +172,10 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prepen } func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv4.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ @@ -185,7 +188,10 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { } func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv6.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b7a06f525..b7b07a6c1 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -14,9 +14,9 @@ // Package ipv4 contains the implementation of the ipv4 network protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the -// network protocols when calling stack.New(). Then endpoints can be created -// by passing ipv4.ProtocolNumber as the network protocol number when calling +// activated on the stack by passing ipv4.NewProtocol() as one of the network +// protocols when calling stack.New(). Then endpoints can be created by passing +// ipv4.ProtocolNumber as the network protocol number when calling // Stack.NewEndpoint(). package ipv4 @@ -32,9 +32,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ipv4 protocol name. - ProtocolName = "ipv4" - // ProtocolNumber is the ipv4 protocol number. ProtocolNumber = header.IPv4ProtocolNumber @@ -53,6 +50,7 @@ type endpoint struct { linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher fragmentation *fragmentation.Fragmentation + protocol *protocol } // NewEndpoint creates a new ipv4 endpoint. @@ -64,6 +62,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi linkEP: linkEP, dispatcher: dispatcher, fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + protocol: p, } return e, nil @@ -204,7 +203,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen if length > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1) + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, protocol, e.protocol.hashIV)%buckets], 1) } ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, @@ -267,7 +266,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect if payload.Size() > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1) + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) } ip.SetID(uint16(id)) } @@ -325,14 +324,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { // Close cleans up resources associated with the endpoint. func (e *endpoint) Close() {} -type protocol struct{} - -// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported -// only for tests that short-circuit the stack. Regular use of the protocol is -// done via the stack, which gets a protocol descriptor from the init() function -// below. -func NewProtocol() stack.NetworkProtocol { - return &protocol{} +type protocol struct { + ids []uint32 + hashIV uint32 } // Number returns the ipv4 protocol number. @@ -378,7 +372,7 @@ func calculateMTU(mtu uint32) uint32 { // hashRoute calculates a hash value for the given route. It uses the source & // destination address, the transport protocol number, and a random initial // value (generated once on initialization) to generate the hash. -func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 { +func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { t := r.LocalAddress a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 t = r.RemoteAddress @@ -386,22 +380,16 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 { return hash.Hash3Words(a, b, uint32(protocol), hashIV) } -var ( - ids []uint32 - hashIV uint32 -) - -func init() { - ids = make([]uint32, buckets) +// NewProtocol returns an IPv4 network protocol. +func NewProtocol() stack.NetworkProtocol { + ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. r := hash.RandN32(1 + buckets) for i := range ids { ids[i] = r[i] } - hashIV = r[buckets] + hashIV := r[buckets] - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) + return &protocol{ids: ids, hashIV: hashIV} } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index ae827ca27..b6641ccc3 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -33,7 +33,10 @@ import ( ) func TestExcludeBroadcast(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) const defaultMTU = 65536 ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) @@ -238,7 +241,9 @@ type context struct { func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { // Make the packet and write it. - s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) s.CreateNIC(1, ep) const ( diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 653d984e9..01f5a17ec 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -81,7 +81,10 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li } func TestICMPCounts(t *testing.T) { - s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }) { if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -205,8 +208,14 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab func newTestContext(t *testing.T) *testContext { c := &testContext{ - s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), - s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), + s0: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }), + s1: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }), } const defaultMTU = 65536 diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 331a8bdaa..7de6a4546 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -14,9 +14,9 @@ // Package ipv6 contains the implementation of the ipv6 network protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the -// network protocols when calling stack.New(). Then endpoints can be created -// by passing ipv6.ProtocolNumber as the network protocol number when calling +// activated on the stack by passing ipv6.NewProtocol() as one of the network +// protocols when calling stack.New(). Then endpoints can be created by passing +// ipv6.ProtocolNumber as the network protocol number when calling // Stack.NewEndpoint(). package ipv6 @@ -28,9 +28,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ipv6 protocol name. - ProtocolName = "ipv6" - // ProtocolNumber is the ipv6 protocol number. ProtocolNumber = header.IPv6ProtocolNumber @@ -160,14 +157,6 @@ func (*endpoint) Close() {} type protocol struct{} -// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported -// only for tests that short-circuit the stack. Regular use of the protocol is -// done via the stack, which gets a protocol descriptor from the init() function -// below. -func NewProtocol() stack.NetworkProtocol { - return &protocol{} -} - // Number returns the ipv6 protocol number. func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber @@ -221,8 +210,7 @@ func calculateMTU(mtu uint32) uint32 { return maxPayloadSize } -func init() { - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) +// NewProtocol returns an IPv6 network protocol. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 57bcd5455..78c674c2c 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -124,17 +124,20 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst // UDP packets destined to the IPv6 link-local all-nodes multicast address. func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { tests := []struct { - name string - protocolName string - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + name string + protocolFactory stack.TransportProtocol + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.ProtocolName6, testReceiveICMP}, - {"UDP", udp.ProtocolName, testReceiveUDP}, + {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, + {"UDP", udp.NewProtocol(), testReceiveUDP}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + }) e := channel.New(10, 1280, linkAddr1) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -152,19 +155,22 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { // address. func TestReceiveOnSolicitedNodeAddr(t *testing.T) { tests := []struct { - name string - protocolName string - rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + name string + protocolFactory stack.TransportProtocol + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.ProtocolName6, testReceiveICMP}, - {"UDP", udp.ProtocolName, testReceiveUDP}, + {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, + {"UDP", udp.NewProtocol(), testReceiveUDP}, } snmc := header.SolicitedNodeAddr(addr2) for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + }) e := channel.New(10, 1280, linkAddr1) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -237,7 +243,9 @@ func TestAddIpv6Address(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New([]string{ProtocolName}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 571915d3f..e30791fe3 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -31,7 +31,10 @@ import ( func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() - s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index f12189580..2239c1e66 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -126,7 +126,10 @@ func main() { // Create the stack with ipv4 and tcp protocols, then add a tun-based // NIC and ipv4 address. - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) mtu, err := rawfile.GetMTU(tunName) if err != nil { diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 329941775..bca73cbb1 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -111,7 +111,10 @@ func main() { // Create the stack with ip and tcp protocols, then add a tun-based // NIC and address. - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) mtu, err := rawfile.GetMTU(tunName) if err != nil { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 07e4c770d..80101d4bb 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -366,14 +366,6 @@ type LinkAddressCache interface { RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) } -// TransportProtocolFactory functions are used by the stack to instantiate -// transport protocols. -type TransportProtocolFactory func() TransportProtocol - -// NetworkProtocolFactory provides methods to be used by the stack to -// instantiate network protocols. -type NetworkProtocolFactory func() NetworkProtocol - // UnassociatedEndpointFactory produces endpoints for writing packets not // associated with a particular transport protocol. Such endpoints can be used // to write arbitrary packets that include the IP header. @@ -381,34 +373,6 @@ type UnassociatedEndpointFactory interface { NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) } -var ( - transportProtocols = make(map[string]TransportProtocolFactory) - networkProtocols = make(map[string]NetworkProtocolFactory) - - unassociatedFactory UnassociatedEndpointFactory -) - -// RegisterTransportProtocolFactory registers a new transport protocol factory -// with the stack so that it becomes available to users of the stack. This -// function is intended to be called by init() functions of the protocols. -func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) { - transportProtocols[name] = p -} - -// RegisterNetworkProtocolFactory registers a new network protocol factory with -// the stack so that it becomes available to users of the stack. This function -// is intended to be called by init() functions of the protocols. -func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) { - networkProtocols[name] = p -} - -// RegisterUnassociatedFactory registers a factory to produce endpoints not -// associated with any particular transport protocol. This function is intended -// to be called by init() functions of the protocols. -func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) { - unassociatedFactory = f -} - // GSOType is the type of GSO segments. // // +stateify savable diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index f7ba3cb0f..18d1704a5 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -17,11 +17,6 @@ // // For consumers, the only function of interest is New(), everything else is // provided by the tcpip/public package. -// -// For protocol implementers, RegisterTransportProtocolFactory() and -// RegisterNetworkProtocolFactory() are used to register protocol factories with -// the stack, which will then be used to instantiate protocol objects when -// consumers interact with the stack. package stack import ( @@ -351,6 +346,9 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + // unassociatedFactory creates unassociated endpoints. If nil, raw + // endpoints are disabled. It is set during Stack creation and is + // immutable. unassociatedFactory UnassociatedEndpointFactory demux *transportDemuxer @@ -359,10 +357,6 @@ type Stack struct { linkAddrCache *linkAddrCache - // raw indicates whether raw sockets may be created. It is set during - // Stack creation and is immutable. - raw bool - mu sync.RWMutex nics map[tcpip.NICID]*NIC forwarding bool @@ -398,6 +392,12 @@ type Stack struct { // Options contains optional Stack configuration. type Options struct { + // NetworkProtocols lists the network protocols to enable. + NetworkProtocols []NetworkProtocol + + // TransportProtocols lists the transport protocols to enable. + TransportProtocols []TransportProtocol + // Clock is an optional clock source used for timestampping packets. // // If no Clock is specified, the clock source will be time.Now. @@ -411,8 +411,9 @@ type Options struct { // stack (false). HandleLocal bool - // Raw indicates whether raw sockets may be created. - Raw bool + // UnassociatedFactory produces unassociated endpoints raw endpoints. + // Raw endpoints are enabled only if this is non-nil. + UnassociatedFactory UnassociatedEndpointFactory } // New allocates a new networking stack with only the requested networking and @@ -422,7 +423,7 @@ type Options struct { // SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the // stack. Please refer to individual protocol implementations as to what options // are supported. -func New(network []string, transport []string, opts Options) *Stack { +func New(opts Options) *Stack { clock := opts.Clock if clock == nil { clock = &tcpip.StdClock{} @@ -438,17 +439,11 @@ func New(network []string, transport []string, opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, - raw: opts.Raw, icmpRateLimiter: NewICMPRateLimiter(), } // Add specified network protocols. - for _, name := range network { - netProtoFactory, ok := networkProtocols[name] - if !ok { - continue - } - netProto := netProtoFactory() + for _, netProto := range opts.NetworkProtocols { s.networkProtocols[netProto.Number()] = netProto if r, ok := netProto.(LinkAddressResolver); ok { s.linkAddrResolvers[r.LinkAddressProtocol()] = r @@ -456,18 +451,14 @@ func New(network []string, transport []string, opts Options) *Stack { } // Add specified transport protocols. - for _, name := range transport { - transProtoFactory, ok := transportProtocols[name] - if !ok { - continue - } - transProto := transProtoFactory() + for _, transProto := range opts.TransportProtocols { s.transportProtocols[transProto.Number()] = &transportProtocolState{ proto: transProto, } } - s.unassociatedFactory = unassociatedFactory + // Add the factory for unassociated endpoints, if present. + s.unassociatedFactory = opts.UnassociatedFactory // Create the global transport demuxer. s.demux = newTransportDemuxer(s) @@ -602,7 +593,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { - if !s.raw { + if s.unassociatedFactory == nil { return nil, tcpip.ErrNotPermitted } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 7aa10bce9..d2dede8a9 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -222,11 +222,17 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +func fakeNetFactory() stack.NetworkProtocol { + return &fakeNetworkProtocol{} +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. ep := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -370,7 +376,9 @@ func TestNetworkSend(t *testing.T) { // address: 1. The route table sends all packets through the only // existing nic. ep := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) } @@ -395,7 +403,9 @@ func TestNetworkSendMultiRoute(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -476,7 +486,9 @@ func TestRoutes(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -554,7 +566,9 @@ func TestAddressRemoval(t *testing.T) { localAddr := tcpip.Address([]byte{localAddrByte}) remoteAddr := tcpip.Address("\x02") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -599,7 +613,9 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { localAddr := tcpip.Address([]byte{localAddrByte}) remoteAddr := tcpip.Address("\x02") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -688,7 +704,9 @@ func TestEndpointExpiration(t *testing.T) { for _, promiscuous := range []bool{true, false} { for _, spoofing := range []bool{true, false} { t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicid, ep); err != nil { @@ -844,7 +862,9 @@ func TestEndpointExpiration(t *testing.T) { } func TestPromiscuousMode(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -894,7 +914,9 @@ func TestSpoofingWithAddress(t *testing.T) { nonExistentLocalAddr := tcpip.Address("\x02") dstAddr := tcpip.Address("\x03") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -958,7 +980,9 @@ func TestSpoofingNoAddress(t *testing.T) { nonExistentLocalAddr := tcpip.Address("\x01") dstAddr := tcpip.Address("\x02") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1003,7 +1027,9 @@ func TestSpoofingNoAddress(t *testing.T) { } func TestBroadcastNeedsNoRoute(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1074,7 +1100,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, } { t.Run(tc.name, func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1130,7 +1158,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { // Add a range of addresses, then check that a packet is delivered. func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1196,7 +1226,9 @@ func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, sub // existent. func TestCheckLocalAddressForSubnet(t *testing.T) { const nicID tcpip.NICID = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -1234,7 +1266,9 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { // Set a range of addresses, then send a packet to a destination outside the // range and then check it doesn't get delivered. func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1266,7 +1300,10 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { } func TestNetworkOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{}, + }) // Try an unsupported network protocol. if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { @@ -1319,7 +1356,9 @@ func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.S } func TestAddresRangeAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1360,7 +1399,9 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1425,7 +1466,9 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { } func TestGetMainNICAddressAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1508,7 +1551,9 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1533,7 +1578,9 @@ func TestAddAddress(t *testing.T) { func TestAddProtocolAddress(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1565,7 +1612,9 @@ func TestAddProtocolAddress(t *testing.T) { func TestAddAddressWithOptions(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1594,7 +1643,9 @@ func TestAddAddressWithOptions(t *testing.T) { func TestAddProtocolAddressWithOptions(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -1628,7 +1679,9 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { } func TestNICStats(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed: ", err) @@ -1674,7 +1727,9 @@ func TestNICStats(t *testing.T) { func TestNICForwarding(t *testing.T) { // Create a stack with the fake network protocol, two NICs, each with // an address. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) s.SetForwarding(true) ep1 := channel.New(10, defaultMTU, "") @@ -1722,9 +1777,3 @@ func TestNICForwarding(t *testing.T) { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } - -func init() { - stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol { - return &fakeNetworkProtocol{} - }) -} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 0e69ac7c8..56e8a5d9b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -282,9 +282,16 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } } +func fakeTransFactory() stack.TransportProtocol { + return &fakeTransportProtocol{} +} + func TestTransportReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -346,7 +353,10 @@ func TestTransportReceive(t *testing.T) { func TestTransportControlReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -414,7 +424,10 @@ func TestTransportControlReceive(t *testing.T) { func TestTransportSend(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -457,7 +470,10 @@ func TestTransportSend(t *testing.T) { } func TestTransportOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) // Try an unsupported transport protocol. if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { @@ -498,7 +514,10 @@ func TestTransportOptions(t *testing.T) { } func TestTransportForwarding(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) s.SetForwarding(true) // TODO(b/123449044): Change this to a channel NIC. @@ -576,9 +595,3 @@ func TestTransportForwarding(t *testing.T) { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } - -func init() { - stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol { - return &fakeTransportProtocol{} - }) -} diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 1eb790932..bfb16f7c3 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -14,10 +14,9 @@ // Package icmp contains the implementation of the ICMP and IPv6-ICMP transport // protocols for use in ping. To use it in the networking stack, this package -// must be added to the project, and -// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or -// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when -// calling stack.New(). Then endpoints can be created by passing +// must be added to the project, and activated on the stack by passing +// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport +// protocols when calling stack.New(). Then endpoints can be created by passing // icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number // when calling Stack.NewEndpoint(). package icmp @@ -34,15 +33,9 @@ import ( ) const ( - // ProtocolName4 is the string representation of the icmp protocol name. - ProtocolName4 = "icmp4" - // ProtocolNumber4 is the ICMP protocol number. ProtocolNumber4 = header.ICMPv4ProtocolNumber - // ProtocolName6 is the string representation of the icmp protocol name. - ProtocolName6 = "icmp6" - // ProtocolNumber6 is the IPv6-ICMP protocol number. ProtocolNumber6 = header.ICMPv6ProtocolNumber ) @@ -125,12 +118,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func init() { - stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol { - return &protocol{ProtocolNumber4} - }) +// NewProtocol4 returns an ICMPv4 transport protocol. +func NewProtocol4() stack.TransportProtocol { + return &protocol{ProtocolNumber4} +} - stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol { - return &protocol{ProtocolNumber6} - }) +// NewProtocol6 returns an ICMPv6 transport protocol. +func NewProtocol6() stack.TransportProtocol { + return &protocol{ProtocolNumber6} } diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go index 783c21e6b..a2512d666 100644 --- a/pkg/tcpip/transport/raw/protocol.go +++ b/pkg/tcpip/transport/raw/protocol.go @@ -20,13 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -type factory struct{} +// EndpointFactory implements stack.UnassociatedEndpointFactory. +type EndpointFactory struct{} // NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory. -func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) } - -func init() { - stack.RegisterUnassociatedFactory(factory{}) -} diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2a13b2022..d5d8ab96a 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -14,7 +14,7 @@ // Package tcp contains the implementation of the TCP transport protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the +// activated on the stack by passing tcp.NewProtocol() as one of the // transport protocols when calling stack.New(). Then endpoints can be created // by passing tcp.ProtocolNumber as the transport protocol number when calling // Stack.NewEndpoint(). @@ -34,9 +34,6 @@ import ( ) const ( - // ProtocolName is the string representation of the tcp protocol name. - ProtocolName = "tcp" - // ProtocolNumber is the tcp protocol number. ProtocolNumber = header.TCPProtocolNumber @@ -254,13 +251,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { } } -func init() { - stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { - return &protocol{ - sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize}, - recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, - congestionControl: ccReno, - availableCongestionControl: []string{ccReno, ccCubic}, - } - }) +// NewProtocol returns a TCP transport protocol. +func NewProtocol() stack.TransportProtocol { + return &protocol{ + sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize}, + recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, + congestionControl: ccReno, + availableCongestionControl: []string{ccReno, ccCubic}, + } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 7fa5cfb6e..2be094876 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2873,7 +2873,10 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { } func TestDefaultBufferSizes(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) // Check the default values. ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) @@ -2919,7 +2922,10 @@ func TestDefaultBufferSizes(t *testing.T) { } func TestMinMaxBufferSizes(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) // Check the default values. ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) @@ -2965,10 +2971,13 @@ func TestMinMaxBufferSizes(t *testing.T) { } func makeStack() (*stack.Stack, *tcpip.Error) { - s := stack.New([]string{ - ipv4.ProtocolName, - ipv6.ProtocolName, - }, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ + ipv4.NewProtocol(), + ipv6.NewProtocol(), + }, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) id := loopback.New() if testing.Verbose() { diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 78eff5c3a..d3f1d2cdf 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -137,7 +137,10 @@ type Context struct { // New allocates and initializes a test context containing a new // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) // Allow minimum send/receive buffer sizes to be 1 during tests. if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 068d9a272..f5cc932dd 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -14,7 +14,7 @@ // Package udp contains the implementation of the UDP transport protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing udp.ProtocolName (or "udp") as one of the +// activated on the stack by passing udp.NewProtocol() as one of the // transport protocols when calling stack.New(). Then endpoints can be created // by passing udp.ProtocolNumber as the transport protocol number when calling // Stack.NewEndpoint(). @@ -30,9 +30,6 @@ import ( ) const ( - // ProtocolName is the string representation of the udp protocol name. - ProtocolName = "udp" - // ProtocolNumber is the udp protocol number. ProtocolNumber = header.UDPProtocolNumber ) @@ -182,8 +179,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func init() { - stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { - return &protocol{} - }) +// NewProtocol returns a UDP transport protocol. +func NewProtocol() stack.TransportProtocol { + return &protocol{} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index c6deab892..2ec27be4d 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -274,7 +274,10 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() - s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) ep := channel.New(256, mtu, "") wep := stack.LinkEndpoint(ep) diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 54d1ab129..d90381c0f 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -80,6 +80,7 @@ go_library( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/urpc", diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index d824d7dc5..adf345490 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -54,6 +54,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/runsc/boot/filter" @@ -911,15 +912,17 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) { case NetworkNone, NetworkSandbox: // NetworkNone sets up loopback using netstack. - netProtos := []string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName} - protoNames := []string{tcp.ProtocolName, udp.ProtocolName, icmp.ProtocolName4} - s := epsocket.Stack{stack.New(netProtos, protoNames, stack.Options{ - Clock: clock, - Stats: epsocket.Metrics, - HandleLocal: true, + netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()} + transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()} + s := epsocket.Stack{stack.New(stack.Options{ + NetworkProtocols: netProtos, + TransportProtocols: transProtos, + Clock: clock, + Stats: epsocket.Metrics, + HandleLocal: true, // Enable raw sockets for users with sufficient // privileges. - Raw: true, + UnassociatedFactory: raw.EndpointFactory{}, })} // Enable SACK Recovery. -- cgit v1.2.3 From abbee5615f4480d8a41b4cf63839d2ab13b19abf Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 27 Sep 2019 14:12:35 -0700 Subject: Implement SO_BINDTODEVICE sockopt PiperOrigin-RevId: 271644926 --- pkg/sentry/socket/epsocket/epsocket.go | 20 ++ pkg/sentry/syscalls/linux/sys_socket.go | 2 +- pkg/tcpip/ports/ports.go | 114 ++++-- pkg/tcpip/ports/ports_test.go | 113 +++++- pkg/tcpip/stack/BUILD | 4 + pkg/tcpip/stack/nic.go | 15 +- pkg/tcpip/stack/stack.go | 58 +--- pkg/tcpip/stack/transport_demuxer.go | 227 +++++++----- pkg/tcpip/stack/transport_demuxer_test.go | 352 +++++++++++++++++++ pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/tcpip.go | 4 + pkg/tcpip/transport/icmp/endpoint.go | 6 +- pkg/tcpip/transport/tcp/accept.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 56 ++- pkg/tcpip/transport/tcp/tcp_test.go | 116 +++++++ pkg/tcpip/transport/tcp/testing/context/context.go | 26 +- pkg/tcpip/transport/udp/BUILD | 1 + pkg/tcpip/transport/udp/endpoint.go | 42 ++- pkg/tcpip/transport/udp/forwarder.go | 2 +- pkg/tcpip/transport/udp/udp_test.go | 120 +++---- test/syscalls/linux/BUILD | 75 ++++ test/syscalls/linux/socket_bind_to_device.cc | 314 +++++++++++++++++ .../linux/socket_bind_to_device_distribution.cc | 381 +++++++++++++++++++++ .../linux/socket_bind_to_device_sequence.cc | 316 +++++++++++++++++ test/syscalls/linux/socket_bind_to_device_util.cc | 75 ++++ test/syscalls/linux/socket_bind_to_device_util.h | 67 ++++ test/syscalls/linux/uidgid.cc | 25 +- test/util/BUILD | 11 + test/util/uid_util.cc | 44 +++ test/util/uid_util.h | 29 ++ 30 files changed, 2308 insertions(+), 314 deletions(-) create mode 100644 pkg/tcpip/stack/transport_demuxer_test.go create mode 100644 test/syscalls/linux/socket_bind_to_device.cc create mode 100644 test/syscalls/linux/socket_bind_to_device_distribution.cc create mode 100644 test/syscalls/linux/socket_bind_to_device_sequence.cc create mode 100644 test/syscalls/linux/socket_bind_to_device_util.cc create mode 100644 test/syscalls/linux/socket_bind_to_device_util.h create mode 100644 test/util/uid_util.cc create mode 100644 test/util/uid_util.h (limited to 'pkg/tcpip/transport/tcp/testing') diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 3e66f9cbb..5812085fa 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -942,6 +942,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return int32(v), nil + case linux.SO_BINDTODEVICE: + var v tcpip.BindToDeviceOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + if len(v) == 0 { + return []byte{}, nil + } + if outLen < linux.IFNAMSIZ { + return nil, syserr.ErrInvalidArgument + } + return append([]byte(v), 0), nil + case linux.SO_BROADCAST: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1305,6 +1318,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + case linux.SO_BINDTODEVICE: + n := bytes.IndexByte(optVal, 0) + if n == -1 { + n = len(optVal) + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n]))) + case linux.SO_BROADCAST: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 3bac4d90d..b5a72ce63 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -531,7 +531,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.ENOTSOCK } - if optLen <= 0 { + if optLen < 0 { return 0, nil, syserror.EINVAL } if optLen > maxOptLen { diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 315780c0c..40e202717 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -47,43 +47,76 @@ type portNode struct { refs int } -// bindAddresses is a set of IP addresses. -type bindAddresses map[tcpip.Address]portNode - -// isAvailable checks whether an IP address is available to bind to. -func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool { - if addr == anyIPAddress { - if len(b) == 0 { - return true - } +// deviceNode is never empty. When it has no elements, it is removed from the +// map that references it. +type deviceNode map[tcpip.NICID]portNode + +// isAvailable checks whether binding is possible by device. If not binding to a +// device, check against all portNodes. If binding to a specific device, check +// against the unspecified device and the provided device. +func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool { + if bindToDevice == 0 { + // Trying to binding all devices. if !reuse { + // Can't bind because the (addr,port) is already bound. return false } - for _, n := range b { - if !n.reuse { + for _, p := range d { + if !p.reuse { + // Can't bind because the (addr,port) was previously bound without reuse. return false } } return true } - // If all addresses for this portDescriptor are already bound, no - // address is available. - if n, ok := b[anyIPAddress]; ok { - if !reuse { + if p, ok := d[0]; ok { + if !reuse || !p.reuse { return false } - if !n.reuse { + } + + if p, ok := d[bindToDevice]; ok { + if !reuse || !p.reuse { return false } } - if n, ok := b[addr]; ok { - if !reuse { + return true +} + +// bindAddresses is a set of IP addresses. +type bindAddresses map[tcpip.Address]deviceNode + +// isAvailable checks whether an IP address is available to bind to. If the +// address is the "any" address, check all other addresses. Otherwise, just +// check against the "any" address and the provided address. +func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool { + if addr == anyIPAddress { + // If binding to the "any" address then check that there are no conflicts + // with all addresses. + for _, d := range b { + if !d.isAvailable(reuse, bindToDevice) { + return false + } + } + return true + } + + // Check that there is no conflict with the "any" address. + if d, ok := b[anyIPAddress]; ok { + if !d.isAvailable(reuse, bindToDevice) { + return false + } + } + + // Check that this is no conflict with the provided address. + if d, ok := b[addr]; ok { + if !d.isAvailable(reuse, bindToDevice) { return false } - return n.reuse } + return true } @@ -116,17 +149,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, reuse) + return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, reuse) { + if !addrs.isAvailable(addr, reuse, bindToDevice) { return false } } @@ -138,14 +171,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, reuse) { + if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) { return 0, tcpip.ErrPortInUse } return port, nil @@ -153,13 +186,13 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil + return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) { return false } @@ -171,11 +204,16 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber m = make(bindAddresses) s.allocatedPorts[desc] = m } - if n, ok := m[addr]; ok { + d, ok := m[addr] + if !ok { + d = make(deviceNode) + m[addr] = d + } + if n, ok := d[bindToDevice]; ok { n.refs++ - m[addr] = n + d[bindToDevice] = n } else { - m[addr] = portNode{reuse: reuse, refs: 1} + d[bindToDevice] = portNode{reuse: reuse, refs: 1} } } @@ -184,22 +222,28 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { - n, ok := m[addr] + d, ok := m[addr] + if !ok { + continue + } + n, ok := d[bindToDevice] if !ok { continue } n.refs-- + d[bindToDevice] = n if n.refs == 0 { + delete(d, bindToDevice) + } + if len(d) == 0 { delete(m, addr) - } else { - m[addr] = n } if len(m) == 0 { delete(s.allocatedPorts, desc) diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 689401661..a67e283f1 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -34,6 +34,7 @@ type portReserveTestAction struct { want *tcpip.Error reuse bool release bool + device tcpip.NICID } func TestPortReservation(t *testing.T) { @@ -100,6 +101,112 @@ func TestPortReservation(t *testing.T) { {port: 24, ip: anyIPAddress, release: true}, {port: 24, ip: anyIPAddress, reuse: false, want: nil}, }, + }, { + tname: "bind twice with device fails", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 3, want: nil}, + {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind to device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 1, want: nil}, + {port: 24, ip: fakeIPAddress, device: 2, want: nil}, + }, + }, { + tname: "bind to device and then without device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind without device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, want: nil}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuse", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + }, + }, { + tname: "binding with reuse and device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "mixing reuse and not reuse by binding to device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 999, want: nil}, + }, + }, { + tname: "can't bind to 0 after mixing reuse and not reuse", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind and release", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + + // Release the bind to device 0 and try again. + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil}, + }, + }, { + tname: "bind twice with reuse once", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "release an unreserved device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil}, + // The below don't exist. + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true}, + {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + // Release all. + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true}, + }, }, } { t.Run(test.tname, func(t *testing.T) { @@ -108,12 +215,12 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 28c49e8ff..3842f1f7d 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -54,6 +54,7 @@ go_test( size = "small", srcs = [ "stack_test.go", + "transport_demuxer_test.go", "transport_test.go", ], deps = [ @@ -64,6 +65,9 @@ go_test( "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/transport/udp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0e8a23f00..f6106f762 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -34,8 +34,6 @@ type NIC struct { linkEP LinkEndpoint loopback bool - demux *transportDemuxer - mu sync.RWMutex spoofing bool promiscuous bool @@ -85,7 +83,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback name: name, linkEP: ep, loopback: loopback, - demux: newTransportDemuxer(stack), primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), @@ -707,9 +704,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) { - n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) - } + n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) if len(vv.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() @@ -723,9 +718,6 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.demux.deliverPacket(r, protocol, netHeader, vv, id) { - return - } if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { return } @@ -767,10 +759,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { - return - } - if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { + if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) { return } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 18d1704a5..6a8079823 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1033,73 +1033,27 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep. // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { - if nicID == 0 { - return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { - if nicID == 0 { - s.demux.unregisterEndpoint(netProtos, protocol, id, ep) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterEndpoint(netProtos, protocol, id, ep) - } +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) } // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { - if nicID == 0 { - return s.demux.registerRawEndpoint(netProto, transProto, ep) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerRawEndpoint(netProto, transProto, ep) + return s.demux.registerRawEndpoint(netProto, transProto, ep) } // UnregisterRawTransportEndpoint removes the endpoint for the transport // protocol from the stack transport dispatcher. func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { - if nicID == 0 { - s.demux.unregisterRawEndpoint(netProto, transProto, ep) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterRawEndpoint(netProto, transProto, ep) - } + s.demux.unregisterRawEndpoint(netProto, transProto, ep) } // RegisterRestoredEndpoint records e as an endpoint that has been restored on diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index cf8a6d129..8c768c299 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -35,25 +35,109 @@ type protocolIDs struct { type transportEndpoints struct { // mu protects all fields of the transportEndpoints. mu sync.RWMutex - endpoints map[TransportEndpointID]TransportEndpoint + endpoints map[TransportEndpointID]*endpointsByNic // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. rawEndpoints []RawTransportEndpoint } +type endpointsByNic struct { + mu sync.RWMutex + endpoints map[tcpip.NICID]*multiPortEndpoint + // seed is a random secret for a jenkins hash. + seed uint32 +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { + epsByNic.mu.RLock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return + } + } + + // If this is a broadcast or multicast datagram, deliver the datagram to all + // endpoints bound to the right device. + if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { + mpep.handlePacketAll(r, id, vv) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return + } + + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[n.ID()] + if !ok { + mpep, ok = epsByNic.endpoints[0] + } + if !ok { + return + } + + // TODO(eyalsoha): Why don't we look at id to see if this packet needs to + // broadcast like we are doing with handlePacket above? + + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv) +} + +// registerEndpoint returns true if it succeeds. It fails and returns +// false if ep already has an element with the same key. +func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + + if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok { + // There was already a bind. + return multiPortEp.singleRegisterEndpoint(t, reusePort) + } + + // This is a new binding. + multiPortEp := &multiPortEndpoint{} + multiPortEp.endpointsMap = make(map[TransportEndpoint]int) + multiPortEp.reuse = reusePort + epsByNic.endpoints[bindToDevice] = multiPortEp + return multiPortEp.singleRegisterEndpoint(t, reusePort) +} + +// unregisterEndpoint returns true if endpointsByNic has to be unregistered. +func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + multiPortEp, ok := epsByNic.endpoints[bindToDevice] + if !ok { + return false + } + if multiPortEp.unregisterEndpoint(t) { + delete(epsByNic.endpoints, bindToDevice) + } + return len(epsByNic.endpoints) == 0 +} + // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) { +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() - e, ok := eps.endpoints[id] + epsByNic, ok := eps.endpoints[id] if !ok { return } - if multiPortEp, ok := e.(*multiPortEndpoint); ok { - if !multiPortEp.unregisterEndpoint(ep) { - return - } + if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + return } delete(eps.endpoints, id) } @@ -75,7 +159,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ - endpoints: make(map[TransportEndpointID]TransportEndpoint), + endpoints: make(map[TransportEndpointID]*endpointsByNic), } } } @@ -85,10 +169,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id, ep) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) return err } } @@ -97,13 +181,14 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum } // multiPortEndpoint is a container for TransportEndpoints which are bound to -// the same pair of address and port. +// the same pair of address and port. endpointsArr always has at least one +// element. type multiPortEndpoint struct { mu sync.RWMutex endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int - // seed is a random secret for a jenkins hash. - seed uint32 + // reuse indicates if more than one endpoint is allowed. + reuse bool } // reciprocalScale scales a value into range [0, n). @@ -117,9 +202,10 @@ func reciprocalScale(val, n uint32) uint32 { // selectEndpoint calculates a hash of destination and source addresses and // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. -func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint { - ep.mu.RLock() - defer ep.mu.RUnlock() +func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { + if len(mpep.endpointsArr) == 1 { + return mpep.endpointsArr[0] + } payload := []byte{ byte(id.LocalPort), @@ -128,51 +214,50 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd byte(id.RemotePort >> 8), } - h := jenkins.Sum32(ep.seed) + h := jenkins.Sum32(seed) h.Write(payload) h.Write([]byte(id.LocalAddress)) h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(ep.endpointsArr))) - return ep.endpointsArr[idx] + idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr))) + return mpep.endpointsArr[idx] } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { - // If this is a broadcast or multicast datagram, deliver the datagram to all - // endpoints managed by ep. - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { - for i, endpoint := range ep.endpointsArr { - // HandlePacket modifies vv, so each endpoint needs its own copy. - if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) - break - } - vvCopy := buffer.NewView(vv.Size()) - copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { + ep.mu.RLock() + for i, endpoint := range ep.endpointsArr { + // HandlePacket modifies vv, so each endpoint needs its own copy except for + // the final one. + if i == len(ep.endpointsArr)-1 { + endpoint.HandlePacket(r, id, vv) + break } - } else { - ep.selectEndpoint(id).HandlePacket(r, id, vv) + vvCopy := buffer.NewView(vv.Size()) + copy(vvCopy, vv.ToView()) + endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) } + ep.mu.RUnlock() // Don't use defer for performance reasons. } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { - ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv) -} - -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) { +// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint +// list. The list might be empty already. +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - // A new endpoint is added into endpointsArr and its index there is - // saved in endpointsMap. This will allows to remove endpoint from - // the array fast. + if len(ep.endpointsArr) > 0 { + // If it was previously bound, we need to check if we can bind again. + if !ep.reuse || !reusePort { + return tcpip.ErrPortInUse + } + } + + // A new endpoint is added into endpointsArr and its index there is saved in + // endpointsMap. This will allow us to remove endpoint from the array fast. ep.endpointsMap[t] = len(ep.endpointsArr) ep.endpointsArr = append(ep.endpointsArr, t) + return nil } // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. @@ -197,53 +282,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { return true } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { if id.RemotePort != 0 { + // TODO(eyalsoha): Why? reusePort = false } eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return nil + return tcpip.ErrUnknownProtocol } eps.mu.Lock() defer eps.mu.Unlock() - var multiPortEp *multiPortEndpoint - if _, ok := eps.endpoints[id]; ok { - if !reusePort { - return tcpip.ErrPortInUse - } - multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint) - if !ok { - return tcpip.ErrPortInUse - } + if epsByNic, ok := eps.endpoints[id]; ok { + // There was already a binding. + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } - if reusePort { - if multiPortEp == nil { - multiPortEp = &multiPortEndpoint{} - multiPortEp.endpointsMap = make(map[TransportEndpoint]int) - multiPortEp.seed = rand.Uint32() - eps.endpoints[id] = multiPortEp - } - - multiPortEp.singleRegisterEndpoint(ep) - - return nil + // This is a new binding. + epsByNic := &endpointsByNic{ + endpoints: make(map[tcpip.NICID]*multiPortEndpoint), + seed: rand.Uint32(), } - eps.endpoints[id] = ep + eps.endpoints[id] = epsByNic - return nil + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.unregisterEndpoint(id, ep) + eps.unregisterEndpoint(id, ep, bindToDevice) } } } @@ -273,7 +346,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a broadcast, then find all matching transport endpoints. // Otherwise, try to find a single matching transport endpoint. - destEps := make([]TransportEndpoint, 0, 1) + destEps := make([]*endpointsByNic, 0, 1) eps.mu.RLock() if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast { @@ -299,7 +372,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // Deliver the packet. for _, ep := range destEps { - ep.HandlePacket(r, id, vv) + ep.handlePacket(r, id, vv) } return true @@ -331,7 +404,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -348,12 +421,12 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, } // Deliver the packet. - ep.HandleControlPacket(id, typ, extra, vv) + ep.handleControlPacket(n, id, typ, extra, vv) return true } -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint { +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { return ep diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go new file mode 100644 index 000000000..210233dc0 --- /dev/null +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -0,0 +1,352 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "math" + "math/rand" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + stackAddr = "\x0a\x00\x00\x01" + stackPort = 1234 + testPort = 4096 +) + +type testContext struct { + t *testing.T + linkEPs map[string]*channel.Endpoint + s *stack.Stack + + ep tcpip.Endpoint + wq waiter.Queue +} + +func (c *testContext) cleanup() { + if c.ep != nil { + c.ep.Close() + } +} + +func (c *testContext) createV6Endpoint(v6only bool) { + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + var v tcpip.V6OnlyOption + if v6only { + v = 1 + } + if err := c.ep.SetSockOpt(v); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } +} + +// newDualTestContextMultiNic creates the testing context and also linkEpNames +// named NICs. +func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + linkEPs := make(map[string]*channel.Endpoint) + for i, linkEpName := range linkEpNames { + channelEP := channel.New(256, mtu, "") + nicid := tcpip.NICID(i + 1) + if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + linkEPs[linkEpName] = channelEP + + if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress IPv4 failed: %v", err) + } + + if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil { + t.Fatalf("AddAddress IPv6 failed: %v", err) + } + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: 1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + + return &testContext{ + t: t, + s: s, + linkEPs: linkEPs, + } +} + +type headers struct { + srcPort uint16 + dstPort uint16 +} + +func newPayload() []byte { + b := make([]byte, 30+rand.Intn(100)) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return b +} + +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: testV6Addr, + DstAddr: stackV6Addr, + }) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + + // Calculate the UDP checksum and set it. + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum)) + + // Inject packet. + c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) +} + +func TestTransportDemuxerRegister(t *testing.T) { + for _, test := range []struct { + name string + proto tcpip.NetworkProtocolNumber + want *tcpip.Error + }{ + {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, + {"success", ipv4.ProtocolNumber, nil}, + } { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { + t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) + } + }) + } +} + +// TestReuseBindToDevice injects varied packets on input devices and checks that +// the distribution of packets received matches expectations. +func TestDistribution(t *testing.T) { + type endpointSockopts struct { + reuse int + bindToDevice string + } + for _, test := range []struct { + name string + // endpoints will received the inject packets. + endpoints []endpointSockopts + // wantedDistribution is the wanted ratio of packets received on each + // endpoint for each NIC on which packets are injected. + wantedDistributions map[string][]float64 + }{ + { + "BindPortReuse", + // 5 endpoints that all have reuse set. + []endpointSockopts{ + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed evenly. + "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, + }, + }, + { + "BindToDevice", + // 3 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{0, "dev0"}, + endpointSockopts{0, "dev1"}, + endpointSockopts{0, "dev2"}, + }, + map[string][]float64{ + // Injected packets on dev0 go only to the endpoint bound to dev0. + "dev0": []float64{1, 0, 0}, + // Injected packets on dev1 go only to the endpoint bound to dev1. + "dev1": []float64{0, 1, 0}, + // Injected packets on dev2 go only to the endpoint bound to dev2. + "dev2": []float64{0, 0, 1}, + }, + }, + { + "ReuseAndBindToDevice", + // 6 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed among endpoints bound to + // dev0. + "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, + // Injected packets on dev1 get distributed among endpoints bound to + // dev1 or unbound. + "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + // Injected packets on dev999 go only to the unbound. + "dev999": []float64{0, 0, 0, 0, 0, 1}, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + for device, wantedDistribution := range test.wantedDistributions { + t.Run(device, func(t *testing.T) { + var devices []string + for d := range test.wantedDistributions { + devices = append(devices, d) + } + c := newDualTestContextMultiNic(t, defaultMTU, devices) + defer c.cleanup() + + c.createV6Endpoint(false) + + eps := make(map[tcpip.Endpoint]int) + + pollChannel := make(chan tcpip.Endpoint) + for i, endpoint := range test.endpoints { + // Try to receive the data. + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + var err *tcpip.Error + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + eps[ep] = i + + go func(ep tcpip.Endpoint) { + for range ch { + pollChannel <- ep + } + }(ep) + + defer ep.Close() + reusePortOption := tcpip.ReusePortOption(endpoint.reuse) + if err := ep.SetSockOpt(reusePortOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) + } + bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) + if err := ep.SetSockOpt(bindToDeviceOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) + } + if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { + t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) + } + } + + npackets := 100000 + nports := 10000 + if got, want := len(test.endpoints), len(wantedDistribution); got != want { + t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) + } + ports := make(map[uint16]tcpip.Endpoint) + stats := make(map[tcpip.Endpoint]int) + for i := 0; i < npackets; i++ { + // Send a packet. + port := uint16(i % nports) + payload := newPayload() + c.sendV6Packet(payload, + &headers{ + srcPort: testPort + port, + dstPort: stackPort}, + device) + + var addr tcpip.FullAddress + ep := <-pollChannel + _, _, err := ep.Read(&addr) + if err != nil { + c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) + } + stats[ep]++ + if i < nports { + ports[uint16(i)] = ep + } else { + // Check that all packets from one client are handled by the same + // socket. + if want, got := ports[port], ep; want != got { + t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) + } + } + } + + // Check that a packet distribution is as expected. + for ep, i := range eps { + wantedRatio := wantedDistribution[i] + wantedRecv := wantedRatio * float64(npackets) + actualRecv := stats[ep] + actualRatio := float64(stats[ep]) / float64(npackets) + // The deviation is less than 10%. + if math.Abs(actualRatio-wantedRatio) > 0.05 { + t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) + } + } + }) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 56e8a5d9b..842a16277 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -127,7 +127,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.id.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false) + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false /* reuse */, 0 /* bindToDevice */) if err != nil { return err } @@ -168,7 +168,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, f, - false, + false, /* reuse */ + 0, /* bindtoDevice */ ); err != nil { return err } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c021c67ac..faaa4a4e3 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -495,6 +495,10 @@ type ReuseAddressOption int // to be bound to an identical socket address. type ReusePortOption int +// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets +// should bind only on a specific NIC. +type BindToDeviceOption string + // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. type QuickAckOption int diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index a111fdb2a..a3a910d41 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -104,7 +104,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -543,14 +543,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindtodevice */) switch err { case nil: return true, nil diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 0802e984e..3ae4a5426 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -242,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort, n.bindToDevice); err != nil { n.Close() return nil, err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 35b489c68..a1cd0d481 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -280,6 +280,9 @@ type endpoint struct { // reusePort is set to true if SO_REUSEPORT is enabled. reusePort bool + // bindToDevice is set to the NIC on which to bind or disabled if 0. + bindToDevice tcpip.NICID + // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. @@ -564,11 +567,11 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -625,12 +628,12 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -1060,6 +1063,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.BindToDeviceOption: + e.mu.Lock() + defer e.mu.Unlock() + if v == "" { + e.bindToDevice = 0 + return nil + } + for nicid, nic := range e.stack.NICInfo() { + if nic.Name == string(v) { + e.bindToDevice = nicid + return nil + } + } + return tcpip.ErrUnknownDevice + case tcpip.QuickAckOption: if v == 0 { atomic.StoreUint32(&e.slowAck, 1) @@ -1260,6 +1278,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.BindToDeviceOption: + e.mu.RLock() + defer e.mu.RUnlock() + if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { + *o = tcpip.BindToDeviceOption(nic.Name) + return nil + } + *o = "" + return nil + case *tcpip.QuickAckOption: *o = 1 if v := atomic.LoadUint32(&e.slowAck); v != 0 { @@ -1466,7 +1494,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er if e.id.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice) if err != nil { return err } @@ -1480,13 +1508,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er if sameAddr && p == e.id.RemotePort { return false, nil } - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) { + // reusePort is false below because connect cannot reuse a port even if + // reusePort was set. + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false /* reusePort */, e.bindToDevice) { return false, nil } id := e.id id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) { + switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { case nil: e.id = id return true, nil @@ -1504,7 +1534,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // before Connect: in such a case we don't want to hold on to // reservations anymore. if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -1648,7 +1678,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice); err != nil { return err } @@ -1729,7 +1759,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice) if err != nil { return err } @@ -1739,16 +1769,16 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { e.id.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func() { + defer func(bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) e.isPortReserved = false e.effectiveNetProtos = nil e.id.LocalPort = 0 e.id.LocalAddress = "" e.boundNICID = 0 } - }() + }(e.bindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 2be094876..089826a88 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -465,6 +465,66 @@ func TestSimpleReceive(t *testing.T) { ) } +func TestConnectBindToDevice(t *testing.T) { + for _, test := range []struct { + name string + device string + want tcp.EndpointState + }{ + {"RightDevice", "nic1", tcp.StateEstablished}, + {"WrongDevice", "nic2", tcp.StateSynSent}, + {"AnyDevice", "", tcp.StateEstablished}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + bindToDevice := tcpip.BindToDeviceOption(test.device) + c.EP.SetSockOpt(bindToDevice) + // Start connection attempt. + waitEntry, _ := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + + c.GetPacket() + if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + }) + } +} + func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -2970,6 +3030,62 @@ func TestMinMaxBufferSizes(t *testing.T) { checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) } +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + defer ep.Close() + + if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { + t.Errorf("CreateNamedNIC failed: %v", err) + } + + // Make an nameless NIC. + if err := s.CreateNIC(54321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %v", err) + } + + // strPtr is used instead of taking the address of string literals, which is + // a compiler error. + strPtr := func(s string) *string { + return &s + } + + testActions := []struct { + name string + setBindToDevice *string + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, ""}, + {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, + {"BindToExistent", strPtr("my_device"), nil, "my_device"}, + {"UnbindToDevice", strPtr(""), nil, ""}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + } + } + bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") + if ep.GetSockOpt(&bindToDevice) != nil { + t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %q, want %q", got, want) + } + }) + } +} + func makeStack() (*stack.Stack, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index d3f1d2cdf..ef823e4ae 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -158,7 +158,14 @@ func New(t *testing.T, mtu uint32) *Context { if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNIC(1, wep); err != nil { + if err := s.CreateNamedNIC(1, "nic1", wep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) + if testing.Verbose() { + wep2 = sniffer.New(channel.New(1000, mtu, "")) + } + if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -588,12 +595,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.Port = tcpHdr.SourcePort() } -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { +// Create creates a TCP endpoint. +func (c *Context) Create(epRcvBuf int) { // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -606,6 +609,15 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.t.Fatalf("SetSockOpt failed failed: %v", err) } } +} + +// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends +// the specified option bytes as the Option field in the initial SYN packet. +// +// It also sets the receive buffer for the endpoint to the specified +// value in epRcvBuf. +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { + c.Create(epRcvBuf) c.Connect(iss, rcvWnd, options) } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index c1ca22b35..7a635ab8d 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -52,6 +52,7 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 0bec7e62d..52f5af777 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -88,6 +88,7 @@ type endpoint struct { multicastNICID tcpip.NICID multicastLoop bool reusePort bool + bindToDevice tcpip.NICID broadcast bool // shutdownFlags represent the current shutdown state of the endpoint. @@ -144,8 +145,8 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) } for _, mem := range e.multicastMemberships { @@ -551,6 +552,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.reusePort = v != 0 e.mu.Unlock() + case tcpip.BindToDeviceOption: + e.mu.Lock() + defer e.mu.Unlock() + if v == "" { + e.bindToDevice = 0 + return nil + } + for nicid, nic := range e.stack.NICInfo() { + if nic.Name == string(v) { + e.bindToDevice = nicid + return nil + } + } + return tcpip.ErrUnknownDevice + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 @@ -646,6 +662,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.BindToDeviceOption: + e.mu.RLock() + defer e.mu.RUnlock() + if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { + *o = tcpip.BindToDeviceOption(nic.Name) + return nil + } + *o = tcpip.BindToDeviceOption("") + return nil + case *tcpip.KeepaliveEnabledOption: *o = 0 return nil @@ -753,12 +779,12 @@ func (e *endpoint) Disconnect() *tcpip.Error { } else { if e.id.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) e.id = id e.route.Release() e.route = stack.Route{} @@ -835,7 +861,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Remove the old registration. if e.id.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) } e.id = id @@ -898,16 +924,16 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { if e.id.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) if err != nil { return id, err } id.LocalPort = port } - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice) } return id, err } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index a9edc2c8d..2d0bc5221 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -74,7 +74,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil { ep.Close() return nil, err } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 2ec27be4d..5059ca22d 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -17,7 +17,6 @@ package udp_test import ( "bytes" "fmt" - "math" "math/rand" "testing" "time" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -476,87 +476,59 @@ func newMinPayload(minSize int) []byte { return b } -func TestBindPortReuse(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - var eps [5]tcpip.Endpoint - reusePortOpt := tcpip.ReusePortOption(1) - - pollChannel := make(chan tcpip.Endpoint) - for i := 0; i < len(eps); i++ { - // Try to receive the data. - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - var err *tcpip.Error - eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - go func(ep tcpip.Endpoint) { - for range ch { - pollChannel <- ep - } - }(eps[i]) +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - defer eps[i].Close() - if err := eps[i].SetSockOpt(reusePortOpt); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) } + defer ep.Close() - npackets := 100000 - nports := 10000 - ports := make(map[uint16]tcpip.Endpoint) - stats := make(map[tcpip.Endpoint]int) - for i := 0; i < npackets; i++ { - // Send a packet. - port := uint16(i % nports) - payload := newPayload() - c.injectV6Packet(payload, &header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port}, - dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - }) + if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { + t.Errorf("CreateNamedNIC failed: %v", err) + } - var addr tcpip.FullAddress - ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - stats[ep]++ - if i < nports { - ports[uint16(i)] = ep - } else { - // Check that all packets from one client are handled - // by the same socket. - if ports[port] != ep { - t.Fatalf("Port mismatch") - } - } + // Make an nameless NIC. + if err := s.CreateNIC(54321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %v", err) } - if len(stats) != len(eps) { - t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps)) + // strPtr is used instead of taking the address of string literals, which is + // a compiler error. + strPtr := func(s string) *string { + return &s } - // Check that a packet distribution is fair between sockets. - for _, c := range stats { - n := float64(npackets) / float64(len(eps)) - // The deviation is less than 10%. - if math.Abs(float64(c)-n) > n/10 { - t.Fatal(c, n) - } + testActions := []struct { + name string + setBindToDevice *string + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, ""}, + {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, + {"BindToExistent", strPtr("my_device"), nil, "my_device"}, + {"UnbindToDevice", strPtr(""), nil, ""}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + } + } + bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") + if ep.GetSockOpt(&bindToDevice) != nil { + t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %q, want %q", got, want) + } + }) } } diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 28b23ce58..e645eebfa 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2463,6 +2463,63 @@ cc_binary( ], ) +cc_binary( + name = "socket_bind_to_device_test", + testonly = 1, + srcs = [ + "socket_bind_to_device.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_bind_to_device_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "socket_bind_to_device_sequence_test", + testonly = 1, + srcs = [ + "socket_bind_to_device_sequence.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_bind_to_device_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "socket_bind_to_device_distribution_test", + testonly = 1, + srcs = [ + "socket_bind_to_device_distribution.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_bind_to_device_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_googletest//:gtest", + ], +) + cc_binary( name = "socket_ip_udp_loopback_non_blocking_test", testonly = 1, @@ -2740,6 +2797,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "socket_bind_to_device_util", + testonly = 1, + srcs = [ + "socket_bind_to_device_util.cc", + ], + hdrs = [ + "socket_bind_to_device_util.h", + ], + deps = [ + "//test/util:test_util", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + cc_binary( name = "socket_stream_local_test", testonly = 1, @@ -3253,6 +3327,7 @@ cc_binary( "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "//test/util:uid_util", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", diff --git a/test/syscalls/linux/socket_bind_to_device.cc b/test/syscalls/linux/socket_bind_to_device.cc new file mode 100644 index 000000000..d20821cac --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device.cc @@ -0,0 +1,314 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_bind_to_device_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +using std::string; + +// Test fixture for SO_BINDTODEVICE tests. +class BindToDeviceTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + printf("Testing case: %s\n", GetParam().description.c_str()); + ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) + << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; + + interface_name_ = "eth1"; + auto interface_names = GetInterfaceNames(); + if (interface_names.find(interface_name_) == interface_names.end()) { + // Need a tunnel. + tunnel_ = ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New()); + interface_name_ = tunnel_->GetName(); + ASSERT_FALSE(interface_name_.empty()); + } + socket_ = ASSERT_NO_ERRNO_AND_VALUE(GetParam().Create()); + } + + string interface_name() const { return interface_name_; } + + int socket_fd() const { return socket_->get(); } + + private: + std::unique_ptr tunnel_; + string interface_name_; + std::unique_ptr socket_; +}; + +constexpr char kIllegalIfnameChar = '/'; + +// Tests getsockopt of the default value. +TEST_P(BindToDeviceTest, GetsockoptDefault) { + char name_buffer[IFNAMSIZ * 2]; + char original_name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Read the default SO_BINDTODEVICE. + memset(original_name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + for (size_t i = 0; i <= sizeof(name_buffer); i++) { + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = i; + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, &name_buffer_size), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(name_buffer_size, 0); + EXPECT_EQ(memcmp(name_buffer, original_name_buffer, sizeof(name_buffer)), + 0); + } +} + +// Tests setsockopt of invalid device name. +TEST_P(BindToDeviceTest, SetsockoptInvalidDeviceName) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Set an invalid device name. + memset(name_buffer, kIllegalIfnameChar, 5); + name_buffer_size = 5; + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + name_buffer_size), + SyscallFailsWithErrno(ENODEV)); +} + +// Tests setsockopt of a buffer with a valid device name but not +// null-terminated, with different sizes of buffer. +TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithoutNullTermination) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1); + // Intentionally overwrite the null at the end. + memset(name_buffer + interface_name().size(), kIllegalIfnameChar, + sizeof(name_buffer) - interface_name().size()); + for (size_t i = 1; i <= sizeof(name_buffer); i++) { + name_buffer_size = i; + SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); + // It should only work if the size provided is exactly right. + if (name_buffer_size == interface_name().size()) { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallSucceeds()); + } else { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallFailsWithErrno(ENODEV)); + } + } +} + +// Tests setsockopt of a buffer with a valid device name and null-terminated, +// with different sizes of buffer. +TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithNullTermination) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1); + // Don't overwrite the null at the end. + memset(name_buffer + interface_name().size() + 1, kIllegalIfnameChar, + sizeof(name_buffer) - interface_name().size() - 1); + for (size_t i = 1; i <= sizeof(name_buffer); i++) { + name_buffer_size = i; + SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); + // It should only work if the size provided is at least the right size. + if (name_buffer_size >= interface_name().size()) { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallSucceeds()); + } else { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallFailsWithErrno(ENODEV)); + } + } +} + +// Tests that setsockopt of an invalid device name doesn't unset the previous +// valid setsockopt. +TEST_P(BindToDeviceTest, SetsockoptValidThenInvalid) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + + // Write unsuccessfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = 5; + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallFailsWithErrno(ENODEV)); + + // Read it back successfully, it's unchanged. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); +} + +// Tests that setsockopt of zero-length string correctly unsets the previous +// value. +TEST_P(BindToDeviceTest, SetsockoptValidThenClear) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + + // Clear it successfully. + name_buffer_size = 0; + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + name_buffer_size), + SyscallSucceeds()); + + // Read it back successfully, it's cleared. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, 0); +} + +// Tests that setsockopt of empty string correctly unsets the previous +// value. +TEST_P(BindToDeviceTest, SetsockoptValidThenClearWithNull) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + + // Clear it successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer[0] = 0; + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + name_buffer_size), + SyscallSucceeds()); + + // Read it back successfully, it's cleared. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, 0); +} + +// Tests getsockopt with different buffer sizes. +TEST_P(BindToDeviceTest, GetsockoptDevice) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back at various buffer sizes. + for (size_t i = 0; i <= sizeof(name_buffer); i++) { + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = i; + SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); + // Linux only allows a buffer at least IFNAMSIZ, even if less would suffice + // for this interface name. + if (name_buffer_size >= IFNAMSIZ) { + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + } else { + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, &name_buffer_size), + SyscallFailsWithErrno(EINVAL)); + EXPECT_EQ(name_buffer_size, i); + } + } +} + +INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceTest, + ::testing::Values(IPv4UDPUnboundSocket(0), + IPv4TCPUnboundSocket(0))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc new file mode 100644 index 000000000..4d2400328 --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc @@ -0,0 +1,381 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_bind_to_device_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +using std::string; +using std::vector; + +struct EndpointConfig { + std::string bind_to_device; + double expected_ratio; +}; + +struct DistributionTestCase { + std::string name; + std::vector endpoints; +}; + +struct ListenerConnector { + TestAddress listener; + TestAddress connector; +}; + +// Test fixture for SO_BINDTODEVICE tests the distribution of packets received +// with varying SO_BINDTODEVICE settings. +class BindToDeviceDistributionTest + : public ::testing::TestWithParam< + ::testing::tuple> { + protected: + void SetUp() override { + printf("Testing case: %s, listener=%s, connector=%s\n", + ::testing::get<1>(GetParam()).name.c_str(), + ::testing::get<0>(GetParam()).listener.description.c_str(), + ::testing::get<0>(GetParam()).connector.description.c_str()); + ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) + << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; + } +}; + +PosixErrorOr AddrPort(int family, sockaddr_storage const& addr) { + switch (family) { + case AF_INET: + return static_cast( + reinterpret_cast(&addr)->sin_port); + case AF_INET6: + return static_cast( + reinterpret_cast(&addr)->sin6_port); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) { + switch (family) { + case AF_INET: + reinterpret_cast(addr)->sin_port = port; + return NoError(); + case AF_INET6: + reinterpret_cast(addr)->sin6_port = port; + return NoError(); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +// Binds sockets to different devices and then creates many TCP connections. +// Checks that the distribution of connections received on the sockets matches +// the expectation. +TEST_P(BindToDeviceDistributionTest, Tcp) { + auto const& [listener_connector, test] = GetParam(); + + TestAddress const& listener = listener_connector.listener; + TestAddress const& connector = listener_connector.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + + auto interface_names = GetInterfaceNames(); + + // Create the listening sockets. + std::vector listener_fds; + std::vector> all_tunnels; + for (auto const& endpoint : test.endpoints) { + if (!endpoint.bind_to_device.empty() && + interface_names.find(endpoint.bind_to_device) == + interface_names.end()) { + all_tunnels.push_back( + ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device))); + interface_names.insert(endpoint.bind_to_device); + } + + listener_fds.push_back(ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP))); + int fd = listener_fds.back().get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, + endpoint.bind_to_device.c_str(), + endpoint.bind_to_device.size() + 1), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast(&listen_addr), listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (listener_fds.size() > 1) { + continue; + } + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10000; + std::atomic connects_received = ATOMIC_VAR_INIT(0); + std::vector accept_counts(listener_fds.size(), 0); + std::vector> listen_threads( + listener_fds.size()); + + for (int i = 0; i < listener_fds.size(); i++) { + listen_threads[i] = absl::make_unique( + [&listener_fds, &accept_counts, &connects_received, i, + kConnectAttempts]() { + do { + auto fd = Accept(listener_fds[i].get(), nullptr, nullptr); + if (!fd.ok()) { + // Another thread has shutdown our read side causing the accept to + // fail. + ASSERT_GE(connects_received, kConnectAttempts) + << "errno = " << fd.error(); + return; + } + // Receive some data from a socket to be sure that the connect() + // system call has been completed on another side. + int data; + EXPECT_THAT( + RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(sizeof(data))); + accept_counts[i]++; + } while (++connects_received < kConnectAttempts); + + // Shutdown all sockets to wake up other threads. + for (auto const& listener_fd : listener_fds) { + shutdown(listener_fd.get(), SHUT_RDWR); + } + }); + } + + for (int i = 0; i < kConnectAttempts; i++) { + FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT( + RetryEINTR(connect)(fd.get(), reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), + SyscallSucceedsWithValue(sizeof(i))); + } + + // Join threads to be sure that all connections have been counted. + for (auto const& listen_thread : listen_threads) { + listen_thread->Join(); + } + // Check that connections are distributed correctly among listening sockets. + for (int i = 0; i < accept_counts.size(); i++) { + EXPECT_THAT( + accept_counts[i], + EquivalentWithin(static_cast(kConnectAttempts * + test.endpoints[i].expected_ratio), + 0.10)) + << "endpoint " << i << " got the wrong number of packets"; + } +} + +// Binds sockets to different devices and then sends many UDP packets. Checks +// that the distribution of packets received on the sockets matches the +// expectation. +TEST_P(BindToDeviceDistributionTest, Udp) { + auto const& [listener_connector, test] = GetParam(); + + TestAddress const& listener = listener_connector.listener; + TestAddress const& connector = listener_connector.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + + auto interface_names = GetInterfaceNames(); + + // Create the listening socket. + std::vector listener_fds; + std::vector> all_tunnels; + for (auto const& endpoint : test.endpoints) { + if (!endpoint.bind_to_device.empty() && + interface_names.find(endpoint.bind_to_device) == + interface_names.end()) { + all_tunnels.push_back( + ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device))); + interface_names.insert(endpoint.bind_to_device); + } + + listener_fds.push_back( + ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0))); + int fd = listener_fds.back().get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, + endpoint.bind_to_device.c_str(), + endpoint.bind_to_device.size() + 1), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast(&listen_addr), listener.addr_len), + SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (listener_fds.size() > 1) { + continue; + } + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10000; + std::atomic packets_received = ATOMIC_VAR_INIT(0); + std::vector packets_per_socket(listener_fds.size(), 0); + std::vector> receiver_threads( + listener_fds.size()); + + for (int i = 0; i < listener_fds.size(); i++) { + receiver_threads[i] = absl::make_unique( + [&listener_fds, &packets_per_socket, &packets_received, i]() { + do { + struct sockaddr_storage addr = {}; + socklen_t addrlen = sizeof(addr); + int data; + + auto ret = RetryEINTR(recvfrom)( + listener_fds[i].get(), &data, sizeof(data), 0, + reinterpret_cast(&addr), &addrlen); + + if (packets_received < kConnectAttempts) { + ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data))); + } + + if (ret != sizeof(data)) { + // Another thread may have shutdown our read side causing the + // recvfrom to fail. + break; + } + + packets_received++; + packets_per_socket[i]++; + + // A response is required to synchronize with the main thread, + // otherwise the main thread can send more than can fit into receive + // queues. + EXPECT_THAT(RetryEINTR(sendto)( + listener_fds[i].get(), &data, sizeof(data), 0, + reinterpret_cast(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(data))); + } while (packets_received < kConnectAttempts); + + // Shutdown all sockets to wake up other threads. + for (auto const& listener_fd : listener_fds) { + shutdown(listener_fd.get(), SHUT_RDWR); + } + }); + } + + for (int i = 0; i < kConnectAttempts; i++) { + FileDescriptor const fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); + EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0, + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + int data; + EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(sizeof(data))); + } + + // Join threads to be sure that all connections have been counted. + for (auto const& receiver_thread : receiver_threads) { + receiver_thread->Join(); + } + // Check that packets are distributed correctly among listening sockets. + for (int i = 0; i < packets_per_socket.size(); i++) { + EXPECT_THAT( + packets_per_socket[i], + EquivalentWithin(static_cast(kConnectAttempts * + test.endpoints[i].expected_ratio), + 0.10)) + << "endpoint " << i << " got the wrong number of packets"; + } +} + +std::vector GetDistributionTestCases() { + return std::vector{ + {"Even distribution among sockets not bound to device", + {{"", 1. / 3}, {"", 1. / 3}, {"", 1. / 3}}}, + {"Sockets bound to other interfaces get no packets", + {{"eth1", 0}, {"", 1. / 2}, {"", 1. / 2}}}, + {"Bound has priority over unbound", {{"eth1", 0}, {"", 0}, {"lo", 1}}}, + {"Even distribution among sockets bound to device", + {{"eth1", 0}, {"lo", 1. / 2}, {"lo", 1. / 2}}}, + }; +} + +INSTANTIATE_TEST_SUITE_P( + BindToDeviceTest, BindToDeviceDistributionTest, + ::testing::Combine(::testing::Values( + // Listeners bound to IPv4 addresses refuse + // connections using IPv6 addresses. + ListenerConnector{V4Any(), V4Loopback()}, + ListenerConnector{V4Loopback(), V4MappedLoopback()}), + ::testing::ValuesIn(GetDistributionTestCases()))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc new file mode 100644 index 000000000..a7365d139 --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc @@ -0,0 +1,316 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_bind_to_device_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +using std::string; +using std::vector; + +// Test fixture for SO_BINDTODEVICE tests the results of sequences of socket +// binding. +class BindToDeviceSequenceTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + printf("Testing case: %s\n", GetParam().description.c_str()); + ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) + << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; + socket_factory_ = GetParam(); + + interface_names_ = GetInterfaceNames(); + } + + PosixErrorOr> NewSocket() const { + return socket_factory_.Create(); + } + + // Gets a device by device_id. If the device_id has been seen before, returns + // the previously returned device. If not, finds or creates a new device. + // Returns an empty string on failure. + void GetDevice(int device_id, string *device_name) { + auto device = devices_.find(device_id); + if (device != devices_.end()) { + *device_name = device->second; + return; + } + + // Need to pick a new device. Try ethernet first. + *device_name = absl::StrCat("eth", next_unused_eth_); + if (interface_names_.find(*device_name) != interface_names_.end()) { + devices_[device_id] = *device_name; + next_unused_eth_++; + return; + } + + // Need to make a new tunnel device. gVisor tests should have enough + // ethernet devices to never reach here. + ASSERT_FALSE(IsRunningOnGvisor()); + // Need a tunnel. + tunnels_.push_back(ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New())); + devices_[device_id] = tunnels_.back()->GetName(); + *device_name = devices_[device_id]; + } + + // Release the socket + void ReleaseSocket(int socket_id) { + // Close the socket that was made in a previous action. The socket_id + // indicates which socket to close based on index into the list of actions. + sockets_to_close_.erase(socket_id); + } + + // Bind a socket with the reuse option and bind_to_device options. Checks + // that all steps succeed and that the bind command's error matches want. + // Sets the socket_id to uniquely identify the socket bound if it is not + // nullptr. + void BindSocket(bool reuse, int device_id = 0, int want = 0, + int *socket_id = nullptr) { + next_socket_id_++; + sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket_fd = sockets_to_close_[next_socket_id_]->get(); + if (socket_id != nullptr) { + *socket_id = next_socket_id_; + } + + // If reuse is indicated, do that. + if (reuse) { + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + } + + // If the device is non-zero, bind to that device. + if (device_id != 0) { + string device_name; + ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name)); + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, + device_name.c_str(), device_name.size() + 1), + SyscallSucceedsWithValue(0)); + char get_device[100]; + socklen_t get_device_size = 100; + EXPECT_THAT(getsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, get_device, + &get_device_size), + SyscallSucceedsWithValue(0)); + } + + struct sockaddr_in addr = {}; + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + addr.sin_port = port_; + if (want == 0) { + ASSERT_THAT( + bind(socket_fd, reinterpret_cast(&addr), + sizeof(addr)), + SyscallSucceeds()); + } else { + ASSERT_THAT( + bind(socket_fd, reinterpret_cast(&addr), + sizeof(addr)), + SyscallFailsWithErrno(want)); + } + + if (port_ == 0) { + // We don't yet know what port we'll be using so we need to fetch it and + // remember it for future commands. + socklen_t addr_size = sizeof(addr); + ASSERT_THAT( + getsockname(socket_fd, reinterpret_cast(&addr), + &addr_size), + SyscallSucceeds()); + port_ = addr.sin_port; + } + } + + private: + SocketKind socket_factory_; + // devices maps from the device id in the test case to the name of the device. + std::unordered_map devices_; + // These are the tunnels that were created for the test and will be destroyed + // by the destructor. + vector> tunnels_; + // A list of all interface names before the test started. + std::unordered_set interface_names_; + // The next ethernet device to use when requested a device. + int next_unused_eth_ = 1; + // The port for all tests. Originally 0 (any) and later set to the port that + // all further commands will use. + in_port_t port_ = 0; + // sockets_to_close_ is a map from action index to the socket that was + // created. + std::unordered_map> + sockets_to_close_; + int next_socket_id_ = 0; +}; + +TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 3)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 3, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindToDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 1)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 2)); +} + +TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithReuse) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, /* bind_to_device */ 0)); +} + +TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 789)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 999, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 999, 0)); +} + +TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindAndRelease) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + int to_release; + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, 0, &to_release)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 345, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 789)); + // Release the bind to device 0 and try again. + ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 345)); +} + +TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceSequenceTest, + ::testing::Values(IPv4UDPUnboundSocket(0), + IPv4TCPUnboundSocket(0))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_util.cc b/test/syscalls/linux/socket_bind_to_device_util.cc new file mode 100644 index 000000000..f4ee775bd --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_util.cc @@ -0,0 +1,75 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_bind_to_device_util.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +using std::string; + +PosixErrorOr> Tunnel::New(string tunnel_name) { + int fd; + RETURN_ERROR_IF_SYSCALL_FAIL(fd = open("/dev/net/tun", O_RDWR)); + + // Using `new` to access a non-public constructor. + auto new_tunnel = absl::WrapUnique(new Tunnel(fd)); + + ifreq ifr = {}; + ifr.ifr_flags = IFF_TUN; + strncpy(ifr.ifr_name, tunnel_name.c_str(), sizeof(ifr.ifr_name)); + + RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(fd, TUNSETIFF, &ifr)); + new_tunnel->name_ = ifr.ifr_name; + return new_tunnel; +} + +std::unordered_set GetInterfaceNames() { + struct if_nameindex* interfaces = if_nameindex(); + std::unordered_set names; + if (interfaces == nullptr) { + return names; + } + for (auto interface = interfaces; + interface->if_index != 0 || interface->if_name != nullptr; interface++) { + names.insert(interface->if_name); + } + if_freenameindex(interfaces); + return names; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_util.h b/test/syscalls/linux/socket_bind_to_device_util.h new file mode 100644 index 000000000..f941ccc86 --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_util.h @@ -0,0 +1,67 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ +#define GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +class Tunnel { + public: + static PosixErrorOr> New( + std::string tunnel_name = ""); + const std::string& GetName() const { return name_; } + + ~Tunnel() { + if (fd_ != -1) { + close(fd_); + } + } + + private: + Tunnel(int fd) : fd_(fd) {} + int fd_ = -1; + std::string name_; +}; + +std::unordered_set GetInterfaceNames(); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc index d48453a93..6218fbce1 100644 --- a/test/syscalls/linux/uidgid.cc +++ b/test/syscalls/linux/uidgid.cc @@ -25,6 +25,7 @@ #include "test/util/posix_error.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" +#include "test/util/uid_util.h" ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID"); ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID"); @@ -68,30 +69,6 @@ TEST(UidGidTest, Getgroups) { // here; see the setgroups test below. } -// If the caller's real/effective/saved user/group IDs are all 0, IsRoot returns -// true. Otherwise IsRoot logs an explanatory message and returns false. -PosixErrorOr IsRoot() { - uid_t ruid, euid, suid; - int rc = getresuid(&ruid, &euid, &suid); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "getresuid"); - } - if (ruid != 0 || euid != 0 || suid != 0) { - return false; - } - gid_t rgid, egid, sgid; - rc = getresgid(&rgid, &egid, &sgid); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "getresgid"); - } - if (rgid != 0 || egid != 0 || sgid != 0) { - return false; - } - return true; -} - // Checks that the calling process' real/effective/saved user IDs are // ruid/euid/suid respectively. PosixError CheckUIDs(uid_t ruid, uid_t euid, uid_t suid) { diff --git a/test/util/BUILD b/test/util/BUILD index 25ed9c944..5d2a9cc2c 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -324,3 +324,14 @@ cc_library( ":test_util", ], ) + +cc_library( + name = "uid_util", + testonly = 1, + srcs = ["uid_util.cc"], + hdrs = ["uid_util.h"], + deps = [ + ":posix_error", + ":save_util", + ], +) diff --git a/test/util/uid_util.cc b/test/util/uid_util.cc new file mode 100644 index 000000000..b131b4b99 --- /dev/null +++ b/test/util/uid_util.cc @@ -0,0 +1,44 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/util/posix_error.h" +#include "test/util/save_util.h" + +namespace gvisor { +namespace testing { + +PosixErrorOr IsRoot() { + uid_t ruid, euid, suid; + int rc = getresuid(&ruid, &euid, &suid); + MaybeSave(); + if (rc < 0) { + return PosixError(errno, "getresuid"); + } + if (ruid != 0 || euid != 0 || suid != 0) { + return false; + } + gid_t rgid, egid, sgid; + rc = getresgid(&rgid, &egid, &sgid); + MaybeSave(); + if (rc < 0) { + return PosixError(errno, "getresgid"); + } + if (rgid != 0 || egid != 0 || sgid != 0) { + return false; + } + return true; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/uid_util.h b/test/util/uid_util.h new file mode 100644 index 000000000..2cd387fb0 --- /dev/null +++ b/test/util/uid_util.h @@ -0,0 +1,29 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_UID_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_UID_UTIL_H_ + +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +// Returns true if the caller's real/effective/saved user/group IDs are all 0. +PosixErrorOr IsRoot(); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_UID_UTIL_H_ -- cgit v1.2.3