summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go8
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go189
2 files changed, 193 insertions, 4 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 6e692da07..b7d735889 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -483,10 +483,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID = e.BindNICID
}
- if to.Addr == header.IPv4Broadcast && !e.broadcast {
- return 0, nil, tcpip.ErrBroadcastDisabled
- }
-
dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
@@ -503,6 +499,10 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
resolve = route.Resolve
}
+ if !e.broadcast && route.IsBroadcast() {
+ return 0, nil, tcpip.ErrBroadcastDisabled
+ }
+
if route.IsResolutionRequired() {
if ch, err := resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 90781cf49..66e8911c8 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -2142,3 +2142,192 @@ func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEn
c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
}
}
+
+func TestOutgoingSubnetBroadcast(t *testing.T) {
+ const nicID1 = 1
+
+ ipv4Addr := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 24,
+ }
+ ipv4Subnet := ipv4Addr.Subnet()
+ ipv4SubnetBcast := ipv4Subnet.Broadcast()
+ ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 31,
+ }
+ ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
+ ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
+ ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 32,
+ }
+ ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
+ ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
+ ipv6Addr := tcpip.AddressWithPrefix{
+ Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ PrefixLen: 64,
+ }
+ ipv6Subnet := ipv6Addr.Subnet()
+ ipv6SubnetBcast := ipv6Subnet.Broadcast()
+ remNetAddr := tcpip.AddressWithPrefix{
+ Address: "\x64\x0a\x7b\x18",
+ PrefixLen: 24,
+ }
+ remNetSubnet := remNetAddr.Subnet()
+ remNetSubnetBcast := remNetSubnet.Broadcast()
+
+ tests := []struct {
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ requiresBroadcastOpt bool
+ }{
+ {
+ name: "IPv4 Broadcast to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4SubnetBcast,
+ requiresBroadcastOpt: true,
+ },
+ {
+ name: "IPv4 Broadcast to local /31 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix31,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet31,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet31Bcast,
+ requiresBroadcastOpt: false,
+ },
+ {
+ name: "IPv4 Broadcast to local /32 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix32,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet32,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet32Bcast,
+ requiresBroadcastOpt: false,
+ },
+ // IPv6 has no notion of a broadcast.
+ {
+ name: "IPv6 'Broadcast' to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv6Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv6SubnetBcast,
+ requiresBroadcastOpt: false,
+ },
+ {
+ name: "IPv4 Broadcast to remote subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: remNetSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ requiresBroadcastOpt: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ }
+
+ s.SetRouteTable(test.routes)
+
+ var netProto tcpip.NetworkProtocolNumber
+ switch l := len(test.remoteAddr); l {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ netProto = header.IPv6ProtocolNumber
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err)
+ }
+ defer ep.Close()
+
+ data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ to := tcpip.FullAddress{
+ Addr: test.remoteAddr,
+ Port: 80,
+ }
+ opts := tcpip.WriteOptions{To: &to}
+ expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled
+ if !test.requiresBroadcastOpt {
+ expectedErrWithoutBcastOpt = nil
+ }
+
+ if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err)
+ }
+
+ if n, _, err := ep.Write(data, opts); err != nil {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil {
+ t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err)
+ }
+
+ if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ }
+ })
+ }
+}