diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 239 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer_test.go | 222 |
2 files changed, 246 insertions, 215 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index c55e3e8bc..9a33ed375 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -35,7 +35,7 @@ type protocolIDs struct { type transportEndpoints struct { // mu protects all fields of the transportEndpoints. mu sync.RWMutex - endpoints map[TransportEndpointID]*endpointsByNic + endpoints map[TransportEndpointID]*endpointsByNIC // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. rawEndpoints []RawTransportEndpoint @@ -46,11 +46,11 @@ type transportEndpoints struct { func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() - epsByNic, ok := eps.endpoints[id] + epsByNIC, ok := eps.endpoints[id] if !ok { return } - if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + if !epsByNIC.unregisterEndpoint(bindToDevice, ep) { return } delete(eps.endpoints, id) @@ -66,18 +66,85 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { return es } -type endpointsByNic struct { +// iterEndpointsLocked yields all endpointsByNIC in eps that match id, in +// descending order of match quality. If a call to yield returns false, +// iterEndpointsLocked stops iteration and returns immediately. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) { + // Try to find a match with the id as provided. + if ep, ok := eps.endpoints[id]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with the id minus the local address. + nid := id + + nid.LocalAddress = "" + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with the id minus the remote part. + nid.LocalAddress = id.LocalAddress + nid.RemoteAddress = "" + nid.RemotePort = 0 + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with only the local port. + nid.LocalAddress = "" + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } +} + +// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in +// descending order of match quality. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC { + var matchedEPs []*endpointsByNIC + eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { + matchedEPs = append(matchedEPs, ep) + return true + }) + return matchedEPs +} + +// findEndpointLocked returns the endpoint that most closely matches the given id. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC { + var matchedEP *endpointsByNIC + eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { + matchedEP = ep + return false + }) + return matchedEP +} + +type endpointsByNIC struct { mu sync.RWMutex endpoints map[tcpip.NICID]*multiPortEndpoint // seed is a random secret for a jenkins hash. seed uint32 } -func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { - epsByNic.mu.RLock() - defer epsByNic.mu.RUnlock() +func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() var eps []TransportEndpoint - for _, ep := range epsByNic.endpoints { + for _, ep := range epsByNIC.endpoints { eps = append(eps, ep.transportEndpoints()...) } return eps @@ -85,13 +152,13 @@ func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { - epsByNic.mu.RLock() +func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { + epsByNIC.mu.RLock() - mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] if !ok { - if mpep, ok = epsByNic.endpoints[0]; !ok { - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + if mpep, ok = epsByNIC.endpoints[0]; !ok { + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } } @@ -100,29 +167,29 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p // endpoints bound to the right device. if isMulticastOrBroadcast(id.LocalAddress) { mpep.handlePacketAll(r, id, pkt) - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } // multiPortEndpoints are guaranteed to have at least one element. - transEP := selectEndpoint(id, mpep, epsByNic.seed) + transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(r, transEP, id, pkt) - epsByNic.mu.RUnlock() + epsByNIC.mu.RUnlock() return } transEP.HandlePacket(r, id, pkt) - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { - epsByNic.mu.RLock() - defer epsByNic.mu.RUnlock() +func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() - mpep, ok := epsByNic.endpoints[n.ID()] + mpep, ok := epsByNIC.endpoints[n.ID()] if !ok { - mpep, ok = epsByNic.endpoints[0] + mpep, ok = epsByNIC.endpoints[0] } if !ok { return @@ -132,16 +199,16 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt) + selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { - epsByNic.mu.Lock() - defer epsByNic.mu.Unlock() +func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNIC.mu.Lock() + defer epsByNIC.mu.Unlock() - multiPortEp, ok := epsByNic.endpoints[bindToDevice] + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] if !ok { multiPortEp = &multiPortEndpoint{ demux: d, @@ -149,24 +216,24 @@ func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto t transProto: transProto, reuse: reusePort, } - epsByNic.endpoints[bindToDevice] = multiPortEp + epsByNIC.endpoints[bindToDevice] = multiPortEp } return multiPortEp.singleRegisterEndpoint(t, reusePort) } -// unregisterEndpoint returns true if endpointsByNic has to be unregistered. -func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { - epsByNic.mu.Lock() - defer epsByNic.mu.Unlock() - multiPortEp, ok := epsByNic.endpoints[bindToDevice] +// unregisterEndpoint returns true if endpointsByNIC has to be unregistered. +func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { + epsByNIC.mu.Lock() + defer epsByNIC.mu.Unlock() + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] if !ok { return false } if multiPortEp.unregisterEndpoint(t) { - delete(epsByNic.endpoints, bindToDevice) + delete(epsByNIC.endpoints, bindToDevice) } - return len(epsByNic.endpoints) == 0 + return len(epsByNIC.endpoints) == 0 } // transportDemuxer demultiplexes packets targeted at a transport endpoint @@ -198,7 +265,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { for proto := range stack.transportProtocols { protoIDs := protocolIDs{netProto, proto} d.protocol[protoIDs] = &transportEndpoints{ - endpoints: make(map[TransportEndpointID]*endpointsByNic), + endpoints: make(map[TransportEndpointID]*endpointsByNIC), } qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol) if isQueued { @@ -378,16 +445,16 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() - epsByNic, ok := eps.endpoints[id] + epsByNIC, ok := eps.endpoints[id] if !ok { - epsByNic = &endpointsByNic{ + epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), seed: rand.Uint32(), } - eps.endpoints[id] = epsByNic + eps.endpoints[id] = epsByNIC } - return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) + return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it @@ -413,7 +480,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // transport endpoints. if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { eps.mu.RLock() - destEPs := d.findAllEndpointsLocked(eps, id) + destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. if len(destEPs) == 0 { @@ -439,7 +506,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto } eps.mu.RLock() - ep := d.findEndpointLocked(eps, id) + ep := eps.findEndpointLocked(id) eps.mu.RUnlock() if ep == nil { if protocol == header.UDPProtocolNumber { @@ -483,115 +550,47 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return false } - // Try to find the endpoint. eps.mu.RLock() - ep := d.findEndpointLocked(eps, id) + ep := eps.findEndpointLocked(id) eps.mu.RUnlock() - - // Fail if we didn't find one. if ep == nil { return false } - // Deliver the packet. ep.handleControlPacket(n, id, typ, extra, pkt) - return true } -// iterEndpointsLocked yields all endpointsByNic in eps that match id, in -// descending order of match quality. If a call to yield returns false, -// iterEndpointsLocked stops iteration and returns immediately. -// -// Preconditions: eps.mu must be locked. -func (d *transportDemuxer) iterEndpointsLocked(eps *transportEndpoints, id TransportEndpointID, yield func(*endpointsByNic) bool) { - // Try to find a match with the id as provided. - if ep, ok := eps.endpoints[id]; ok { - if !yield(ep) { - return - } - } - - // Try to find a match with the id minus the local address. - nid := id - - nid.LocalAddress = "" - if ep, ok := eps.endpoints[nid]; ok { - if !yield(ep) { - return - } - } - - // Try to find a match with the id minus the remote part. - nid.LocalAddress = id.LocalAddress - nid.RemoteAddress = "" - nid.RemotePort = 0 - if ep, ok := eps.endpoints[nid]; ok { - if !yield(ep) { - return - } - } - - // Try to find a match with only the local port. - nid.LocalAddress = "" - if ep, ok := eps.endpoints[nid]; ok { - if !yield(ep) { - return - } - } -} - -func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { - var matchedEPs []*endpointsByNic - d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool { - matchedEPs = append(matchedEPs, ep) - return true - }) - return matchedEPs -} - // findTransportEndpoint find a single endpoint that most closely matches the provided id. func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil } - // Try to find the endpoint. + eps.mu.RLock() - epsByNic := d.findEndpointLocked(eps, id) - // Fail if we didn't find one. - if epsByNic == nil { + epsByNIC := eps.findEndpointLocked(id) + if epsByNIC == nil { eps.mu.RUnlock() return nil } - epsByNic.mu.RLock() + epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] if !ok { - if mpep, ok = epsByNic.endpoints[0]; !ok { - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + if mpep, ok = epsByNIC.endpoints[0]; !ok { + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return nil } } - ep := selectEndpoint(id, mpep, epsByNic.seed) - epsByNic.mu.RUnlock() + ep := selectEndpoint(id, mpep, epsByNIC.seed) + epsByNIC.mu.RUnlock() return ep } -// findEndpointLocked returns the endpoint that most closely matches the given -// id. -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { - var matchedEP *endpointsByNic - d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool { - matchedEP = ep - return false - }) - return matchedEP -} - // registerRawEndpoint registers the given endpoint with the dispatcher such // that packets of the appropriate protocol are delivered to it. A single // packet can be sent to one or more raw endpoints along with a non-raw diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 84311bcc8..c65b0c632 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -31,84 +31,58 @@ import ( ) const ( - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testPort = 4096 + testSrcAddrV4 = "\x0a\x00\x00\x01" + testDstAddrV4 = "\x0a\x00\x00\x02" + + testDstPort = 1234 + testSrcPort = 4096 ) type testContext struct { - t *testing.T linkEps map[tcpip.NICID]*channel.Endpoint s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue -} - -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createV6Endpoint(v6only bool) { - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } + wq waiter.Queue } // newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) linkEps := make(map[tcpip.NICID]*channel.Endpoint) for _, linkEpID := range linkEpIDs { channelEp := channel.New(256, mtu, "") if err := s.CreateNIC(linkEpID, channelEp); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatalf("CreateNIC failed: %s", err) } linkEps[linkEpID] = channelEp - if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress IPv4 failed: %v", err) + if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { + t.Fatalf("AddAddress IPv4 failed: %s", err) } - if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress IPv6 failed: %v", err) + if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { + t.Fatalf("AddAddress IPv6 failed: %s", err) } } s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, + {Destination: header.IPv4EmptySubnet, NIC: 1}, + {Destination: header.IPv6EmptySubnet, NIC: 1}, }) return &testContext{ - t: t, s: s, linkEps: linkEps, } } type headers struct { - srcPort uint16 - dstPort uint16 + srcPort, dstPort uint16 } func newPayload() []byte { @@ -119,6 +93,47 @@ func newPayload() []byte { return b } +func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { + buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) + + // Initialize the IP header. + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: 0x80, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(udp.ProtocolNumber), + SrcAddr: testSrcAddrV4, + DstAddr: testDstAddrV4, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u))) + + // Calculate the UDP checksum and set it. + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum)) + + // Inject packet. + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) +} + func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) @@ -130,8 +145,8 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, - SrcAddr: testV6Addr, - DstAddr: stackV6Addr, + SrcAddr: testSrcAddrV6, + DstAddr: testDstAddrV6, }) // Initialize the UDP header. @@ -143,7 +158,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -151,7 +166,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI // Inject packet. c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ - Data: buf.ToVectorisedView(), + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), }) } @@ -179,15 +196,15 @@ func TestTransportDemuxerRegister(t *testing.T) { t.Fatalf("%T does not implement stack.TransportEndpoint", ep) } if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want { - t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) + t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) } }) } } -// TestReuseBindToDevice injects varied packets on input devices and checks that +// TestBindToDeviceDistribution injects varied packets on input devices and checks that // the distribution of packets received matches expectations. -func TestDistribution(t *testing.T) { +func TestBindToDeviceDistribution(t *testing.T) { type endpointSockopts struct { reuse int bindToDevice tcpip.NICID @@ -196,19 +213,19 @@ func TestDistribution(t *testing.T) { name string // endpoints will received the inject packets. endpoints []endpointSockopts - // wantedDistribution is the wanted ratio of packets received on each + // wantDistributions is the want ratio of packets received on each // endpoint for each NIC on which packets are injected. - wantedDistributions map[tcpip.NICID][]float64 + wantDistributions map[tcpip.NICID][]float64 }{ { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - {1, 0}, - {1, 0}, - {1, 0}, - {1, 0}, - {1, 0}, + {reuse: 1, bindToDevice: 0}, + {reuse: 1, bindToDevice: 0}, + {reuse: 1, bindToDevice: 0}, + {reuse: 1, bindToDevice: 0}, + {reuse: 1, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. @@ -219,9 +236,9 @@ func TestDistribution(t *testing.T) { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - {0, 1}, - {0, 2}, - {0, 3}, + {reuse: 0, bindToDevice: 1}, + {reuse: 0, bindToDevice: 2}, + {reuse: 0, bindToDevice: 3}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. @@ -236,12 +253,12 @@ func TestDistribution(t *testing.T) { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - {1, 1}, - {1, 1}, - {1, 2}, - {1, 2}, - {1, 2}, - {1, 0}, + {reuse: 1, bindToDevice: 1}, + {reuse: 1, bindToDevice: 1}, + {reuse: 1, bindToDevice: 2}, + {reuse: 1, bindToDevice: 2}, + {reuse: 1, bindToDevice: 2}, + {reuse: 1, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to @@ -255,17 +272,17 @@ func TestDistribution(t *testing.T) { }, }, } { - t.Run(test.name, func(t *testing.T) { - for device, wantedDistribution := range test.wantedDistributions { - t.Run(string(device), func(t *testing.T) { + for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{ + "IPv4": ipv4.ProtocolNumber, + "IPv6": ipv6.ProtocolNumber, + } { + for device, wantDistribution := range test.wantDistributions { + t.Run(test.name+protoName+string(device), func(t *testing.T) { var devices []tcpip.NICID - for d := range test.wantedDistributions { + for d := range test.wantDistributions { devices = append(devices, d) } c := newDualTestContextMultiNIC(t, defaultMTU, devices) - defer c.cleanup() - - c.createV6Endpoint(false) eps := make(map[tcpip.Endpoint]int) @@ -279,9 +296,9 @@ func TestDistribution(t *testing.T) { defer close(ch) var err *tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } eps[ep] = i @@ -294,20 +311,30 @@ func TestDistribution(t *testing.T) { defer ep.Close() reusePortOption := tcpip.ReusePortOption(endpoint.reuse) if err := ep.SetSockOpt(reusePortOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) + t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", reusePortOption, i, err) } bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) if err := ep.SetSockOpt(bindToDeviceOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) + t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err) + } + + var dstAddr tcpip.Address + switch netProtoNum { + case ipv4.ProtocolNumber: + dstAddr = testDstAddrV4 + case ipv6.ProtocolNumber: + dstAddr = testDstAddrV6 + default: + t.Fatalf("unexpected protocol number: %d", netProtoNum) } - if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) + if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { + t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) } } npackets := 100000 nports := 10000 - if got, want := len(test.endpoints), len(wantedDistribution); got != want { + if got, want := len(test.endpoints), len(wantDistribution); got != want { t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) } ports := make(map[uint16]tcpip.Endpoint) @@ -316,17 +343,22 @@ func TestDistribution(t *testing.T) { // Send a packet. port := uint16(i % nports) payload := newPayload() - c.sendV6Packet(payload, - &headers{ - srcPort: testPort + port, - dstPort: stackPort}, - device) + hdrs := &headers{ + srcPort: testSrcPort + port, + dstPort: testDstPort, + } + switch netProtoNum { + case ipv4.ProtocolNumber: + c.sendV4Packet(payload, hdrs, device) + case ipv6.ProtocolNumber: + c.sendV6Packet(payload, hdrs, device) + default: + t.Fatalf("unexpected protocol number: %d", netProtoNum) + } - var addr tcpip.FullAddress ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) + if _, _, err := ep.Read(nil); err != nil { + t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) } stats[ep]++ if i < nports { @@ -342,17 +374,17 @@ func TestDistribution(t *testing.T) { // Check that a packet distribution is as expected. for ep, i := range eps { - wantedRatio := wantedDistribution[i] - wantedRecv := wantedRatio * float64(npackets) + wantRatio := wantDistribution[i] + wantRecv := wantRatio * float64(npackets) actualRecv := stats[ep] actualRatio := float64(stats[ep]) / float64(npackets) // The deviation is less than 10%. - if math.Abs(actualRatio-wantedRatio) > 0.05 { - t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) + if math.Abs(actualRatio-wantRatio) > 0.05 { + t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets) } } }) } - }) + } } } |