diff options
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 118 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 203 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 51 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 221 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 36 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp.go | 2 |
10 files changed, 580 insertions, 143 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index b47a7be51..7df77c66e 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -49,7 +49,6 @@ type endpoint struct { enabled uint32 nic stack.NetworkInterface - linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache nud stack.NUDHandler } @@ -92,12 +91,12 @@ func (e *endpoint) DefaultTTL() uint8 { } func (e *endpoint) MTU() uint32 { - lmtu := e.linkEP.MTU() + lmtu := e.nic.MTU() return lmtu - uint32(e.MaxHeaderLength()) } func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.ARPSize + return e.nic.MaxHeaderLength() + header.ARPSize } func (e *endpoint) Close() { @@ -154,17 +153,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize, + // As per RFC 826, under Packet Reception: + // Swap hardware and protocol fields, putting the local hardware and + // protocol addresses in the sender fields. + // + // Send the packet to the (new) target hardware address on the same + // hardware on which the request was received. + origSender := h.HardwareAddressSender() + r.RemoteLinkAddress = tcpip.LinkAddress(origSender) + respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) - packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) + packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize)) packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) + copy(packet.HardwareAddressTarget(), origSender) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt) case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -207,7 +214,6 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L e := &endpoint{ protocol: p, nic: nic, - linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, } @@ -223,6 +229,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ + NetProto: ProtocolNumber, RemoteLinkAddress: remoteLinkAddr, } if len(r.RemoteLinkAddress) == 0 { diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 6861cfdaf..d436873b6 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -270,7 +270,7 @@ func buildDummyStack(t *testing.T) *stack.Stack { var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { - tester testObject + testObject mu struct { sync.RWMutex @@ -302,10 +302,6 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } -func (t *testInterface) LinkEndpoint() stack.LinkEndpoint { - return &t.tester -} - func TestSourceAddressValidation(t *testing.T) { rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize @@ -517,7 +513,7 @@ func TestIPv4Send(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, v4: true, }, @@ -538,10 +534,10 @@ func TestIPv4Send(t *testing.T) { }) // Issue the write. - nic.tester.protocol = 123 - nic.tester.srcAddr = localIPv4Addr - nic.tester.dstAddr = remoteIPv4Addr - nic.tester.contents = payload + nic.testObject.protocol = 123 + nic.testObject.srcAddr = localIPv4Addr + nic.testObject.dstAddr = remoteIPv4Addr + nic.testObject.contents = payload r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { @@ -560,12 +556,12 @@ func TestIPv4Receive(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, v4: true, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -590,10 +586,10 @@ func TestIPv4Receive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv4Addr - nic.tester.dstAddr = localIPv4Addr - nic.tester.contents = view[header.IPv4MinimumSize:totalLen] + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { @@ -606,8 +602,8 @@ func TestIPv4Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if nic.tester.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) + if nic.testObject.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } } @@ -640,11 +636,11 @@ func TestIPv4ReceiveControl(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -691,16 +687,16 @@ func TestIPv4ReceiveControl(t *testing.T) { // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv4Addr - nic.tester.dstAddr = localIPv4Addr - nic.tester.contents = view[dataOffset:] - nic.tester.typ = c.expectedTyp - nic.tester.extra = c.expectedExtra + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = view[dataOffset:] + nic.testObject.typ = c.expectedTyp + nic.testObject.extra = c.expectedExtra ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) - if want := c.expectedCount; nic.tester.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) + if want := c.expectedCount; nic.testObject.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } }) } @@ -710,12 +706,12 @@ func TestIPv4FragmentationReceive(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, v4: true, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -758,10 +754,10 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv4Addr - nic.tester.dstAddr = localIPv4Addr - nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { @@ -776,8 +772,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if nic.tester.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls) + if nic.testObject.dataCalls != 0 { + t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) } // Send second segment. @@ -788,8 +784,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if nic.tester.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) + if nic.testObject.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } } @@ -797,7 +793,7 @@ func TestIPv6Send(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, }, } @@ -821,10 +817,10 @@ func TestIPv6Send(t *testing.T) { }) // Issue the write. - nic.tester.protocol = 123 - nic.tester.srcAddr = localIPv6Addr - nic.tester.dstAddr = remoteIPv6Addr - nic.tester.contents = payload + nic.testObject.protocol = 123 + nic.testObject.srcAddr = localIPv6Addr + nic.testObject.dstAddr = remoteIPv6Addr + nic.testObject.contents = payload r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { @@ -843,11 +839,11 @@ func TestIPv6Receive(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -871,10 +867,10 @@ func TestIPv6Receive(t *testing.T) { } // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv6Addr - nic.tester.dstAddr = localIPv6Addr - nic.tester.contents = view[header.IPv6MinimumSize:totalLen] + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv6Addr + nic.testObject.dstAddr = localIPv6Addr + nic.testObject.contents = view[header.IPv6MinimumSize:totalLen] r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { @@ -888,8 +884,8 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if nic.tester.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) + if nic.testObject.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } } @@ -931,11 +927,11 @@ func TestIPv6ReceiveControl(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -994,19 +990,19 @@ func TestIPv6ReceiveControl(t *testing.T) { // Give packet to IPv6 endpoint, dispatcher will validate that // it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv6Addr - nic.tester.dstAddr = localIPv6Addr - nic.tester.contents = view[dataOffset:] - nic.tester.typ = c.expectedTyp - nic.tester.extra = c.expectedExtra + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv6Addr + nic.testObject.dstAddr = localIPv6Addr + nic.testObject.contents = view[dataOffset:] + nic.testObject.typ = c.expectedTyp + nic.testObject.extra = c.expectedExtra // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) - if want := c.expectedCount; nic.tester.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) + if want := c.expectedCount; nic.testObject.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } }) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index ee2c23e91..7fc12e229 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -32,6 +32,7 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index eab9a530c..3407755ed 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -102,8 +102,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) - remoteLinkAddr := r.RemoteLinkAddress - // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP // source address MUST be one of its own IP addresses (but not a broadcast // or multicast address). @@ -119,9 +117,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { } defer r.Release() - // Use the remote link address from the incoming packet. - r.ResolveWith(remoteLinkAddr) - // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the // header information, we may have to change this code to handle the // ICMP header no longer being in the data buffer. @@ -244,13 +239,7 @@ func (*icmpReasonProtoUnreachable) isICMPReason() {} // the problematic packet. It incorporates as much of that packet as // possible as well as any error metadata as is available. returnError // expects pkt to hold a valid IPv4 packet as per the wire format. -func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { - sent := r.Stats().ICMP.V4PacketsSent - if !r.Stack().AllowICMPMessage() { - sent.RateLimited.Increment() - return nil - } - +func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { // We check we are responding only when we are allowed to. // See RFC 1812 section 4.3.2.7 (shown below). // @@ -279,6 +268,25 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc return nil } + // Even if we were able to receive a packet from some remote, we may not have + // a route to it - the remote may be blocked via routing rules. We must always + // consult our routing table and find a route to the remote before sending any + // packet. + route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + // From this point on, the incoming route should no longer be used; route + // must be used to send the ICMP error. + r = nil + + sent := p.stack.Stats().ICMP.V4PacketsSent + if !p.stack.AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + networkHeader := pkt.NetworkHeader().View() transportHeader := pkt.TransportHeader().View() @@ -329,11 +337,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // least 8 bytes of the payload must be included. Today linux and other // systems implement the RFC 1812 definition and not the original // requirement. We treat 8 bytes as the minimum but will try send more. - mtu := int(r.MTU()) + mtu := int(route.MTU()) if mtu > header.IPv4MinimumProcessableDatagramSize { mtu = header.IPv4MinimumProcessableDatagramSize } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize + headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize { @@ -378,11 +386,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) counter := sent.DstUnreachable - if err := r.WritePacket( + if err := route.WritePacket( nil, /* gso */ stack.NetworkHeaderParams{ Protocol: header.ICMPv4ProtocolNumber, - TTL: r.DefaultTTL(), + TTL: route.DefaultTTL(), TOS: stack.DefaultTOS, }, icmpPkt, diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 99274dd45..115fb1ab0 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -66,7 +66,6 @@ var _ stack.NetworkEndpoint = (*endpoint)(nil) type endpoint struct { nic stack.NetworkInterface - linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher protocol *protocol @@ -87,7 +86,6 @@ type endpoint struct { func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, - linkEP: nic.LinkEndpoint(), dispatcher: dispatcher, protocol: p, } @@ -178,18 +176,18 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.linkEP.MTU()) + return calculateMTU(e.nic.MTU()) } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv4MaximumHeaderSize + return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize } // GSOMaxSize returns the maximum GSO packet size. func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { + if gso, ok := e.nic.(stack.GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -210,7 +208,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint for { fragPkt, more := buildNextFragment(&pf, networkHeader) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1)) return err } @@ -283,10 +281,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, e.linkEP.MTU(), pkt) + if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, e.nic.MTU(), pkt) } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } @@ -316,7 +314,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) @@ -343,7 +341,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) // Dropped packets aren't errors, so include them in @@ -404,7 +402,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return nil } - if err := e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } @@ -512,13 +510,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // 3 (Port Unreachable), when the designated transport protocol // (e.g., UDP) is unable to demultiplex the datagram but has no // protocol mechanism to inform the sender. - _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC: 1122 Section 3.2.2.1 // A host SHOULD generate Destination Unreachable messages with code: // 2 (Protocol Unreachable), when the designated transport protocol // is not supported - _ = returnError(r, &icmpReasonProtoUnreachable{}, pkt) + _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index f250a3cde..9916d783f 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -16,6 +16,7 @@ package ipv4_test import ( "bytes" + "context" "encoding/hex" "math" "net" @@ -28,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -1492,3 +1494,204 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, lm.limit-- return false, false } + +func TestPacketQueing(t *testing.T) { + const nicID = 1 + + var ( + host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") + + host1IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + host2IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 8, + }, + } + ) + + tests := []struct { + name string + rxPkt func(*channel.Endpoint) + checkResp func(*testing.T, *channel.Endpoint) + }{ + { + name: "ICMP Error", + rxPkt: func(e *channel.Endpoint) { + hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize, + TTL: ipv4.DefaultTTL, + Protocol: uint8(udp.ProtocolNumber), + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != header.IPv4ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + }, + }, + + { + name: "Ping", + rxPkt: func(e *channel.Endpoint) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ipv4.DefaultTTL, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != header.IPv4ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply), + checker.ICMPv4Code(header.ICMPv4UnusedCode))) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + + // Receive a packet to trigger link resolution before a response is sent. + test.rxPkt(e) + + // Wait for a ARP request since link address resolution should be + // performed. + { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != arp.ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) + } + if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + } + rep := header.ARP(p.Pkt.NetworkHeader().View()) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address) + } + } + + // Send an ARP reply to complete link address resolution. + { + hdr := buffer.View(make([]byte, header.ARPSize)) + packet := header.ARP(hdr) + packet.SetIPv4OverEthernet() + packet.SetOp(header.ARPReply) + copy(packet.HardwareAddressSender(), host2NICLinkAddr) + copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address) + copy(packet.HardwareAddressTarget(), host1NICLinkAddr) + copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address) + e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.ToVectorisedView(), + })) + } + + // Expect the response now that the link address has resolved. + test.checkResp(t, e) + + // Since link resolution was already performed, it shouldn't be performed + // again. + test.rxPkt(e) + test.checkResp(t, e) + }) + } +} diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 37c169a5d..7be35c78b 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -440,8 +440,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - remoteLinkAddr := r.RemoteLinkAddress - // As per RFC 4291 section 2.7, multicast addresses must not be used as // source addresses in IPv6 packets. localAddr := r.LocalAddress @@ -456,9 +454,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } defer r.Release() - // Use the link address from the source of the original packet. - r.ResolveWith(remoteLinkAddr) - replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, Data: pkt.Data, @@ -742,14 +737,7 @@ func (*icmpReasonPortUnreachable) isICMPReason() {} // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. -func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { - stats := r.Stats().ICMP - sent := stats.V6PacketsSent - if !r.Stack().AllowICMPMessage() { - sent.RateLimited.Increment() - return nil - } - +func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { // Only send ICMP error if the address is not a multicast v6 // address and the source is not the unspecified address. // @@ -780,6 +768,26 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc return nil } + // Even if we were able to receive a packet from some remote, we may not have + // a route to it - the remote may be blocked via routing rules. We must always + // consult our routing table and find a route to the remote before sending any + // packet. + route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + // From this point on, the incoming route should no longer be used; route + // must be used to send the ICMP error. + r = nil + + stats := p.stack.Stats().ICMP + sent := stats.V6PacketsSent + if !p.stack.AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { @@ -806,11 +814,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // packet that caused the error) as possible without making // the error message packet exceed the minimum IPv6 MTU // [IPv6]. - mtu := int(r.MTU()) + mtu := int(route.MTU()) if mtu > header.IPv6MinimumMTU { mtu = header.IPv6MinimumMTU } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize + headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize available := int(mtu) - headerLen if available < header.IPv6MinimumSize { return nil @@ -843,9 +851,16 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data)) - err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt) - if err != nil { + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data)) + if err := route.WritePacket( + nil, /* gso */ + stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, + newPkt, + ); err != nil { sent.Dropped.Increment() return err } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 6f13339a4..3affcc4e4 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -16,9 +16,11 @@ package ipv6 import ( "context" + "net" "reflect" "strings" "testing" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -28,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -40,6 +43,9 @@ const ( defaultChannelSize = 1 defaultMTU = 65536 + + // Extra time to use when waiting for an async event to occur. + defaultAsyncPositiveEventTimeout = 30 * time.Second ) var ( @@ -110,7 +116,9 @@ func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { var _ stack.NetworkInterface = (*testInterface)(nil) -type testInterface struct{} +type testInterface struct { + stack.NetworkLinkEndpoint +} func (*testInterface) ID() tcpip.NICID { return 0 @@ -128,10 +136,6 @@ func (*testInterface) Enabled() bool { return true } -func (*testInterface) LinkEndpoint() stack.LinkEndpoint { - return nil -} - func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -1281,3 +1285,210 @@ func TestLinkAddressRequest(t *testing.T) { )) } } + +func TestPacketQueing(t *testing.T) { + const nicID = 1 + + var ( + host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") + + host1IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::1").To16()), + PrefixLen: 64, + }, + } + host2IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::2").To16()), + PrefixLen: 64, + }, + } + ) + + tests := []struct { + name string + rxPkt func(*channel.Endpoint) + checkResp func(*testing.T, *channel.Endpoint) + }{ + { + name: "ICMP Error", + rxPkt: func(e *channel.Endpoint) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6PortUnreachable))) + }, + }, + + { + name: "Ping", + rxPkt: func(e *channel.Endpoint) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply), + checker.ICMPv6Code(header.ICMPv6UnusedCode))) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + + e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + + // Receive a packet to trigger link resolution before a response is sent. + test.rxPkt(e) + + // Wait for a neighbor solicitation since link address resolution should + // be performed. + { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != ProtocolNumber { + t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) + } + snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address), + checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}), + )) + } + + // Send a neighbor advertisement to complete link address resolution. + { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) + pkt := header.ICMPv6(hdr.Prepend(naSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr), + }) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: header.NDPHopLimit, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + // Expect the response now that the link address has resolved. + test.checkResp(t, e) + + // Since link resolution was already performed, it shouldn't be performed + // again. + test.rxPkt(e) + test.checkResp(t, e) + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 826342c4f..4f360df2c 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -66,7 +66,6 @@ var _ NDPEndpoint = (*endpoint)(nil) type endpoint struct { nic stack.NetworkInterface - linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache nud stack.NUDHandler dispatcher stack.TransportDispatcher @@ -364,18 +363,18 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.linkEP.MTU()) + return calculateMTU(e.nic.MTU()) } // MaxHeaderLength returns the maximum length needed by ipv6 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize + return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } // GSOMaxSize returns the maximum GSO packet size. func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { + if gso, ok := e.nic.(stack.GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -396,7 +395,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s } func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { - return pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) + return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) } // handleFragments fragments pkt and calls the handler function on each @@ -477,19 +476,19 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.linkEP.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to // WritePackets(). It'll be faster but cost more memory. - return e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt) + return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) }) r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } @@ -511,7 +510,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe e.addIPHeader(r, pb, params) if e.packetMustBeFragmented(pb, gso) { current := pb - _, _, err := e.handleFragments(r, gso, e.linkEP.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(current, fragPkt) current = current.Next() @@ -536,7 +535,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) @@ -563,7 +562,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) // Dropped packets aren't errors, so include them in @@ -643,7 +642,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. if previousHeaderStart != 0 { - _ = returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: previousHeaderStart, }, pkt) @@ -682,7 +681,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -707,7 +706,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // header, so we just make sure Segments Left is zero before processing // the next extension header. if extHdr.SegmentsLeft() != 0 { - _ = returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: it.ParseOffset(), }, pkt) @@ -859,7 +858,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -896,7 +895,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // message with Code 4 in response to a packet for which the // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. - _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC 8200 section 4. (page 7): // Extension headers are numbered from IANA IP Protocol Numbers @@ -917,7 +916,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // // Which when taken together indicate that an unknown protocol should // be treated as an unrecognized next header value. - _ = returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: it.ParseOffset(), }, pkt) @@ -927,7 +926,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } default: - _ = returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: it.ParseOffset(), }, pkt) @@ -1302,7 +1301,6 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, - linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, dispatcher: dispatcher, diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 48a4c65e3..40da011f8 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1289,7 +1289,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt // // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by // LinkEndpoint.LinkAddress) before reaching this point. - linkAddr := ndp.ep.linkEP.LinkAddress() + linkAddr := ndp.ep.nic.LinkAddress() if !header.IsValidUnicastEthernetAddress(linkAddr) { return false } |