diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/dhcp/client.go | 3 | ||||
-rw-r--r-- | pkg/dhcp/dhcp_test.go | 3 | ||||
-rw-r--r-- | pkg/dhcp/server.go | 3 | ||||
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 21 | ||||
-rw-r--r-- | pkg/syserr/netstack.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 70 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 23 |
11 files changed, 161 insertions, 11 deletions
diff --git a/pkg/dhcp/client.go b/pkg/dhcp/client.go index 3330c4998..6d48eec7e 100644 --- a/pkg/dhcp/client.go +++ b/pkg/dhcp/client.go @@ -141,6 +141,9 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg }, nil); err != nil { return Config{}, fmt.Errorf("dhcp: connect failed: %v", err) } + if err := ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { + return Config{}, fmt.Errorf("dhcp: setsockopt SO_BROADCAST: %v", err) + } epin, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { diff --git a/pkg/dhcp/dhcp_test.go b/pkg/dhcp/dhcp_test.go index a21dce6bc..026064394 100644 --- a/pkg/dhcp/dhcp_test.go +++ b/pkg/dhcp/dhcp_test.go @@ -287,6 +287,9 @@ func TestTwoServers(t *testing.T) { if err = ep.Bind(tcpip.FullAddress{Port: ServerPort}, nil); err != nil { t.Fatalf("dhcp: server bind: %v", err) } + if err = ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { + t.Fatalf("dhcp: setsockopt: %v", err) + } serverCtx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/pkg/dhcp/server.go b/pkg/dhcp/server.go index 3e06ab4c7..c72c3b70d 100644 --- a/pkg/dhcp/server.go +++ b/pkg/dhcp/server.go @@ -123,6 +123,9 @@ func newEPConnServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Addr if err := ep.Bind(tcpip.FullAddress{Port: ServerPort}, nil); err != nil { return nil, fmt.Errorf("dhcp: server bind: %v", err) } + if err := ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { + return nil, fmt.Errorf("dhcp: server setsockopt: %v", err) + } c := newEPConn(ctx, wq, ep) return NewServer(ctx, c, addrs, cfg) } diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index a97db5348..e24e58aed 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -582,6 +582,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, name, outLen int) (interface{}, *syserr.Error) { + // TODO: Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_TYPE: if outLen < sizeOfInt32 { @@ -681,6 +682,18 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return int32(v), nil + case linux.SO_BROADCAST: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.BroadcastOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(v), nil + case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -982,6 +995,14 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + case linux.SO_BROADCAST: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v))) + case linux.SO_PASSCRED: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 20e756edb..05ca475d1 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -43,6 +43,7 @@ var ( ErrQueueSizeNotSupported = New(tcpip.ErrQueueSizeNotSupported.String(), linux.ENOTTY) ErrNoSuchFile = New(tcpip.ErrNoSuchFile.String(), linux.ENOENT) ErrInvalidOptionValue = New(tcpip.ErrInvalidOptionValue.String(), linux.EINVAL) + ErrBroadcastDisabled = New(tcpip.ErrBroadcastDisabled.String(), linux.EACCES) ) var netstackErrorTranslations = map[*tcpip.Error]*Error{ @@ -80,6 +81,7 @@ var netstackErrorTranslations = map[*tcpip.Error]*Error{ tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable, tcpip.ErrMessageTooLong: ErrMessageTooLong, tcpip.ErrNoBufferSpace: ErrNoBufferSpace, + tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled, } // TranslateNetstackError converts an error from the tcpip package to a sentry diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 586ca873e..43d7c2ec4 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -399,6 +399,21 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr src, dst := netProto.ParseAddresses(vv.First()) + // If the packet is destined to the IPv4 Broadcast address, then make a + // route to each IPv4 network endpoint and let each endpoint handle the + // packet. + if dst == header.IPv4Broadcast { + for _, ref := range n.endpoints { + if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { + r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref) + r.RemoteLinkAddress = remote + ref.ep.HandlePacket(&r, vv) + ref.decRef() + } + } + return + } + if ref := n.getRef(protocol, dst); ref != nil { r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref) r.RemoteLinkAddress = remote diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index a5ff2159a..c18208dc0 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -132,7 +132,22 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { - ep.selectEndpoint(id).HandlePacket(r, id, vv) + // If this is a broadcast datagram, deliver the datagram to all endpoints + // managed by ep. + if id.LocalAddress == header.IPv4Broadcast { + for i, endpoint := range ep.endpointsArr { + // HandlePacket modifies vv, so each endpoint needs its own copy. + if i == len(ep.endpointsArr)-1 { + endpoint.HandlePacket(r, id, vv) + break + } + vvCopy := buffer.NewView(vv.Size()) + copy(vvCopy, vv.ToView()) + endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) + } + } else { + ep.selectEndpoint(id).HandlePacket(r, id, vv) + } } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. @@ -224,20 +239,47 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN } } -// deliverPacket attempts to deliver the given packet. Returns true if it found -// an endpoint, false otherwise. +var loopbackSubnet = func() tcpip.Subnet { + sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00") + if err != nil { + panic(err) + } + return sn +}() + +// deliverPacket attempts to find one or more matching transport endpoints, and +// then, if matches are found, delivers the packet to them. Returns true if it +// found one or more endpoints, false otherwise. func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false } + // If a sender bound to the Loopback interface sends a broadcast, + // that broadcast must not be delivered to the sender. + if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort { + return false + } + + // If the packet is a broadcast, then find all matching transport endpoints. + // Otherwise, try to find a single matching transport endpoint. + destEps := make([]TransportEndpoint, 0, 1) eps.mu.RLock() - ep := d.findEndpointLocked(eps, vv, id) + + if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast { + for epID, endpoint := range eps.endpoints { + if epID.LocalPort == id.LocalPort { + destEps = append(destEps, endpoint) + } + } + } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { + destEps = append(destEps, ep) + } eps.mu.RUnlock() - // Fail if we didn't find one. - if ep == nil { + // Fail if we didn't find at least one matching transport endpoint. + if len(destEps) == 0 { // UDP packet could not be delivered to an unknown destination port. if protocol == header.UDPProtocolNumber { r.Stats().UDP.UnknownPortErrors.Increment() @@ -246,7 +288,9 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto } // Deliver the packet. - ep.HandlePacket(r, id, vv) + for _, ep := range destEps { + ep.HandlePacket(r, id, vv) + } return true } @@ -277,7 +321,7 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint { // Try to find a match with the id as provided. - if ep := eps.endpoints[id]; ep != nil { + if ep, ok := eps.endpoints[id]; ok { return ep } @@ -285,7 +329,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid := id nid.LocalAddress = "" - if ep := eps.endpoints[nid]; ep != nil { + if ep, ok := eps.endpoints[nid]; ok { return ep } @@ -293,11 +337,15 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.LocalAddress = id.LocalAddress nid.RemoteAddress = "" nid.RemotePort = 0 - if ep := eps.endpoints[nid]; ep != nil { + if ep, ok := eps.endpoints[nid]; ok { return ep } // Try to find a match with only the local port. nid.LocalAddress = "" - return eps.endpoints[nid] + if ep, ok := eps.endpoints[nid]; ok { + return ep + } + + return nil } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index a6e47397a..89e9d6741 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -100,6 +100,7 @@ var ( ErrNetworkUnreachable = &Error{msg: "network is unreachable"} ErrMessageTooLong = &Error{msg: "message too long"} ErrNoBufferSpace = &Error{msg: "no buffer space available"} + ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"} ) // Errors related to Subnet @@ -502,6 +503,10 @@ type RemoveMembershipOption MembershipOption // TCP out-of-band data is delivered along with the normal in-band data. type OutOfBandInlineOption int +// BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether +// datagram sockets are allowed to send packets to a broadcast address. +type BroadcastOption int + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination adddress in the row. @@ -527,6 +532,12 @@ func (r *Route) Match(addr Address) bool { return false } + // Using header.Ipv4Broadcast would introduce an import cycle, so + // we'll use a literal instead. + if addr == "\xff\xff\xff\xff" { + return true + } + for i := 0; i < len(r.Destination); i++ { if (addr[i] & r.Mask[i]) != r.Destination[i] { return false diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 1ee9f8d25..aa31a78af 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -116,6 +116,9 @@ type endpoint struct { route stack.Route `state:"manual"` v6only bool isConnectNotified bool + // TCP should never broadcast but Linux nevertheless supports enabling/ + // disabling SO_BROADCAST, albeit as a NOOP. + broadcast bool // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -813,6 +816,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.notifyProtocolGoroutine(notifyKeepaliveChanged) return nil + case tcpip.BroadcastOption: + e.mu.Lock() + e.broadcast = v != 0 + e.mu.Unlock() + return nil + default: return nil } @@ -971,6 +980,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = 1 return nil + case *tcpip.BroadcastOption: + e.mu.Lock() + v := e.broadcast + e.mu.Unlock() + + *o = 0 + if v { + *o = 1 + } + return nil + default: return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 4891c7941..a07cd9011 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -336,6 +336,7 @@ func loadError(s string) *tcpip.Error { tcpip.ErrNetworkUnreachable, tcpip.ErrMessageTooLong, tcpip.ErrNoBufferSpace, + tcpip.ErrBroadcastDisabled, } messageToError = make(map[string]*tcpip.Error) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 9c3881d63..05d35e526 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -82,6 +82,7 @@ type endpoint struct { multicastAddr tcpip.Address multicastNICID tcpip.NICID reusePort bool + broadcast bool // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -347,6 +348,10 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c nicid = e.bindNICID } + if to.Addr == header.IPv4Broadcast && !e.broadcast { + return 0, nil, tcpip.ErrBroadcastDisabled + } + r, _, _, err := e.connectRoute(nicid, *to) if err != nil { return 0, nil, err @@ -502,6 +507,13 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() e.reusePort = v != 0 e.mu.Unlock() + + case tcpip.BroadcastOption: + e.mu.Lock() + e.broadcast = v != 0 + e.mu.Unlock() + + return nil } return nil } @@ -581,6 +593,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = 0 return nil + case *tcpip.BroadcastOption: + e.mu.RLock() + v := e.broadcast + e.mu.RUnlock() + + *o = 0 + if v { + *o = 1 + } + return nil + default: return tcpip.ErrUnknownProtocolOption } |