summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/link/rawfile/BUILD13
-rw-r--r--pkg/tcpip/link/rawfile/errors.go8
-rw-r--r--pkg/tcpip/link/rawfile/errors_test.go53
-rw-r--r--pkg/tcpip/stack/nic.go44
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go18
-rw-r--r--pkg/tcpip/tcpip.go13
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go61
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go120
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go33
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go68
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go32
11 files changed, 382 insertions, 81 deletions
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 14b527bc2..6c410c5a6 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -18,3 +18,14 @@ go_library(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "rawfile_test",
+ srcs = [
+ "errors_test.go",
+ ],
+ library = "rawfile",
+ deps = [
+ "//pkg/tcpip",
+ ],
+)
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
index a0a873c84..604868fd8 100644
--- a/pkg/tcpip/link/rawfile/errors.go
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -31,10 +31,12 @@ var translations [maxErrno]*tcpip.Error
// *tcpip.Error.
//
// Valid, but unrecognized errnos will be translated to
-// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos.
+// tcpip.ErrInvalidEndpointState (EINVAL).
func TranslateErrno(e syscall.Errno) *tcpip.Error {
- if err := translations[e]; err != nil {
- return err
+ if e > 0 && e < syscall.Errno(len(translations)) {
+ if err := translations[e]; err != nil {
+ return err
+ }
}
return tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go
new file mode 100644
index 000000000..e4cdc66bd
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/errors_test.go
@@ -0,0 +1,53 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package rawfile
+
+import (
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestTranslateErrno(t *testing.T) {
+ for _, test := range []struct {
+ errno syscall.Errno
+ translated *tcpip.Error
+ }{
+ {
+ errno: syscall.Errno(0),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.Errno(maxErrno),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.Errno(514),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.EEXIST,
+ translated: tcpip.ErrDuplicateAddress,
+ },
+ } {
+ got := TranslateErrno(test.errno)
+ if got != test.translated {
+ t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated)
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 863ef6bee..1f1a1426b 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -665,33 +665,15 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
}
}
- // Check if address is a broadcast address for the endpoint's network.
- //
- // Only IPv4 has a notion of broadcast addresses.
if protocol == header.IPv4ProtocolNumber {
- if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil {
n.mu.RUnlock()
return ref
}
}
-
- // A usable reference was not found, create a temporary one if requested by
- // the caller or if the IPv4 address is found in the NIC's subnets and the NIC
- // is a loopback interface.
- createTempEP := spoofingOrPromiscuous
- if !createTempEP && n.isLoopback() && protocol == header.IPv4ProtocolNumber {
- for _, r := range n.mu.endpoints {
- addr := r.addrWithPrefix()
- subnet := addr.Subnet()
- if subnet.Contains(address) {
- createTempEP = true
- break
- }
- }
- }
n.mu.RUnlock()
- if !createTempEP {
+ if !spoofingOrPromiscuous {
return nil
}
@@ -704,20 +686,21 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
return ref
}
-// getRefForBroadcastLocked returns an endpoint where address is the IPv4
-// broadcast address for the endpoint's network.
+// getRefForBroadcastOrLoopbackRLocked returns an endpoint whose address is the
+// broadcast address for the endpoint's network or an address in the endpoint's
+// subnet if the NIC is a loopback interface. This matches linux behaviour.
//
-// n.mu MUST be read locked.
-func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint {
+// n.mu MUST be read or write locked.
+func (n *NIC) getIPv4RefForBroadcastOrLoopbackRLocked(address tcpip.Address) *referencedNetworkEndpoint {
for _, ref := range n.mu.endpoints {
- // Only IPv4 has a notion of broadcast addresses.
+ // Only IPv4 has a notion of broadcast addresses or considers the loopback
+ // interface bound to an address's whole subnet (on linux).
if ref.protocol != header.IPv4ProtocolNumber {
continue
}
- addr := ref.addrWithPrefix()
- subnet := addr.Subnet()
- if subnet.IsBroadcast(address) && ref.tryIncRef() {
+ subnet := ref.addrWithPrefix().Subnet()
+ if (subnet.IsBroadcast(address) || (n.isLoopback() && subnet.Contains(address))) && ref.isValidForOutgoingRLocked() && ref.tryIncRef() {
return ref
}
}
@@ -745,11 +728,8 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add
n.removeEndpointLocked(ref)
}
- // Check if address is a broadcast address for an endpoint's network.
- //
- // Only IPv4 has a notion of broadcast addresses.
if protocol == header.IPv4ProtocolNumber {
- if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil {
return ref
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index b902c6ca9..0774b5382 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isMulticastOrBroadcast(id.LocalAddress) {
+ if isInboundMulticastOrBroadcast(r) {
mpep.handlePacketAll(r, id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
@@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
@@ -546,7 +546,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a TCP packet with a non-unicast source or destination
// address, then do nothing further and instruct the caller to do the same.
- if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ if protocol == header.TCPProtocolNumber && (!isInboundUnicast(r) || !isOutboundUnicast(r)) {
// TCP can only be used to communicate between a single source and a
// single destination; the addresses must be unicast.
r.Stats().TCP.InvalidSegmentsReceived.Increment()
@@ -677,10 +677,14 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isMulticastOrBroadcast(addr tcpip.Address) bool {
- return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
+func isInboundMulticastOrBroadcast(r *Route) bool {
+ return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
}
-func isUnicast(addr tcpip.Address) bool {
- return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
+func isInboundUnicast(r *Route) bool {
+ return r.LocalAddress != header.IPv4Any && r.LocalAddress != header.IPv6Any && !isInboundMulticastOrBroadcast(r)
+}
+
+func isOutboundUnicast(r *Route) bool {
+ return r.RemoteAddress != header.IPv4Any && r.RemoteAddress != header.IPv6Any && !r.IsOutboundBroadcast() && !header.IsV4MulticastAddress(r.RemoteAddress) && !header.IsV6MulticastAddress(r.RemoteAddress)
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 5e34e27ba..b2ddb24ec 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1194,6 +1194,19 @@ const (
TCPTimeWaitReuseLoopbackOnly
)
+// LingerOption is used by SetSockOpt/GetSockOpt to set/get the
+// duration for which a socket lingers before returning from Close.
+//
+// +stateify savable
+type LingerOption struct {
+ Enabled bool
+ Timeout time.Duration
+}
+
+func (*LingerOption) isGettableSocketOption() {}
+
+func (*LingerOption) isSettableSocketOption() {}
+
// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 1b18023c5..fecbe7ba7 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -187,3 +187,64 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
})
}
}
+
+// TestLoopbackSubnetLifetimeBoundToAddr tests that the lifetime of an address
+// in a loopback interface's associated subnet is bound to the permanently bound
+// address.
+func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
+ const nicID = 1
+
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ }
+ addrBytes := []byte(ipv4Addr.Address)
+ addrBytes[len(addrBytes)-1]++
+ otherAddr := tcpip.Address(addrBytes)
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ r, err := s.FindRoute(nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, err)
+ }
+ defer r.Release()
+
+ params := stack.NetworkHeaderParams{
+ Protocol: 111,
+ TTL: 64,
+ TOS: stack.DefaultTOS,
+ }
+ data := buffer.View([]byte{1, 2, 3, 4})
+ if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: data.ToVectorisedView(),
+ })); err != nil {
+ t.Fatalf("r.WritePacket(nil, %#v, _): %s", params, err)
+ }
+
+ // Removing the address should make the endpoint invalid.
+ if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err)
+ }
+ if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: data.ToVectorisedView(),
+ })); err != tcpip.ErrInvalidEndpointState {
+ t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState)
+ }
+}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 52c27e045..659acbc7a 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -436,3 +437,122 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
})
}
}
+
+// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all
+// interested endpoints.
+func TestReuseAddrAndBroadcast(t *testing.T) {
+ const (
+ nicID = 1
+ localPort = 9000
+ loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff")
+ )
+
+ data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+
+ tests := []struct {
+ name string
+ broadcastAddr tcpip.Address
+ }{
+ {
+ name: "Subnet directed broadcast",
+ broadcastAddr: loopbackBroadcast,
+ },
+ {
+ name: "IPv4 broadcast",
+ broadcastAddr: header.IPv4Broadcast,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x7f\x00\x00\x01",
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ // We use the empty subnet instead of just the loopback subnet so we
+ // also have a route to the IPv4 Broadcast address.
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // We create endpoints that bind to both the wildcard address and the
+ // broadcast address to make sure both of these types of "broadcast
+ // interested" endpoints receive broadcast packets.
+ wq := waiter.Queue{}
+ var eps []tcpip.Endpoint
+ for _, bindWildcard := range []bool{false, true} {
+ // Create multiple endpoints for each type of "broadcast interested"
+ // endpoint so we can test that all endpoints receive the broadcast
+ // packet.
+ for i := 0; i < 2; i++ {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err)
+ }
+
+ bindAddr := tcpip.FullAddress{Port: localPort}
+ if bindWildcard {
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ }
+ } else {
+ bindAddr.Addr = test.broadcastAddr
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ }
+ }
+
+ eps = append(eps, ep)
+ }
+ }
+
+ for i, wep := range eps {
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{
+ Addr: test.broadcastAddr,
+ Port: localPort,
+ },
+ }
+ if n, _, err := wep.Write(data, writeOpts); err != nil {
+ t.Fatalf("eps[%d].Write(_, _): %s", i, err)
+ } else if want := int64(len(data)); n != want {
+ t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want)
+ }
+
+ for j, rep := range eps {
+ if gotPayload, _, err := rep.Read(nil); err != nil {
+ t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err)
+ } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 6d5046a3d..faea7f2bb 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -654,6 +654,9 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -1007,6 +1010,26 @@ func (e *endpoint) Close() {
return
}
+ if e.linger.Enabled && e.linger.Timeout == 0 {
+ s := e.EndpointState()
+ isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv
+ if isResetState {
+ // Close the endpoint without doing full shutdown and
+ // send a RST.
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.closeNoShutdownLocked()
+
+ // Wake up worker to close the endpoint.
+ switch s {
+ case StateSynRecv:
+ e.notifyProtocolGoroutine(notifyClose)
+ default:
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ }
+ return
+ }
+ }
+
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
@@ -1807,6 +1830,11 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.SocketDetachFilterOption:
return nil
+ case *tcpip.LingerOption:
+ e.LockUser()
+ e.linger = *v
+ e.UnlockUser()
+
default:
return nil
}
@@ -2031,6 +2059,11 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
Port: port,
}
+ case *tcpip.LingerOption:
+ e.LockUser()
+ *o = e.linger
+ e.UnlockUser()
+
default:
return tcpip.ErrUnknownProtocolOption
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 0d13e1efd..b1e5f1b24 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5214,6 +5214,8 @@ func TestListenBacklogFull(t *testing.T) {
func TestListenNoAcceptNonUnicastV4(t *testing.T) {
multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+ subnet := context.StackAddrWithPrefix.Subnet()
+ subnetBroadcastAddr := subnet.Broadcast()
tests := []struct {
name string
@@ -5221,53 +5223,59 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
dstAddr tcpip.Address
}{
{
- "SourceUnspecified",
- header.IPv4Any,
- context.StackAddr,
+ name: "SourceUnspecified",
+ srcAddr: header.IPv4Any,
+ dstAddr: context.StackAddr,
},
{
- "SourceBroadcast",
- header.IPv4Broadcast,
- context.StackAddr,
+ name: "SourceBroadcast",
+ srcAddr: header.IPv4Broadcast,
+ dstAddr: context.StackAddr,
},
{
- "SourceOurMulticast",
- multicastAddr,
- context.StackAddr,
+ name: "SourceOurMulticast",
+ srcAddr: multicastAddr,
+ dstAddr: context.StackAddr,
},
{
- "SourceOtherMulticast",
- otherMulticastAddr,
- context.StackAddr,
+ name: "SourceOtherMulticast",
+ srcAddr: otherMulticastAddr,
+ dstAddr: context.StackAddr,
},
{
- "DestUnspecified",
- context.TestAddr,
- header.IPv4Any,
+ name: "DestUnspecified",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Any,
},
{
- "DestBroadcast",
- context.TestAddr,
- header.IPv4Broadcast,
+ name: "DestBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Broadcast,
},
{
- "DestOurMulticast",
- context.TestAddr,
- multicastAddr,
+ name: "DestOurMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: multicastAddr,
},
{
- "DestOtherMulticast",
- context.TestAddr,
- otherMulticastAddr,
+ name: "DestOtherMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: otherMulticastAddr,
+ },
+ {
+ name: "SrcSubnetBroadcast",
+ srcAddr: subnetBroadcastAddr,
+ dstAddr: context.StackAddr,
+ },
+ {
+ name: "DestSubnetBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: subnetBroadcastAddr,
},
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -5367,11 +5375,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) {
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index baf7df197..85e8c1c75 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -53,11 +53,11 @@ const (
TestPort = 4096
// StackV6Addr is the IPv6 address assigned to the stack.
- StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
// TestV6Addr is the source address for packets sent to the stack via
// the link layer endpoint.
- TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
// StackV4MappedAddr is StackAddr as a mapped v6 address.
StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr
@@ -73,6 +73,18 @@ const (
testInitialSequenceNumber = 789
)
+// StackAddrWithPrefix is StackAddr with its associated prefix length.
+var StackAddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackAddr,
+ PrefixLen: 24,
+}
+
+// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length.
+var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackV6Addr,
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+}
+
// Headers is used to represent the TCP header fields when building a
// new packet.
type Headers struct {
@@ -184,12 +196,20 @@ func New(t *testing.T, mtu uint32) *Context {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: StackAddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
}
- if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: StackV6AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{