diff options
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 43 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 99 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/mld.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 16 |
6 files changed, 88 insertions, 82 deletions
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 750aa4022..95efada3a 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -688,25 +688,38 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver. func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + nicID := nic.ID() + + p.mu.Lock() + netEP, ok := p.mu.eps[nicID] + p.mu.Unlock() + if !ok { + return tcpip.ErrNotConnected + } + remoteAddr := targetAddr if len(remoteLinkAddr) == 0 { remoteAddr = header.SolicitedNodeAddr(targetAddr) remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr) } - r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + if len(localAddr) == 0 { + addressEndpoint := netEP.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */) + if addressEndpoint == nil { + return tcpip.ErrNetworkUnreachable + } + + localAddr = addressEndpoint.AddressWithPrefix().Address + } else if p.stack.CheckLocalAddress(nicID, ProtocolNumber, localAddr) == 0 { + return tcpip.ErrBadLocalAddress } - defer r.Release() - r.ResolveWith(remoteLinkAddr) optsSerializer := header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()), } neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize, + ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize, }) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) @@ -714,20 +727,18 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot ns := header.NDPNeighborSolicit(packet.MessageBody()) ns.SetTargetAddress(targetAddr) ns.Options().Serialize(optsSerializer) - packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet.SetChecksum(header.ICMPv6Checksum(packet, localAddr, remoteAddr, buffer.VectorisedView{})) - p.mu.Lock() - netEP, ok := p.mu.eps[nic.ID()] - p.mu.Unlock() - if !ok { - return tcpip.ErrNotConnected + if err := addIPHeader(localAddr, remoteAddr, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, header.IPv6ExtHdrSerializer{}); err != nil { + panic(fmt.Sprintf("failed to add IP header: %s", err)) } + stat := netEP.stats.icmp.packetsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - }, pkt); err != nil { + if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { stat.dropped.Increment() return err } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 3fd1bacae..641c60b7c 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -23,6 +23,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -1320,13 +1321,13 @@ func TestLinkAddressRequest(t *testing.T) { name: "Unicast with unassigned address", localAddr: lladdr1, remoteLinkAddr: linkAddr1, - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: tcpip.ErrBadLocalAddress, }, { name: "Multicast with unassigned address", localAddr: lladdr1, remoteLinkAddr: "", - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: tcpip.ErrBadLocalAddress, }, { name: "Unicast with no local address available", @@ -1341,58 +1342,58 @@ func TestLinkAddressRequest(t *testing.T) { } for _, test := range tests { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - p := s.NetworkProtocolInstance(ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") - } + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + p := s.NetworkProtocolInstance(ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") + } - linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) - if err := s.CreateNIC(nicID, linkEP); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if len(test.nicAddr) != 0 { - if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + } } - } - // We pass a test network interface to LinkAddressRequest with the same NIC - // ID and link endpoint used by the NIC we created earlier so that we can - // mock a link address request and observe the packets sent to the link - // endpoint even though the stack uses the real NIC. - if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) - } + // We pass a test network interface to LinkAddressRequest with the same NIC + // ID and link endpoint used by the NIC we created earlier so that we can + // mock a link address request and observe the packets sent to the link + // endpoint even though the stack uses the real NIC. + if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + } - if test.expectedErr != nil { - return - } + if test.expectedErr != nil { + return + } - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } - if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) - } - if pkt.Route.RemoteAddress != test.expectedRemoteAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) - } - if pkt.Route.LocalAddress != lladdr1 { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1) - } - checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), - checker.SrcAddr(lladdr1), - checker.DstAddr(test.expectedRemoteAddr), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(lladdr0), - checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), - )) + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + var want stack.RouteInfo + want.NetProto = ProtocolNumber + want.RemoteLinkAddress = test.expectedRemoteLinkAddr + if diff := cmp.Diff(want, pkt.Route, cmp.AllowUnexported(want)); diff != "" { + t.Errorf("route info mismatch (-want +got):\n%s", diff) + } + checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), + checker.SrcAddr(lladdr1), + checker.DstAddr(test.expectedRemoteAddr), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(lladdr0), + checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), + )) + }) } } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 40176594e..d658f9bcb 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -555,7 +555,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) *tcpip.Error { +func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) *tcpip.Error { extHdrsLen := extensionHeaders.Length() length := pkt.Size() + extensionHeaders.Length() if length > math.MaxUint16 { @@ -625,7 +625,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil { + if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil { return err } @@ -718,7 +718,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe stats := e.stats.ip linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil { + if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil { return 0, err } diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index 78d86e523..c376016e9 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -249,7 +249,7 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp Data: buffer.View(icmp).ToVectorisedView(), }) - if err := mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ + if err := addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.MLDHopLimit, }, extensionHeaders); err != nil { diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 41112a0c4..ca4ff621d 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -732,7 +732,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add }) sent := ndp.ep.stats.icmp.packetsSent - if err := ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ + if err := addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, nil /* extensionHeaders */); err != nil { @@ -1857,7 +1857,7 @@ func (ndp *ndpState) startSolicitingRouters() { }) sent := ndp.ep.stats.icmp.packetsSent - if err := ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, nil /* extensionHeaders */); err != nil { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index aed3042d1..7a22309e5 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -638,18 +638,12 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatal("expected an NDP NS response") } - if p.Route.LocalAddress != nicAddr { - t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, nicAddr) - } - if p.Route.LocalLinkAddress != nicLinkAddr { - t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) - } respNSDst := header.SolicitedNodeAddr(test.nsSrc) - if p.Route.RemoteAddress != respNSDst { - t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) - } - if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + var want stack.RouteInfo + want.NetProto = ProtocolNumber + want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst) + if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" { + t.Errorf("route info mismatch (-want +got):\n%s", diff) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), |