diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 56 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 155 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 17 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 48 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 156 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 95 |
7 files changed, 185 insertions, 364 deletions
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7a501acdc..cb7dec1ea 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -74,8 +74,30 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { } func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + netHdr := pkt.NetworkHeader().View() + _, dst := f.proto.ParseAddresses(netHdr) + + addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + // Dispatch the packet to the transport protocol. + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt) + return + } + + r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */) + if err != nil { + return + } + defer r.Release() + + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + pkt = NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: vv.ToView().ToVectorisedView(), + }) + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + _ = r.WriteHeaderIncludedPacket(pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -106,8 +128,13 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { + // The network header should not already be populated. + if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok { + return tcpip.ErrMalformedHeader + } + + return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) } func (f *fwdTestNetworkEndpoint) Close() { @@ -117,6 +144,8 @@ func (f *fwdTestNetworkEndpoint) Close() { // fwdTestNetworkProtocol is a network-layer protocol that implements Address // resolution. type fwdTestNetworkProtocol struct { + stack *Stack + addrCache *linkAddrCache neigh *neighborCache addrResolveDelay time.Duration @@ -304,20 +333,6 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer return n, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - p := fwdTestPacketInfo{ - Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), - } - - select { - case e.C <- p: - default: - } - - return nil -} - // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} @@ -334,7 +349,10 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ - NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }}, + NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { + proto.stack = s + return proto + }}, UseNeighborCache: useNeighborCache, }) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 60c81a3aa..3e6ceff28 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -232,7 +232,8 @@ func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Unlock() } -func (n *NIC) isPromiscuousMode() bool { +// Promiscuous implements NetworkInterface. +func (n *NIC) Promiscuous() bool { n.mu.RLock() rv := n.mu.promiscuous n.mu.RUnlock() @@ -320,16 +321,21 @@ func (n *NIC) setSpoofing(enable bool) { // primaryAddress returns an address that can be used to communicate with // remoteAddr. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { - n.mu.RLock() - spoofing := n.mu.spoofing - n.mu.RUnlock() - ep, ok := n.networkEndpoints[protocol] if !ok { return nil } - return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return nil + } + + n.mu.RLock() + spoofing := n.mu.spoofing + n.mu.RUnlock() + + return addressableEndpoint.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) } type getAddressBehaviour int @@ -388,11 +394,17 @@ func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, addre // getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean // is passed to indicate whether or not we should generate temporary endpoints. func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { - if ep, ok := n.networkEndpoints[protocol]; ok { - return ep.AcquireAssignedAddress(address, createTemp, peb) + ep, ok := n.networkEndpoints[protocol] + if !ok { + return nil } - return nil + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return nil + } + + return addressableEndpoint.AcquireAssignedAddress(address, createTemp, peb) } // addAddress adds a new address to n, so that it starts accepting packets @@ -403,7 +415,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return tcpip.ErrUnknownProtocol } - addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return tcpip.ErrNotSupported + } + + addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) if err == nil { // We have no need for the address endpoint. addressEndpoint.DecRef() @@ -416,7 +433,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { - for _, a := range ep.PermanentAddresses() { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + for _, a := range addressableEndpoint.PermanentAddresses() { addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } @@ -427,7 +449,12 @@ func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { - for _, a := range ep.PrimaryAddresses() { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + for _, a := range addressableEndpoint.PrimaryAddresses() { addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } @@ -445,13 +472,23 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit return tcpip.AddressWithPrefix{} } - return ep.MainAddress() + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return tcpip.AddressWithPrefix{} + } + + return addressableEndpoint.MainAddress() } // removeAddress removes an address from n. func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { for _, ep := range n.networkEndpoints { - if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { continue } else { return err @@ -564,13 +601,6 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return false } -func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) - defer r.Release() - r.PopulatePacketInfo(pkt) - n.getNetworkEndpoint(protocol).HandlePacket(pkt) -} - // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the link endpoint. @@ -592,7 +622,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.stats.Rx.Packets.Increment() n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) - netProto, ok := n.stack.networkProtocols[protocol] + networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -617,11 +647,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp ep.HandlePacket(n.id, local, protocol, p) } - if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { - n.stack.stats.IP.PacketsReceived.Increment() - } - // Parse headers. + netProto := n.stack.NetworkProtocolInstance(protocol) transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) if !ok { // The packet is too small to contain a network header. @@ -636,9 +663,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if n.stack.handleLocal && !n.IsLoopback() { + src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View()) if r := n.getAddress(protocol, src); r != nil { r.DecRef() @@ -651,78 +677,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - // Loopback traffic skips the prerouting chain. - if !n.IsLoopback() { - // iptables filtering. - ipt := n.stack.IPTables() - address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { - // iptables is telling us to drop the packet. - n.stack.stats.IP.IPTablesPreroutingDropped.Increment() - return - } - } - - if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil { - n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt) - return - } - - // This NIC doesn't care about the packet. Find a NIC that cares about the - // packet and forward it to the NIC. - // - // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding(protocol) { - r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) - if err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - return - } - - // Found a NIC. - n := r.localAddressNIC - if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { - if n.isValidForOutgoing(addressEndpoint) { - pkt.NICID = n.ID() - r.RemoteAddress = src - pkt.NetworkPacketInfo = r.networkPacketInfo() - n.getNetworkEndpoint(protocol).HandlePacket(pkt) - addressEndpoint.DecRef() - r.Release() - return - } - - addressEndpoint.DecRef() - } - - // n doesn't have a destination endpoint. - // Send the packet out of n. - // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease - // the TTL field for ipv4/ipv6. - - // pkt may have set its header and may not have enough headroom for - // link-layer header for the other link to prepend. Here we create a new - // packet to forward. - fwdPkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), - // We need to do a deep copy of the IP packet because WritePacket (and - // friends) take ownership of the packet buffer, but we do not own it. - Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), - }) - - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } - - r.Release() - return - } - - // If a packet socket handled the packet, don't treat it as invalid. - if len(packetEPs) == 0 { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } + networkEndpoint.HandlePacket(pkt) } // DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 00e9a82ae..43ca03ada 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -65,10 +65,6 @@ const ( // NetworkPacketInfo holds information about a network layer packet. type NetworkPacketInfo struct { - // RemoteAddressBroadcast is true if the packet's remote address is a - // broadcast address. - RemoteAddressBroadcast bool - // LocalAddressBroadcast is true if the packet's local address is a broadcast // address. LocalAddressBroadcast bool @@ -266,10 +262,10 @@ const ( // NetOptions is an interface that allows us to pass network protocol specific // options through the Stack layer code. type NetOptions interface { - // AllocationSize returns the amount of memory that must be allocated to + // SizeWithPadding returns the amount of memory that must be allocated to // hold the options given that the value must be rounded up to the next // multiple of 4 bytes. - AllocationSize() int + SizeWithPadding() int } // NetworkHeaderParams are the header parameters given as input by the @@ -518,6 +514,9 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + // Promiscuous returns true if the interface is in promiscuous mode. + Promiscuous() bool + // WritePacketToRemote writes the packet to the given remote link address. WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } @@ -525,8 +524,6 @@ type NetworkInterface interface { // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { - AddressableEndpoint - // Enable enables the endpoint. // // Must only be called when the stack is in a state that allows the endpoint @@ -742,10 +739,6 @@ type LinkEndpoint interface { // endpoint. Capabilities() LinkEndpointCapabilities - // WriteRawPacket writes a packet directly to the link. The packet - // should already have an ethernet header. It takes ownership of vv. - WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error - // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. // diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 15ff437c7..53cb6694f 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -170,28 +170,6 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) } -// PopulatePacketInfo populates a packet buffer's packet information fields. -// -// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by -// the network layer. -func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) { - if r.local() { - pkt.RXTransportChecksumValidated = true - } - pkt.NetworkPacketInfo = r.networkPacketInfo() -} - -// networkPacketInfo returns the network packet information of the route. -// -// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by -// the network layer. -func (r *Route) networkPacketInfo() NetworkPacketInfo { - return NetworkPacketInfo{ - RemoteAddressBroadcast: r.IsOutboundBroadcast(), - LocalAddressBroadcast: r.isInboundBroadcast(), - } -} - // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { return r.outgoingNIC.ID() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0fe157128..f4504e633 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -82,6 +82,7 @@ type TCPRACKState struct { FACK seqnum.Value RTT time.Duration Reord bool + DSACKSeen bool } // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. @@ -1080,7 +1081,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. Running: nic.Enabled(), - Promiscuous: nic.isPromiscuousMode(), + Promiscuous: nic.Promiscuous(), Loopback: nic.IsLoopback(), } nics[id] = NICInfo{ @@ -1809,49 +1810,20 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip nic.unregisterPacketEndpoint(netProto, ep) } -// WritePacket writes data directly to the specified NIC. It adds an ethernet -// header based on the arguments. -func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { +// WritePacketToRemote writes a payload on the specified NIC using the provided +// network protocol and remote link address. +func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { s.mu.Lock() nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { return tcpip.ErrUnknownDevice } - - // Add our own fake ethernet header. - ethFields := header.EthernetFields{ - SrcAddr: nic.LinkEndpoint.LinkAddress(), - DstAddr: dst, - Type: netProto, - } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - vv := buffer.View(fakeHeader).ToVectorisedView() - vv.Append(payload) - - if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil { - return err - } - - return nil -} - -// WriteRawPacket writes data directly to the specified NIC without adding any -// headers. -func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { - s.mu.Lock() - nic, ok := s.nics[nicID] - s.mu.Unlock() - if !ok { - return tcpip.ErrUnknownDevice - } - - if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil { - return err - } - - return nil + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(nic.MaxHeaderLength()), + Data: payload, + }) + return nic.WritePacketToRemote(remote, nil, netProto, pkt) } // NetworkProtocolInstance returns the protocol instance in the stack for the diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index dedfdd435..0d94af139 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -112,7 +112,15 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. netHdr := pkt.NetworkHeader().View() - f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++ + + dst := tcpip.Address(netHdr[dstAddrOffset:][:1]) + addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint == nil { + return + } + addressEndpoint.DecRef() + + f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++ // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { @@ -159,9 +167,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - pkt := pkt.Clone() - r.PopulatePacketInfo(pkt) - f.HandlePacket(pkt) + f.HandlePacket(pkt.Clone()) } if r.Loop&stack.PacketOut == 0 { return nil @@ -2214,88 +2220,6 @@ func TestNICStats(t *testing.T) { } } -func TestNICForwarding(t *testing.T) { - const nicID1 = 1 - const nicID2 = 2 - const dstAddr = tcpip.Address("\x03") - - tests := []struct { - name string - headerLen uint16 - }{ - { - name: "Zero header length", - }, - { - name: "Non-zero header length", - headerLen: 16, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - s.SetForwarding(fakeNetNumber, true) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) - } - - ep2 := channelLinkWithHeaderLength{ - Endpoint: channel.New(10, defaultMTU, ""), - headerLength: test.headerLen, - } - if err := s.CreateNIC(nicID2, &ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) - } - - // Route all packets to dstAddr to NIC 2. - { - subnet, err := tcpip.NewSubnet(dstAddr, "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) - } - - // Send a packet to dstAddr. - buf := buffer.NewView(30) - buf[dstAddrOffset] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - pkt, ok := ep2.Read() - if !ok { - t.Fatal("packet not forwarded") - } - - // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want { - t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want) - } - - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } - - if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } - }) - } -} - // TestNICContextPreservation tests that you can read out via stack.NICInfo the // Context data you pass via NICContext.Context in stack.CreateNICWithOptions. func TestNICContextPreservation(t *testing.T) { @@ -4228,3 +4152,63 @@ func TestFindRouteWithForwarding(t *testing.T) { }) } } + +func TestWritePacketToRemote(t *testing.T) { + const nicID = 1 + const MTU = 1280 + e := channel.New(1, MTU, linkAddr1) + s := stack.New(stack.Options{}) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("CreateNIC(%d) = %s", nicID, err) + } + tests := []struct { + name string + protocol tcpip.NetworkProtocolNumber + payload []byte + }{ + { + name: "SuccessIPv4", + protocol: header.IPv4ProtocolNumber, + payload: []byte{1, 2, 3, 4}, + }, + { + name: "SuccessIPv6", + protocol: header.IPv6ProtocolNumber, + payload: []byte{5, 6, 7, 8}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err) + } + + pkt, ok := e.Read() + if got, want := ok, true; got != want { + t.Fatalf("e.Read() = %t, want %t", got, want) + } + if got, want := pkt.Proto, test.protocol; got != want { + t.Fatalf("pkt.Proto = %d, want %d", got, want) + } + if got, want := pkt.Route.RemoteLinkAddress, linkAddr2; got != want { + t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want) + } + if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { + t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) + } + }) + } + + t.Run("InvalidNICID", func(t *testing.T) { + if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want) + } + pkt, ok := e.Read() + if got, want := ok, false; got != want { + t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want) + } + }) +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index c457b67a2..5b9043d85 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -47,6 +46,9 @@ type fakeTransportEndpoint struct { // acceptQueue is non-nil iff bound. acceptQueue []fakeTransportEndpoint + + // ops is used to set and get socket options. + ops tcpip.SocketOptions } func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { @@ -59,6 +61,9 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} +func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { + return &f.ops +} func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } @@ -184,9 +189,9 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai if len(f.acceptQueue) == 0 { return nil, nil, nil } - a := f.acceptQueue[0] + a := &f.acceptQueue[0] f.acceptQueue = f.acceptQueue[1:] - return &a, nil, nil + return a, nil, nil } func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { @@ -553,87 +558,3 @@ func TestTransportOptions(t *testing.T) { t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") } } - -func TestTransportForwarding(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - s.SetForwarding(fakeNetNumber, true) - - // TODO(b/123449044): Change this to a channel NIC. - ep1 := loopback.New() - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatalf("CreateNIC #1 failed: %v", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress #1 failed: %v", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatalf("CreateNIC #2 failed: %v", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress #2 failed: %v", err) - } - - // Route all packets to address 3 to NIC 2 and all packets to address - // 1 to NIC 1. - { - subnet0, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - }) - } - - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Send a packet to address 1 from address 3. - req := buffer.NewView(30) - req[0] = 1 - req[1] = 3 - req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: req.ToVectorisedView(), - })) - - aep, _, err := ep.Accept(nil) - if err != nil || aep == nil { - t.Fatalf("Accept failed: %v, %v", aep, err) - } - - resp := buffer.NewView(30) - if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - p, ok := ep2.Read() - if !ok { - t.Fatal("Response packet not forwarded") - } - - nh := stack.PayloadSince(p.Pkt.NetworkHeader()) - if dst := nh[0]; dst != 3 { - t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) - } - if src := nh[1]; src != 1 { - t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) - } -} |