summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/BUILD1
-rw-r--r--pkg/tcpip/network/arp/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go126
-rw-r--r--pkg/tcpip/network/arp/arp_test.go110
-rw-r--r--pkg/tcpip/network/arp/stats_test.go44
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go32
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go96
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go39
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go87
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go36
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go82
-rw-r--r--pkg/tcpip/network/ip_test.go290
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go116
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go17
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go61
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go87
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go79
-rw-r--r--pkg/tcpip/network/ipv4/stats_test.go4
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go242
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go116
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go117
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go46
-rw-r--r--pkg/tcpip/network/ipv6/mld.go12
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go10
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go177
-rw-r--r--pkg/tcpip/network/testutil/testutil.go8
27 files changed, 1289 insertions, 748 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 9ebf31b78..0caa65251 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -25,5 +25,6 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index c7ab876bf..933845269 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -10,7 +10,6 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
- "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 6bc8c5c02..0d7fadc31 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -22,7 +22,6 @@ import (
"reflect"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -35,6 +34,8 @@ const (
ProtocolNumber = header.ARPProtocolNumber
)
+var _ stack.LinkAddressResolver = (*endpoint)(nil)
+
// ARP endpoints need to implement stack.NetworkEndpoint because the stack
// considers the layer above the link-layer a network layer; the only
// facility provided by the stack to deliver packets to a layer above
@@ -49,15 +50,13 @@ type endpoint struct {
// Must be accessed using atomic operations.
enabled uint32
- nic stack.NetworkInterface
- linkAddrCache stack.LinkAddressCache
- nud stack.NUDHandler
- stats sharedStats
+ nic stack.NetworkInterface
+ stats sharedStats
}
-func (e *endpoint) Enable() *tcpip.Error {
+func (e *endpoint) Enable() tcpip.Error {
if !e.nic.Enabled() {
- return tcpip.ErrNotPermitted
+ return &tcpip.ErrNotPermitted{}
}
e.setEnabled(true)
@@ -101,12 +100,10 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) Close() {
- e.protocol.forgetEndpoint(e.nic.ID())
-}
+func (*endpoint) Close() {}
-func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
@@ -115,12 +112,12 @@ func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
+func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) {
+ return 0, &tcpip.ErrNotSupported{}
}
-func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
@@ -151,10 +148,12 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
remoteAddr := tcpip.Address(h.ProtocolAddressSender())
remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr)
- } else {
- e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
+ switch err := e.nic.HandleNeighborProbe(header.IPv4ProtocolNumber, remoteAddr, remoteLinkAddr); err.(type) {
+ case nil:
+ case *tcpip.ErrNotSupported:
+ // The stack may support ARP but the NIC may not need link resolution.
+ default:
+ panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err))
}
respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -195,14 +194,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(addr, linkAddr)
- return
- }
-
// The solicited, override, and isRouter flags are not available for ARP;
// they are only available for IPv6 Neighbor Advertisements.
- e.nud.HandleConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{
+ switch err := e.nic.HandleNeighborConfirmation(header.IPv4ProtocolNumber, addr, linkAddr, stack.ReachabilityConfirmationFlags{
// Solicited and unsolicited (also referred to as gratuitous) ARP Replies
// are handled equivalently to a solicited Neighbor Advertisement.
Solicited: true,
@@ -211,7 +205,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
Override: false,
// ARP does not distinguish between router and non-router hosts.
IsRouter: false,
- })
+ }); err.(type) {
+ case nil:
+ case *tcpip.ErrNotSupported:
+ // The stack may support ARP but the NIC may not need link resolution.
+ default:
+ panic(fmt.Sprintf("unexpected error when informing NIC of neighbor confirmation message: %s", err))
+ }
}
}
@@ -221,19 +221,10 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
}
var _ stack.NetworkProtocol = (*protocol)(nil)
-var _ stack.LinkAddressResolver = (*protocol)(nil)
// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
type protocol struct {
stack *stack.Stack
-
- mu struct {
- sync.RWMutex
-
- // eps is keyed by NICID to allow protocol methods to retrieve the correct
- // endpoint depending on the NIC.
- eps map[tcpip.NICID]*endpoint
- }
}
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
@@ -244,12 +235,10 @@ func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) {
return "", ""
}
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
- protocol: p,
- nic: nic,
- linkAddrCache: linkAddrCache,
- nud: nud,
+ protocol: p,
+ nic: nic,
}
tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem())
@@ -257,60 +246,43 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
stackStats := p.stack.Stats()
e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP)
- p.mu.Lock()
- p.mu.eps[nic.ID()] = e
- p.mu.Unlock()
-
return e
}
-func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
- p.mu.Lock()
- defer p.mu.Unlock()
- delete(p.mu.eps, nicID)
-}
-
// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
-func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
- nicID := nic.ID()
-
- p.mu.Lock()
- netEP, ok := p.mu.eps[nicID]
- p.mu.Unlock()
- if !ok {
- return tcpip.ErrNotConnected
- }
+func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
+ nicID := e.nic.ID()
- stats := netEP.stats.arp
+ stats := e.stats.arp
if len(remoteLinkAddr) == 0 {
remoteLinkAddr = header.EthernetBroadcastAddress
}
if len(localAddr) == 0 {
- addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
+ addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
if len(addr.Address) == 0 {
stats.outgoingRequestInterfaceHasNoLocalAddressErrors.Increment()
- return tcpip.ErrNetworkUnreachable
+ return &tcpip.ErrNetworkUnreachable{}
}
localAddr = addr.Address
- } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
stats.outgoingRequestBadLocalAddressErrors.Increment()
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize,
+ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
pkt.NetworkProtocolNumber = ProtocolNumber
@@ -318,14 +290,14 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
h.SetOp(header.ARPRequest)
// TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
// link address.
- _ = copy(h.HardwareAddressSender(), nic.LinkAddress())
+ _ = copy(h.HardwareAddressSender(), e.nic.LinkAddress())
if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize {
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize {
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
- if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
stats.outgoingRequestsDropped.Increment()
return err
}
@@ -334,7 +306,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
-func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
return header.EthernetBroadcastAddress, true
}
@@ -345,13 +317,13 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
}
// SetOption implements stack.NetworkProtocol.SetOption.
-func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Option implements stack.NetworkProtocol.Option.
-func (*protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Close implements stack.TransportProtocol.Close.
@@ -369,9 +341,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return &protocol{
stack: s,
- mu: struct {
- sync.RWMutex
- eps map[tcpip.NICID]*endpoint
- }{eps: make(map[tcpip.NICID]*endpoint)},
}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 001fca727..24357e15d 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -125,8 +125,8 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo
func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error {
select {
case got := <-d.C:
- if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
- return fmt.Errorf("got invalid event (-got +want):\n%s", diff)
+ if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
+ return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
}
case <-ctx.Done():
return fmt.Errorf("%s for %s", ctx.Err(), want)
@@ -491,9 +491,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
t.Fatal(err)
}
- neighbors, err := c.s.Neighbors(nicID)
+ neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber)
if err != nil {
- t.Fatalf("c.s.Neighbors(%d): %s", nicID, err)
+ t.Fatalf("c.s.Neighbors(%d, %d): %s", nicID, ipv4.ProtocolNumber, err)
}
neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
@@ -530,52 +530,19 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
}
}
-var _ stack.NetworkInterface = (*testInterface)(nil)
+var _ stack.LinkEndpoint = (*testLinkEndpoint)(nil)
-type testInterface struct {
+type testLinkEndpoint struct {
stack.LinkEndpoint
- nicID tcpip.NICID
-
- writeErr *tcpip.Error
-}
-
-func (t *testInterface) ID() tcpip.NICID {
- return t.nicID
-}
-
-func (*testInterface) IsLoopback() bool {
- return false
-}
-
-func (*testInterface) Name() string {
- return ""
-}
-
-func (*testInterface) Enabled() bool {
- return true
-}
-
-func (*testInterface) Promiscuous() bool {
- return false
-}
-
-func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt)
-}
-
-func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol)
+ writeErr tcpip.Error
}
-func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if t.writeErr != nil {
return t.writeErr
}
- var r stack.RouteInfo
- r.NetProto = protocol
- r.RemoteLinkAddress = remoteLinkAddr
return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt)
}
@@ -589,8 +556,8 @@ func TestLinkAddressRequest(t *testing.T) {
nicAddr tcpip.Address
localAddr tcpip.Address
remoteLinkAddr tcpip.LinkAddress
- linkErr *tcpip.Error
- expectedErr *tcpip.Error
+ linkErr tcpip.Error
+ expectedErr tcpip.Error
expectedLocalAddr tcpip.Address
expectedRemoteLinkAddr tcpip.LinkAddress
expectedRequestsSent uint64
@@ -651,7 +618,7 @@ func TestLinkAddressRequest(t *testing.T) {
nicAddr: stackAddr,
localAddr: testAddr,
remoteLinkAddr: remoteLinkAddr,
- expectedErr: tcpip.ErrBadLocalAddress,
+ expectedErr: &tcpip.ErrBadLocalAddress{},
expectedRequestsSent: 0,
expectedRequestBadLocalAddressErrors: 1,
expectedRequestInterfaceHasNoLocalAddressErrors: 0,
@@ -662,7 +629,7 @@ func TestLinkAddressRequest(t *testing.T) {
nicAddr: stackAddr,
localAddr: testAddr,
remoteLinkAddr: "",
- expectedErr: tcpip.ErrBadLocalAddress,
+ expectedErr: &tcpip.ErrBadLocalAddress{},
expectedRequestsSent: 0,
expectedRequestBadLocalAddressErrors: 1,
expectedRequestInterfaceHasNoLocalAddressErrors: 0,
@@ -673,7 +640,7 @@ func TestLinkAddressRequest(t *testing.T) {
nicAddr: "",
localAddr: "",
remoteLinkAddr: remoteLinkAddr,
- expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedErr: &tcpip.ErrNetworkUnreachable{},
expectedRequestsSent: 0,
expectedRequestBadLocalAddressErrors: 0,
expectedRequestInterfaceHasNoLocalAddressErrors: 1,
@@ -684,7 +651,7 @@ func TestLinkAddressRequest(t *testing.T) {
nicAddr: "",
localAddr: "",
remoteLinkAddr: "",
- expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedErr: &tcpip.ErrNetworkUnreachable{},
expectedRequestsSent: 0,
expectedRequestBadLocalAddressErrors: 0,
expectedRequestInterfaceHasNoLocalAddressErrors: 1,
@@ -695,8 +662,8 @@ func TestLinkAddressRequest(t *testing.T) {
nicAddr: stackAddr,
localAddr: stackAddr,
remoteLinkAddr: remoteLinkAddr,
- linkErr: tcpip.ErrInvalidEndpointState,
- expectedErr: tcpip.ErrInvalidEndpointState,
+ linkErr: &tcpip.ErrInvalidEndpointState{},
+ expectedErr: &tcpip.ErrInvalidEndpointState{},
expectedRequestsSent: 0,
expectedRequestBadLocalAddressErrors: 0,
expectedRequestInterfaceHasNoLocalAddressErrors: 0,
@@ -709,31 +676,31 @@ func TestLinkAddressRequest(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
})
- p := s.NetworkProtocolInstance(arp.ProtocolNumber)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
- }
-
linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
- if err := s.CreateNIC(nicID, linkEP); err != nil {
+ if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
+ ep, err := s.GetNetworkEndpoint(nicID, arp.ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, arp.ProtocolNumber, err)
+ }
+ linkRes, ok := ep.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep)
+ }
+
if len(test.nicAddr) != 0 {
if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
}
}
- // We pass a test network interface to LinkAddressRequest with the same
- // NIC ID and link endpoint used by the NIC we created earlier so that we
- // can mock a link address request and observe the packets sent to the
- // link endpoint even though the stack uses the real NIC to validate the
- // local address.
- iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr}
- if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface); err != test.expectedErr {
- t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
+ {
+ err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr)
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff)
+ }
}
if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent {
@@ -781,18 +748,3 @@ func TestLinkAddressRequest(t *testing.T) {
})
}
}
-
-func TestLinkAddressRequestWithoutNIC(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
- })
- p := s.NetworkProtocolInstance(arp.ProtocolNumber)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
- }
-
- if err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}); err != tcpip.ErrNotConnected {
- t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, tcpip.ErrNotConnected)
- }
-}
diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go
index 036fdf739..65c708ac4 100644
--- a/pkg/tcpip/network/arp/stats_test.go
+++ b/pkg/tcpip/network/arp/stats_test.go
@@ -34,55 +34,13 @@ func (t *testInterface) ID() tcpip.NICID {
return t.nicID
}
-func knownNICIDs(proto *protocol) []tcpip.NICID {
- var nicIDs []tcpip.NICID
-
- for k := range proto.mu.eps {
- nicIDs = append(nicIDs, k)
- }
-
- return nicIDs
-}
-
-func TestClearEndpointFromProtocolOnClose(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- nic := testInterface{nicID: 1}
- ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
- var nicIDs []tcpip.NICID
-
- proto.mu.Lock()
- foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
-
- if !hasEndpointBeforeClose {
- t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs)
- }
- if foundEP != ep {
- t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
- }
-
- ep.Close()
-
- proto.mu.Lock()
- _, hasEndpointAfterClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
- if hasEndpointAfterClose {
- t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
- }
-}
-
func TestMultiCounterStatsInitialization(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
var nic testInterface
- ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ ep := proto.NewEndpoint(&nic, nil).(*endpoint)
// At this point, the Stack's stats and the NetworkEndpoint's stats are
// expected to be bound by a MultiCounterStat.
refStack := s.Stats()
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 1af87d713..243738951 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -84,7 +84,7 @@ type Fragmentation struct {
lowLimit int
reassemblers map[FragmentID]*reassembler
rList reassemblerList
- size int
+ memSize int
timeout time.Duration
blockSize uint16
clock tcpip.Clock
@@ -156,22 +156,22 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
// the protocol to identify a fragment.
func (f *Fragmentation) Process(
id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (
- buffer.VectorisedView, uint8, bool, error) {
+ *stack.PacketBuffer, uint8, bool, error) {
if first > last {
- return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
+ return nil, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
}
if first%f.blockSize != 0 {
- return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs)
+ return nil, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs)
}
fragmentSize := last - first + 1
if more && fragmentSize%f.blockSize != 0 {
- return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
+ return nil, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
}
if l := pkt.Data.Size(); l != int(fragmentSize) {
- return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
+ return nil, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
}
f.mu.Lock()
@@ -190,24 +190,24 @@ func (f *Fragmentation) Process(
}
f.mu.Unlock()
- res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, pkt)
+ resPkt, firstFragmentProto, done, memConsumed, err := r.process(first, last, more, proto, pkt)
if err != nil {
// We probably got an invalid sequence of fragments. Just
// discard the reassembler and move on.
f.mu.Lock()
f.release(r, false /* timedOut */)
f.mu.Unlock()
- return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err)
+ return nil, 0, false, fmt.Errorf("fragmentation processing error: %w", err)
}
f.mu.Lock()
- f.size += consumed
+ f.memSize += memConsumed
if done {
f.release(r, false /* timedOut */)
}
// Evict reassemblers if we are consuming more memory than highLimit until
// we reach lowLimit.
- if f.size > f.highLimit {
- for f.size > f.lowLimit {
+ if f.memSize > f.highLimit {
+ for f.memSize > f.lowLimit {
tail := f.rList.Back()
if tail == nil {
break
@@ -216,7 +216,7 @@ func (f *Fragmentation) Process(
}
}
f.mu.Unlock()
- return res, firstFragmentProto, done, nil
+ return resPkt, firstFragmentProto, done, nil
}
func (f *Fragmentation) release(r *reassembler, timedOut bool) {
@@ -228,10 +228,10 @@ func (f *Fragmentation) release(r *reassembler, timedOut bool) {
delete(f.reassemblers, r.id)
f.rList.Remove(r)
- f.size -= r.size
- if f.size < 0 {
- log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
- f.size = 0
+ f.memSize -= r.memSize
+ if f.memSize < 0 {
+ log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.memSize)
+ f.memSize = 0
}
if h := f.timeoutHandler; timedOut && h != nil {
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 3a79688a8..905bbc19b 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -16,7 +16,6 @@ package fragmentation
import (
"errors"
- "reflect"
"testing"
"time"
@@ -112,20 +111,20 @@ func TestFragmentationProcess(t *testing.T) {
f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil)
firstFragmentProto := c.in[0].proto
for i, in := range c.in {
- vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
+ resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
if err != nil {
t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s",
in.id, in.first, in.last, in.more, in.proto, in.pkt, err)
}
- if !reflect.DeepEqual(vv, c.out[i].vv) {
- t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) = (%X, _, _, _), want = (%X, _, _, _)",
- in.id, in.first, in.last, in.more, in.proto, in.pkt, vv.ToView(), c.out[i].vv.ToView())
- }
if done != c.out[i].done {
t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)",
in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done)
}
if c.out[i].done {
+ if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data.ToOwnedView()); diff != "" {
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s",
+ in.id, in.first, in.last, in.more, in.proto, in.pkt, diff)
+ }
if firstFragmentProto != proto {
t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)",
in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto)
@@ -173,9 +172,17 @@ func TestReassemblingTimeout(t *testing.T) {
// reassembly is done after the fragment is processd.
expectDone bool
- // sizeAfterEvent is the expected size of the fragmentation instance after
- // the event.
- sizeAfterEvent int
+ // memSizeAfterEvent is the expected memory size of the fragmentation
+ // instance after the event.
+ memSizeAfterEvent int
+ }
+
+ memSizeOfFrags := func(frags ...*fragment) int {
+ var size int
+ for _, frag := range frags {
+ size += pkt(len(frag.data), frag.data).MemSize()
+ }
+ return size
}
half1 := &fragment{first: 0, last: 0, more: true, data: "0"}
@@ -189,16 +196,16 @@ func TestReassemblingTimeout(t *testing.T) {
name: "half1 and half2 are reassembled successfully",
events: []event{
{
- name: "half1",
- fragment: half1,
- expectDone: false,
- sizeAfterEvent: 1,
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ memSizeAfterEvent: memSizeOfFrags(half1),
},
{
- name: "half2",
- fragment: half2,
- expectDone: true,
- sizeAfterEvent: 0,
+ name: "half2",
+ fragment: half2,
+ expectDone: true,
+ memSizeAfterEvent: 0,
},
},
},
@@ -206,36 +213,36 @@ func TestReassemblingTimeout(t *testing.T) {
name: "half1 timeout, half2 timeout",
events: []event{
{
- name: "half1",
- fragment: half1,
- expectDone: false,
- sizeAfterEvent: 1,
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ memSizeAfterEvent: memSizeOfFrags(half1),
},
{
- name: "half1 just before reassembly timeout",
- clockAdvance: reassemblyTimeout - 1,
- sizeAfterEvent: 1,
+ name: "half1 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ memSizeAfterEvent: memSizeOfFrags(half1),
},
{
- name: "half1 reassembly timeout",
- clockAdvance: 1,
- sizeAfterEvent: 0,
+ name: "half1 reassembly timeout",
+ clockAdvance: 1,
+ memSizeAfterEvent: 0,
},
{
- name: "half2",
- fragment: half2,
- expectDone: false,
- sizeAfterEvent: 1,
+ name: "half2",
+ fragment: half2,
+ expectDone: false,
+ memSizeAfterEvent: memSizeOfFrags(half2),
},
{
- name: "half2 just before reassembly timeout",
- clockAdvance: reassemblyTimeout - 1,
- sizeAfterEvent: 1,
+ name: "half2 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ memSizeAfterEvent: memSizeOfFrags(half2),
},
{
- name: "half2 reassembly timeout",
- clockAdvance: 1,
- sizeAfterEvent: 0,
+ name: "half2 reassembly timeout",
+ clockAdvance: 1,
+ memSizeAfterEvent: 0,
},
},
},
@@ -255,8 +262,8 @@ func TestReassemblingTimeout(t *testing.T) {
t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone)
}
}
- if got, want := f.size, event.sizeAfterEvent; got != want {
- t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want)
+ if got, want := f.memSize, event.memSizeAfterEvent; got != want {
+ t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want)
}
}
})
@@ -264,7 +271,9 @@ func TestReassemblingTimeout(t *testing.T) {
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}, nil)
+ lowLimit := pkt(1, "0").MemSize()
+ highLimit := 3 * lowLimit // Allow at most 3 such packets.
+ f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil)
// Send first fragment with id = 0.
f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0"))
// Send first fragment with id = 1.
@@ -288,15 +297,14 @@ func TestMemoryLimits(t *testing.T) {
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}, nil)
+ memSize := pkt(1, "0").MemSize()
+ f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil)
// Send first fragment with id = 0.
f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
// Send the same packet again.
f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
- got := f.size
- want := 1
- if got != want {
+ if got, want := f.memSize, memSize; got != want {
t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
}
}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 9b20bb1d8..933d63d32 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -29,13 +28,15 @@ type hole struct {
last uint16
filled bool
final bool
- data buffer.View
+ // pkt is the fragment packet if hole is filled. We keep the whole pkt rather
+ // than the fragmented payload to prevent binding to specific buffer types.
+ pkt *stack.PacketBuffer
}
type reassembler struct {
reassemblerEntry
id FragmentID
- size int
+ memSize int
proto uint8
mu sync.Mutex
holes []hole
@@ -59,18 +60,18 @@ func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
return r
}
-func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) {
+func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (*stack.PacketBuffer, uint8, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.done {
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, 0, false, 0, nil
+ return nil, 0, false, 0, nil
}
var holeFound bool
- var consumed int
+ var memConsumed int
for i := range r.holes {
currentHole := &r.holes[i]
@@ -90,12 +91,12 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
// https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349
if first < currentHole.first || currentHole.last < last {
// Incoming fragment only partially fits in the free hole.
- return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap
+ return nil, 0, false, 0, ErrFragmentOverlap
}
if !more {
if !currentHole.final || currentHole.filled && currentHole.last != last {
// We have another final fragment, which does not perfectly overlap.
- return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
+ return nil, 0, false, 0, ErrFragmentConflict
}
}
@@ -124,16 +125,15 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
})
currentHole.final = false
}
- v := pkt.Data.ToOwnedView()
- consumed = v.Size()
- r.size += consumed
+ memConsumed = pkt.MemSize()
+ r.memSize += memConsumed
// Update the current hole to precisely match the incoming fragment.
r.holes[i] = hole{
first: first,
last: last,
filled: true,
final: currentHole.final,
- data: v,
+ pkt: pkt,
}
r.filled++
// For IPv6, it is possible to have different Protocol values between
@@ -153,25 +153,24 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
}
if !holeFound {
// Incoming fragment is beyond end.
- return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
+ return nil, 0, false, 0, ErrFragmentConflict
}
// Check if all the holes have been filled and we are ready to reassemble.
if r.filled < len(r.holes) {
- return buffer.VectorisedView{}, 0, false, consumed, nil
+ return nil, 0, false, memConsumed, nil
}
sort.Slice(r.holes, func(i, j int) bool {
return r.holes[i].first < r.holes[j].first
})
- var size int
- views := make([]buffer.View, 0, len(r.holes))
- for _, hole := range r.holes {
- views = append(views, hole.data)
- size += hole.data.Size()
+ resPkt := r.holes[0].pkt
+ for i := 1; i < len(r.holes); i++ {
+ fragPkt := r.holes[i].pkt
+ fragPkt.Data.ReadToVV(&resPkt.Data, fragPkt.Data.Size())
}
- return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil
+ return resPkt, r.proto, true, memConsumed, nil
}
func (r *reassembler) checkDoneOrMark() bool {
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index 2ff03eeeb..214a93709 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -15,6 +15,7 @@
package fragmentation
import (
+ "bytes"
"math"
"testing"
@@ -44,16 +45,21 @@ func TestReassemblerProcess(t *testing.T) {
return payload
}
- pkt := func(size int) *stack.PacketBuffer {
+ pkt := func(sizes ...int) *stack.PacketBuffer {
+ var vv buffer.VectorisedView
+ for _, size := range sizes {
+ vv.AppendView(v(size))
+ }
return stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: v(size).ToVectorisedView(),
+ Data: vv,
})
}
var tests = []struct {
- name string
- params []processParams
- want []hole
+ name string
+ params []processParams
+ want []hole
+ wantPkt *stack.PacketBuffer
}{
{
name: "No fragments",
@@ -64,7 +70,7 @@ func TestReassemblerProcess(t *testing.T) {
name: "One fragment at beginning",
params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
want: []hole{
- {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)},
{first: 2, last: math.MaxUint16, filled: false, final: true},
},
},
@@ -72,7 +78,7 @@ func TestReassemblerProcess(t *testing.T) {
name: "One fragment in the middle",
params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
want: []hole{
- {first: 1, last: 2, filled: true, final: false, data: v(2)},
+ {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)},
{first: 0, last: 0, filled: false, final: false},
{first: 3, last: math.MaxUint16, filled: false, final: true},
},
@@ -81,7 +87,7 @@ func TestReassemblerProcess(t *testing.T) {
name: "One fragment at the end",
params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}},
want: []hole{
- {first: 1, last: 2, filled: true, final: true, data: v(2)},
+ {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)},
{first: 0, last: 0, filled: false},
},
},
@@ -89,8 +95,9 @@ func TestReassemblerProcess(t *testing.T) {
name: "One fragment completing a packet",
params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}},
want: []hole{
- {first: 0, last: 1, filled: true, final: true, data: v(2)},
+ {first: 0, last: 1, filled: true, final: true},
},
+ wantPkt: pkt(2),
},
{
name: "Two fragments completing a packet",
@@ -99,9 +106,10 @@ func TestReassemblerProcess(t *testing.T) {
{first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
},
want: []hole{
- {first: 0, last: 1, filled: true, final: false, data: v(2)},
- {first: 2, last: 3, filled: true, final: true, data: v(2)},
+ {first: 0, last: 1, filled: true, final: false},
+ {first: 2, last: 3, filled: true, final: true},
},
+ wantPkt: pkt(2, 2),
},
{
name: "Two fragments completing a packet with a duplicate",
@@ -111,9 +119,10 @@ func TestReassemblerProcess(t *testing.T) {
{first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
},
want: []hole{
- {first: 0, last: 1, filled: true, final: false, data: v(2)},
- {first: 2, last: 3, filled: true, final: true, data: v(2)},
+ {first: 0, last: 1, filled: true, final: false},
+ {first: 2, last: 3, filled: true, final: true},
},
+ wantPkt: pkt(2, 2),
},
{
name: "Two fragments completing a packet with a partial duplicate",
@@ -123,9 +132,10 @@ func TestReassemblerProcess(t *testing.T) {
{first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
},
want: []hole{
- {first: 0, last: 3, filled: true, final: false, data: v(4)},
- {first: 4, last: 5, filled: true, final: true, data: v(2)},
+ {first: 0, last: 3, filled: true, final: false},
+ {first: 4, last: 5, filled: true, final: true},
},
+ wantPkt: pkt(4, 2),
},
{
name: "Two overlapping fragments",
@@ -134,7 +144,7 @@ func TestReassemblerProcess(t *testing.T) {
{first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap},
},
want: []hole{
- {first: 0, last: 10, filled: true, final: false, data: v(11)},
+ {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)},
{first: 11, last: math.MaxUint16, filled: false, final: true},
},
},
@@ -145,7 +155,7 @@ func TestReassemblerProcess(t *testing.T) {
{first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict},
},
want: []hole{
- {first: 10, last: 14, filled: true, final: true, data: v(5)},
+ {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)},
{first: 0, last: 9, filled: false, final: false},
},
},
@@ -156,7 +166,7 @@ func TestReassemblerProcess(t *testing.T) {
{first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
},
want: []hole{
- {first: 5, last: 14, filled: true, final: true, data: v(10)},
+ {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)},
{first: 0, last: 4, filled: false, final: false},
},
},
@@ -167,7 +177,7 @@ func TestReassemblerProcess(t *testing.T) {
{first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict},
},
want: []hole{
- {first: 5, last: 14, filled: true, final: true, data: v(10)},
+ {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)},
{first: 0, last: 4, filled: false, final: false},
},
},
@@ -176,14 +186,47 @@ func TestReassemblerProcess(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
r := newReassembler(FragmentID{}, &faketime.NullClock{})
+ var resPkt *stack.PacketBuffer
+ var isDone bool
for _, param := range test.params {
- _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
+ pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
if done != param.wantDone || err != param.wantError {
t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
}
+ if done {
+ resPkt = pkt
+ isDone = true
+ }
+ }
+
+ ignorePkt := func(a, b *stack.PacketBuffer) bool { return true }
+ cmpPktData := func(a, b *stack.PacketBuffer) bool {
+ if a == nil || b == nil {
+ return a == b
+ }
+ return bytes.Equal(a.Data.ToOwnedView(), b.Data.ToOwnedView())
}
- if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" {
- t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
+
+ if isDone {
+ if diff := cmp.Diff(
+ test.want, r.holes,
+ cmp.AllowUnexported(hole{}),
+ // Do not compare pkt in hole. Data will be altered.
+ cmp.Comparer(ignorePkt),
+ ); diff != "" {
+ t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" {
+ t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff)
+ }
+ } else {
+ if diff := cmp.Diff(
+ test.want, r.holes,
+ cmp.AllowUnexported(hole{}),
+ cmp.Comparer(cmpPktData),
+ ); diff != "" {
+ t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
+ }
}
})
}
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
index f2f0e069c..b9f129728 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -126,6 +126,16 @@ type multicastGroupState struct {
//
// Must not be nil.
delayedReportJob *tcpip.Job
+
+ // delyedReportJobFiresAt is the time when the delayed report job will fire.
+ //
+ // A zero value indicates that the job is not scheduled.
+ delayedReportJobFiresAt time.Time
+}
+
+func (m *multicastGroupState) cancelDelayedReportJob() {
+ m.delayedReportJob.Cancel()
+ m.delayedReportJobFiresAt = time.Time{}
}
// GenericMulticastProtocolOptions holds options for the generic multicast
@@ -174,10 +184,10 @@ type MulticastGroupProtocol interface {
//
// Returns false if the caller should queue the report to be sent later. Note,
// returning false does not mean that the receiver hit an error.
- SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error)
+ SendReport(groupAddress tcpip.Address) (sent bool, err tcpip.Error)
// SendLeave sends a multicast leave for the specified group address.
- SendLeave(groupAddress tcpip.Address) *tcpip.Error
+ SendLeave(groupAddress tcpip.Address) tcpip.Error
}
// GenericMulticastProtocolState is the per interface generic multicast protocol
@@ -428,7 +438,7 @@ func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Ad
// on that interface, it stops its timer and does not send a Report for
// that address, thus suppressing duplicate reports on the link.
if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() {
- info.delayedReportJob.Cancel()
+ info.cancelDelayedReportJob()
info.lastToSendReport = false
info.state = idleMember
g.memberships[groupAddress] = info
@@ -603,7 +613,7 @@ func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress
return
}
- info.delayedReportJob.Cancel()
+ info.cancelDelayedReportJob()
g.maybeSendLeave(groupAddress, info.lastToSendReport)
info.lastToSendReport = false
info.state = nonMember
@@ -645,14 +655,24 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr
// If a timer for any address is already running, it is reset to the new
// random value only if the requested Maximum Response Delay is less than
// the remaining value of the running timer.
+ now := time.Unix(0 /* seconds */, g.opts.Clock.NowNanoseconds())
if info.state == delayingMember {
- // TODO: Reset the timer if time remaining is greater than maxResponseTime.
- return
+ if info.delayedReportJobFiresAt.IsZero() {
+ panic(fmt.Sprintf("delayed report unscheduled while in the delaying member state; group = %s", groupAddress))
+ }
+
+ if info.delayedReportJobFiresAt.Sub(now) <= maxResponseTime {
+ // The timer is scheduled to fire before the maximum response time so we
+ // leave our timer as is.
+ return
+ }
}
info.state = delayingMember
- info.delayedReportJob.Cancel()
- info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime))
+ info.cancelDelayedReportJob()
+ maxResponseTime = g.calculateDelayTimerDuration(maxResponseTime)
+ info.delayedReportJob.Schedule(maxResponseTime)
+ info.delayedReportJobFiresAt = now.Add(maxResponseTime)
}
// calculateDelayTimerDuration returns a random time between (0, maxRespTime].
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
index 85593f211..60eaea37e 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -141,7 +141,7 @@ func (m *mockMulticastGroupProtocol) Enabled() bool {
// SendReport implements ip.MulticastGroupProtocol.
//
// Precondition: m.mu must be locked.
-func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
if m.mu.TryLock() {
m.mu.Unlock()
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
@@ -158,7 +158,7 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo
// SendLeave implements ip.MulticastGroupProtocol.
//
// Precondition: m.mu must be locked.
-func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error {
if m.mu.TryLock() {
m.mu.Unlock()
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
@@ -408,40 +408,46 @@ func TestHandleReport(t *testing.T) {
func TestHandleQuery(t *testing.T) {
tests := []struct {
- name string
- queryAddr tcpip.Address
- maxDelay time.Duration
- expectReportsFor []tcpip.Address
+ name string
+ queryAddr tcpip.Address
+ maxDelay time.Duration
+ expectQueriedReportsFor []tcpip.Address
+ expectDelayedReportsFor []tcpip.Address
}{
{
- name: "Unpecified empty",
- queryAddr: "",
- maxDelay: 0,
- expectReportsFor: []tcpip.Address{addr1, addr2},
+ name: "Unpecified empty",
+ queryAddr: "",
+ maxDelay: 0,
+ expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
+ expectDelayedReportsFor: nil,
},
{
- name: "Unpecified any",
- queryAddr: "\x00",
- maxDelay: 1,
- expectReportsFor: []tcpip.Address{addr1, addr2},
+ name: "Unpecified any",
+ queryAddr: "\x00",
+ maxDelay: 1,
+ expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
+ expectDelayedReportsFor: nil,
},
{
- name: "Specified",
- queryAddr: addr1,
- maxDelay: 2,
- expectReportsFor: []tcpip.Address{addr1},
+ name: "Specified",
+ queryAddr: addr1,
+ maxDelay: 2,
+ expectQueriedReportsFor: []tcpip.Address{addr1},
+ expectDelayedReportsFor: []tcpip.Address{addr2},
},
{
- name: "Specified all-nodes",
- queryAddr: addr3,
- maxDelay: 3,
- expectReportsFor: nil,
+ name: "Specified all-nodes",
+ queryAddr: addr3,
+ maxDelay: 3,
+ expectQueriedReportsFor: nil,
+ expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
},
{
- name: "Specified other",
- queryAddr: addr4,
- maxDelay: 4,
- expectReportsFor: nil,
+ name: "Specified other",
+ queryAddr: addr4,
+ maxDelay: 4,
+ expectQueriedReportsFor: nil,
+ expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
},
}
@@ -469,20 +475,20 @@ func TestHandleQuery(t *testing.T) {
if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- // Receiving a query should make us schedule a new delayed report if it
- // is a query directed at us or a general query.
+ // Receiving a query should make us reschedule our delayed report timer
+ // to some time within the new max response delay.
mgp.handleQuery(test.queryAddr, test.maxDelay)
- if len(test.expectReportsFor) != 0 {
- clock.Advance(test.maxDelay)
- if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+ clock.Advance(test.maxDelay)
+ if diff := mgp.check(test.expectQueriedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The groups that were not affected by the query should still send a
+ // report after the max unsolicited report delay.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check(test.expectDelayedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Should have no more messages to send.
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 2a6ec19dc..6a1f11a36 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -18,6 +18,7 @@ import (
"strings"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -58,6 +59,14 @@ var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{
PrefixLen: 120,
}
+type transportError struct {
+ origin tcpip.SockErrOrigin
+ typ uint8
+ code uint8
+ info uint32
+ kind stack.TransportErrorKind
+}
+
// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
// The former is used to pretend that it's a link endpoint so that we can
// inspect packets written by the network endpoints. The latter is used to
@@ -73,8 +82,7 @@ type testObject struct {
srcAddr tcpip.Address
dstAddr tcpip.Address
v4 bool
- typ stack.ControlType
- extra uint32
+ transErr transportError
dataCalls int
controlCalls int
@@ -118,16 +126,23 @@ func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumb
return stack.TransportPacketHandled
}
-// DeliverTransportControlPacket is called by network endpoints after parsing
+// DeliverTransportError is called by network endpoints after parsing
// incoming control (ICMP) packets. This is used by the test object to verify
// that the results of the parsing are expected.
-func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) {
t.checkValues(trans, pkt.Data, remote, local)
- if typ != t.typ {
- t.t.Errorf("typ = %v, want %v", typ, t.typ)
- }
- if extra != t.extra {
- t.t.Errorf("extra = %v, want %v", extra, t.extra)
+ if diff := cmp.Diff(
+ t.transErr,
+ transportError{
+ origin: transErr.Origin(),
+ typ: transErr.Type(),
+ code: transErr.Code(),
+ info: transErr.Info(),
+ kind: transErr.Kind(),
+ },
+ cmp.AllowUnexported(transportError{}),
+ ); diff != "" {
+ t.t.Errorf("transport error mismatch (-want +got):\n%s", diff)
}
t.controlCalls++
}
@@ -167,7 +182,7 @@ func (*testObject) Wait() {}
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
@@ -189,7 +204,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
panic("not implemented")
}
@@ -203,7 +218,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
panic("not implemented")
}
-func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
+func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -219,7 +234,7 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
}
-func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
+func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -306,8 +321,16 @@ func (t *testInterface) setEnabled(v bool) {
t.mu.disabled = !v
}
-func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
+func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error {
+ return nil
+}
+
+func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error {
+ return nil
}
func TestSourceAddressValidation(t *testing.T) {
@@ -464,7 +487,7 @@ func TestEnableWhenNICDisabled(t *testing.T) {
// We pass nil for all parameters except the NetworkInterface and Stack
// since Enable only depends on these.
- ep := p.NewEndpoint(&nic, nil, nil, nil)
+ ep := p.NewEndpoint(&nic, nil)
// The endpoint should initially be disabled, regardless the NIC's enabled
// status.
@@ -479,8 +502,9 @@ func TestEnableWhenNICDisabled(t *testing.T) {
// Attempting to enable the endpoint while the NIC is disabled should
// fail.
nic.setEnabled(false)
- if err := ep.Enable(); err != tcpip.ErrNotPermitted {
- t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted)
+ err := ep.Enable()
+ if _, ok := err.(*tcpip.ErrNotPermitted); !ok {
+ t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{})
}
// ep should consider the NIC's enabled status when determining its own
// enabled status so we "enable" the NIC to read just the endpoint's
@@ -525,7 +549,7 @@ func TestIPv4Send(t *testing.T) {
v4: true,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, nil)
+ ep := proto.NewEndpoint(&nic, nil)
defer ep.Close()
// Allocate and initialize the payload view.
@@ -659,7 +683,7 @@ func TestReceive(t *testing.T) {
v4: test.v4,
},
}
- ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, nil, nil, &nic.testObject)
+ ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -692,24 +716,81 @@ func TestReceive(t *testing.T) {
}
func TestIPv4ReceiveControl(t *testing.T) {
- const mtu = 0xbeef - header.IPv4MinimumSize
+ const (
+ mtu = 0xbeef - header.IPv4MinimumSize
+ dataLen = 8
+ )
+
cases := []struct {
name string
expectedCount int
fragmentOffset uint16
code header.ICMPv4Code
- expectedTyp stack.ControlType
- expectedExtra uint32
+ transErr transportError
trunc int
}{
- {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0},
- {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10},
- {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8},
- {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8},
- {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8},
- {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
- {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
- {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
+ {
+ name: "FragmentationNeeded",
+ expectedCount: 1,
+ fragmentOffset: 0,
+ code: header.ICMPv4FragmentationNeeded,
+ transErr: transportError{
+ origin: tcpip.SockExtErrorOriginICMP,
+ typ: uint8(header.ICMPv4DstUnreachable),
+ code: uint8(header.ICMPv4FragmentationNeeded),
+ info: mtu,
+ kind: stack.PacketTooBigTransportError,
+ },
+ trunc: 0,
+ },
+ {
+ name: "Truncated (missing IPv4 header)",
+ expectedCount: 0,
+ fragmentOffset: 0,
+ code: header.ICMPv4FragmentationNeeded,
+ trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize,
+ },
+ {
+ name: "Truncated (partial offending packet's IP header)",
+ expectedCount: 0,
+ fragmentOffset: 0,
+ code: header.ICMPv4FragmentationNeeded,
+ trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize - 1,
+ },
+ {
+ name: "Truncated (partial offending packet's data)",
+ expectedCount: 0,
+ fragmentOffset: 0,
+ code: header.ICMPv4FragmentationNeeded,
+ trunc: header.ICMPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize + dataLen - 1,
+ },
+ {
+ name: "Port unreachable",
+ expectedCount: 1,
+ fragmentOffset: 0,
+ code: header.ICMPv4PortUnreachable,
+ transErr: transportError{
+ origin: tcpip.SockExtErrorOriginICMP,
+ typ: uint8(header.ICMPv4DstUnreachable),
+ code: uint8(header.ICMPv4PortUnreachable),
+ kind: stack.DestinationPortUnreachableTransportError,
+ },
+ trunc: 0,
+ },
+ {
+ name: "Non-zero fragment offset",
+ expectedCount: 0,
+ fragmentOffset: 100,
+ code: header.ICMPv4PortUnreachable,
+ trunc: 0,
+ },
+ {
+ name: "Zero-length packet",
+ expectedCount: 0,
+ fragmentOffset: 100,
+ code: header.ICMPv4PortUnreachable,
+ trunc: 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + dataLen,
+ },
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
@@ -720,7 +801,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
t: t,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
+ ep := proto.NewEndpoint(&nic, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -728,7 +809,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
}
const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
- view := buffer.NewView(dataOffset + 8)
+ view := buffer.NewView(dataOffset + dataLen)
// Create the outer IPv4 header.
ip := header.IPv4(view)
@@ -775,8 +856,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
nic.testObject.srcAddr = remoteIPv4Addr
nic.testObject.dstAddr = localIPv4Addr
nic.testObject.contents = view[dataOffset:]
- nic.testObject.typ = c.expectedTyp
- nic.testObject.extra = c.expectedExtra
+ nic.testObject.transErr = c.transErr
addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
if !ok {
@@ -809,7 +889,7 @@ func TestIPv4FragmentationReceive(t *testing.T) {
v4: true,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
+ ep := proto.NewEndpoint(&nic, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -904,7 +984,7 @@ func TestIPv6Send(t *testing.T) {
t: t,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, nil)
+ ep := proto.NewEndpoint(&nic, nil)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -943,30 +1023,112 @@ func TestIPv6Send(t *testing.T) {
}
func TestIPv6ReceiveControl(t *testing.T) {
+ const (
+ mtu = 0xffff
+ outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa"
+ dataLen = 8
+ )
+
newUint16 := func(v uint16) *uint16 { return &v }
- const mtu = 0xffff
- const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa"
+ portUnreachableTransErr := transportError{
+ origin: tcpip.SockExtErrorOriginICMP6,
+ typ: uint8(header.ICMPv6DstUnreachable),
+ code: uint8(header.ICMPv6PortUnreachable),
+ kind: stack.DestinationPortUnreachableTransportError,
+ }
+
cases := []struct {
name string
expectedCount int
fragmentOffset *uint16
typ header.ICMPv6Type
code header.ICMPv6Code
- expectedTyp stack.ControlType
- expectedExtra uint32
+ transErr transportError
trunc int
}{
- {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0},
- {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10},
- {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8},
- {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8},
- {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8},
- {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
- {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8},
- {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
- {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
- {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
+ {
+ name: "PacketTooBig",
+ expectedCount: 1,
+ fragmentOffset: nil,
+ typ: header.ICMPv6PacketTooBig,
+ code: header.ICMPv6UnusedCode,
+ transErr: transportError{
+ origin: tcpip.SockExtErrorOriginICMP6,
+ typ: uint8(header.ICMPv6PacketTooBig),
+ code: uint8(header.ICMPv6UnusedCode),
+ info: mtu,
+ kind: stack.PacketTooBigTransportError,
+ },
+ trunc: 0,
+ },
+ {
+ name: "Truncated (missing offending packet's IPv6 header)",
+ expectedCount: 0,
+ fragmentOffset: nil,
+ typ: header.ICMPv6PacketTooBig,
+ code: header.ICMPv6UnusedCode,
+ trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize,
+ },
+ {
+ name: "Truncated PacketTooBig (partial offending packet's IPv6 header)",
+ expectedCount: 0,
+ fragmentOffset: nil,
+ typ: header.ICMPv6PacketTooBig,
+ code: header.ICMPv6UnusedCode,
+ trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize - 1,
+ },
+ {
+ name: "Truncated (partial offending packet's data)",
+ expectedCount: 0,
+ fragmentOffset: nil,
+ typ: header.ICMPv6PacketTooBig,
+ code: header.ICMPv6UnusedCode,
+ trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + dataLen - 1,
+ },
+ {
+ name: "Port unreachable",
+ expectedCount: 1,
+ fragmentOffset: nil,
+ typ: header.ICMPv6DstUnreachable,
+ code: header.ICMPv6PortUnreachable,
+ transErr: portUnreachableTransErr,
+ trunc: 0,
+ },
+ {
+ name: "Truncated DstPortUnreachable (partial offending packet's IP header)",
+ expectedCount: 0,
+ fragmentOffset: nil,
+ typ: header.ICMPv6DstUnreachable,
+ code: header.ICMPv6PortUnreachable,
+ trunc: header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + header.IPv6MinimumSize - 1,
+ },
+ {
+ name: "DstPortUnreachable for Fragmented, zero offset",
+ expectedCount: 1,
+ fragmentOffset: newUint16(0),
+ typ: header.ICMPv6DstUnreachable,
+ code: header.ICMPv6PortUnreachable,
+ transErr: portUnreachableTransErr,
+ trunc: 0,
+ },
+ {
+ name: "DstPortUnreachable for Non-zero fragment offset",
+ expectedCount: 0,
+ fragmentOffset: newUint16(100),
+ typ: header.ICMPv6DstUnreachable,
+ code: header.ICMPv6PortUnreachable,
+ transErr: portUnreachableTransErr,
+ trunc: 0,
+ },
+ {
+ name: "Zero-length packet",
+ expectedCount: 0,
+ fragmentOffset: nil,
+ typ: header.ICMPv6DstUnreachable,
+ code: header.ICMPv6PortUnreachable,
+ trunc: 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + dataLen,
+ },
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
@@ -977,7 +1139,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
t: t,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
+ ep := proto.NewEndpoint(&nic, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -988,7 +1150,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
if c.fragmentOffset != nil {
dataOffset += header.IPv6FragmentHeaderSize
}
- view := buffer.NewView(dataOffset + 8)
+ view := buffer.NewView(dataOffset + dataLen)
// Create the outer IPv6 header.
ip := header.IPv6(view)
@@ -1039,8 +1201,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
nic.testObject.srcAddr = remoteIPv6Addr
nic.testObject.dstAddr = localIPv6Addr
nic.testObject.contents = view[dataOffset:]
- nic.testObject.typ = c.expectedTyp
- nic.testObject.extra = c.expectedExtra
+ nic.testObject.transErr = c.transErr
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
@@ -1122,7 +1283,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
remoteAddr tcpip.Address
pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
- expectedErr *tcpip.Error
+ expectedErr tcpip.Error
}{
{
name: "IPv4",
@@ -1187,7 +1348,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
ip.SetHeaderLength(header.IPv4MinimumSize - 1)
return hdr.View().ToVectorisedView()
},
- expectedErr: tcpip.ErrMalformedHeader,
+ expectedErr: &tcpip.ErrMalformedHeader{},
},
{
name: "IPv4 too small",
@@ -1205,7 +1366,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
- expectedErr: tcpip.ErrMalformedHeader,
+ expectedErr: &tcpip.ErrMalformedHeader{},
},
{
name: "IPv4 minimum size",
@@ -1465,7 +1626,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
- expectedErr: tcpip.ErrMalformedHeader,
+ expectedErr: &tcpip.ErrMalformedHeader{},
},
}
@@ -1506,10 +1667,13 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}
defer r.Release()
- if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: test.pktGen(t, subTest.srcAddr),
- })); err != test.expectedErr {
- t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr)
+ {
+ err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: test.pktGen(t, subTest.srcAddr),
+ }))
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff)
+ }
}
if test.expectedErr != nil {
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 330a7d170..9713c4448 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -44,6 +44,7 @@ go_test(
"//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 6bb97c46a..74e70e283 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -23,11 +23,108 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// icmpv4DestinationUnreachableSockError is a general ICMPv4 Destination
+// Unreachable error.
+//
+// +stateify savable
+type icmpv4DestinationUnreachableSockError struct{}
+
+// Origin implements tcpip.SockErrorCause.
+func (*icmpv4DestinationUnreachableSockError) Origin() tcpip.SockErrOrigin {
+ return tcpip.SockExtErrorOriginICMP
+}
+
+// Type implements tcpip.SockErrorCause.
+func (*icmpv4DestinationUnreachableSockError) Type() uint8 {
+ return uint8(header.ICMPv4DstUnreachable)
+}
+
+// Info implements tcpip.SockErrorCause.
+func (*icmpv4DestinationUnreachableSockError) Info() uint32 {
+ return 0
+}
+
+var _ stack.TransportError = (*icmpv4DestinationHostUnreachableSockError)(nil)
+
+// icmpv4DestinationHostUnreachableSockError is an ICMPv4 Destination Host
+// Unreachable error.
+//
+// It indicates that a packet was not able to reach the destination host.
+//
+// +stateify savable
+type icmpv4DestinationHostUnreachableSockError struct {
+ icmpv4DestinationUnreachableSockError
+}
+
+// Code implements tcpip.SockErrorCause.
+func (*icmpv4DestinationHostUnreachableSockError) Code() uint8 {
+ return uint8(header.ICMPv4HostUnreachable)
+}
+
+// Kind implements stack.TransportError.
+func (*icmpv4DestinationHostUnreachableSockError) Kind() stack.TransportErrorKind {
+ return stack.DestinationHostUnreachableTransportError
+}
+
+var _ stack.TransportError = (*icmpv4DestinationPortUnreachableSockError)(nil)
+
+// icmpv4DestinationPortUnreachableSockError is an ICMPv4 Destination Port
+// Unreachable error.
+//
+// It indicates that a packet reached the destination host, but the transport
+// protocol was not active on the destination port.
+//
+// +stateify savable
+type icmpv4DestinationPortUnreachableSockError struct {
+ icmpv4DestinationUnreachableSockError
+}
+
+// Code implements tcpip.SockErrorCause.
+func (*icmpv4DestinationPortUnreachableSockError) Code() uint8 {
+ return uint8(header.ICMPv4PortUnreachable)
+}
+
+// Kind implements stack.TransportError.
+func (*icmpv4DestinationPortUnreachableSockError) Kind() stack.TransportErrorKind {
+ return stack.DestinationPortUnreachableTransportError
+}
+
+var _ stack.TransportError = (*icmpv4FragmentationNeededSockError)(nil)
+
+// icmpv4FragmentationNeededSockError is an ICMPv4 Destination Unreachable error
+// due to fragmentation being required but the packet was set to not be
+// fragmented.
+//
+// It indicates that a link exists on the path to the destination with an MTU
+// that is too small to carry the packet.
+//
+// +stateify savable
+type icmpv4FragmentationNeededSockError struct {
+ icmpv4DestinationUnreachableSockError
+
+ mtu uint32
+}
+
+// Code implements tcpip.SockErrorCause.
+func (*icmpv4FragmentationNeededSockError) Code() uint8 {
+ return uint8(header.ICMPv4FragmentationNeeded)
+}
+
+// Info implements tcpip.SockErrorCause.
+func (e *icmpv4FragmentationNeededSockError) Info() uint32 {
+ return e.mtu
+}
+
+// Kind implements stack.TransportError.
+func (*icmpv4FragmentationNeededSockError) Kind() stack.TransportErrorKind {
+ return stack.PacketTooBigTransportError
+}
+
// handleControl handles the case when an ICMP error packet contains the headers
// of the original packet that caused the ICMP one to be sent. This information
// is used to find out which transport endpoint must be notified about the ICMP
// packet. We only expect the payload, not the enclosing ICMP packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.PacketBuffer) {
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
return
@@ -54,10 +151,10 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
return
}
- // Skip the ip header, then deliver control message.
+ // Skip the ip header, then deliver the error.
pkt.Data.TrimFront(hlen)
p := hdr.TransportProtocol()
- e.dispatcher.DeliverTransportControlPacket(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt)
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
@@ -222,19 +319,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
case header.ICMPv4HostUnreachable:
- e.handleControl(stack.ControlNoRoute, 0, pkt)
-
+ e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt)
case header.ICMPv4PortUnreachable:
- e.handleControl(stack.ControlPortUnreachable, 0, pkt)
-
+ e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt)
case header.ICMPv4FragmentationNeeded:
networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize)
if err != nil {
networkMTU = 0
}
- e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt)
+ e.handleControl(&icmpv4FragmentationNeededSockError{mtu: networkMTU}, pkt)
}
-
case header.ICMPv4SrcQuench:
received.srcQuench.Increment()
@@ -310,7 +404,7 @@ func (*icmpReasonParamProblem) isICMPReason() {}
// the problematic packet. It incorporates as much of that packet as
// possible as well as any error metadata as is available. returnError
// expects pkt to hold a valid IPv4 packet as per the wire format.
-func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error {
origIPHdr := header.IPv4(pkt.NetworkHeader().View())
origIPHdrSrc := origIPHdr.SourceAddress()
origIPHdrDst := origIPHdr.DestinationAddress()
@@ -376,7 +470,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
netEP, ok := p.mu.eps[pkt.NICID]
p.mu.Unlock()
if !ok {
- return tcpip.ErrNotConnected
+ return &tcpip.ErrNotConnected{}
}
sent := netEP.stats.icmp.packetsSent
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index 4550aacd6..acc126c3b 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -103,7 +103,7 @@ func (igmp *igmpState) Enabled() bool {
// SendReport implements ip.MulticastGroupProtocol.
//
// Precondition: igmp.ep.mu must be read locked.
-func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
igmpType := header.IGMPv2MembershipReport
if igmp.v1Present() {
igmpType = header.IGMPv1MembershipReport
@@ -114,7 +114,7 @@ func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Erro
// SendLeave implements ip.MulticastGroupProtocol.
//
// Precondition: igmp.ep.mu must be read locked.
-func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
// As per RFC 2236 Section 6, Page 8: "If the interface state says the
// Querier is running IGMPv1, this action SHOULD be skipped. If the flag
// saying we were the last host to report is cleared, this action MAY be
@@ -215,6 +215,11 @@ func (igmp *igmpState) setV1Present(v bool) {
}
}
+func (igmp *igmpState) resetV1Present() {
+ igmp.igmpV1Job.Cancel()
+ igmp.setV1Present(false)
+}
+
// handleMembershipQuery handles a membership query.
//
// Precondition: igmp.ep.mu must be locked.
@@ -242,7 +247,7 @@ func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
// writePacket assembles and sends an IGMP packet.
//
// Precondition: igmp.ep.mu must be read locked.
-func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) {
+func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, tcpip.Error) {
igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize))
igmpData.SetType(igmpType)
igmpData.SetGroupAddress(groupAddress)
@@ -293,7 +298,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
// messages.
//
// If the group already exists in the membership map, returns
-// tcpip.ErrDuplicateAddress.
+// *tcpip.ErrDuplicateAddress.
//
// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
@@ -312,13 +317,13 @@ func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool {
// if required.
//
// Precondition: igmp.ep.mu must be locked.
-func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) tcpip.Error {
// LeaveGroup returns false only if the group was not joined.
if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
return nil
}
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
// softLeaveAll leaves all groups from the perspective of IGMP, but remains
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index 1ee573ac8..95fd75ab7 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -101,10 +101,10 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma
})
}
-// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is
-// present on the network. The IGMP stack will then send IGMPv1 Membership
-// reports for backwards compatibility.
-func TestIgmpV1Present(t *testing.T) {
+// TestIGMPV1Present tests the node's ability to fallback to V1 when a V1
+// router is detected. V1 present status is expected to be reset when the NIC
+// cycles.
+func TestIGMPV1Present(t *testing.T) {
e, s, clock := createStack(t, true)
if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil {
t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
@@ -116,14 +116,16 @@ func TestIgmpV1Present(t *testing.T) {
// This NIC will send an IGMPv2 report immediately, before this test can get
// the IGMPv1 General Membership Query in.
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
+ {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
+ }
+ if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
+ t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
+ }
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
}
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
if t.Failed() {
t.FailNow()
}
@@ -145,19 +147,38 @@ func TestIgmpV1Present(t *testing.T) {
// Verify the solicited Membership Report is sent. Now that this NIC has seen
// an IGMPv1 query, it should send an IGMPv1 Membership Report.
- p, ok = e.Read()
- if ok {
+ if p, ok := e.Read(); ok {
t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt)
}
clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- p, ok = e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V1MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 {
- t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got)
+ {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("unable to Read IGMP packet, expected V1MembershipReport")
+ }
+ if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 {
+ t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got)
+ }
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr)
+ }
+
+ // Cycling the interface should reset the V1 present flag.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
+ }
+ if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 {
+ t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got)
+ }
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
}
- validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr)
}
func TestSendQueuedIGMPReports(t *testing.T) {
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index a05275a5b..b2d626107 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -101,11 +101,11 @@ func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) {
// Use the same control type as an ICMPv4 destination host unreachable error
// since the host is considered unreachable if we cannot resolve the link
// address to the next hop.
- e.handleControl(stack.ControlNoRoute, 0, pkt)
+ e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt)
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
nic: nic,
dispatcher: dispatcher,
@@ -137,14 +137,14 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
}
// Enable implements stack.NetworkEndpoint.
-func (e *endpoint) Enable() *tcpip.Error {
+func (e *endpoint) Enable() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
// If the NIC is not enabled, the endpoint can't do anything meaningful so
// don't enable the endpoint.
if !e.nic.Enabled() {
- return tcpip.ErrNotPermitted
+ return &tcpip.ErrNotPermitted{}
}
// If the endpoint is already enabled, there is nothing for it to do.
@@ -212,7 +212,9 @@ func (e *endpoint) disableLocked() {
}
// The endpoint may have already left the multicast group.
- if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) {
+ case nil, *tcpip.ErrBadLocalAddress:
+ default:
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
}
@@ -221,10 +223,18 @@ func (e *endpoint) disableLocked() {
e.mu.igmp.softLeaveAll()
// The address may have already been removed.
- if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ switch err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err.(type) {
+ case nil, *tcpip.ErrBadLocalAddress:
+ default:
panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
}
+ // Reset the IGMP V1 present flag.
+ //
+ // If the node comes back up on the same network, it will re-learn that it
+ // needs to perform IGMPv1.
+ e.mu.igmp.resetV1Present()
+
if !e.setEnabled(false) {
panic("should have only done work to disable the endpoint if it was enabled")
}
@@ -256,7 +266,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) *tcpip.Error {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) tcpip.Error {
hdrLen := header.IPv4MinimumSize
var optLen int
if options != nil {
@@ -264,12 +274,12 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
}
hdrLen += optLen
if hdrLen > header.IPv4MaximumHeaderSize {
- return tcpip.ErrMessageTooLong
+ return &tcpip.ErrMessageTooLong{}
}
ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
length := pkt.Size()
if length > math.MaxUint16 {
- return tcpip.ErrMessageTooLong
+ return &tcpip.ErrMessageTooLong{}
}
// RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
// datagrams. Since the DF bit is never being set here, all datagrams
@@ -294,7 +304,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
// fragment. It returns the number of fragments handled and the number of
// fragments left to be processed. The IP header must already be present in the
// original packet.
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
// Round the MTU down to align to 8 bytes.
fragmentPayloadSize := networkMTU &^ 7
networkHeader := header.IPv4(pkt.NetworkHeader().View())
@@ -314,7 +324,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil {
return err
}
@@ -353,7 +363,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return e.writePacket(r, gso, pkt, false /* headerIncluded */)
}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error {
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
@@ -377,7 +387,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
}
if packetMustBeFragmented(pkt, networkMTU, gso) {
- sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
@@ -398,7 +408,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("multiple packets in local loop")
}
@@ -423,7 +433,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// Keep track of the packet that is about to be fragmented so it can be
// removed once the fragmentation is done.
originalPkt := pkt
- if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pkt, fragPkt)
pkt = fragPkt
@@ -488,22 +498,22 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
hdrLen := header.IPv4(h).HeaderLength()
if hdrLen < header.IPv4MinimumSize {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
h, ok = pkt.Data.PullUp(int(hdrLen))
if !ok {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
ip := header.IPv4(h)
@@ -541,14 +551,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// wire format. We also want to check if the header's fields are valid before
// sending the packet.
if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */)
}
// forwardPacket attempts to forward a packet to its final destination.
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
h := header.IPv4(pkt.NetworkHeader().View())
ttl := h.TTL()
if ttl == 0 {
@@ -568,7 +578,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
networkEndpoint.(*endpoint).handlePacket(pkt)
return nil
}
- if err != tcpip.ErrBadAddress {
+ if _, ok := err.(*tcpip.ErrBadAddress); !ok {
return err
}
@@ -730,7 +740,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
}
proto := h.Protocol()
- data, _, ready, err := e.protocol.fragmentation.Process(
+ resPkt, _, ready, err := e.protocol.fragmentation.Process(
// As per RFC 791 section 2.3, the identification value is unique
// for a source-destination pair and protocol.
fragmentation.FragmentID{
@@ -753,7 +763,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
if !ready {
return
}
- pkt.Data = data
+ pkt = resPkt
+ h = header.IPv4(pkt.NetworkHeader().View())
// The reassembler doesn't take care of fixing up the header, so we need
// to do it here.
@@ -825,7 +836,7 @@ func (e *endpoint) Close() {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
@@ -837,7 +848,7 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
}
// RemovePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.mu.addressableEndpointState.RemovePermanentAddress(addr)
@@ -894,7 +905,7 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) JoinGroup(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.joinGroupLocked(addr)
@@ -903,9 +914,9 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
// joinGroupLocked is like JoinGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) tcpip.Error {
if !header.IsV4MulticastAddress(addr) {
- return tcpip.ErrBadAddress
+ return &tcpip.ErrBadAddress{}
}
e.mu.igmp.joinGroup(addr)
@@ -913,7 +924,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.leaveGroupLocked(addr)
@@ -922,7 +933,7 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
// leaveGroupLocked is like LeaveGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) tcpip.Error {
return e.mu.igmp.leaveGroup(addr)
}
@@ -995,24 +1006,24 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// SetOption implements NetworkProtocol.SetOption.
-func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
p.SetDefaultTTL(uint8(*v))
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// Option implements NetworkProtocol.Option.
-func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(p.DefaultTTL())
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
@@ -1058,9 +1069,9 @@ func (p *protocol) SetForwarding(v bool) {
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
-func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) {
+func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
if linkMTU < header.IPv4MinimumMTU {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
// As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in
@@ -1068,7 +1079,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Erro
// The maximal internet header is 60 octets, and a typical internet header
// is 20 octets, allowing a margin for headers of higher level protocols.
if networkHeaderSize > header.IPv4MaximumHeaderSize {
- return 0, tcpip.ErrMalformedHeader
+ return 0, &tcpip.ErrMalformedHeader{}
}
networkMTU := linkMTU
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index dac7cbfd4..a296bed79 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -38,6 +38,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -78,8 +79,11 @@ func TestExcludeBroadcast(t *testing.T) {
defer ep.Close()
// Cannot connect using a broadcast address as the source.
- if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute {
- t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
+ {
+ err := ep.Connect(randomAddr)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Errorf("got ep.Connect(...) = %v, want = %v", err, &tcpip.ErrNoRoute{})
+ }
}
// However, we can bind to a broadcast address to listen.
@@ -1376,8 +1380,8 @@ func TestFragmentationErrors(t *testing.T) {
payloadSize int
allowPackets int
outgoingErrors int
- mockError *tcpip.Error
- wantError *tcpip.Error
+ mockError tcpip.Error
+ wantError tcpip.Error
}{
{
description: "No frag",
@@ -1386,8 +1390,8 @@ func TestFragmentationErrors(t *testing.T) {
transportHeaderLength: 0,
allowPackets: 0,
outgoingErrors: 1,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on first frag",
@@ -1396,8 +1400,8 @@ func TestFragmentationErrors(t *testing.T) {
transportHeaderLength: 0,
allowPackets: 0,
outgoingErrors: 3,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on second frag",
@@ -1406,8 +1410,8 @@ func TestFragmentationErrors(t *testing.T) {
transportHeaderLength: 0,
allowPackets: 1,
outgoingErrors: 2,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on first frag MTU smaller than header",
@@ -1416,8 +1420,8 @@ func TestFragmentationErrors(t *testing.T) {
payloadSize: 500,
allowPackets: 0,
outgoingErrors: 4,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error when MTU is smaller than IPv4 minimum MTU",
@@ -1427,7 +1431,7 @@ func TestFragmentationErrors(t *testing.T) {
allowPackets: 0,
outgoingErrors: 1,
mockError: nil,
- wantError: tcpip.ErrInvalidEndpointState,
+ wantError: &tcpip.ErrInvalidEndpointState{},
},
}
@@ -1441,8 +1445,8 @@ func TestFragmentationErrors(t *testing.T) {
TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != ft.wantError {
- t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
+ if diff := cmp.Diff(ft.wantError, err); diff != "" {
+ t.Fatalf("unexpected error from r.WritePacket(_, _, _), (-want, +got):\n%s", diff)
}
if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
@@ -2055,7 +2059,7 @@ func TestReceiveFragments(t *testing.T) {
// the fragment block size of 8 (RFC 791 section 3.1 page 14).
ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2)
udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:]
- // Used to test the max reassembled payload length (65,535 octets).
+ // Used to test the max reassembled IPv4 payload length.
ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2)
udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:]
@@ -2403,6 +2407,7 @@ func TestReceiveFragments(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ RawFactory: raw.EndpointFactory{},
})
e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00"))
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2428,6 +2433,13 @@ func TestReceiveFragments(t *testing.T) {
t.Fatalf("Bind(%+v): %s", bindAddr, err)
}
+ // Bring up a raw endpoint so we can examine network headers.
+ epRaw, err := s.NewRawEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq, true /* associated */)
+ if err != nil {
+ t.Fatalf("NewRawEndpoint(%d, %d, _, true): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err)
+ }
+ defer epRaw.Close()
+
// Prepare and send the fragments.
for _, frag := range test.fragments {
hdr := buffer.NewPrependable(header.IPv4MinimumSize)
@@ -2459,10 +2471,11 @@ func TestReceiveFragments(t *testing.T) {
}
for i, expectedPayload := range test.expectedPayloads {
+ // Check UDP payload delivered by UDP endpoint.
var buf bytes.Buffer
result, err := ep.Read(&buf, tcpip.ReadOptions{})
if err != nil {
- t.Fatalf("(i=%d) Read: %s", i, err)
+ t.Fatalf("(i=%d) ep.Read: %s", i, err)
}
if diff := cmp.Diff(tcpip.ReadResult{
Count: len(expectedPayload),
@@ -2471,12 +2484,30 @@ func TestReceiveFragments(t *testing.T) {
t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff)
}
if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" {
- t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
+ t.Errorf("(i=%d) ep.Read: UDP payload mismatch (-want +got):\n%s", i, diff)
+ }
+
+ // Check IPv4 header in packet delivered by raw endpoint.
+ buf.Reset()
+ result, err = epRaw.Read(&buf, tcpip.ReadOptions{})
+ if err != nil {
+ t.Fatalf("(i=%d) epRaw.Read: %s", i, err)
+ }
+ // Reassambly does not take care of checksum. Here we write our own
+ // check routine instead of using checker.IPv4.
+ ip := header.IPv4(buf.Bytes())
+ for _, check := range []checker.NetworkChecker{
+ checker.FragmentFlags(0),
+ checker.FragmentOffset(0),
+ checker.IPFullLength(uint16(header.IPv4MinimumSize + header.UDPMinimumSize + len(expectedPayload))),
+ } {
+ check(t, []header.Network{ip})
}
}
- if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
- t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
+ res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("(last) got Read = (%#v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
}
})
}
@@ -2554,11 +2585,11 @@ func TestWriteStats(t *testing.T) {
// Parameterize the tests to run with both WritePacket and WritePackets.
writers := []struct {
name string
- writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error)
}{
{
name: "WritePacket",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
nWritten := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
@@ -2570,7 +2601,7 @@ func TestWriteStats(t *testing.T) {
},
}, {
name: "WritePackets",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
},
},
@@ -2580,7 +2611,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go
index b28e7dcde..fbbc6e69c 100644
--- a/pkg/tcpip/network/ipv4/stats_test.go
+++ b/pkg/tcpip/network/ipv4/stats_test.go
@@ -50,7 +50,7 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) {
})
proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
nic := testInterface{nicID: 1}
- ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ ep := proto.NewEndpoint(&nic, nil).(*endpoint)
var nicIDs []tcpip.NICID
proto.mu.Lock()
@@ -82,7 +82,7 @@ func TestMultiCounterStatsInitialization(t *testing.T) {
})
proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
var nic testInterface
- ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ ep := proto.NewEndpoint(&nic, nil).(*endpoint)
// At this point, the Stack's stats and the NetworkEndpoint's stats are
// expected to be bound by a MultiCounterStat.
refStack := s.Stats()
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 95efada3a..dcfd93bab 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -23,11 +23,136 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// icmpv6DestinationUnreachableSockError is a general ICMPv6 Destination
+// Unreachable error.
+//
+// +stateify savable
+type icmpv6DestinationUnreachableSockError struct{}
+
+// Origin implements tcpip.SockErrorCause.
+func (*icmpv6DestinationUnreachableSockError) Origin() tcpip.SockErrOrigin {
+ return tcpip.SockExtErrorOriginICMP6
+}
+
+// Type implements tcpip.SockErrorCause.
+func (*icmpv6DestinationUnreachableSockError) Type() uint8 {
+ return uint8(header.ICMPv6DstUnreachable)
+}
+
+// Info implements tcpip.SockErrorCause.
+func (*icmpv6DestinationUnreachableSockError) Info() uint32 {
+ return 0
+}
+
+var _ stack.TransportError = (*icmpv6DestinationNetworkUnreachableSockError)(nil)
+
+// icmpv6DestinationNetworkUnreachableSockError is an ICMPv6 Destination Network
+// Unreachable error.
+//
+// It indicates that the destination network is unreachable.
+//
+// +stateify savable
+type icmpv6DestinationNetworkUnreachableSockError struct {
+ icmpv6DestinationUnreachableSockError
+}
+
+// Code implements tcpip.SockErrorCause.
+func (*icmpv6DestinationNetworkUnreachableSockError) Code() uint8 {
+ return uint8(header.ICMPv6NetworkUnreachable)
+}
+
+// Kind implements stack.TransportError.
+func (*icmpv6DestinationNetworkUnreachableSockError) Kind() stack.TransportErrorKind {
+ return stack.DestinationNetworkUnreachableTransportError
+}
+
+var _ stack.TransportError = (*icmpv6DestinationPortUnreachableSockError)(nil)
+
+// icmpv6DestinationPortUnreachableSockError is an ICMPv6 Destination Port
+// Unreachable error.
+//
+// It indicates that a packet reached the destination host, but the transport
+// protocol was not active on the destination port.
+//
+// +stateify savable
+type icmpv6DestinationPortUnreachableSockError struct {
+ icmpv6DestinationUnreachableSockError
+}
+
+// Code implements tcpip.SockErrorCause.
+func (*icmpv6DestinationPortUnreachableSockError) Code() uint8 {
+ return uint8(header.ICMPv6PortUnreachable)
+}
+
+// Kind implements stack.TransportError.
+func (*icmpv6DestinationPortUnreachableSockError) Kind() stack.TransportErrorKind {
+ return stack.DestinationPortUnreachableTransportError
+}
+
+var _ stack.TransportError = (*icmpv6DestinationAddressUnreachableSockError)(nil)
+
+// icmpv6DestinationAddressUnreachableSockError is an ICMPv6 Destination Address
+// Unreachable error.
+//
+// It indicates that a packet was not able to reach the destination.
+//
+// +stateify savable
+type icmpv6DestinationAddressUnreachableSockError struct {
+ icmpv6DestinationUnreachableSockError
+}
+
+// Code implements tcpip.SockErrorCause.
+func (*icmpv6DestinationAddressUnreachableSockError) Code() uint8 {
+ return uint8(header.ICMPv6AddressUnreachable)
+}
+
+// Kind implements stack.TransportError.
+func (*icmpv6DestinationAddressUnreachableSockError) Kind() stack.TransportErrorKind {
+ return stack.DestinationHostUnreachableTransportError
+}
+
+var _ stack.TransportError = (*icmpv6PacketTooBigSockError)(nil)
+
+// icmpv6PacketTooBigSockError is an ICMPv6 Packet Too Big error.
+//
+// It indicates that a link exists on the path to the destination with an MTU
+// that is too small to carry the packet.
+//
+// +stateify savable
+type icmpv6PacketTooBigSockError struct {
+ mtu uint32
+}
+
+// Origin implements tcpip.SockErrorCause.
+func (*icmpv6PacketTooBigSockError) Origin() tcpip.SockErrOrigin {
+ return tcpip.SockExtErrorOriginICMP6
+}
+
+// Type implements tcpip.SockErrorCause.
+func (*icmpv6PacketTooBigSockError) Type() uint8 {
+ return uint8(header.ICMPv6PacketTooBig)
+}
+
+// Code implements tcpip.SockErrorCause.
+func (*icmpv6PacketTooBigSockError) Code() uint8 {
+ return uint8(header.ICMPv6UnusedCode)
+}
+
+// Info implements tcpip.SockErrorCause.
+func (e *icmpv6PacketTooBigSockError) Info() uint32 {
+ return e.mtu
+}
+
+// Kind implements stack.TransportError.
+func (*icmpv6PacketTooBigSockError) Kind() stack.TransportErrorKind {
+ return stack.PacketTooBigTransportError
+}
+
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.PacketBuffer) {
h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
if !ok {
return
@@ -67,8 +192,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
p = fragHdr.TransportProtocol()
}
- // Deliver the control packet to the transport endpoint.
- e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ e.dispatcher.DeliverTransportError(src, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt)
}
// getLinkAddrOption searches NDP options for a given link address option using
@@ -175,7 +299,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
if err != nil {
networkMTU = 0
}
- e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt)
+ e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt)
case header.ICMPv6DstUnreachable:
received.dstUnreachable.Increment()
@@ -187,11 +311,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
switch header.ICMPv6(hdr).Code() {
case header.ICMPv6NetworkUnreachable:
- e.handleControl(stack.ControlNetworkUnreachable, 0, pkt)
+ e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt)
case header.ICMPv6PortUnreachable:
- e.handleControl(stack.ControlPortUnreachable, 0, pkt)
+ e.handleControl(&icmpv6DestinationPortUnreachableSockError{}, pkt)
}
-
case header.ICMPv6NeighborSolicit:
received.neighborSolicit.Increment()
if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize {
@@ -237,7 +360,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) {
+ case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
+ default:
panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
}
}
@@ -287,10 +412,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
} else if unspecifiedSource {
received.invalid.Increment()
return
- } else if e.nud != nil {
- e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(srcAddr, sourceLinkAddr)
+ switch err := e.nic.HandleNeighborProbe(ProtocolNumber, srcAddr, sourceLinkAddr); err.(type) {
+ case nil:
+ case *tcpip.ErrNotSupported:
+ // The stack may support ICMPv6 but the NIC may not need link resolution.
+ default:
+ panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err))
+ }
}
// As per RFC 4861 section 7.1.1:
@@ -413,10 +542,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) {
+ case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
+ return
+ default:
panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
}
- return
}
it, err := na.Options().Iter(false /* check */)
@@ -441,20 +572,30 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
return
}
- // If the NA message has the target link layer option, update the link
- // address cache with the link address for the target of the message.
- if e.nud == nil {
- if len(targetLinkAddr) != 0 {
- e.linkAddrCache.AddLinkAddress(targetAddr, targetLinkAddr)
- }
+ // As per RFC 4861 section 7.1.2:
+ // A node MUST silently discard any received Neighbor Advertisement
+ // messages that do not satisfy all of the following validity checks:
+ // ...
+ // - If the IP Destination Address is a multicast address the
+ // Solicited flag is zero.
+ if header.IsV6MulticastAddress(dstAddr) && na.SolicitedFlag() {
+ received.invalid.Increment()
return
}
- e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{
+ // If the NA message has the target link layer option, update the link
+ // address cache with the link address for the target of the message.
+ switch err := e.nic.HandleNeighborConfirmation(ProtocolNumber, targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{
Solicited: na.SolicitedFlag(),
Override: na.OverrideFlag(),
IsRouter: na.RouterFlag(),
- })
+ }); err.(type) {
+ case nil:
+ case *tcpip.ErrNotSupported:
+ // The stack may support ICMPv6 but the NIC may not need link resolution.
+ default:
+ panic(fmt.Sprintf("unexpected error when informing NIC of neighbor confirmation message: %s", err))
+ }
case header.ICMPv6EchoRequest:
received.echoRequest.Increment()
@@ -560,10 +701,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
return
}
- if e.nud != nil {
- // A RS with a specified source IP address modifies the NUD state
- // machine in the same way a reachability probe would.
- e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e.protocol)
+ // A RS with a specified source IP address modifies the neighbor table
+ // in the same way a regular probe would.
+ switch err := e.nic.HandleNeighborProbe(ProtocolNumber, srcAddr, sourceLinkAddr); err.(type) {
+ case nil:
+ case *tcpip.ErrNotSupported:
+ // The stack may support ICMPv6 but the NIC may not need link resolution.
+ default:
+ panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err))
}
}
@@ -612,8 +757,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// If the RA has the source link layer option, update the link address
// cache with the link address for the advertised router.
- if len(sourceLinkAddr) != 0 && e.nud != nil {
- e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e.protocol)
+ if len(sourceLinkAddr) != 0 {
+ switch err := e.nic.HandleNeighborProbe(ProtocolNumber, routerAddr, sourceLinkAddr); err.(type) {
+ case nil:
+ case *tcpip.ErrNotSupported:
+ // The stack may support ICMPv6 but the NIC may not need link resolution.
+ default:
+ panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err))
+ }
}
e.mu.Lock()
@@ -679,24 +830,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
}
-var _ stack.LinkAddressResolver = (*protocol)(nil)
-
// LinkAddressProtocol implements stack.LinkAddressResolver.
-func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv6ProtocolNumber
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
- nicID := nic.ID()
-
- p.mu.Lock()
- netEP, ok := p.mu.eps[nicID]
- p.mu.Unlock()
- if !ok {
- return tcpip.ErrNotConnected
- }
-
+func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
remoteAddr := targetAddr
if len(remoteLinkAddr) == 0 {
remoteAddr = header.SolicitedNodeAddr(targetAddr)
@@ -704,22 +844,22 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
if len(localAddr) == 0 {
- addressEndpoint := netEP.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */)
+ addressEndpoint := e.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */)
if addressEndpoint == nil {
- return tcpip.ErrNetworkUnreachable
+ return &tcpip.ErrNetworkUnreachable{}
}
localAddr = addressEndpoint.AddressWithPrefix().Address
- } else if p.stack.CheckLocalAddress(nicID, ProtocolNumber, localAddr) == 0 {
- return tcpip.ErrBadLocalAddress
+ } else if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, localAddr) == 0 {
+ return &tcpip.ErrBadLocalAddress{}
}
optsSerializer := header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()),
+ header.NDPSourceLinkLayerAddressOption(e.nic.LinkAddress()),
}
neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize,
+ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize,
})
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
@@ -736,9 +876,9 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
- stat := netEP.stats.icmp.packetsSent
+ stat := e.stats.icmp.packetsSent
- if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
stat.dropped.Increment()
return err
}
@@ -748,7 +888,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
-func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if header.IsV6MulticastAddress(addr) {
return header.EthernetAddressFromMulticastIPv6Address(addr), true
}
@@ -813,7 +953,7 @@ func (*icmpReasonReassemblyTimeout) isICMPReason() {}
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
-func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error {
origIPHdr := header.IPv6(pkt.NetworkHeader().View())
origIPHdrSrc := origIPHdr.SourceAddress()
origIPHdrDst := origIPHdr.DestinationAddress()
@@ -884,7 +1024,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
netEP, ok := p.mu.eps[pkt.NICID]
p.mu.Unlock()
if !ok {
- return tcpip.ErrNotConnected
+ return &tcpip.ErrNotConnected{}
}
sent := netEP.stats.icmp.packetsSent
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 641c60b7c..92f9ee2c2 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -79,7 +79,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
-func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
return nil
}
@@ -93,35 +93,14 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st
return stack.TransportPacketHandled
}
-var _ stack.LinkAddressCache = (*stubLinkAddressCache)(nil)
-
-type stubLinkAddressCache struct{}
-
-func (*stubLinkAddressCache) AddLinkAddress(tcpip.Address, tcpip.LinkAddress) {}
-
-type stubNUDHandler struct {
- probeCount int
- confirmationCount int
-}
-
-var _ stack.NUDHandler = (*stubNUDHandler)(nil)
-
-func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) {
- s.probeCount++
-}
-
-func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) {
- s.confirmationCount++
-}
-
-func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) {
-}
-
var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
stack.LinkEndpoint
+ probeCount int
+ confirmationCount int
+
nicID tcpip.NICID
}
@@ -145,21 +124,31 @@ func (*testInterface) Promiscuous() bool {
return false
}
-func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt)
}
-func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol)
}
-func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
var r stack.RouteInfo
r.NetProto = protocol
r.RemoteLinkAddress = remoteLinkAddr
return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt)
}
+func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error {
+ t.probeCount++
+ return nil
+}
+
+func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error {
+ t.confirmationCount++
+ return nil
+}
+
func TestICMPCounts(t *testing.T) {
tests := []struct {
name string
@@ -202,7 +191,7 @@ func TestICMPCounts(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
+ ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{})
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -360,7 +349,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{})
+ ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{})
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -573,12 +562,19 @@ func newTestContext(t *testing.T) *testContext {
}},
)
- return c
-}
+ t.Cleanup(func() {
+ if err := c.s0.RemoveNIC(nicID); err != nil {
+ t.Errorf("c.s0.RemoveNIC(%d): %s", nicID, err)
+ }
+ if err := c.s1.RemoveNIC(nicID); err != nil {
+ t.Errorf("c.s1.RemoveNIC(%d): %s", nicID, err)
+ }
-func (c *testContext) cleanup() {
- c.linkEP0.Close()
- c.linkEP1.Close()
+ c.linkEP0.Close()
+ c.linkEP1.Close()
+ })
+
+ return c
}
type routeArgs struct {
@@ -628,7 +624,6 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
func TestLinkResolution(t *testing.T) {
c := newTestContext(t)
- defer c.cleanup()
r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
@@ -1283,7 +1278,7 @@ func TestLinkAddressRequest(t *testing.T) {
localAddr tcpip.Address
remoteLinkAddr tcpip.LinkAddress
- expectedErr *tcpip.Error
+ expectedErr tcpip.Error
expectedRemoteAddr tcpip.Address
expectedRemoteLinkAddr tcpip.LinkAddress
}{
@@ -1321,23 +1316,23 @@ func TestLinkAddressRequest(t *testing.T) {
name: "Unicast with unassigned address",
localAddr: lladdr1,
remoteLinkAddr: linkAddr1,
- expectedErr: tcpip.ErrBadLocalAddress,
+ expectedErr: &tcpip.ErrBadLocalAddress{},
},
{
name: "Multicast with unassigned address",
localAddr: lladdr1,
remoteLinkAddr: "",
- expectedErr: tcpip.ErrBadLocalAddress,
+ expectedErr: &tcpip.ErrBadLocalAddress{},
},
{
name: "Unicast with no local address available",
remoteLinkAddr: linkAddr1,
- expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedErr: &tcpip.ErrNetworkUnreachable{},
},
{
name: "Multicast with no local address available",
remoteLinkAddr: "",
- expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedErr: &tcpip.ErrNetworkUnreachable{},
},
}
@@ -1346,28 +1341,32 @@ func TestLinkAddressRequest(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
- p := s.NetworkProtocolInstance(ProtocolNumber)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
- }
linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
if err := s.CreateNIC(nicID, linkEP); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
+
+ ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err)
+ }
+ linkRes, ok := ep.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep)
+ }
+
if len(test.nicAddr) != 0 {
if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
}
}
- // We pass a test network interface to LinkAddressRequest with the same NIC
- // ID and link endpoint used by the NIC we created earlier so that we can
- // mock a link address request and observe the packets sent to the link
- // endpoint even though the stack uses the real NIC.
- if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr {
- t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
+ {
+ err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr)
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff)
+ }
}
if test.expectedErr != nil {
@@ -1797,8 +1796,9 @@ func TestCallsToNeighborCache(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- nudHandler := &stubNUDHandler{}
- ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{})
+
+ testInterface := testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}
+ ep := netProto.NewEndpoint(&testInterface, &stubDispatcher{})
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -1833,11 +1833,11 @@ func TestCallsToNeighborCache(t *testing.T) {
ep.HandlePacket(pkt)
// Confirm the endpoint calls the correct NUDHandler method.
- if nudHandler.probeCount != test.wantProbeCount {
- t.Errorf("got nudHandler.probeCount = %d, want = %d", nudHandler.probeCount, test.wantProbeCount)
+ if testInterface.probeCount != test.wantProbeCount {
+ t.Errorf("got testInterface.probeCount = %d, want = %d", testInterface.probeCount, test.wantProbeCount)
}
- if nudHandler.confirmationCount != test.wantConfirmationCount {
- t.Errorf("got nudHandler.confirmationCount = %d, want = %d", nudHandler.confirmationCount, test.wantConfirmationCount)
+ if testInterface.confirmationCount != test.wantConfirmationCount {
+ t.Errorf("got testInterface.confirmationCount = %d, want = %d", testInterface.confirmationCount, test.wantConfirmationCount)
}
})
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index d658f9bcb..c2e8c3ea7 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -164,6 +164,7 @@ func getLabel(addr tcpip.Address) uint8 {
panic(fmt.Sprintf("should have a label for address = %s", addr))
}
+var _ stack.LinkAddressResolver = (*endpoint)(nil)
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
@@ -172,13 +173,11 @@ var _ stack.NDPEndpoint = (*endpoint)(nil)
var _ NDPEndpoint = (*endpoint)(nil)
type endpoint struct {
- nic stack.NetworkInterface
- linkAddrCache stack.LinkAddressCache
- nud stack.NUDHandler
- dispatcher stack.TransportDispatcher
- protocol *protocol
- stack *stack.Stack
- stats sharedStats
+ nic stack.NetworkInterface
+ dispatcher stack.TransportDispatcher
+ protocol *protocol
+ stack *stack.Stack
+ stats sharedStats
// enabled is set to 1 when the endpoint is enabled and 0 when it is
// disabled.
@@ -236,7 +235,7 @@ func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) {
})
pkt.NICID = e.nic.ID()
pkt.NetworkProtocolNumber = ProtocolNumber
- e.handleControl(stack.ControlAddressUnreachable, 0, pkt)
+ e.handleControl(&icmpv6DestinationAddressUnreachableSockError{}, pkt)
}
// onAddressAssignedLocked handles an address being assigned.
@@ -307,17 +306,17 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool {
// dupTentativeAddrDetected removes the tentative address if it exists. If the
// address was generated via SLAAC, an attempt is made to generate a new
// address.
-func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
addressEndpoint := e.getAddressRLocked(addr)
if addressEndpoint == nil {
- return tcpip.ErrBadAddress
+ return &tcpip.ErrBadAddress{}
}
if addressEndpoint.GetKind() != stack.PermanentTentative {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
// If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
@@ -369,14 +368,14 @@ func (e *endpoint) transitionForwarding(forwarding bool) {
}
// Enable implements stack.NetworkEndpoint.
-func (e *endpoint) Enable() *tcpip.Error {
+func (e *endpoint) Enable() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
// If the NIC is not enabled, the endpoint can't do anything meaningful so
// don't enable the endpoint.
if !e.nic.Enabled() {
- return tcpip.ErrNotPermitted
+ return &tcpip.ErrNotPermitted{}
}
// If the endpoint is already enabled, there is nothing for it to do.
@@ -418,7 +417,7 @@ func (e *endpoint) Enable() *tcpip.Error {
//
// Addresses may have aleady completed DAD but in the time since the endpoint
// was last enabled, other devices may have acquired the same addresses.
- var err *tcpip.Error
+ var err tcpip.Error
e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
addr := addressEndpoint.AddressWithPrefix().Address
if !header.IsV6UnicastAddress(addr) {
@@ -499,7 +498,9 @@ func (e *endpoint) disableLocked() {
e.stopDADForPermanentAddressesLocked()
// The endpoint may have already left the multicast group.
- if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) {
+ case nil, *tcpip.ErrBadLocalAddress:
+ default:
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
}
@@ -555,11 +556,11 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
-func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) *tcpip.Error {
+func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) tcpip.Error {
extHdrsLen := extensionHeaders.Length()
length := pkt.Size() + extensionHeaders.Length()
if length > math.MaxUint16 {
- return tcpip.ErrMessageTooLong
+ return &tcpip.ErrMessageTooLong{}
}
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
@@ -585,7 +586,7 @@ func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *sta
// fragments left to be processed. The IP header must already be present in the
// original packet. The transport header protocol number is required to avoid
// parsing the IPv6 extension headers.
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
networkHeader := header.IPv6(pkt.NetworkHeader().View())
// TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
@@ -598,13 +599,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// of 8 as per RFC 8200 section 4.5:
// Each complete fragment, except possibly the last ("rightmost") one, is
// an integer multiple of 8 octets long.
- return 0, 1, tcpip.ErrMessageTooLong
+ return 0, 1, &tcpip.ErrMessageTooLong{}
}
if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) {
// As per RFC 8200 Section 4.5, the Transport Header is expected to be small
// enough to fit in the first fragment.
- return 0, 1, tcpip.ErrMessageTooLong
+ return 0, 1, &tcpip.ErrMessageTooLong{}
}
pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt))
@@ -624,7 +625,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil {
return err
}
@@ -662,7 +663,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */)
}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error {
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
@@ -685,7 +686,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
}
if packetMustBeFragmented(pkt, networkMTU, gso) {
- sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
@@ -707,7 +708,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
}
@@ -731,7 +732,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// Keep track of the packet that is about to be fragmented so it can be
// removed once the fragmentation is done.
originalPkt := pb
- if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pb, fragPkt)
pb = fragPkt
@@ -798,11 +799,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
// The packet already has an IP header, but there are a few required checks.
h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
if !ok {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
ip := header.IPv6(h)
@@ -827,14 +828,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// sending the packet.
proto, _, _, _, ok := parse.IPv6(pkt)
if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */)
}
// forwardPacket attempts to forward a packet to its final destination.
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
h := header.IPv6(pkt.NetworkHeader().View())
hopLimit := h.HopLimit()
if hopLimit <= 1 {
@@ -856,7 +857,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
networkEndpoint.(*endpoint).handlePacket(pkt)
return nil
}
- if err != tcpip.ErrBadAddress {
+ if _, ok := err.(*tcpip.ErrBadAddress); !ok {
return err
}
@@ -1165,7 +1166,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// Note that pkt doesn't have its transport header set after reassembly,
// and won't until DeliverNetworkPacket sets it.
- data, proto, ready, err := e.protocol.fragmentation.Process(
+ resPkt, proto, ready, err := e.protocol.fragmentation.Process(
// IPv6 ignores the Protocol field since the ID only needs to be unique
// across source-destination pairs, as per RFC 8200 section 4.5.
fragmentation.FragmentID{
@@ -1186,7 +1187,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
}
if ready {
- pkt.Data = data
+ pkt = resPkt
// We create a new iterator with the reassembled packet because we could
// have more extension headers in the reassembled payload, as per RFC
@@ -1330,7 +1331,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
e.mu.Lock()
@@ -1345,7 +1346,7 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
// solicited-node multicast group and start duplicate address detection.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
if err != nil {
return nil, err
@@ -1374,13 +1375,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
}
// RemovePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
addressEndpoint := e.getAddressRLocked(addr)
if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
return e.removePermanentEndpointLocked(addressEndpoint, true)
@@ -1390,7 +1391,7 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
// it works with a stack.AddressEndpoint.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
+func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) tcpip.Error {
addr := addressEndpoint.AddressWithPrefix()
unicast := header.IsV6UnicastAddress(addr.Address)
if unicast {
@@ -1415,12 +1416,12 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn
}
snmc := header.SolicitedNodeAddr(addr.Address)
+ err := e.leaveGroupLocked(snmc)
// The endpoint may have already left the multicast group.
- if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
- return err
+ if _, ok := err.(*tcpip.ErrBadLocalAddress); ok {
+ err = nil
}
-
- return nil
+ return err
}
// hasPermanentAddressLocked returns true if the endpoint has a permanent
@@ -1630,7 +1631,7 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) JoinGroup(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.joinGroupLocked(addr)
@@ -1639,9 +1640,9 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
// joinGroupLocked is like JoinGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) tcpip.Error {
if !header.IsV6MulticastAddress(addr) {
- return tcpip.ErrBadAddress
+ return &tcpip.ErrBadAddress{}
}
e.mu.mld.joinGroup(addr)
@@ -1649,7 +1650,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.leaveGroupLocked(addr)
@@ -1658,7 +1659,7 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
// leaveGroupLocked is like LeaveGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) tcpip.Error {
return e.mu.mld.leaveGroup(addr)
}
@@ -1730,13 +1731,11 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
- nic: nic,
- linkAddrCache: linkAddrCache,
- nud: nud,
- dispatcher: dispatcher,
- protocol: p,
+ nic: nic,
+ dispatcher: dispatcher,
+ protocol: p,
}
e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
@@ -1762,24 +1761,24 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
}
// SetOption implements NetworkProtocol.SetOption.
-func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
p.SetDefaultTTL(uint8(*v))
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// Option implements NetworkProtocol.Option.
-func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(p.DefaultTTL())
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
@@ -1842,9 +1841,9 @@ func (p *protocol) SetForwarding(v bool) {
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
// which includes the length of the extension headers.
-func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) {
+func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, tcpip.Error) {
if linkMTU < header.IPv6MinimumMTU {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
// As per RFC 7112 section 5, we should discard packets if their IPv6 header
@@ -1855,7 +1854,7 @@ func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Erro
// bytes ensures that the header chain length does not exceed the IPv6
// minimum MTU.
if networkHeadersLen > header.IPv6MinimumMTU {
- return 0, tcpip.ErrMalformedHeader
+ return 0, &tcpip.ErrMalformedHeader{}
}
networkMTU := linkMTU - uint32(networkHeadersLen)
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 5276878a0..1c6c37c91 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -996,8 +996,9 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}
// Should not have any more UDP packets.
- if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
- t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
+ res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
}
})
}
@@ -1988,8 +1989,9 @@ func TestReceiveIPv6Fragments(t *testing.T) {
}
}
- if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
- t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
+ res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
}
})
}
@@ -2472,11 +2474,11 @@ func TestWriteStats(t *testing.T) {
writers := []struct {
name string
- writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error)
}{
{
name: "WritePacket",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
nWritten := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
@@ -2488,7 +2490,7 @@ func TestWriteStats(t *testing.T) {
},
}, {
name: "WritePackets",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
},
},
@@ -2498,7 +2500,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
@@ -2597,7 +2599,7 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) {
})
proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
var nic testInterface
- ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ ep := proto.NewEndpoint(&nic, nil).(*endpoint)
var nicIDs []tcpip.NICID
proto.mu.Lock()
@@ -2832,8 +2834,8 @@ func TestFragmentationErrors(t *testing.T) {
payloadSize int
allowPackets int
outgoingErrors int
- mockError *tcpip.Error
- wantError *tcpip.Error
+ mockError tcpip.Error
+ wantError tcpip.Error
}{
{
description: "No frag",
@@ -2842,8 +2844,8 @@ func TestFragmentationErrors(t *testing.T) {
transHdrLen: 0,
allowPackets: 0,
outgoingErrors: 1,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on first frag",
@@ -2852,8 +2854,8 @@ func TestFragmentationErrors(t *testing.T) {
transHdrLen: 0,
allowPackets: 0,
outgoingErrors: 3,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on second frag",
@@ -2862,8 +2864,8 @@ func TestFragmentationErrors(t *testing.T) {
transHdrLen: 0,
allowPackets: 1,
outgoingErrors: 2,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error when MTU is smaller than transport header",
@@ -2873,7 +2875,7 @@ func TestFragmentationErrors(t *testing.T) {
allowPackets: 0,
outgoingErrors: 1,
mockError: nil,
- wantError: tcpip.ErrMessageTooLong,
+ wantError: &tcpip.ErrMessageTooLong{},
},
{
description: "Error when MTU is smaller than IPv6 minimum MTU",
@@ -2883,7 +2885,7 @@ func TestFragmentationErrors(t *testing.T) {
allowPackets: 0,
outgoingErrors: 1,
mockError: nil,
- wantError: tcpip.ErrInvalidEndpointState,
+ wantError: &tcpip.ErrInvalidEndpointState{},
},
}
@@ -2897,8 +2899,8 @@ func TestFragmentationErrors(t *testing.T) {
TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != ft.wantError {
- t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
+ if diff := cmp.Diff(ft.wantError, err); diff != "" {
+ t.Errorf("unexpected error from WritePacket(_, _, _), (-want, +got):\n%s", diff)
}
if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
@@ -3073,7 +3075,7 @@ func TestMultiCounterStatsInitialization(t *testing.T) {
})
proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
var nic testInterface
- ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ ep := proto.NewEndpoint(&nic, nil).(*endpoint)
// At this point, the Stack's stats and the NetworkEndpoint's stats are
// supposed to be bound.
refStack := s.Stats()
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index c376016e9..2cc0dfebd 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -68,14 +68,14 @@ func (mld *mldState) Enabled() bool {
// SendReport implements ip.MulticastGroupProtocol.
//
// Precondition: mld.ep.mu must be read locked.
-func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport)
}
// SendLeave implements ip.MulticastGroupProtocol.
//
// Precondition: mld.ep.mu must be read locked.
-func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
_, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
return err
}
@@ -112,7 +112,7 @@ func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
// joinGroup handles joining a new group and sending and scheduling the required
// messages.
//
-// If the group is already joined, returns tcpip.ErrDuplicateAddress.
+// If the group is already joined, returns *tcpip.ErrDuplicateAddress.
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
@@ -131,13 +131,13 @@ func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool {
// required.
//
// Precondition: mld.ep.mu must be locked.
-func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+func (mld *mldState) leaveGroup(groupAddress tcpip.Address) tcpip.Error {
// LeaveGroup returns false only if the group was not joined.
if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
return nil
}
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
// softLeaveAll leaves all groups from the perspective of MLD, but remains
@@ -166,7 +166,7 @@ func (mld *mldState) sendQueuedReports() {
// writePacket assembles and sends an MLD packet.
//
// Precondition: mld.ep.mu must be read locked.
-func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) {
+func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, tcpip.Error) {
sentStats := mld.ep.stats.icmp.packetsSent
var mldStat tcpip.MultiCounterStat
switch mldType {
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index ca4ff621d..d7dde1767 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -241,7 +241,7 @@ type NDPDispatcher interface {
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
+ OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error)
// OnDefaultRouterDiscovered is called when a new default router is
// discovered. Implementations must return true if the newly discovered
@@ -614,10 +614,10 @@ type slaacPrefixState struct {
// tentative.
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error {
// addr must be a valid unicast IPv6 address.
if !header.IsV6UnicastAddress(addr) {
- return tcpip.ErrAddressFamilyNotSupported
+ return &tcpip.ErrAddressFamilyNotSupported{}
}
if addressEndpoint.GetKind() != stack.PermanentTentative {
@@ -666,7 +666,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
dadDone := remaining == 0
- var err *tcpip.Error
+ var err tcpip.Error
if !dadDone {
err = ndp.sendDADPacket(addr, addressEndpoint)
}
@@ -717,7 +717,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// addr.
//
// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
-func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 7a22309e5..8edaa9508 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -63,7 +63,7 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
+ ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{})
if err := ep.Enable(); err != nil {
t.Fatalf("ep.Enable(): %s", err)
}
@@ -90,7 +90,7 @@ type testNDPDispatcher struct {
addr tcpip.Address
}
-func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) {
}
func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
@@ -167,10 +167,10 @@ type linkResolutionResult struct {
ok bool
}
-// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
+// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a
// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
-func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
+func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
const nicID = 1
tests := []struct {
@@ -199,6 +199,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
e := channel.New(0, 1280, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -242,17 +243,19 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
})
wantInvalid := uint64(0)
- wantErr := (*tcpip.Error)(nil)
wantSucccess := true
if len(test.expectedLinkAddr) == 0 {
wantInvalid = 1
- wantErr = tcpip.ErrWouldBlock
wantSucccess = false
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{})
+ }
+ } else {
+ if err != nil {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
}
- if err != wantErr {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, wantErr)
- }
if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" {
t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff)
}
@@ -263,11 +266,11 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
}
}
-// TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests
+// TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests
// that receiving a valid NDP NS message with the Source Link Layer Address
// option results in a new entry in the link address cache for the sender of
// the message.
-func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) {
+func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) {
const nicID = 1
tests := []struct {
@@ -335,18 +338,18 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
Data: hdr.View().ToVectorisedView(),
}))
- neighbors, err := s.Neighbors(nicID)
+ neighbors, err := s.Neighbors(nicID, ProtocolNumber)
if err != nil {
- t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err)
}
neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
for _, n := range neighbors {
if existing, ok := neighborByAddr[n.Addr]; ok {
if diff := cmp.Diff(existing, n); diff != "" {
- t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
+ t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff)
}
- t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing)
+ t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing)
}
neighborByAddr[n.Addr] = n
}
@@ -380,7 +383,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
}
}
-func TestNeighorSolicitationResponse(t *testing.T) {
+func TestNeighborSolicitationResponse(t *testing.T) {
const nicID = 1
nicAddr := lladdr0
remoteAddr := lladdr1
@@ -719,10 +722,10 @@ func TestNeighorSolicitationResponse(t *testing.T) {
}
}
-// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a
+// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a
// valid NDP NA message with the Target Link Layer Address option results in a
// new entry in the link address cache for the target of the message.
-func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
+func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
const nicID = 1
tests := []struct {
@@ -756,8 +759,10 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseLinkAddrCache: true,
})
e := channel.New(0, 1280, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -801,17 +806,19 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
})
wantInvalid := uint64(0)
- wantErr := (*tcpip.Error)(nil)
wantSucccess := true
if len(test.expectedLinkAddr) == 0 {
wantInvalid = 1
- wantErr = tcpip.ErrWouldBlock
wantSucccess = false
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{})
+ }
+ } else {
+ if err != nil {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
}
- if err != wantErr {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, wantErr)
- }
if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" {
t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff)
}
@@ -822,11 +829,11 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
}
}
-// TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests
+// TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests
// that receiving a valid NDP NA message with the Target Link Layer Address
// option does not result in a new entry in the neighbor cache for the target
// of the message.
-func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) {
+func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) {
const nicID = 1
tests := []struct {
@@ -901,18 +908,18 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
Data: hdr.View().ToVectorisedView(),
}))
- neighbors, err := s.Neighbors(nicID)
+ neighbors, err := s.Neighbors(nicID, ProtocolNumber)
if err != nil {
- t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err)
}
neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
for _, n := range neighbors {
if existing, ok := neighborByAddr[n.Addr]; ok {
if diff := cmp.Diff(existing, n); diff != "" {
- t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
+ t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff)
}
- t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing)
+ t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing)
}
neighborByAddr[n.Addr] = n
}
@@ -1168,6 +1175,118 @@ func TestNDPValidation(t *testing.T) {
}
+// TestNeighborAdvertisementValidation tests that the NIC validates received
+// Neighbor Advertisements.
+//
+// In particular, if the IP Destination Address is a multicast address, and the
+// Solicited flag is not zero, the Neighbor Advertisement is invalid and should
+// be discarded.
+func TestNeighborAdvertisementValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ ipDstAddr tcpip.Address
+ solicitedFlag bool
+ valid bool
+ }{
+ {
+ name: "Multicast IP destination address with Solicited flag set",
+ ipDstAddr: header.IPv6AllNodesMulticastAddress,
+ solicitedFlag: true,
+ valid: false,
+ },
+ {
+ name: "Multicast IP destination address with Solicited flag unset",
+ ipDstAddr: header.IPv6AllNodesMulticastAddress,
+ solicitedFlag: false,
+ valid: true,
+ },
+ {
+ name: "Unicast IP destination address with Solicited flag set",
+ ipDstAddr: lladdr0,
+ solicitedFlag: true,
+ valid: true,
+ },
+ {
+ name: "Unicast IP destination address with Solicited flag unset",
+ ipDstAddr: lladdr0,
+ solicitedFlag: false,
+ valid: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseNeighborCache: true,
+ })
+ e := channel.New(0, header.IPv6MinimumMTU, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
+ na.SetTargetAddress(lladdr1)
+ na.SetSolicitedFlag(test.solicitedFlag)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, test.ipDstAddr, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: test.ipDstAddr,
+ })
+
+ stats := s.Stats().ICMP.V6.PacketsReceived
+ invalid := stats.Invalid
+ rxNA := stats.NeighborAdvert
+
+ if got := rxNA.Value(); got != 0 {
+ t.Fatalf("got rxNA = %d, want = 0", got)
+ }
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+
+ if got := rxNA.Value(); got != 1 {
+ t.Fatalf("got rxNA = %d, want = 1", got)
+ }
+ var wantInvalid uint64 = 1
+ if test.valid {
+ wantInvalid = 0
+ }
+ if got := invalid.Value(); got != wantInvalid {
+ t.Fatalf("got invalid = %d, want = %d", got, wantInvalid)
+ }
+ // As per RFC 4861 section 7.2.5:
+ // When a valid Neighbor Advertisement is received ...
+ // If no entry exists, the advertisement SHOULD be silently discarded.
+ // There is no need to create an entry if none exists, since the
+ // recipient has apparently not initiated any communication with the
+ // target.
+ if neighbors, err := s.Neighbors(nicID, ProtocolNumber); err != nil {
+ t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err)
+ } else if len(neighbors) != 0 {
+ t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
+ }
+ })
+ }
+}
+
// TestRouterAdvertValidation tests that when the NIC is configured to handle
// NDP Router Advertisement packets, it validates the Router Advertisement
// properly before handling them.
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
index 9bd009374..f5fa77b65 100644
--- a/pkg/tcpip/network/testutil/testutil.go
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -35,7 +35,7 @@ type MockLinkEndpoint struct {
WrittenPackets []*stack.PacketBuffer
mtu uint32
- err *tcpip.Error
+ err tcpip.Error
allowPackets int
}
@@ -43,7 +43,7 @@ type MockLinkEndpoint struct {
//
// err is the error that will be returned once allowPackets packets are written
// to the endpoint.
-func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint {
+func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint {
return &MockLinkEndpoint{
mtu: mtu,
err: err,
@@ -64,7 +64,7 @@ func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
// WritePacket implements LinkEndpoint.WritePacket.
-func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if ep.allowPackets == 0 {
return ep.err
}
@@ -74,7 +74,7 @@ func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip
}
// WritePackets implements LinkEndpoint.WritePackets.
-func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
var n int
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {